diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..25978b0 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,33 @@ +default_language_version: + python: python3.13 + +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v5.0.0 + hooks: + - id: check-added-large-files + args: ["--maxkb=512"] + - id: check-yaml + - id: check-toml + - id: check-docstring-first + + - repo: local + hooks: + - id: ruff-format + name: ruff format + entry: uv run ruff format . + language: system + types: [python] + + - id: ruff-check + name: ruff check + entry: uv run ruff check --fix . + language: system + types: [python] + + - id: pytest + name: pytest + entry: uv run pytest tests/ -v + language: system + pass_filenames: false + stages: [pre-push] \ No newline at end of file diff --git a/quickdraw_bot/prepare.py b/quickdraw_bot/prepare.py index 1c17f82..5ae671d 100644 --- a/quickdraw_bot/prepare.py +++ b/quickdraw_bot/prepare.py @@ -15,17 +15,13 @@ class Prepare: def __init__(self, config_path: str = './assets/config.yaml'): self.config = load_config(config_path)['prepare'] self._set_random_seed() - + def _set_random_seed(self): random.seed(self.config['random_seed']) np.random.seed(self.config['random_seed']) - + def _load_dataset(self) -> dict[str, np.ndarray]: - data: dict[str, list] = { - 'images': [], - 'cate_names': [], - 'cate_ids': [] - } + data: dict[str, list] = {'images': [], 'cate_names': [], 'cate_ids': []} cls_id_map: dict[str, int] = {} raw_data_dir = Path(self.config['data_dir']) / 'raw' for npy_file in sorted(raw_data_dir.glob('*.npy')): @@ -49,7 +45,7 @@ class Prepare: data['cate_names'] = np.array(data['cate_names']).astype('S30') data['cate_ids'] = np.array(data['cate_ids']).astype(np.uint16) return data - + def _split_data(self, data: dict[str, np.ndarray]) -> dict[str, dict[str, np.ndarray]]: weights = self.config['data_split'] if abs(sum(weights.values()) - 1.0) > 1e-6: @@ -69,7 +65,7 @@ class Prepare: start = end sets[name] = {key: value[idx] for key, value in data.items()} return sets - + def _save_npz(self, sets: dict[str, dict[str, np.ndarray]]) -> None: save_dir = Path(self.config['data_dir']) / 'processed' save_dir.mkdir(exist_ok=True) diff --git a/quickdraw_bot/train.py b/quickdraw_bot/train.py index a0b8b5f..83f7e61 100644 --- a/quickdraw_bot/train.py +++ b/quickdraw_bot/train.py @@ -24,13 +24,13 @@ class Train: self._device = torch.device(self.config['device_type']) self._ensure_deterministic() - + def _ensure_deterministic(self) -> None: torch.use_deterministic_algorithms(mode=True, warn_only=True) random.seed(self.config['random_seed']) np.random.seed(self.config['random_seed']) torch.manual_seed(self.config['random_seed']) - + def _get_dataloader(self): train_dataset = QuickDrawDataset( data_npz_path=self.config['train_npz'], @@ -52,26 +52,18 @@ class Train: shuffle=True, num_workers=4, pin_memory=False, # not support for mps - persistent_workers=True + persistent_workers=True, ) valid_dataloader = DataLoader( - valid_dataset, - batch_size=self.config['batch_size'], - shuffle=False, - num_workers=1, - pin_memory=False, - persistent_workers=True + valid_dataset, batch_size=self.config['batch_size'], shuffle=False, num_workers=1, pin_memory=False, persistent_workers=True ) return train_dataloader, valid_dataloader - + def _get_model(self) -> torch.nn.Module: - model = BabyCNN( - num_classes=self.config['num_of_class'], - dropout_p=0.3 - ).to(self._device) + model = BabyCNN(num_classes=self.config['num_of_class'], dropout_p=0.3).to(self._device) model.train() return model - + def _get_optimizer(self, model: torch.nn.Module) -> torch.optim.Optimizer: if self.config['optimizer_name'] == 'adam': optimizer = torch.optim.Adam(model.parameters(), lr=self.config['learning_rate']) @@ -80,27 +72,27 @@ class Train: else: raise ValueError(f'Unknown optimizer name: {self.config["optimizer_name"]}') return optimizer - + def _get_scheduler(self, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler._LRScheduler: scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=1, eta_min=0.0001) if self.config['warmup_epochs'] > 0: warmup = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01, end_factor=1.0, total_iters=self.config['warmup_epochs']) scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers=[warmup, scheduler], milestones=[self.config['warmup_epochs']]) return scheduler - + def _get_loss(self) -> torch.nn.modules.loss._Loss: - loss = torch.nn.CrossEntropyLoss( - label_smoothing=0.1 - ).to(self._device) + loss = torch.nn.CrossEntropyLoss(label_smoothing=0.1).to(self._device) return loss - + def _get_metrics(self) -> tuple[MetricCollection, ConfusionMatrix]: - metric_collection = MetricCollection([ - Accuracy(task='multiclass', num_classes=self.config['num_of_class'], top_k=1), - Precision(task='multiclass', num_classes=self.config['num_of_class'], average='macro'), - Recall(task='multiclass', num_classes=self.config['num_of_class'], average='macro'), - F1Score(task='multiclass', num_classes=self.config['num_of_class'], average='macro'), - ]).to(self._device) + metric_collection = MetricCollection( + [ + Accuracy(task='multiclass', num_classes=self.config['num_of_class'], top_k=1), + Precision(task='multiclass', num_classes=self.config['num_of_class'], average='macro'), + Recall(task='multiclass', num_classes=self.config['num_of_class'], average='macro'), + F1Score(task='multiclass', num_classes=self.config['num_of_class'], average='macro'), + ] + ).to(self._device) confusion_matrix = ConfusionMatrix( task='multiclass', threshold=0.5, @@ -116,16 +108,11 @@ class Train: loss = self._get_loss() metrics, _ = self._get_metrics() - with Live( - dir='./doc/exp/train', - report='html', - dvcyaml='./assets/dvc.yaml', - exp_message=self.config['exp_msg']) as live: - + with Live(dir='./doc/exp/train', report='html', dvcyaml='./assets/dvc.yaml', exp_message=self.config['exp_msg']) as live: for epoch in tqdm(range(self.config['num_of_epochs']), desc='Training Epoch'): metrics.reset() model.train() - total_train_loss = 0. + total_train_loss = 0.0 for inputs, targets in train_dataloader: inputs = inputs.to(self._device) targets = targets.to(self._device) @@ -141,7 +128,7 @@ class Train: metrics.reset() model.eval() - total_valid_loss = 0. + total_valid_loss = 0.0 with torch.no_grad(): for inputs, targets in valid_dataloader: inputs = inputs.to(self._device) @@ -166,9 +153,9 @@ class Train: scheduler.step() live.next_step() - + torch.save(model, './assets/model.pth') if __name__ == '__main__': - Train().run() \ No newline at end of file + Train().run() diff --git a/quickdraw_bot/utils/dataset.py b/quickdraw_bot/utils/dataset.py index 469ccf8..115046c 100644 --- a/quickdraw_bot/utils/dataset.py +++ b/quickdraw_bot/utils/dataset.py @@ -11,14 +11,15 @@ from torchvision.transforms import v2 class QuickDrawDataset(torch.utils.data.Dataset): - def __init__(self, - data_npz_path: str, - image_shape: tuple[int, int, int] = (1, 28, 28), - enable_data_aug: bool = False, - file_lazy_load: bool = False, - return_cate_name: bool = False, - vis_dir: str = None, - ) -> None: + def __init__( + self, + data_npz_path: str, + image_shape: tuple[int, int, int] = (1, 28, 28), + enable_data_aug: bool = False, + file_lazy_load: bool = False, + return_cate_name: bool = False, + vis_dir: str = None, + ) -> None: super().__init__() self._images: torch.Tensor | np.ndarray = [] self._cate_names: list[str] = [] @@ -43,13 +44,15 @@ class QuickDrawDataset(torch.utils.data.Dataset): v2.RandomApply([v2.RandomAffine(degrees=(-30, 30), translate=(0.2, 0.2), scale=(0.8, 1.2), shear=(-10, 10))], p=0.5), v2.RandomPerspective(distortion_scale=0.15, p=0.2), v2.RandomApply([v2.ElasticTransform(alpha=15.0, sigma=3.0)], p=0.2), - v2.RandomErasing(p=0.2, scale=(0.02, 0.2)) + v2.RandomErasing(p=0.2, scale=(0.02, 0.2)), ] - self._transform = v2.Compose([ - *aug_pipeline, - v2.Resize(self._image_shape[1:]), - v2.ToDtype(torch.float32, scale=True), - ]) + self._transform = v2.Compose( + [ + *aug_pipeline, + v2.Resize(self._image_shape[1:]), + v2.ToDtype(torch.float32, scale=True), + ] + ) def _collect_data(self) -> None: if self._file_lazy_load: @@ -62,7 +65,7 @@ class QuickDrawDataset(torch.utils.data.Dataset): self._cate_names = [cate_name.decode() for cate_name in npz_file['cate_names']] self._cate_ids = torch.from_numpy(npz_file['cate_ids']).long() self._images = torch.from_numpy(npz_file['images']) - + def __len__(self) -> int: return len(self._images) @@ -82,7 +85,7 @@ class QuickDrawDataset(torch.utils.data.Dataset): if self._return_cate_name: return x, y, self._cate_names[index] return x, y - + def set_data_aug(self, enable_data_aug: bool) -> None: self._enable_data_aug = enable_data_aug self._set_data_transform() diff --git a/quickdraw_bot/utils/model.py b/quickdraw_bot/utils/model.py index 849a606..2d35874 100644 --- a/quickdraw_bot/utils/model.py +++ b/quickdraw_bot/utils/model.py @@ -9,9 +9,7 @@ import torch.nn.functional as F class BabyCNN(nn.Module): - def __init__(self, - num_classes: int = 10, - dropout_p: float = 0.5) -> None: + def __init__(self, num_classes: int = 10, dropout_p: float = 0.5) -> None: super().__init__() # Conv Block 1: 28x28 -> 14x14 @@ -30,7 +28,7 @@ class BabyCNN(nn.Module): self.fc2 = nn.Linear(in_features=128, out_features=num_classes) self._init_weights() - + def _init_weights(self): for m in self.modules(): if isinstance(m, nn.Conv2d): @@ -50,4 +48,4 @@ class BabyCNN(nn.Module): x = x.view(x.size(0), -1) x = self.dropout(F.relu(self.fc1(x))) x = self.fc2(x) - return x \ No newline at end of file + return x diff --git a/quickdraw_bot/utils/utils.py b/quickdraw_bot/utils/utils.py index 467ed63..a3bec12 100644 --- a/quickdraw_bot/utils/utils.py +++ b/quickdraw_bot/utils/utils.py @@ -5,4 +5,4 @@ import yaml def load_config(config_path: str) -> dict: with open(config_path, 'r') as f: - return yaml.safe_load(f) \ No newline at end of file + return yaml.safe_load(f)