test_dvc/utils/dataset.py

27 lines
601 B
Python

# dataset.py
#
# author: deng
# date : 20231229
from pathlib import Path, PosixPath

import numpy as np
import torch
from torch.utils.data import Dataset
class ProcessedDataset(Dataset):
    """Load processed ``(x, y)`` samples stored as ``.npz`` files.

    Every ``*.npz`` file directly inside ``dataset_dir`` is one sample, and
    each file must contain the arrays ``x`` and ``y``.

    Args:
        dataset_dir: Directory holding the ``.npz`` files. A ``str`` or any
            ``pathlib.Path`` is accepted.
    """

    def __init__(self, dataset_dir: Path):
        # Normalise to Path so plain strings work as well as Path objects.
        self.dataset_dir = Path(dataset_dir)
        # glob() yields files in filesystem order, which is not guaranteed.
        # Sort so a given index always maps to the same file (reproducibility).
        self.file_paths = sorted(self.dataset_dir.glob('*.npz'))

    def __len__(self) -> int:
        """Return the number of ``.npz`` files found at construction time."""
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Return sample ``idx`` as an ``(x, y)`` pair of tensors.

        Raises:
            IndexError: if ``idx`` is out of range for ``file_paths``.
            KeyError: if the file lacks an ``x`` or ``y`` array.
        """
        # np.load on an .npz returns an NpzFile that keeps the underlying
        # file open. The context manager closes it once both arrays have been
        # read into memory, so no descriptor is leaked per sample.
        with np.load(self.file_paths[idx]) as npz:
            x = torch.from_numpy(npz['x'])
            y = torch.from_numpy(npz['y'])
        return x, y