# train.py
#
# author: deng
# date : 20260617

import random

import numpy as np
import torch
from dvclive import Live
from torch.utils.data import DataLoader
from torchmetrics import MetricCollection
from torchmetrics.classification import Accuracy, ConfusionMatrix, F1Score, Precision, Recall
from tqdm import tqdm

from quickdraw_bot.utils.dataset import QuickDrawDataset
from quickdraw_bot.utils.model import BabyCNN
from quickdraw_bot.utils.utils import load_config


class Train:
    """Train a ``BabyCNN`` classifier and log per-epoch metrics with DVCLive.

    Every hyper-parameter read by this class comes from the ``'train'`` entry
    of the config returned by ``load_config``. The keys used are:
    ``device_type``, ``random_seed``, ``train_npz``, ``valid_npz``,
    ``file_lazy_load``, ``batch_size``, ``num_of_class``, ``optimizer_name``,
    ``learning_rate``, ``warmup_epochs``, ``num_of_epochs`` and ``exp_msg``.
    """

    # (suffix of the logged metric name, key in the computed metric dict)
    _LOGGED_METRICS = (
        ('accuracy', 'MulticlassAccuracy'),
        ('precision', 'MulticlassPrecision'),
        ('recall', 'MulticlassRecall'),
        ('f1', 'MulticlassF1Score'),
    )

    def __init__(self, config_path: str = './assets/config.yaml'):
        """Load the training config, resolve the device and seed the RNGs.

        Args:
            config_path: Path handed to ``load_config``; the ``'train'`` entry
                of the result is stored as ``self.config``.
        """
        self.config = load_config(config_path)['train']
        self._device = torch.device(self.config['device_type'])

        self._ensure_deterministic()

    def _ensure_deterministic(self) -> None:
        """Seed the Python, NumPy and PyTorch RNGs and request deterministic ops."""
        # warn_only=True: ops without a deterministic implementation only
        # emit a warning instead of raising.
        torch.use_deterministic_algorithms(mode=True, warn_only=True)

        seed = self.config['random_seed']
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    def _get_dataloader(self) -> tuple[DataLoader, DataLoader]:
        """Build the training and validation data loaders.

        Returns:
            ``(train_dataloader, valid_dataloader)``. Only the training
            dataset enables augmentation and only the training loader
            shuffles.
        """
        train_dataset = QuickDrawDataset(
            data_npz_path=self.config['train_npz'],
            enable_data_aug=True,
            file_lazy_load=self.config['file_lazy_load'],
            return_cate_name=False,
            # vis_dir='./tmp'
        )
        valid_dataset = QuickDrawDataset(
            data_npz_path=self.config['valid_npz'],
            enable_data_aug=False,
            file_lazy_load=self.config['file_lazy_load'],
            return_cate_name=False,
        )

        train_dataloader = DataLoader(
            train_dataset,
            batch_size=self.config['batch_size'],
            shuffle=True,
            num_workers=4,
            pin_memory=False,  # not supported on mps
            persistent_workers=True,
        )
        valid_dataloader = DataLoader(
            valid_dataset,
            batch_size=self.config['batch_size'],
            shuffle=False,
            num_workers=1,
            pin_memory=False,
            persistent_workers=True,
        )
        return train_dataloader, valid_dataloader

    def _get_model(self) -> torch.nn.Module:
        """Create the ``BabyCNN`` on the configured device, in training mode."""
        model = BabyCNN(
            num_classes=self.config['num_of_class'],
            dropout_p=0.3,
        ).to(self._device)
        model.train()
        return model

    def _get_optimizer(self, model: torch.nn.Module) -> torch.optim.Optimizer:
        """Create the optimizer selected by ``config['optimizer_name']``.

        Args:
            model: Module whose parameters are optimized.

        Returns:
            An ``Adam`` (``'adam'``) or ``SGD`` (``'sgd'``) optimizer using
            ``config['learning_rate']``.

        Raises:
            ValueError: If the configured optimizer name is not supported.
        """
        optimizer_classes = {
            'adam': torch.optim.Adam,
            'sgd': torch.optim.SGD,
        }
        name = self.config['optimizer_name']
        if name not in optimizer_classes:
            raise ValueError(f'Unknown optimizer name: {name}')
        return optimizer_classes[name](model.parameters(), lr=self.config['learning_rate'])

    def _get_scheduler(self, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler._LRScheduler:
        """Create the learning-rate scheduler, optionally preceded by warmup.

        The base schedule is cosine annealing with warm restarts. When
        ``config['warmup_epochs']`` is positive, a linear warmup that scales
        the learning rate from 1% to 100% runs first and hands over to the
        cosine schedule at that milestone. ``run`` steps the scheduler once
        per epoch.

        Args:
            optimizer: Optimizer whose learning rate is scheduled.

        Returns:
            The cosine scheduler, or a ``SequentialLR`` wrapping warmup and
            cosine schedulers.
        """
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=10, T_mult=1, eta_min=0.0001
        )
        warmup_epochs = self.config['warmup_epochs']
        if warmup_epochs > 0:
            warmup = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=0.01, end_factor=1.0, total_iters=warmup_epochs
            )
            scheduler = torch.optim.lr_scheduler.SequentialLR(
                optimizer, schedulers=[warmup, scheduler], milestones=[warmup_epochs]
            )
        return scheduler

    def _get_loss(self) -> torch.nn.modules.loss._Loss:
        """Create the label-smoothed cross-entropy loss on the configured device."""
        return torch.nn.CrossEntropyLoss(label_smoothing=0.1).to(self._device)

    def _get_metrics(self) -> 'tuple[MetricCollection, ConfusionMatrix]':
        """Create the multiclass metrics used during training and validation.

        Returns:
            ``(metric_collection, confusion_matrix)``, both moved to the
            configured device. The collection holds top-1 accuracy and
            macro-averaged precision, recall and F1.
        """
        num_classes = self.config['num_of_class']
        metric_collection = MetricCollection([
            Accuracy(task='multiclass', num_classes=num_classes, top_k=1),
            Precision(task='multiclass', num_classes=num_classes, average='macro'),
            Recall(task='multiclass', num_classes=num_classes, average='macro'),
            F1Score(task='multiclass', num_classes=num_classes, average='macro'),
        ]).to(self._device)
        # NOTE(review): `threshold` looks relevant only to binary/multilabel
        # tasks and is presumably ignored for 'multiclass' - confirm.
        confusion_matrix = ConfusionMatrix(
            task='multiclass',
            threshold=0.5,
            num_classes=num_classes,
        ).to(self._device)
        return metric_collection, confusion_matrix

    def _run_epoch(
        self,
        model: torch.nn.Module,
        dataloader: DataLoader,
        loss: torch.nn.Module,
        metrics: 'MetricCollection',
        optimizer: 'torch.optim.Optimizer | None' = None,
    ):
        """Run one full pass over ``dataloader``.

        The pass is a training pass (train mode, gradients, optimizer steps)
        when ``optimizer`` is given and an evaluation pass (eval mode, no
        gradients) otherwise.

        Args:
            model: Network to run.
            dataloader: Source of ``(inputs, targets)`` batches.
            loss: Criterion called as ``loss(outputs, targets)``.
            metrics: Metric collection; reset at the start of the pass.
            optimizer: Optimizer to step after each batch, or ``None`` to
                evaluate only.

        Returns:
            ``(average_loss, computed_metrics)`` where the average is the mean
            of the per-batch loss values over ``len(dataloader)`` batches.
        """
        is_training = optimizer is not None

        metrics.reset()
        model.train(is_training)
        total_loss = 0.
        with torch.set_grad_enabled(is_training):
            for inputs, targets in dataloader:
                inputs = inputs.to(self._device)
                targets = targets.to(self._device)

                if is_training:
                    optimizer.zero_grad()
                outputs = model(inputs)
                batch_loss = loss(outputs, targets)
                total_loss += batch_loss.item()
                if is_training:
                    batch_loss.backward()
                    optimizer.step()
                metrics.update(outputs, targets)

        computed = metrics.compute()
        return total_loss / len(dataloader), computed

    def _log_split_metrics(self, live, split: str, avg_loss: float, computed) -> None:
        """Log the loss and every tracked metric of one split to DVCLive.

        Args:
            live: Active ``Live`` session.
            split: Prefix of the metric names, e.g. ``'train'`` or ``'valid'``.
            avg_loss: Average loss of the split for the current epoch.
            computed: Mapping returned by ``MetricCollection.compute()``.
        """
        live.log_metric(f'{split}/loss', avg_loss)
        for suffix, key in self._LOGGED_METRICS:
            live.log_metric(f'{split}/{suffix}', computed[key].item())

    def run(self):
        """Train for ``config['num_of_epochs']`` epochs, then save the model.

        Each epoch runs a training pass and a validation pass, logs the loss
        and metrics of both to DVCLive, and steps the learning-rate scheduler.
        The whole model object is written to ``./assets/model.pth``.
        """
        train_dataloader, valid_dataloader = self._get_dataloader()
        model = self._get_model()
        optimizer = self._get_optimizer(model)
        scheduler = self._get_scheduler(optimizer)
        loss = self._get_loss()
        metrics, _ = self._get_metrics()  # the confusion matrix is not used here

        with Live(
            dir='./doc/exp/train',
            report='html',
            dvcyaml='./assets/dvc.yaml',
            exp_message=self.config['exp_msg'],
        ) as live:

            for _ in tqdm(range(self.config['num_of_epochs']), desc='Training Epoch'):
                avg_train_loss, train_metrics = self._run_epoch(
                    model, train_dataloader, loss, metrics, optimizer
                )
                avg_valid_loss, valid_metrics = self._run_epoch(
                    model, valid_dataloader, loss, metrics
                )

                self._log_split_metrics(live, 'train', avg_train_loss, train_metrics)
                self._log_split_metrics(live, 'valid', avg_valid_loss, valid_metrics)

                scheduler.step()
                live.next_step()

            # NOTE(review): the original nesting of this statement could not be
            # recovered; it runs once, after the last epoch. torch.save(model)
            # pickles the whole module, so loading it needs the model class to
            # be importable.
            torch.save(model, './assets/model.pth')


if __name__ == '__main__':
    # Script entry point: run a full training session with the default config path.
    Train().run()