fix file_lazy_load

This commit is contained in:
2026-06-18 13:51:36 +08:00
parent c5d39c8e16
commit 9c46e1c345
5 changed files with 30 additions and 23 deletions

View File

@ -8,8 +8,8 @@ prepare:
random_seed: 1
train:
device_type: mps
train_npz: ./data/processed/train.npz
valid_npz: ./data/processed/valid.npz
train_data_dir: ./data/processed/train
valid_data_dir: ./data/processed/valid
batch_size: 256
num_of_class: 20
optimizer_name: sgd # sgd, adam
@ -20,7 +20,7 @@ train:
random_seed: 1
exp_msg: init train
eval:
test_npz: ./data/processed/test.npz
test_data_dir: ./data/processed/test
model_path: ./assets/model.pth
random_seed: 1
deploy:

View File

@ -19,7 +19,7 @@ class Eval:
self._device = torch.device('mps' if torch.mps.is_available() else 'cpu')
def _get_dataloader(self):
test_dataset = QuickDrawDataset(data_npz_path=self.config['test_npz'], return_cate_name=False)
test_dataset = QuickDrawDataset(data_dir=self.config['test_data_dir'], return_cate_name=False)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)
return test_dataloader

View File

@ -5,6 +5,7 @@
import random
from pathlib import Path
from shutil import rmtree
import numpy as np
@ -66,16 +67,21 @@ class Prepare:
sets[name] = {key: value[idx] for key, value in data.items()}
return sets
def _save_npz(self, sets: dict[str, dict[str, np.ndarray]]) -> None:
def _save_data(self, sets: dict[str, dict[str, np.ndarray]]) -> None:
save_dir = Path(self.config['data_dir']) / 'processed'
save_dir.mkdir(exist_ok=True)
if save_dir.exists():
rmtree(save_dir)
save_dir.mkdir()
for usage, data in sets.items():
np.savez(f'{save_dir}/{usage}.npz', **data)
usage_dir = save_dir / usage
usage_dir.mkdir()
for key, value in data.items():
np.save(f'{usage_dir}/{key}.npy', value)
def run(self):
data = self._load_dataset()
sets = self._split_data(data)
self._save_npz(sets)
self._save_data(sets)
if __name__ == '__main__':

View File

@ -33,14 +33,14 @@ class Train:
def _get_dataloader(self):
train_dataset = QuickDrawDataset(
data_npz_path=self.config['train_npz'],
data_dir=self.config['train_data_dir'],
enable_data_aug=True,
file_lazy_load=self.config['file_lazy_load'],
return_cate_name=False,
# vis_dir='./tmp'
)
valid_dataset = QuickDrawDataset(
data_npz_path=self.config['valid_npz'],
data_dir=self.config['valid_data_dir'],
enable_data_aug=False,
file_lazy_load=self.config['file_lazy_load'],
return_cate_name=False,
@ -55,7 +55,12 @@ class Train:
persistent_workers=True,
)
valid_dataloader = DataLoader(
valid_dataset, batch_size=self.config['batch_size'], shuffle=False, num_workers=1, pin_memory=False, persistent_workers=True
valid_dataset,
batch_size=self.config['batch_size'],
shuffle=False,
num_workers=1,
pin_memory=False, # not support for mps
persistent_workers=True,
)
return train_dataloader, valid_dataloader

View File

@ -13,7 +13,7 @@ from torchvision.transforms import v2
class QuickDrawDataset(torch.utils.data.Dataset):
def __init__(
self,
data_npz_path: str,
data_dir: str,
image_shape: tuple[int, int, int] = (1, 28, 28),
enable_data_aug: bool = False,
file_lazy_load: bool = False,
@ -21,13 +21,13 @@ class QuickDrawDataset(torch.utils.data.Dataset):
vis_dir: str = None,
) -> None:
super().__init__()
self._images: torch.Tensor | np.ndarray = []
self._images: torch.Tensor | np.ndarray = None
self._cate_names: list[str] = []
self._cate_ids: torch.Tensor = []
self._transform: callable = None
self._enable_data_aug = enable_data_aug
self._data_npz_path = data_npz_path
self._data_dir = Path(data_dir)
self._image_shape = image_shape
self._file_lazy_load = file_lazy_load
self._return_cate_name = return_cate_name
@ -55,19 +55,15 @@ class QuickDrawDataset(torch.utils.data.Dataset):
)
def _collect_data(self) -> None:
self._cate_names = [cate_name.decode() for cate_name in np.load(self._data_dir / 'cate_names.npy')]
self._cate_ids = torch.from_numpy(np.load(self._data_dir / 'cate_ids.npy')).long()
if self._file_lazy_load:
self._npz_file = np.load(self._data_npz_path, mmap_mode='r')
self._cate_names = [cate_name.decode() for cate_name in self._npz_file['cate_names']]
self._cate_ids = torch.from_numpy(self._npz_file['cate_ids'].copy()).long()
self._images = self._npz_file['images']
self._images = np.load(self._data_dir / 'images.npy', mmap_mode='r')
else:
with np.load(self._data_npz_path, mmap_mode=None) as npz_file:
self._cate_names = [cate_name.decode() for cate_name in npz_file['cate_names']]
self._cate_ids = torch.from_numpy(npz_file['cate_ids']).long()
self._images = torch.from_numpy(npz_file['images'])
self._images = torch.from_numpy(np.load(self._data_dir / 'images.npy'))
def __len__(self) -> int:
return len(self._images)
return len(self._cate_ids)
def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor]:
if self._file_lazy_load: