apply precommit

This commit is contained in:
2026-06-18 09:27:42 +08:00
parent 2f2db72db1
commit 945b04bb56
6 changed files with 85 additions and 68 deletions

33
.pre-commit-config.yaml Normal file
View File

@ -0,0 +1,33 @@
# pre-commit configuration.
#
# Generic file checks come from the upstream pre-commit-hooks repository;
# Ruff and pytest run as local "system" hooks through `uv run`, so they use
# whatever versions the project environment provides.
default_language_version:
  python: python3.13
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v5.0.0
    hooks:
      - id: check-added-large-files
        args: ["--maxkb=512"]
      - id: check-yaml
      - id: check-toml
      - id: check-docstring-first
  - repo: local
    hooks:
      # Lint (with autofix) runs before the formatter so that any code
      # rewritten by `--fix` is formatted in the same run.
      - id: ruff-check
        name: ruff check
        entry: uv run ruff check --fix .
        language: system
        types: [python]
        # The entry already targets the whole tree ("."); do not let
        # pre-commit append the staged filenames on top of it, which would
        # also allow it to split the run into several concurrent whole-repo
        # invocations.
        pass_filenames: false
      - id: ruff-format
        name: ruff format
        entry: uv run ruff format .
        language: system
        types: [python]
        # Same reasoning as ruff-check: "." is the target, not the file list.
        pass_filenames: false
      # The test suite is only run when pushing, not on every commit.
      - id: pytest
        name: pytest
        entry: uv run pytest tests/ -v
        language: system
        pass_filenames: false
        stages: [pre-push]

View File

@ -21,11 +21,7 @@ class Prepare:
np.random.seed(self.config['random_seed']) np.random.seed(self.config['random_seed'])
def _load_dataset(self) -> dict[str, np.ndarray]: def _load_dataset(self) -> dict[str, np.ndarray]:
data: dict[str, list] = { data: dict[str, list] = {'images': [], 'cate_names': [], 'cate_ids': []}
'images': [],
'cate_names': [],
'cate_ids': []
}
cls_id_map: dict[str, int] = {} cls_id_map: dict[str, int] = {}
raw_data_dir = Path(self.config['data_dir']) / 'raw' raw_data_dir = Path(self.config['data_dir']) / 'raw'
for npy_file in sorted(raw_data_dir.glob('*.npy')): for npy_file in sorted(raw_data_dir.glob('*.npy')):

View File

@ -52,23 +52,15 @@ class Train:
shuffle=True, shuffle=True,
num_workers=4, num_workers=4,
pin_memory=False, # not support for mps pin_memory=False, # not support for mps
persistent_workers=True persistent_workers=True,
) )
valid_dataloader = DataLoader( valid_dataloader = DataLoader(
valid_dataset, valid_dataset, batch_size=self.config['batch_size'], shuffle=False, num_workers=1, pin_memory=False, persistent_workers=True
batch_size=self.config['batch_size'],
shuffle=False,
num_workers=1,
pin_memory=False,
persistent_workers=True
) )
return train_dataloader, valid_dataloader return train_dataloader, valid_dataloader
def _get_model(self) -> torch.nn.Module: def _get_model(self) -> torch.nn.Module:
model = BabyCNN( model = BabyCNN(num_classes=self.config['num_of_class'], dropout_p=0.3).to(self._device)
num_classes=self.config['num_of_class'],
dropout_p=0.3
).to(self._device)
model.train() model.train()
return model return model
@ -89,18 +81,18 @@ class Train:
return scheduler return scheduler
def _get_loss(self) -> torch.nn.modules.loss._Loss: def _get_loss(self) -> torch.nn.modules.loss._Loss:
loss = torch.nn.CrossEntropyLoss( loss = torch.nn.CrossEntropyLoss(label_smoothing=0.1).to(self._device)
label_smoothing=0.1
).to(self._device)
return loss return loss
def _get_metrics(self) -> tuple[MetricCollection, ConfusionMatrix]: def _get_metrics(self) -> tuple[MetricCollection, ConfusionMatrix]:
metric_collection = MetricCollection([ metric_collection = MetricCollection(
Accuracy(task='multiclass', num_classes=self.config['num_of_class'], top_k=1), [
Precision(task='multiclass', num_classes=self.config['num_of_class'], average='macro'), Accuracy(task='multiclass', num_classes=self.config['num_of_class'], top_k=1),
Recall(task='multiclass', num_classes=self.config['num_of_class'], average='macro'), Precision(task='multiclass', num_classes=self.config['num_of_class'], average='macro'),
F1Score(task='multiclass', num_classes=self.config['num_of_class'], average='macro'), Recall(task='multiclass', num_classes=self.config['num_of_class'], average='macro'),
]).to(self._device) F1Score(task='multiclass', num_classes=self.config['num_of_class'], average='macro'),
]
).to(self._device)
confusion_matrix = ConfusionMatrix( confusion_matrix = ConfusionMatrix(
task='multiclass', task='multiclass',
threshold=0.5, threshold=0.5,
@ -116,16 +108,11 @@ class Train:
loss = self._get_loss() loss = self._get_loss()
metrics, _ = self._get_metrics() metrics, _ = self._get_metrics()
with Live( with Live(dir='./doc/exp/train', report='html', dvcyaml='./assets/dvc.yaml', exp_message=self.config['exp_msg']) as live:
dir='./doc/exp/train',
report='html',
dvcyaml='./assets/dvc.yaml',
exp_message=self.config['exp_msg']) as live:
for epoch in tqdm(range(self.config['num_of_epochs']), desc='Training Epoch'): for epoch in tqdm(range(self.config['num_of_epochs']), desc='Training Epoch'):
metrics.reset() metrics.reset()
model.train() model.train()
total_train_loss = 0. total_train_loss = 0.0
for inputs, targets in train_dataloader: for inputs, targets in train_dataloader:
inputs = inputs.to(self._device) inputs = inputs.to(self._device)
targets = targets.to(self._device) targets = targets.to(self._device)
@ -141,7 +128,7 @@ class Train:
metrics.reset() metrics.reset()
model.eval() model.eval()
total_valid_loss = 0. total_valid_loss = 0.0
with torch.no_grad(): with torch.no_grad():
for inputs, targets in valid_dataloader: for inputs, targets in valid_dataloader:
inputs = inputs.to(self._device) inputs = inputs.to(self._device)

View File

@ -11,14 +11,15 @@ from torchvision.transforms import v2
class QuickDrawDataset(torch.utils.data.Dataset): class QuickDrawDataset(torch.utils.data.Dataset):
def __init__(self, def __init__(
data_npz_path: str, self,
image_shape: tuple[int, int, int] = (1, 28, 28), data_npz_path: str,
enable_data_aug: bool = False, image_shape: tuple[int, int, int] = (1, 28, 28),
file_lazy_load: bool = False, enable_data_aug: bool = False,
return_cate_name: bool = False, file_lazy_load: bool = False,
vis_dir: str = None, return_cate_name: bool = False,
) -> None: vis_dir: str = None,
) -> None:
super().__init__() super().__init__()
self._images: torch.Tensor | np.ndarray = [] self._images: torch.Tensor | np.ndarray = []
self._cate_names: list[str] = [] self._cate_names: list[str] = []
@ -43,13 +44,15 @@ class QuickDrawDataset(torch.utils.data.Dataset):
v2.RandomApply([v2.RandomAffine(degrees=(-30, 30), translate=(0.2, 0.2), scale=(0.8, 1.2), shear=(-10, 10))], p=0.5), v2.RandomApply([v2.RandomAffine(degrees=(-30, 30), translate=(0.2, 0.2), scale=(0.8, 1.2), shear=(-10, 10))], p=0.5),
v2.RandomPerspective(distortion_scale=0.15, p=0.2), v2.RandomPerspective(distortion_scale=0.15, p=0.2),
v2.RandomApply([v2.ElasticTransform(alpha=15.0, sigma=3.0)], p=0.2), v2.RandomApply([v2.ElasticTransform(alpha=15.0, sigma=3.0)], p=0.2),
v2.RandomErasing(p=0.2, scale=(0.02, 0.2)) v2.RandomErasing(p=0.2, scale=(0.02, 0.2)),
] ]
self._transform = v2.Compose([ self._transform = v2.Compose(
*aug_pipeline, [
v2.Resize(self._image_shape[1:]), *aug_pipeline,
v2.ToDtype(torch.float32, scale=True), v2.Resize(self._image_shape[1:]),
]) v2.ToDtype(torch.float32, scale=True),
]
)
def _collect_data(self) -> None: def _collect_data(self) -> None:
if self._file_lazy_load: if self._file_lazy_load:

View File

@ -9,9 +9,7 @@ import torch.nn.functional as F
class BabyCNN(nn.Module): class BabyCNN(nn.Module):
def __init__(self, def __init__(self, num_classes: int = 10, dropout_p: float = 0.5) -> None:
num_classes: int = 10,
dropout_p: float = 0.5) -> None:
super().__init__() super().__init__()
# Conv Block 1: 28x28 -> 14x14 # Conv Block 1: 28x28 -> 14x14