use resnet50 to train a cifar10 classifier

This commit is contained in:
2023-12-30 00:03:36 +08:00
parent 7a891969e0
commit a27d0a24d9
27 changed files with 393 additions and 19 deletions

0
utils/__init__.py Normal file
View File

26
utils/dataset.py Normal file
View File

@ -0,0 +1,26 @@
# prepare.py
#
# author: deng
# date : 20231229
from pathlib import PosixPath
import torch
import numpy as np
from torch.utils.data import Dataset
class ProcessedDataset(Dataset):
""""Load processed data"""
def __init__(self, dataset_dir: PosixPath):
self.dataset_dir = dataset_dir
self.file_paths = list(self.dataset_dir.glob('*.npz'))
def __len__(self):
return len(self.file_paths)
def __getitem__(self, idx):
npz = np.load(self.file_paths[idx])
x = torch.from_numpy(npz['x'])
y = torch.from_numpy(npz['y'])
return x, y