apply precommit
This commit is contained in:
33
.pre-commit-config.yaml
Normal file
33
.pre-commit-config.yaml
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
default_language_version:
|
||||||
|
python: python3.13
|
||||||
|
|
||||||
|
repos:
|
||||||
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
|
rev: v5.0.0
|
||||||
|
hooks:
|
||||||
|
- id: check-added-large-files
|
||||||
|
args: ["--maxkb=512"]
|
||||||
|
- id: check-yaml
|
||||||
|
- id: check-toml
|
||||||
|
- id: check-docstring-first
|
||||||
|
|
||||||
|
- repo: local
|
||||||
|
hooks:
|
||||||
|
- id: ruff-format
|
||||||
|
name: ruff format
|
||||||
|
entry: uv run ruff format .
|
||||||
|
language: system
|
||||||
|
types: [python]
|
||||||
|
|
||||||
|
- id: ruff-check
|
||||||
|
name: ruff check
|
||||||
|
entry: uv run ruff check --fix .
|
||||||
|
language: system
|
||||||
|
types: [python]
|
||||||
|
|
||||||
|
- id: pytest
|
||||||
|
name: pytest
|
||||||
|
entry: uv run pytest tests/ -v
|
||||||
|
language: system
|
||||||
|
pass_filenames: false
|
||||||
|
stages: [pre-push]
|
||||||
@ -15,17 +15,13 @@ class Prepare:
|
|||||||
def __init__(self, config_path: str = './assets/config.yaml'):
|
def __init__(self, config_path: str = './assets/config.yaml'):
|
||||||
self.config = load_config(config_path)['prepare']
|
self.config = load_config(config_path)['prepare']
|
||||||
self._set_random_seed()
|
self._set_random_seed()
|
||||||
|
|
||||||
def _set_random_seed(self):
|
def _set_random_seed(self):
|
||||||
random.seed(self.config['random_seed'])
|
random.seed(self.config['random_seed'])
|
||||||
np.random.seed(self.config['random_seed'])
|
np.random.seed(self.config['random_seed'])
|
||||||
|
|
||||||
def _load_dataset(self) -> dict[str, np.ndarray]:
|
def _load_dataset(self) -> dict[str, np.ndarray]:
|
||||||
data: dict[str, list] = {
|
data: dict[str, list] = {'images': [], 'cate_names': [], 'cate_ids': []}
|
||||||
'images': [],
|
|
||||||
'cate_names': [],
|
|
||||||
'cate_ids': []
|
|
||||||
}
|
|
||||||
cls_id_map: dict[str, int] = {}
|
cls_id_map: dict[str, int] = {}
|
||||||
raw_data_dir = Path(self.config['data_dir']) / 'raw'
|
raw_data_dir = Path(self.config['data_dir']) / 'raw'
|
||||||
for npy_file in sorted(raw_data_dir.glob('*.npy')):
|
for npy_file in sorted(raw_data_dir.glob('*.npy')):
|
||||||
@ -49,7 +45,7 @@ class Prepare:
|
|||||||
data['cate_names'] = np.array(data['cate_names']).astype('S30')
|
data['cate_names'] = np.array(data['cate_names']).astype('S30')
|
||||||
data['cate_ids'] = np.array(data['cate_ids']).astype(np.uint16)
|
data['cate_ids'] = np.array(data['cate_ids']).astype(np.uint16)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def _split_data(self, data: dict[str, np.ndarray]) -> dict[str, dict[str, np.ndarray]]:
|
def _split_data(self, data: dict[str, np.ndarray]) -> dict[str, dict[str, np.ndarray]]:
|
||||||
weights = self.config['data_split']
|
weights = self.config['data_split']
|
||||||
if abs(sum(weights.values()) - 1.0) > 1e-6:
|
if abs(sum(weights.values()) - 1.0) > 1e-6:
|
||||||
@ -69,7 +65,7 @@ class Prepare:
|
|||||||
start = end
|
start = end
|
||||||
sets[name] = {key: value[idx] for key, value in data.items()}
|
sets[name] = {key: value[idx] for key, value in data.items()}
|
||||||
return sets
|
return sets
|
||||||
|
|
||||||
def _save_npz(self, sets: dict[str, dict[str, np.ndarray]]) -> None:
|
def _save_npz(self, sets: dict[str, dict[str, np.ndarray]]) -> None:
|
||||||
save_dir = Path(self.config['data_dir']) / 'processed'
|
save_dir = Path(self.config['data_dir']) / 'processed'
|
||||||
save_dir.mkdir(exist_ok=True)
|
save_dir.mkdir(exist_ok=True)
|
||||||
|
|||||||
@ -24,13 +24,13 @@ class Train:
|
|||||||
self._device = torch.device(self.config['device_type'])
|
self._device = torch.device(self.config['device_type'])
|
||||||
|
|
||||||
self._ensure_deterministic()
|
self._ensure_deterministic()
|
||||||
|
|
||||||
def _ensure_deterministic(self) -> None:
|
def _ensure_deterministic(self) -> None:
|
||||||
torch.use_deterministic_algorithms(mode=True, warn_only=True)
|
torch.use_deterministic_algorithms(mode=True, warn_only=True)
|
||||||
random.seed(self.config['random_seed'])
|
random.seed(self.config['random_seed'])
|
||||||
np.random.seed(self.config['random_seed'])
|
np.random.seed(self.config['random_seed'])
|
||||||
torch.manual_seed(self.config['random_seed'])
|
torch.manual_seed(self.config['random_seed'])
|
||||||
|
|
||||||
def _get_dataloader(self):
|
def _get_dataloader(self):
|
||||||
train_dataset = QuickDrawDataset(
|
train_dataset = QuickDrawDataset(
|
||||||
data_npz_path=self.config['train_npz'],
|
data_npz_path=self.config['train_npz'],
|
||||||
@ -52,26 +52,18 @@ class Train:
|
|||||||
shuffle=True,
|
shuffle=True,
|
||||||
num_workers=4,
|
num_workers=4,
|
||||||
pin_memory=False, # not support for mps
|
pin_memory=False, # not support for mps
|
||||||
persistent_workers=True
|
persistent_workers=True,
|
||||||
)
|
)
|
||||||
valid_dataloader = DataLoader(
|
valid_dataloader = DataLoader(
|
||||||
valid_dataset,
|
valid_dataset, batch_size=self.config['batch_size'], shuffle=False, num_workers=1, pin_memory=False, persistent_workers=True
|
||||||
batch_size=self.config['batch_size'],
|
|
||||||
shuffle=False,
|
|
||||||
num_workers=1,
|
|
||||||
pin_memory=False,
|
|
||||||
persistent_workers=True
|
|
||||||
)
|
)
|
||||||
return train_dataloader, valid_dataloader
|
return train_dataloader, valid_dataloader
|
||||||
|
|
||||||
def _get_model(self) -> torch.nn.Module:
|
def _get_model(self) -> torch.nn.Module:
|
||||||
model = BabyCNN(
|
model = BabyCNN(num_classes=self.config['num_of_class'], dropout_p=0.3).to(self._device)
|
||||||
num_classes=self.config['num_of_class'],
|
|
||||||
dropout_p=0.3
|
|
||||||
).to(self._device)
|
|
||||||
model.train()
|
model.train()
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def _get_optimizer(self, model: torch.nn.Module) -> torch.optim.Optimizer:
|
def _get_optimizer(self, model: torch.nn.Module) -> torch.optim.Optimizer:
|
||||||
if self.config['optimizer_name'] == 'adam':
|
if self.config['optimizer_name'] == 'adam':
|
||||||
optimizer = torch.optim.Adam(model.parameters(), lr=self.config['learning_rate'])
|
optimizer = torch.optim.Adam(model.parameters(), lr=self.config['learning_rate'])
|
||||||
@ -80,27 +72,27 @@ class Train:
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f'Unknown optimizer name: {self.config["optimizer_name"]}')
|
raise ValueError(f'Unknown optimizer name: {self.config["optimizer_name"]}')
|
||||||
return optimizer
|
return optimizer
|
||||||
|
|
||||||
def _get_scheduler(self, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler._LRScheduler:
|
def _get_scheduler(self, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler._LRScheduler:
|
||||||
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=1, eta_min=0.0001)
|
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=1, eta_min=0.0001)
|
||||||
if self.config['warmup_epochs'] > 0:
|
if self.config['warmup_epochs'] > 0:
|
||||||
warmup = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01, end_factor=1.0, total_iters=self.config['warmup_epochs'])
|
warmup = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01, end_factor=1.0, total_iters=self.config['warmup_epochs'])
|
||||||
scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers=[warmup, scheduler], milestones=[self.config['warmup_epochs']])
|
scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers=[warmup, scheduler], milestones=[self.config['warmup_epochs']])
|
||||||
return scheduler
|
return scheduler
|
||||||
|
|
||||||
def _get_loss(self) -> torch.nn.modules.loss._Loss:
|
def _get_loss(self) -> torch.nn.modules.loss._Loss:
|
||||||
loss = torch.nn.CrossEntropyLoss(
|
loss = torch.nn.CrossEntropyLoss(label_smoothing=0.1).to(self._device)
|
||||||
label_smoothing=0.1
|
|
||||||
).to(self._device)
|
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
def _get_metrics(self) -> tuple[MetricCollection, ConfusionMatrix]:
|
def _get_metrics(self) -> tuple[MetricCollection, ConfusionMatrix]:
|
||||||
metric_collection = MetricCollection([
|
metric_collection = MetricCollection(
|
||||||
Accuracy(task='multiclass', num_classes=self.config['num_of_class'], top_k=1),
|
[
|
||||||
Precision(task='multiclass', num_classes=self.config['num_of_class'], average='macro'),
|
Accuracy(task='multiclass', num_classes=self.config['num_of_class'], top_k=1),
|
||||||
Recall(task='multiclass', num_classes=self.config['num_of_class'], average='macro'),
|
Precision(task='multiclass', num_classes=self.config['num_of_class'], average='macro'),
|
||||||
F1Score(task='multiclass', num_classes=self.config['num_of_class'], average='macro'),
|
Recall(task='multiclass', num_classes=self.config['num_of_class'], average='macro'),
|
||||||
]).to(self._device)
|
F1Score(task='multiclass', num_classes=self.config['num_of_class'], average='macro'),
|
||||||
|
]
|
||||||
|
).to(self._device)
|
||||||
confusion_matrix = ConfusionMatrix(
|
confusion_matrix = ConfusionMatrix(
|
||||||
task='multiclass',
|
task='multiclass',
|
||||||
threshold=0.5,
|
threshold=0.5,
|
||||||
@ -116,16 +108,11 @@ class Train:
|
|||||||
loss = self._get_loss()
|
loss = self._get_loss()
|
||||||
metrics, _ = self._get_metrics()
|
metrics, _ = self._get_metrics()
|
||||||
|
|
||||||
with Live(
|
with Live(dir='./doc/exp/train', report='html', dvcyaml='./assets/dvc.yaml', exp_message=self.config['exp_msg']) as live:
|
||||||
dir='./doc/exp/train',
|
|
||||||
report='html',
|
|
||||||
dvcyaml='./assets/dvc.yaml',
|
|
||||||
exp_message=self.config['exp_msg']) as live:
|
|
||||||
|
|
||||||
for epoch in tqdm(range(self.config['num_of_epochs']), desc='Training Epoch'):
|
for epoch in tqdm(range(self.config['num_of_epochs']), desc='Training Epoch'):
|
||||||
metrics.reset()
|
metrics.reset()
|
||||||
model.train()
|
model.train()
|
||||||
total_train_loss = 0.
|
total_train_loss = 0.0
|
||||||
for inputs, targets in train_dataloader:
|
for inputs, targets in train_dataloader:
|
||||||
inputs = inputs.to(self._device)
|
inputs = inputs.to(self._device)
|
||||||
targets = targets.to(self._device)
|
targets = targets.to(self._device)
|
||||||
@ -141,7 +128,7 @@ class Train:
|
|||||||
|
|
||||||
metrics.reset()
|
metrics.reset()
|
||||||
model.eval()
|
model.eval()
|
||||||
total_valid_loss = 0.
|
total_valid_loss = 0.0
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for inputs, targets in valid_dataloader:
|
for inputs, targets in valid_dataloader:
|
||||||
inputs = inputs.to(self._device)
|
inputs = inputs.to(self._device)
|
||||||
@ -166,9 +153,9 @@ class Train:
|
|||||||
|
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
live.next_step()
|
live.next_step()
|
||||||
|
|
||||||
torch.save(model, './assets/model.pth')
|
torch.save(model, './assets/model.pth')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
Train().run()
|
Train().run()
|
||||||
|
|||||||
@ -11,14 +11,15 @@ from torchvision.transforms import v2
|
|||||||
|
|
||||||
|
|
||||||
class QuickDrawDataset(torch.utils.data.Dataset):
|
class QuickDrawDataset(torch.utils.data.Dataset):
|
||||||
def __init__(self,
|
def __init__(
|
||||||
data_npz_path: str,
|
self,
|
||||||
image_shape: tuple[int, int, int] = (1, 28, 28),
|
data_npz_path: str,
|
||||||
enable_data_aug: bool = False,
|
image_shape: tuple[int, int, int] = (1, 28, 28),
|
||||||
file_lazy_load: bool = False,
|
enable_data_aug: bool = False,
|
||||||
return_cate_name: bool = False,
|
file_lazy_load: bool = False,
|
||||||
vis_dir: str = None,
|
return_cate_name: bool = False,
|
||||||
) -> None:
|
vis_dir: str = None,
|
||||||
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._images: torch.Tensor | np.ndarray = []
|
self._images: torch.Tensor | np.ndarray = []
|
||||||
self._cate_names: list[str] = []
|
self._cate_names: list[str] = []
|
||||||
@ -43,13 +44,15 @@ class QuickDrawDataset(torch.utils.data.Dataset):
|
|||||||
v2.RandomApply([v2.RandomAffine(degrees=(-30, 30), translate=(0.2, 0.2), scale=(0.8, 1.2), shear=(-10, 10))], p=0.5),
|
v2.RandomApply([v2.RandomAffine(degrees=(-30, 30), translate=(0.2, 0.2), scale=(0.8, 1.2), shear=(-10, 10))], p=0.5),
|
||||||
v2.RandomPerspective(distortion_scale=0.15, p=0.2),
|
v2.RandomPerspective(distortion_scale=0.15, p=0.2),
|
||||||
v2.RandomApply([v2.ElasticTransform(alpha=15.0, sigma=3.0)], p=0.2),
|
v2.RandomApply([v2.ElasticTransform(alpha=15.0, sigma=3.0)], p=0.2),
|
||||||
v2.RandomErasing(p=0.2, scale=(0.02, 0.2))
|
v2.RandomErasing(p=0.2, scale=(0.02, 0.2)),
|
||||||
]
|
]
|
||||||
self._transform = v2.Compose([
|
self._transform = v2.Compose(
|
||||||
*aug_pipeline,
|
[
|
||||||
v2.Resize(self._image_shape[1:]),
|
*aug_pipeline,
|
||||||
v2.ToDtype(torch.float32, scale=True),
|
v2.Resize(self._image_shape[1:]),
|
||||||
])
|
v2.ToDtype(torch.float32, scale=True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
def _collect_data(self) -> None:
|
def _collect_data(self) -> None:
|
||||||
if self._file_lazy_load:
|
if self._file_lazy_load:
|
||||||
@ -62,7 +65,7 @@ class QuickDrawDataset(torch.utils.data.Dataset):
|
|||||||
self._cate_names = [cate_name.decode() for cate_name in npz_file['cate_names']]
|
self._cate_names = [cate_name.decode() for cate_name in npz_file['cate_names']]
|
||||||
self._cate_ids = torch.from_numpy(npz_file['cate_ids']).long()
|
self._cate_ids = torch.from_numpy(npz_file['cate_ids']).long()
|
||||||
self._images = torch.from_numpy(npz_file['images'])
|
self._images = torch.from_numpy(npz_file['images'])
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
return len(self._images)
|
return len(self._images)
|
||||||
|
|
||||||
@ -82,7 +85,7 @@ class QuickDrawDataset(torch.utils.data.Dataset):
|
|||||||
if self._return_cate_name:
|
if self._return_cate_name:
|
||||||
return x, y, self._cate_names[index]
|
return x, y, self._cate_names[index]
|
||||||
return x, y
|
return x, y
|
||||||
|
|
||||||
def set_data_aug(self, enable_data_aug: bool) -> None:
|
def set_data_aug(self, enable_data_aug: bool) -> None:
|
||||||
self._enable_data_aug = enable_data_aug
|
self._enable_data_aug = enable_data_aug
|
||||||
self._set_data_transform()
|
self._set_data_transform()
|
||||||
|
|||||||
@ -9,9 +9,7 @@ import torch.nn.functional as F
|
|||||||
|
|
||||||
|
|
||||||
class BabyCNN(nn.Module):
|
class BabyCNN(nn.Module):
|
||||||
def __init__(self,
|
def __init__(self, num_classes: int = 10, dropout_p: float = 0.5) -> None:
|
||||||
num_classes: int = 10,
|
|
||||||
dropout_p: float = 0.5) -> None:
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# Conv Block 1: 28x28 -> 14x14
|
# Conv Block 1: 28x28 -> 14x14
|
||||||
@ -30,7 +28,7 @@ class BabyCNN(nn.Module):
|
|||||||
self.fc2 = nn.Linear(in_features=128, out_features=num_classes)
|
self.fc2 = nn.Linear(in_features=128, out_features=num_classes)
|
||||||
|
|
||||||
self._init_weights()
|
self._init_weights()
|
||||||
|
|
||||||
def _init_weights(self):
|
def _init_weights(self):
|
||||||
for m in self.modules():
|
for m in self.modules():
|
||||||
if isinstance(m, nn.Conv2d):
|
if isinstance(m, nn.Conv2d):
|
||||||
@ -50,4 +48,4 @@ class BabyCNN(nn.Module):
|
|||||||
x = x.view(x.size(0), -1)
|
x = x.view(x.size(0), -1)
|
||||||
x = self.dropout(F.relu(self.fc1(x)))
|
x = self.dropout(F.relu(self.fc1(x)))
|
||||||
x = self.fc2(x)
|
x = self.fc2(x)
|
||||||
return x
|
return x
|
||||||
|
|||||||
@ -5,4 +5,4 @@ import yaml
|
|||||||
|
|
||||||
def load_config(config_path: str) -> dict:
|
def load_config(config_path: str) -> dict:
|
||||||
with open(config_path, 'r') as f:
|
with open(config_path, 'r') as f:
|
||||||
return yaml.safe_load(f)
|
return yaml.safe_load(f)
|
||||||
|
|||||||
Reference in New Issue
Block a user