fix file_lazy_load

This commit is contained in:
2026-06-18 13:51:36 +08:00
parent c5d39c8e16
commit 9c46e1c345
5 changed files with 30 additions and 23 deletions

View File

@@ -8,8 +8,8 @@ prepare:
random_seed: 1 random_seed: 1
train: train:
device_type: mps device_type: mps
train_npz: ./data/processed/train.npz train_data_dir: ./data/processed/train
valid_npz: ./data/processed/valid.npz valid_data_dir: ./data/processed/valid
batch_size: 256 batch_size: 256
num_of_class: 20 num_of_class: 20
optimizer_name: sgd # sgd, adam optimizer_name: sgd # sgd, adam
@@ -20,7 +20,7 @@ train:
random_seed: 1 random_seed: 1
exp_msg: init train exp_msg: init train
eval: eval:
test_npz: ./data/processed/test.npz test_data_dir: ./data/processed/test
model_path: ./assets/model.pth model_path: ./assets/model.pth
random_seed: 1 random_seed: 1
deploy: deploy:

View File

@@ -19,7 +19,7 @@ class Eval:
self._device = torch.device('mps' if torch.mps.is_available() else 'cpu') self._device = torch.device('mps' if torch.mps.is_available() else 'cpu')
def _get_dataloader(self): def _get_dataloader(self):
test_dataset = QuickDrawDataset(data_npz_path=self.config['test_npz'], return_cate_name=False) test_dataset = QuickDrawDataset(data_dir=self.config['test_data_dir'], return_cate_name=False)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False) test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)
return test_dataloader return test_dataloader

View File

@@ -5,6 +5,7 @@
import random import random
from pathlib import Path from pathlib import Path
from shutil import rmtree
import numpy as np import numpy as np
@@ -66,16 +67,21 @@ class Prepare:
sets[name] = {key: value[idx] for key, value in data.items()} sets[name] = {key: value[idx] for key, value in data.items()}
return sets return sets
def _save_npz(self, sets: dict[str, dict[str, np.ndarray]]) -> None: def _save_data(self, sets: dict[str, dict[str, np.ndarray]]) -> None:
save_dir = Path(self.config['data_dir']) / 'processed' save_dir = Path(self.config['data_dir']) / 'processed'
save_dir.mkdir(exist_ok=True) if save_dir.exists():
rmtree(save_dir)
save_dir.mkdir()
for usage, data in sets.items(): for usage, data in sets.items():
np.savez(f'{save_dir}/{usage}.npz', **data) usage_dir = save_dir / usage
usage_dir.mkdir()
for key, value in data.items():
np.save(f'{usage_dir}/{key}.npy', value)
def run(self): def run(self):
data = self._load_dataset() data = self._load_dataset()
sets = self._split_data(data) sets = self._split_data(data)
self._save_npz(sets) self._save_data(sets)
if __name__ == '__main__': if __name__ == '__main__':

View File

@@ -33,14 +33,14 @@ class Train:
def _get_dataloader(self): def _get_dataloader(self):
train_dataset = QuickDrawDataset( train_dataset = QuickDrawDataset(
data_npz_path=self.config['train_npz'], data_dir=self.config['train_data_dir'],
enable_data_aug=True, enable_data_aug=True,
file_lazy_load=self.config['file_lazy_load'], file_lazy_load=self.config['file_lazy_load'],
return_cate_name=False, return_cate_name=False,
# vis_dir='./tmp' # vis_dir='./tmp'
) )
valid_dataset = QuickDrawDataset( valid_dataset = QuickDrawDataset(
data_npz_path=self.config['valid_npz'], data_dir=self.config['valid_data_dir'],
enable_data_aug=False, enable_data_aug=False,
file_lazy_load=self.config['file_lazy_load'], file_lazy_load=self.config['file_lazy_load'],
return_cate_name=False, return_cate_name=False,
@@ -55,7 +55,12 @@ class Train:
persistent_workers=True, persistent_workers=True,
) )
valid_dataloader = DataLoader( valid_dataloader = DataLoader(
valid_dataset, batch_size=self.config['batch_size'], shuffle=False, num_workers=1, pin_memory=False, persistent_workers=True valid_dataset,
batch_size=self.config['batch_size'],
shuffle=False,
num_workers=1,
pin_memory=False, # not support for mps
persistent_workers=True,
) )
return train_dataloader, valid_dataloader return train_dataloader, valid_dataloader

View File

@@ -13,7 +13,7 @@ from torchvision.transforms import v2
class QuickDrawDataset(torch.utils.data.Dataset): class QuickDrawDataset(torch.utils.data.Dataset):
def __init__( def __init__(
self, self,
data_npz_path: str, data_dir: str,
image_shape: tuple[int, int, int] = (1, 28, 28), image_shape: tuple[int, int, int] = (1, 28, 28),
enable_data_aug: bool = False, enable_data_aug: bool = False,
file_lazy_load: bool = False, file_lazy_load: bool = False,
@@ -21,13 +21,13 @@ class QuickDrawDataset(torch.utils.data.Dataset):
vis_dir: str = None, vis_dir: str = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self._images: torch.Tensor | np.ndarray = [] self._images: torch.Tensor | np.ndarray = None
self._cate_names: list[str] = [] self._cate_names: list[str] = []
self._cate_ids: torch.Tensor = [] self._cate_ids: torch.Tensor = []
self._transform: callable = None self._transform: callable = None
self._enable_data_aug = enable_data_aug self._enable_data_aug = enable_data_aug
self._data_npz_path = data_npz_path self._data_dir = Path(data_dir)
self._image_shape = image_shape self._image_shape = image_shape
self._file_lazy_load = file_lazy_load self._file_lazy_load = file_lazy_load
self._return_cate_name = return_cate_name self._return_cate_name = return_cate_name
@@ -55,19 +55,15 @@ class QuickDrawDataset(torch.utils.data.Dataset):
) )
def _collect_data(self) -> None: def _collect_data(self) -> None:
self._cate_names = [cate_name.decode() for cate_name in np.load(self._data_dir / 'cate_names.npy')]
self._cate_ids = torch.from_numpy(np.load(self._data_dir / 'cate_ids.npy')).long()
if self._file_lazy_load: if self._file_lazy_load:
self._npz_file = np.load(self._data_npz_path, mmap_mode='r') self._images = np.load(self._data_dir / 'images.npy', mmap_mode='r')
self._cate_names = [cate_name.decode() for cate_name in self._npz_file['cate_names']]
self._cate_ids = torch.from_numpy(self._npz_file['cate_ids'].copy()).long()
self._images = self._npz_file['images']
else: else:
with np.load(self._data_npz_path, mmap_mode=None) as npz_file: self._images = torch.from_numpy(np.load(self._data_dir / 'images.npy'))
self._cate_names = [cate_name.decode() for cate_name in npz_file['cate_names']]
self._cate_ids = torch.from_numpy(npz_file['cate_ids']).long()
self._images = torch.from_numpy(npz_file['images'])
def __len__(self) -> int: def __len__(self) -> int:
return len(self._images) return len(self._cate_ids)
def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor]: def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor]:
if self._file_lazy_load: if self._file_lazy_load: