Files
quickdraw_bot/quickdraw_bot/train.py
2026-06-18 09:16:42 +08:00

174 lines
7.2 KiB
Python

# train.py
#
# author: deng
# date : 20260617
import random
import numpy as np
import torch
from dvclive import Live
from torch.utils.data import DataLoader
from torchmetrics import MetricCollection
from torchmetrics.classification import Accuracy, ConfusionMatrix, F1Score, Precision, Recall
from tqdm import tqdm
from quickdraw_bot.utils.dataset import QuickDrawDataset
from quickdraw_bot.utils.model import BabyCNN
from quickdraw_bot.utils.utils import load_config
class Train:
    """Training pipeline driven by the ``train`` section of a config file.

    The pipeline builds the data loaders, model, optimizer, LR scheduler, loss
    and metrics from the config, runs the epoch loop while logging to DVCLive,
    and finally saves the model to ``./assets/model.pth``.
    """

    # (logged metric suffix, key in the dict returned by MetricCollection.compute())
    _METRIC_LOG_NAMES = (
        ('accuracy', 'MulticlassAccuracy'),
        ('precision', 'MulticlassPrecision'),
        ('recall', 'MulticlassRecall'),
        ('f1', 'MulticlassF1Score'),
    )

    def __init__(self, config_path: str = './assets/config.yaml'):
        """Load the training config, pick the device and seed the RNGs.

        Args:
            config_path: Path handed to ``load_config``; only its ``'train'``
                entry is kept.
        """
        self.config = load_config(config_path)['train']
        self._device = torch.device(self.config['device_type'])
        self._ensure_deterministic()

    def _ensure_deterministic(self) -> None:
        """Seed ``random``, NumPy and PyTorch with ``config['random_seed']``.

        Deterministic algorithms are requested in warn-only mode.
        """
        seed = self.config['random_seed']
        torch.use_deterministic_algorithms(mode=True, warn_only=True)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    def _get_dataloader(self) -> tuple[DataLoader, DataLoader]:
        """Build the training and validation data loaders.

        Returns:
            ``(train_dataloader, valid_dataloader)``. Only the training set has
            data augmentation enabled and is shuffled.
        """
        train_dataset = QuickDrawDataset(
            data_npz_path=self.config['train_npz'],
            enable_data_aug=True,
            file_lazy_load=self.config['file_lazy_load'],
            return_cate_name=False,
            # vis_dir='./tmp'
        )
        valid_dataset = QuickDrawDataset(
            data_npz_path=self.config['valid_npz'],
            enable_data_aug=False,
            file_lazy_load=self.config['file_lazy_load'],
            return_cate_name=False,
        )

        train_dataloader = DataLoader(
            train_dataset,
            batch_size=self.config['batch_size'],
            shuffle=True,
            num_workers=4,
            pin_memory=False,  # pinned memory is not supported on MPS
            persistent_workers=True,
        )
        valid_dataloader = DataLoader(
            valid_dataset,
            batch_size=self.config['batch_size'],
            shuffle=False,
            num_workers=1,
            pin_memory=False,
            persistent_workers=True,
        )
        return train_dataloader, valid_dataloader

    def _get_model(self) -> torch.nn.Module:
        """Create the ``BabyCNN`` model on the configured device, in train mode."""
        model = BabyCNN(
            num_classes=self.config['num_of_class'],
            dropout_p=0.3,
        ).to(self._device)
        model.train()
        return model

    def _get_optimizer(self, model: torch.nn.Module) -> torch.optim.Optimizer:
        """Create the optimizer selected by ``config['optimizer_name']``.

        Args:
            model: Module whose parameters are optimized.

        Returns:
            An Adam (``'adam'``) or SGD (``'sgd'``) optimizer using
            ``config['learning_rate']``.

        Raises:
            ValueError: If the optimizer name is neither ``'adam'`` nor ``'sgd'``.
        """
        name = self.config['optimizer_name']
        if name == 'adam':
            return torch.optim.Adam(model.parameters(), lr=self.config['learning_rate'])
        if name == 'sgd':
            return torch.optim.SGD(model.parameters(), lr=self.config['learning_rate'])
        raise ValueError(f'Unknown optimizer name: {self.config["optimizer_name"]}')

    def _get_scheduler(self, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler._LRScheduler:
        """Create the learning-rate scheduler.

        The base schedule is cosine annealing with warm restarts. When
        ``config['warmup_epochs']`` is positive, a linear warm-up of that many
        steps is chained in front of it.

        Args:
            optimizer: Optimizer whose learning rate is scheduled.
        """
        warmup_epochs = self.config['warmup_epochs']
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=10, T_mult=1, eta_min=0.0001
        )
        if warmup_epochs > 0:
            warmup = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=0.01, end_factor=1.0, total_iters=warmup_epochs
            )
            scheduler = torch.optim.lr_scheduler.SequentialLR(
                optimizer, schedulers=[warmup, scheduler], milestones=[warmup_epochs]
            )
        return scheduler

    def _get_loss(self) -> torch.nn.modules.loss._Loss:
        """Create the cross-entropy loss (label smoothing 0.1) on the device."""
        return torch.nn.CrossEntropyLoss(label_smoothing=0.1).to(self._device)

    def _get_metrics(self) -> 'tuple[MetricCollection, ConfusionMatrix]':
        """Create the multiclass metrics on the configured device.

        Returns:
            ``(metric_collection, confusion_matrix)`` where the collection holds
            top-1 accuracy and macro-averaged precision, recall and F1.
        """
        num_classes = self.config['num_of_class']
        metric_collection = MetricCollection([
            Accuracy(task='multiclass', num_classes=num_classes, top_k=1),
            Precision(task='multiclass', num_classes=num_classes, average='macro'),
            Recall(task='multiclass', num_classes=num_classes, average='macro'),
            F1Score(task='multiclass', num_classes=num_classes, average='macro'),
        ]).to(self._device)
        confusion_matrix = ConfusionMatrix(
            task='multiclass',
            threshold=0.5,
            num_classes=num_classes,
        ).to(self._device)
        return metric_collection, confusion_matrix

    def _run_epoch(self, model: torch.nn.Module, dataloader, loss_fn, metrics, optimizer=None):
        """Run one full pass over ``dataloader`` and return its loss and metrics.

        Args:
            model: Module to train or evaluate.
            dataloader: Iterable yielding ``(inputs, targets)`` tensor batches.
            loss_fn: Callable ``loss_fn(outputs, targets)`` returning a scalar
                tensor. It is assumed to be the per-batch mean, which is the
                default reduction of the loss built by ``_get_loss``.
            metrics: Object exposing ``reset()``, ``update(outputs, targets)``
                and ``compute()``.
            optimizer: When given, the pass is a training pass (gradients and
                parameter updates). When ``None``, the model is put in eval
                mode and gradients are disabled.

        Returns:
            ``(average_loss, computed_metrics)``. The loss is averaged per
            sample, so a smaller final batch is not over-weighted.

        Raises:
            ValueError: If the dataloader yields no samples.
        """
        is_training = optimizer is not None
        model.train(is_training)
        metrics.reset()

        loss_sum = 0.
        num_samples = 0
        with torch.set_grad_enabled(is_training):
            for inputs, targets in dataloader:
                inputs = inputs.to(self._device)
                targets = targets.to(self._device)

                if is_training:
                    optimizer.zero_grad()
                outputs = model(inputs)
                batch_loss = loss_fn(outputs, targets)
                if is_training:
                    batch_loss.backward()
                    optimizer.step()

                # Undo the per-batch mean so every sample counts equally.
                batch_size = inputs.size(0)
                loss_sum += batch_loss.item() * batch_size
                num_samples += batch_size
                # Metrics never need the autograd graph.
                metrics.update(outputs.detach(), targets)

        if num_samples == 0:
            raise ValueError('Dataloader yielded no samples; cannot compute the epoch loss.')
        return loss_sum / num_samples, metrics.compute()

    @classmethod
    def _log_epoch(cls, live, split: str, avg_loss: float, computed) -> None:
        """Log the loss and the four tracked metrics under the ``split/`` prefix."""
        live.log_metric(f'{split}/loss', avg_loss)
        for suffix, key in cls._METRIC_LOG_NAMES:
            live.log_metric(f'{split}/{suffix}', computed[key].item())

    def run(self) -> None:
        """Train for ``config['num_of_epochs']`` epochs and save the model.

        Each epoch runs a training pass, then a validation pass, logs both to
        DVCLive and steps the LR scheduler once.
        """
        train_dataloader, valid_dataloader = self._get_dataloader()
        model = self._get_model()
        optimizer = self._get_optimizer(model)
        scheduler = self._get_scheduler(optimizer)
        loss = self._get_loss()
        # The confusion matrix is built by _get_metrics but not used here.
        metrics, _ = self._get_metrics()

        with Live(
            dir='./doc/exp/train',
            report='html',
            dvcyaml='./assets/dvc.yaml',
            exp_message=self.config['exp_msg'],
        ) as live:
            for _ in tqdm(range(self.config['num_of_epochs']), desc='Training Epoch'):
                avg_train_loss, train_metrics = self._run_epoch(
                    model, train_dataloader, loss, metrics, optimizer
                )
                avg_valid_loss, valid_metrics = self._run_epoch(
                    model, valid_dataloader, loss, metrics
                )

                self._log_epoch(live, 'train', avg_train_loss, train_metrics)
                self._log_epoch(live, 'valid', avg_valid_loss, valid_metrics)

                scheduler.step()
                live.next_step()

            # NOTE(review): the original indentation of this statement was lost;
            # it is assumed to run once after the last epoch - confirm.
            torch.save(model, './assets/model.pth')
if __name__ == '__main__':
    # Script entry point: build the trainer with its default config path
    # and run the full training pipeline.
    trainer = Train()
    trainer.run()