fix file_lazy_load: save each split as per-array .npy files so images.npy can be opened with mmap_mode='r'
This commit is contained in:
@@ -8,8 +8,8 @@ prepare:
|
||||
random_seed: 1
|
||||
train:
|
||||
device_type: mps
|
||||
train_npz: ./data/processed/train.npz
|
||||
valid_npz: ./data/processed/valid.npz
|
||||
train_data_dir: ./data/processed/train
|
||||
valid_data_dir: ./data/processed/valid
|
||||
batch_size: 256
|
||||
num_of_class: 20
|
||||
optimizer_name: sgd # sgd, adam
|
||||
@@ -20,7 +20,7 @@ train:
|
||||
random_seed: 1
|
||||
exp_msg: init train
|
||||
eval:
|
||||
test_npz: ./data/processed/test.npz
|
||||
test_data_dir: ./data/processed/test
|
||||
model_path: ./assets/model.pth
|
||||
random_seed: 1
|
||||
deploy:
|
||||
|
||||
@@ -19,7 +19,7 @@ class Eval:
|
||||
self._device = torch.device('mps' if torch.mps.is_available() else 'cpu')
|
||||
|
||||
def _get_dataloader(self):
|
||||
test_dataset = QuickDrawDataset(data_npz_path=self.config['test_npz'], return_cate_name=False)
|
||||
test_dataset = QuickDrawDataset(data_dir=self.config['test_data_dir'], return_cate_name=False)
|
||||
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)
|
||||
return test_dataloader
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
import random
|
||||
from pathlib import Path
|
||||
from shutil import rmtree
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -66,16 +67,21 @@ class Prepare:
|
||||
sets[name] = {key: value[idx] for key, value in data.items()}
|
||||
return sets
|
||||
|
||||
def _save_npz(self, sets: dict[str, dict[str, np.ndarray]]) -> None:
|
||||
def _save_data(self, sets: dict[str, dict[str, np.ndarray]]) -> None:
|
||||
save_dir = Path(self.config['data_dir']) / 'processed'
|
||||
save_dir.mkdir(exist_ok=True)
|
||||
if save_dir.exists():
|
||||
rmtree(save_dir)
|
||||
save_dir.mkdir()
|
||||
for usage, data in sets.items():
|
||||
np.savez(f'{save_dir}/{usage}.npz', **data)
|
||||
usage_dir = save_dir / usage
|
||||
usage_dir.mkdir()
|
||||
for key, value in data.items():
|
||||
np.save(f'{usage_dir}/{key}.npy', value)
|
||||
|
||||
def run(self):
|
||||
data = self._load_dataset()
|
||||
sets = self._split_data(data)
|
||||
self._save_npz(sets)
|
||||
self._save_data(sets)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -33,14 +33,14 @@ class Train:
|
||||
|
||||
def _get_dataloader(self):
|
||||
train_dataset = QuickDrawDataset(
|
||||
data_npz_path=self.config['train_npz'],
|
||||
data_dir=self.config['train_data_dir'],
|
||||
enable_data_aug=True,
|
||||
file_lazy_load=self.config['file_lazy_load'],
|
||||
return_cate_name=False,
|
||||
# vis_dir='./tmp'
|
||||
)
|
||||
valid_dataset = QuickDrawDataset(
|
||||
data_npz_path=self.config['valid_npz'],
|
||||
data_dir=self.config['valid_data_dir'],
|
||||
enable_data_aug=False,
|
||||
file_lazy_load=self.config['file_lazy_load'],
|
||||
return_cate_name=False,
|
||||
@@ -55,7 +55,12 @@ class Train:
|
||||
persistent_workers=True,
|
||||
)
|
||||
valid_dataloader = DataLoader(
|
||||
valid_dataset, batch_size=self.config['batch_size'], shuffle=False, num_workers=1, pin_memory=False, persistent_workers=True
|
||||
valid_dataset,
|
||||
batch_size=self.config['batch_size'],
|
||||
shuffle=False,
|
||||
num_workers=1,
|
||||
pin_memory=False, # not supported on mps
|
||||
persistent_workers=True,
|
||||
)
|
||||
return train_dataloader, valid_dataloader
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ from torchvision.transforms import v2
|
||||
class QuickDrawDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
data_npz_path: str,
|
||||
data_dir: str,
|
||||
image_shape: tuple[int, int, int] = (1, 28, 28),
|
||||
enable_data_aug: bool = False,
|
||||
file_lazy_load: bool = False,
|
||||
@@ -21,13 +21,13 @@ class QuickDrawDataset(torch.utils.data.Dataset):
|
||||
vis_dir: str = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self._images: torch.Tensor | np.ndarray = []
|
||||
self._images: torch.Tensor | np.ndarray = None
|
||||
self._cate_names: list[str] = []
|
||||
self._cate_ids: torch.Tensor = []
|
||||
self._transform: callable = None
|
||||
|
||||
self._enable_data_aug = enable_data_aug
|
||||
self._data_npz_path = data_npz_path
|
||||
self._data_dir = Path(data_dir)
|
||||
self._image_shape = image_shape
|
||||
self._file_lazy_load = file_lazy_load
|
||||
self._return_cate_name = return_cate_name
|
||||
@@ -55,19 +55,15 @@ class QuickDrawDataset(torch.utils.data.Dataset):
|
||||
)
|
||||
|
||||
def _collect_data(self) -> None:
|
||||
self._cate_names = [cate_name.decode() for cate_name in np.load(self._data_dir / 'cate_names.npy')]
|
||||
self._cate_ids = torch.from_numpy(np.load(self._data_dir / 'cate_ids.npy')).long()
|
||||
if self._file_lazy_load:
|
||||
self._npz_file = np.load(self._data_npz_path, mmap_mode='r')
|
||||
self._cate_names = [cate_name.decode() for cate_name in self._npz_file['cate_names']]
|
||||
self._cate_ids = torch.from_numpy(self._npz_file['cate_ids'].copy()).long()
|
||||
self._images = self._npz_file['images']
|
||||
self._images = np.load(self._data_dir / 'images.npy', mmap_mode='r')
|
||||
else:
|
||||
with np.load(self._data_npz_path, mmap_mode=None) as npz_file:
|
||||
self._cate_names = [cate_name.decode() for cate_name in npz_file['cate_names']]
|
||||
self._cate_ids = torch.from_numpy(npz_file['cate_ids']).long()
|
||||
self._images = torch.from_numpy(npz_file['images'])
|
||||
self._images = torch.from_numpy(np.load(self._data_dir / 'images.npy'))
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._images)
|
||||
return len(self._cate_ids)
|
||||
|
||||
def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if self._file_lazy_load:
|
||||
|
||||
Reference in New Issue
Block a user