use resnet50 to train a cifar10 classifier
This commit is contained in:
0
utils/__init__.py
Normal file
0
utils/__init__.py
Normal file
26
utils/dataset.py
Normal file
26
utils/dataset.py
Normal file
@ -0,0 +1,26 @@
|
||||
# prepare.py
|
||||
#
|
||||
# author: deng
|
||||
# date : 20231229
|
||||
|
||||
from pathlib import PosixPath
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
class ProcessedDataset(Dataset):
|
||||
""""Load processed data"""
|
||||
def __init__(self, dataset_dir: PosixPath):
|
||||
self.dataset_dir = dataset_dir
|
||||
self.file_paths = list(self.dataset_dir.glob('*.npz'))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.file_paths)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
npz = np.load(self.file_paths[idx])
|
||||
x = torch.from_numpy(npz['x'])
|
||||
y = torch.from_numpy(npz['y'])
|
||||
return x, y
|
Reference in New Issue
Block a user