From 9c46e1c345241d8bbea3cd4b5ac2fec42140f71c Mon Sep 17 00:00:00 2001 From: deng Date: Thu, 18 Jun 2026 13:51:36 +0800 Subject: [PATCH] fix file_lazy_load --- quickdraw_bot/assets/config.yaml | 6 +++--- quickdraw_bot/eval.py | 2 +- quickdraw_bot/prepare.py | 14 ++++++++++---- quickdraw_bot/train.py | 11 ++++++++--- quickdraw_bot/utils/dataset.py | 20 ++++++++------------ 5 files changed, 30 insertions(+), 23 deletions(-) diff --git a/quickdraw_bot/assets/config.yaml b/quickdraw_bot/assets/config.yaml index 44d8909..e3f677d 100644 --- a/quickdraw_bot/assets/config.yaml +++ b/quickdraw_bot/assets/config.yaml @@ -8,8 +8,8 @@ prepare: random_seed: 1 train: device_type: mps - train_npz: ./data/processed/train.npz - valid_npz: ./data/processed/valid.npz + train_data_dir: ./data/processed/train + valid_data_dir: ./data/processed/valid batch_size: 256 num_of_class: 20 optimizer_name: sgd # sgd, adam @@ -20,7 +20,7 @@ train: random_seed: 1 exp_msg: init train eval: - test_npz: ./data/processed/test.npz + test_data_dir: ./data/processed/test model_path: ./assets/model.pth random_seed: 1 deploy: diff --git a/quickdraw_bot/eval.py b/quickdraw_bot/eval.py index dff88ca..5ee5536 100644 --- a/quickdraw_bot/eval.py +++ b/quickdraw_bot/eval.py @@ -19,7 +19,7 @@ class Eval: self._device = torch.device('mps' if torch.mps.is_available() else 'cpu') def _get_dataloader(self): - test_dataset = QuickDrawDataset(data_npz_path=self.config['test_npz'], return_cate_name=False) + test_dataset = QuickDrawDataset(data_dir=self.config['test_data_dir'], return_cate_name=False) test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False) return test_dataloader diff --git a/quickdraw_bot/prepare.py b/quickdraw_bot/prepare.py index 5ae671d..a0ca22b 100644 --- a/quickdraw_bot/prepare.py +++ b/quickdraw_bot/prepare.py @@ -5,6 +5,7 @@ import random from pathlib import Path +from shutil import rmtree import numpy as np @@ -66,16 +67,21 @@ class Prepare: sets[name] = {key: value[idx] for key, 
value in data.items()} return sets - def _save_npz(self, sets: dict[str, dict[str, np.ndarray]]) -> None: + def _save_data(self, sets: dict[str, dict[str, np.ndarray]]) -> None: save_dir = Path(self.config['data_dir']) / 'processed' - save_dir.mkdir(exist_ok=True) + if save_dir.exists(): + rmtree(save_dir) + save_dir.mkdir() for usage, data in sets.items(): - np.savez(f'{save_dir}/{usage}.npz', **data) + usage_dir = save_dir / usage + usage_dir.mkdir() + for key, value in data.items(): + np.save(f'{usage_dir}/{key}.npy', value) def run(self): data = self._load_dataset() sets = self._split_data(data) - self._save_npz(sets) + self._save_data(sets) if __name__ == '__main__': diff --git a/quickdraw_bot/train.py b/quickdraw_bot/train.py index 695c5f0..41810da 100644 --- a/quickdraw_bot/train.py +++ b/quickdraw_bot/train.py @@ -33,14 +33,14 @@ class Train: def _get_dataloader(self): train_dataset = QuickDrawDataset( - data_npz_path=self.config['train_npz'], + data_dir=self.config['train_data_dir'], enable_data_aug=True, file_lazy_load=self.config['file_lazy_load'], return_cate_name=False, # vis_dir='./tmp' ) valid_dataset = QuickDrawDataset( - data_npz_path=self.config['valid_npz'], + data_dir=self.config['valid_data_dir'], enable_data_aug=False, file_lazy_load=self.config['file_lazy_load'], return_cate_name=False, @@ -55,7 +55,12 @@ class Train: persistent_workers=True, ) valid_dataloader = DataLoader( - valid_dataset, batch_size=self.config['batch_size'], shuffle=False, num_workers=1, pin_memory=False, persistent_workers=True + valid_dataset, + batch_size=self.config['batch_size'], + shuffle=False, + num_workers=1, + pin_memory=False, # not supported on mps + persistent_workers=True, ) return train_dataloader, valid_dataloader diff --git a/quickdraw_bot/utils/dataset.py b/quickdraw_bot/utils/dataset.py index 115046c..c369163 100644 --- a/quickdraw_bot/utils/dataset.py +++ b/quickdraw_bot/utils/dataset.py @@ -13,7 +13,7 @@ from torchvision.transforms import v2 class 
QuickDrawDataset(torch.utils.data.Dataset): def __init__( self, - data_npz_path: str, + data_dir: str, image_shape: tuple[int, int, int] = (1, 28, 28), enable_data_aug: bool = False, file_lazy_load: bool = False, @@ -21,13 +21,13 @@ class QuickDrawDataset(torch.utils.data.Dataset): vis_dir: str = None, ) -> None: super().__init__() - self._images: torch.Tensor | np.ndarray = [] + self._images: torch.Tensor | np.ndarray = None self._cate_names: list[str] = [] self._cate_ids: torch.Tensor = [] self._transform: callable = None self._enable_data_aug = enable_data_aug - self._data_npz_path = data_npz_path + self._data_dir = Path(data_dir) self._image_shape = image_shape self._file_lazy_load = file_lazy_load self._return_cate_name = return_cate_name @@ -55,19 +55,15 @@ class QuickDrawDataset(torch.utils.data.Dataset): ) def _collect_data(self) -> None: + self._cate_names = [cate_name.decode() for cate_name in np.load(self._data_dir / 'cate_names.npy')] + self._cate_ids = torch.from_numpy(np.load(self._data_dir / 'cate_ids.npy')).long() if self._file_lazy_load: - self._npz_file = np.load(self._data_npz_path, mmap_mode='r') - self._cate_names = [cate_name.decode() for cate_name in self._npz_file['cate_names']] - self._cate_ids = torch.from_numpy(self._npz_file['cate_ids'].copy()).long() - self._images = self._npz_file['images'] + self._images = np.load(self._data_dir / 'images.npy', mmap_mode='r') else: - with np.load(self._data_npz_path, mmap_mode=None) as npz_file: - self._cate_names = [cate_name.decode() for cate_name in npz_file['cate_names']] - self._cate_ids = torch.from_numpy(npz_file['cate_ids']).long() - self._images = torch.from_numpy(npz_file['images']) + self._images = torch.from_numpy(np.load(self._data_dir / 'images.npy')) def __len__(self) -> int: - return len(self._images) + return len(self._cate_ids) def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor]: if self._file_lazy_load: