27 lines
601 B
Python
27 lines
601 B
Python
# prepare.py
#
# author: deng
# date : 20231229
|
|
|
|
from pathlib import PosixPath
|
|
|
|
import torch
|
|
import numpy as np
|
|
from torch.utils.data import Dataset
|
|
|
|
|
|
class ProcessedDataset(Dataset):
    """Load processed ``(x, y)`` samples stored as ``.npz`` files.

    Every ``*.npz`` file directly inside ``dataset_dir`` (non-recursive) is
    treated as one sample. Each file is expected to contain arrays under the
    keys ``'x'`` and ``'y'``, which are returned as torch tensors.

    Args:
        dataset_dir: Directory holding the ``.npz`` sample files.
    """

    def __init__(self, dataset_dir: PosixPath):
        self.dataset_dir = dataset_dir
        # glob() yields entries in filesystem order, which is not guaranteed
        # to be stable across machines or runs; sort so a given index always
        # maps to the same file (reproducible splits/sampling/debugging).
        self.file_paths = sorted(self.dataset_dir.glob('*.npz'))

    def __len__(self):
        """Return the number of ``.npz`` files found in ``dataset_dir``."""
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return it as an ``(x, y)`` tensor pair.

        Raises:
            IndexError: if ``idx`` is out of range for ``file_paths``.
            KeyError: if the file has no ``'x'`` or ``'y'`` array.
        """
        # np.load on an .npz returns an NpzFile that holds the file open;
        # the context manager closes it once both arrays are read into
        # memory, instead of leaking one handle per sample access.
        with np.load(self.file_paths[idx]) as npz:
            x = torch.from_numpy(npz['x'])
            y = torch.from_numpy(npz['y'])
        return x, y
|