# prepare.py
#
# author: deng
# date  : 20231229

from pathlib import Path, PosixPath
from typing import Tuple, Union

import numpy as np
import torch
from torch.utils.data import Dataset


class ProcessedDataset(Dataset):
    """Load processed samples stored as ``.npz`` files in a directory.

    Every ``*.npz`` file found directly inside ``dataset_dir`` is treated as
    one sample. Each file is read for the arrays stored under the keys
    ``'x'`` and ``'y'``.

    Attributes:
        dataset_dir: Directory that was scanned, as a ``Path``.
        file_paths: Paths of the discovered ``.npz`` files, sorted so that a
            given index always maps to the same file.
    """

    def __init__(self, dataset_dir: Union[str, Path]):
        """Collect the ``.npz`` files contained in ``dataset_dir``.

        Args:
            dataset_dir: Directory holding the processed ``.npz`` files.
                A plain string is accepted as well as a ``Path``.
        """
        super().__init__()
        # Normalise so string paths work too; a Path passes through unchanged.
        self.dataset_dir = Path(dataset_dir)
        # glob() yields entries in arbitrary, filesystem-dependent order;
        # sorting keeps the index -> sample mapping reproducible.
        self.file_paths = sorted(self.dataset_dir.glob('*.npz'))

    def __len__(self) -> int:
        """Return the number of ``.npz`` files found at construction time."""
        return len(self.file_paths)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Load the sample at position ``idx``.

        Args:
            idx: Index into ``file_paths``.

        Returns:
            A tuple ``(x, y)`` of tensors built from the arrays stored under
            the keys ``'x'`` and ``'y'`` of the file.

        Raises:
            IndexError: If ``idx`` is outside the range of ``file_paths``.
            KeyError: If the file lacks an ``'x'`` or ``'y'`` entry.
        """
        # NpzFile keeps the archive open until closed; the context manager
        # releases the file handle as soon as both arrays have been read.
        with np.load(self.file_paths[idx]) as npz:
            x = torch.from_numpy(npz['x'])
            y = torch.from_numpy(npz['y'])
        return x, y