fix file_lazy_load
This commit is contained in:
@ -8,8 +8,8 @@ prepare:
|
|||||||
random_seed: 1
|
random_seed: 1
|
||||||
train:
|
train:
|
||||||
device_type: mps
|
device_type: mps
|
||||||
train_npz: ./data/processed/train.npz
|
train_data_dir: ./data/processed/train
|
||||||
valid_npz: ./data/processed/valid.npz
|
valid_data_dir: ./data/processed/valid
|
||||||
batch_size: 256
|
batch_size: 256
|
||||||
num_of_class: 20
|
num_of_class: 20
|
||||||
optimizer_name: sgd # sgd, adam
|
optimizer_name: sgd # sgd, adam
|
||||||
@ -20,7 +20,7 @@ train:
|
|||||||
random_seed: 1
|
random_seed: 1
|
||||||
exp_msg: init train
|
exp_msg: init train
|
||||||
eval:
|
eval:
|
||||||
test_npz: ./data/processed/test.npz
|
test_data_dir: ./data/processed/test
|
||||||
model_path: ./assets/model.pth
|
model_path: ./assets/model.pth
|
||||||
random_seed: 1
|
random_seed: 1
|
||||||
deploy:
|
deploy:
|
||||||
|
|||||||
@ -19,7 +19,7 @@ class Eval:
|
|||||||
self._device = torch.device('mps' if torch.mps.is_available() else 'cpu')
|
self._device = torch.device('mps' if torch.mps.is_available() else 'cpu')
|
||||||
|
|
||||||
def _get_dataloader(self):
|
def _get_dataloader(self):
|
||||||
test_dataset = QuickDrawDataset(data_npz_path=self.config['test_npz'], return_cate_name=False)
|
test_dataset = QuickDrawDataset(data_dir=self.config['test_data_dir'], return_cate_name=False)
|
||||||
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)
|
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)
|
||||||
return test_dataloader
|
return test_dataloader
|
||||||
|
|
||||||
|
|||||||
@ -5,6 +5,7 @@
|
|||||||
|
|
||||||
import random
|
import random
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from shutil import rmtree
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@ -66,16 +67,21 @@ class Prepare:
|
|||||||
sets[name] = {key: value[idx] for key, value in data.items()}
|
sets[name] = {key: value[idx] for key, value in data.items()}
|
||||||
return sets
|
return sets
|
||||||
|
|
||||||
def _save_npz(self, sets: dict[str, dict[str, np.ndarray]]) -> None:
|
def _save_data(self, sets: dict[str, dict[str, np.ndarray]]) -> None:
|
||||||
save_dir = Path(self.config['data_dir']) / 'processed'
|
save_dir = Path(self.config['data_dir']) / 'processed'
|
||||||
save_dir.mkdir(exist_ok=True)
|
if save_dir.exists():
|
||||||
|
rmtree(save_dir)
|
||||||
|
save_dir.mkdir()
|
||||||
for usage, data in sets.items():
|
for usage, data in sets.items():
|
||||||
np.savez(f'{save_dir}/{usage}.npz', **data)
|
usage_dir = save_dir / usage
|
||||||
|
usage_dir.mkdir()
|
||||||
|
for key, value in data.items():
|
||||||
|
np.save(f'{usage_dir}/{key}.npy', value)
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
data = self._load_dataset()
|
data = self._load_dataset()
|
||||||
sets = self._split_data(data)
|
sets = self._split_data(data)
|
||||||
self._save_npz(sets)
|
self._save_data(sets)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
@ -33,14 +33,14 @@ class Train:
|
|||||||
|
|
||||||
def _get_dataloader(self):
|
def _get_dataloader(self):
|
||||||
train_dataset = QuickDrawDataset(
|
train_dataset = QuickDrawDataset(
|
||||||
data_npz_path=self.config['train_npz'],
|
data_dir=self.config['train_data_dir'],
|
||||||
enable_data_aug=True,
|
enable_data_aug=True,
|
||||||
file_lazy_load=self.config['file_lazy_load'],
|
file_lazy_load=self.config['file_lazy_load'],
|
||||||
return_cate_name=False,
|
return_cate_name=False,
|
||||||
# vis_dir='./tmp'
|
# vis_dir='./tmp'
|
||||||
)
|
)
|
||||||
valid_dataset = QuickDrawDataset(
|
valid_dataset = QuickDrawDataset(
|
||||||
data_npz_path=self.config['valid_npz'],
|
data_dir=self.config['valid_data_dir'],
|
||||||
enable_data_aug=False,
|
enable_data_aug=False,
|
||||||
file_lazy_load=self.config['file_lazy_load'],
|
file_lazy_load=self.config['file_lazy_load'],
|
||||||
return_cate_name=False,
|
return_cate_name=False,
|
||||||
@ -55,7 +55,12 @@ class Train:
|
|||||||
persistent_workers=True,
|
persistent_workers=True,
|
||||||
)
|
)
|
||||||
valid_dataloader = DataLoader(
|
valid_dataloader = DataLoader(
|
||||||
valid_dataset, batch_size=self.config['batch_size'], shuffle=False, num_workers=1, pin_memory=False, persistent_workers=True
|
valid_dataset,
|
||||||
|
batch_size=self.config['batch_size'],
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=1,
|
||||||
|
pin_memory=False, # not supported on MPS
|
||||||
|
persistent_workers=True,
|
||||||
)
|
)
|
||||||
return train_dataloader, valid_dataloader
|
return train_dataloader, valid_dataloader
|
||||||
|
|
||||||
|
|||||||
@ -13,7 +13,7 @@ from torchvision.transforms import v2
|
|||||||
class QuickDrawDataset(torch.utils.data.Dataset):
|
class QuickDrawDataset(torch.utils.data.Dataset):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
data_npz_path: str,
|
data_dir: str,
|
||||||
image_shape: tuple[int, int, int] = (1, 28, 28),
|
image_shape: tuple[int, int, int] = (1, 28, 28),
|
||||||
enable_data_aug: bool = False,
|
enable_data_aug: bool = False,
|
||||||
file_lazy_load: bool = False,
|
file_lazy_load: bool = False,
|
||||||
@ -21,13 +21,13 @@ class QuickDrawDataset(torch.utils.data.Dataset):
|
|||||||
vis_dir: str = None,
|
vis_dir: str = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._images: torch.Tensor | np.ndarray = []
|
self._images: torch.Tensor | np.ndarray = None
|
||||||
self._cate_names: list[str] = []
|
self._cate_names: list[str] = []
|
||||||
self._cate_ids: torch.Tensor = []
|
self._cate_ids: torch.Tensor = []
|
||||||
self._transform: callable = None
|
self._transform: callable = None
|
||||||
|
|
||||||
self._enable_data_aug = enable_data_aug
|
self._enable_data_aug = enable_data_aug
|
||||||
self._data_npz_path = data_npz_path
|
self._data_dir = Path(data_dir)
|
||||||
self._image_shape = image_shape
|
self._image_shape = image_shape
|
||||||
self._file_lazy_load = file_lazy_load
|
self._file_lazy_load = file_lazy_load
|
||||||
self._return_cate_name = return_cate_name
|
self._return_cate_name = return_cate_name
|
||||||
@ -55,19 +55,15 @@ class QuickDrawDataset(torch.utils.data.Dataset):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _collect_data(self) -> None:
|
def _collect_data(self) -> None:
|
||||||
|
self._cate_names = [cate_name.decode() for cate_name in np.load(self._data_dir / 'cate_names.npy')]
|
||||||
|
self._cate_ids = torch.from_numpy(np.load(self._data_dir / 'cate_ids.npy')).long()
|
||||||
if self._file_lazy_load:
|
if self._file_lazy_load:
|
||||||
self._npz_file = np.load(self._data_npz_path, mmap_mode='r')
|
self._images = np.load(self._data_dir / 'images.npy', mmap_mode='r')
|
||||||
self._cate_names = [cate_name.decode() for cate_name in self._npz_file['cate_names']]
|
|
||||||
self._cate_ids = torch.from_numpy(self._npz_file['cate_ids'].copy()).long()
|
|
||||||
self._images = self._npz_file['images']
|
|
||||||
else:
|
else:
|
||||||
with np.load(self._data_npz_path, mmap_mode=None) as npz_file:
|
self._images = torch.from_numpy(np.load(self._data_dir / 'images.npy'))
|
||||||
self._cate_names = [cate_name.decode() for cate_name in npz_file['cate_names']]
|
|
||||||
self._cate_ids = torch.from_numpy(npz_file['cate_ids']).long()
|
|
||||||
self._images = torch.from_numpy(npz_file['images'])
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
return len(self._images)
|
return len(self._cate_ids)
|
||||||
|
|
||||||
def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor]:
|
def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
if self._file_lazy_load:
|
if self._file_lazy_load:
|
||||||
|
|||||||
Reference in New Issue
Block a user