apply precommit
This commit is contained in:
33
.pre-commit-config.yaml
Normal file
33
.pre-commit-config.yaml
Normal file
@ -0,0 +1,33 @@
|
||||
default_language_version:
|
||||
python: python3.13
|
||||
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v5.0.0
|
||||
hooks:
|
||||
- id: check-added-large-files
|
||||
args: ["--maxkb=512"]
|
||||
- id: check-yaml
|
||||
- id: check-toml
|
||||
- id: check-docstring-first
|
||||
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: ruff-format
|
||||
name: ruff format
|
||||
entry: uv run ruff format .
|
||||
language: system
|
||||
types: [python]
|
||||
|
||||
- id: ruff-check
|
||||
name: ruff check
|
||||
entry: uv run ruff check --fix .
|
||||
language: system
|
||||
types: [python]
|
||||
|
||||
- id: pytest
|
||||
name: pytest
|
||||
entry: uv run pytest tests/ -v
|
||||
language: system
|
||||
pass_filenames: false
|
||||
stages: [pre-push]
|
||||
@ -21,11 +21,7 @@ class Prepare:
|
||||
np.random.seed(self.config['random_seed'])
|
||||
|
||||
def _load_dataset(self) -> dict[str, np.ndarray]:
|
||||
data: dict[str, list] = {
|
||||
'images': [],
|
||||
'cate_names': [],
|
||||
'cate_ids': []
|
||||
}
|
||||
data: dict[str, list] = {'images': [], 'cate_names': [], 'cate_ids': []}
|
||||
cls_id_map: dict[str, int] = {}
|
||||
raw_data_dir = Path(self.config['data_dir']) / 'raw'
|
||||
for npy_file in sorted(raw_data_dir.glob('*.npy')):
|
||||
|
||||
@ -52,23 +52,15 @@ class Train:
|
||||
shuffle=True,
|
||||
num_workers=4,
|
||||
pin_memory=False, # not support for mps
|
||||
persistent_workers=True
|
||||
persistent_workers=True,
|
||||
)
|
||||
valid_dataloader = DataLoader(
|
||||
valid_dataset,
|
||||
batch_size=self.config['batch_size'],
|
||||
shuffle=False,
|
||||
num_workers=1,
|
||||
pin_memory=False,
|
||||
persistent_workers=True
|
||||
valid_dataset, batch_size=self.config['batch_size'], shuffle=False, num_workers=1, pin_memory=False, persistent_workers=True
|
||||
)
|
||||
return train_dataloader, valid_dataloader
|
||||
|
||||
def _get_model(self) -> torch.nn.Module:
|
||||
model = BabyCNN(
|
||||
num_classes=self.config['num_of_class'],
|
||||
dropout_p=0.3
|
||||
).to(self._device)
|
||||
model = BabyCNN(num_classes=self.config['num_of_class'], dropout_p=0.3).to(self._device)
|
||||
model.train()
|
||||
return model
|
||||
|
||||
@ -89,18 +81,18 @@ class Train:
|
||||
return scheduler
|
||||
|
||||
def _get_loss(self) -> torch.nn.modules.loss._Loss:
|
||||
loss = torch.nn.CrossEntropyLoss(
|
||||
label_smoothing=0.1
|
||||
).to(self._device)
|
||||
loss = torch.nn.CrossEntropyLoss(label_smoothing=0.1).to(self._device)
|
||||
return loss
|
||||
|
||||
def _get_metrics(self) -> tuple[MetricCollection, ConfusionMatrix]:
|
||||
metric_collection = MetricCollection([
|
||||
metric_collection = MetricCollection(
|
||||
[
|
||||
Accuracy(task='multiclass', num_classes=self.config['num_of_class'], top_k=1),
|
||||
Precision(task='multiclass', num_classes=self.config['num_of_class'], average='macro'),
|
||||
Recall(task='multiclass', num_classes=self.config['num_of_class'], average='macro'),
|
||||
F1Score(task='multiclass', num_classes=self.config['num_of_class'], average='macro'),
|
||||
]).to(self._device)
|
||||
]
|
||||
).to(self._device)
|
||||
confusion_matrix = ConfusionMatrix(
|
||||
task='multiclass',
|
||||
threshold=0.5,
|
||||
@ -116,16 +108,11 @@ class Train:
|
||||
loss = self._get_loss()
|
||||
metrics, _ = self._get_metrics()
|
||||
|
||||
with Live(
|
||||
dir='./doc/exp/train',
|
||||
report='html',
|
||||
dvcyaml='./assets/dvc.yaml',
|
||||
exp_message=self.config['exp_msg']) as live:
|
||||
|
||||
with Live(dir='./doc/exp/train', report='html', dvcyaml='./assets/dvc.yaml', exp_message=self.config['exp_msg']) as live:
|
||||
for epoch in tqdm(range(self.config['num_of_epochs']), desc='Training Epoch'):
|
||||
metrics.reset()
|
||||
model.train()
|
||||
total_train_loss = 0.
|
||||
total_train_loss = 0.0
|
||||
for inputs, targets in train_dataloader:
|
||||
inputs = inputs.to(self._device)
|
||||
targets = targets.to(self._device)
|
||||
@ -141,7 +128,7 @@ class Train:
|
||||
|
||||
metrics.reset()
|
||||
model.eval()
|
||||
total_valid_loss = 0.
|
||||
total_valid_loss = 0.0
|
||||
with torch.no_grad():
|
||||
for inputs, targets in valid_dataloader:
|
||||
inputs = inputs.to(self._device)
|
||||
|
||||
@ -11,7 +11,8 @@ from torchvision.transforms import v2
|
||||
|
||||
|
||||
class QuickDrawDataset(torch.utils.data.Dataset):
|
||||
def __init__(self,
|
||||
def __init__(
|
||||
self,
|
||||
data_npz_path: str,
|
||||
image_shape: tuple[int, int, int] = (1, 28, 28),
|
||||
enable_data_aug: bool = False,
|
||||
@ -43,13 +44,15 @@ class QuickDrawDataset(torch.utils.data.Dataset):
|
||||
v2.RandomApply([v2.RandomAffine(degrees=(-30, 30), translate=(0.2, 0.2), scale=(0.8, 1.2), shear=(-10, 10))], p=0.5),
|
||||
v2.RandomPerspective(distortion_scale=0.15, p=0.2),
|
||||
v2.RandomApply([v2.ElasticTransform(alpha=15.0, sigma=3.0)], p=0.2),
|
||||
v2.RandomErasing(p=0.2, scale=(0.02, 0.2))
|
||||
v2.RandomErasing(p=0.2, scale=(0.02, 0.2)),
|
||||
]
|
||||
self._transform = v2.Compose([
|
||||
self._transform = v2.Compose(
|
||||
[
|
||||
*aug_pipeline,
|
||||
v2.Resize(self._image_shape[1:]),
|
||||
v2.ToDtype(torch.float32, scale=True),
|
||||
])
|
||||
]
|
||||
)
|
||||
|
||||
def _collect_data(self) -> None:
|
||||
if self._file_lazy_load:
|
||||
|
||||
@ -9,9 +9,7 @@ import torch.nn.functional as F
|
||||
|
||||
|
||||
class BabyCNN(nn.Module):
|
||||
def __init__(self,
|
||||
num_classes: int = 10,
|
||||
dropout_p: float = 0.5) -> None:
|
||||
def __init__(self, num_classes: int = 10, dropout_p: float = 0.5) -> None:
|
||||
super().__init__()
|
||||
|
||||
# Conv Block 1: 28x28 -> 14x14
|
||||
|
||||
Reference in New Issue
Block a user