# train.py
#
# author: deng
# date  : 20231228

from pathlib import Path
from typing import Optional, Tuple

import yaml
import torch
from rich.progress import track
from torch.utils.data import DataLoader
from torchvision.models import resnet50
from torchmetrics.classification import Accuracy
from dvclive import Live

from utils.dataset import ProcessedDataset


def _select_device() -> torch.device:
    """Return the first available device among MPS, CUDA and CPU (in that order)."""
    # The script originally hard-coded 'mps', which fails on machines without
    # MPS support; MPS stays the first choice but we fall back gracefully.
    if torch.backends.mps.is_available():
        return torch.device('mps')
    if torch.cuda.is_available():
        return torch.device('cuda')
    return torch.device('cpu')


def _run_epoch(
    net: torch.nn.Module,
    dataloader: DataLoader,
    criterion: torch.nn.Module,
    metric: Accuracy,
    device: torch.device,
    optimizer: Optional[torch.optim.Optimizer] = None,
) -> Tuple[float, float]:
    """Run one full pass over ``dataloader`` and return its mean loss and accuracy.

    When ``optimizer`` is given, the network is put in training mode and its
    parameters are updated after every batch. Otherwise the pass runs in
    evaluation mode with gradient tracking disabled.

    Args:
        net (torch.nn.Module): model to run; expected to already be on ``device``.
        dataloader (DataLoader): yields items whose [0] is the input batch and
            whose [1] is the label batch.
        criterion (torch.nn.Module): loss function applied to (outputs, labels).
        metric (Accuracy): accuracy metric; it is updated, computed and reset here.
        device (torch.device): device the batches are moved to.
        optimizer (torch.optim.Optimizer, optional): if given, train; else evaluate.

    Returns:
        Tuple[float, float]: (loss averaged over batches, accuracy over the pass).

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError('dataloader is empty; cannot compute loss/accuracy')

    training = optimizer is not None
    net.train(training)  # train(False) is equivalent to eval()
    total_loss = 0.
    with torch.set_grad_enabled(training):
        for data in dataloader:
            inputs, labels = data[0].to(device), data[1].to(device)
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # detach() so the running sum does not keep every batch's autograd
            # graph alive until the end of the epoch.
            total_loss += loss.detach()
            # Labels are reduced with argmax along dim 1, so they are presumably
            # one-hot / probability vectors of shape (batch, num_classes)
            # — TODO confirm against ProcessedDataset.
            metric.update(outputs.argmax(dim=1), labels.argmax(dim=1))

    mean_loss = float(total_loss) / num_batches
    accuracy = float(metric.compute())
    metric.reset()
    return mean_loss, accuracy


def train(params_path: str = 'params.yaml') -> None:
    """Fine-tune an ImageNet-pretrained ResNet-50 classifier and log metrics with DVCLive.

    Reads the ``train`` section of the parameter YAML. It must contain
    ``data_dir``, ``epochs``, ``batch_size`` and ``learning_rate``, and may
    contain ``num_classes`` (default 10). The function:

    - trains on ``<data_dir>/train``;
    - evaluates on ``<data_dir>/valid`` after every epoch;
    - logs parameters and metrics to ``dvclive/train``;
    - saves the whole model object to ``model.pt``.

    Args:
        params_path (str, optional): path of parameter yaml. Defaults to 'params.yaml'.
    """
    with open(params_path, encoding='utf-8') as f:
        params = yaml.safe_load(f)
    train_params = params['train']
    data_dir = Path(train_params['data_dir'])
    epochs = train_params['epochs']
    batch_size = train_params['batch_size']
    learning_rate = train_params['learning_rate']
    num_classes = train_params.get('num_classes', 10)

    train_dataset = ProcessedDataset(dataset_dir=data_dir.joinpath('train'))
    valid_dataset = ProcessedDataset(dataset_dir=data_dir.joinpath('valid'))
    # Shuffle only the training set, so SGD does not see batches in a fixed order.
    train_dataloader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=2
    )
    valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, num_workers=2)

    device = _select_device()
    # ResNet-50 emits 1000 ImageNet logits; a linear head maps them to our classes.
    net = torch.nn.Sequential(
        resnet50(weights='IMAGENET1K_V1'),
        torch.nn.Linear(1000, num_classes),
    )
    net.to(device)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate)
    metric = Accuracy(task='multiclass', num_classes=num_classes)
    metric.to(device)

    with Live(dir='dvclive/train', report='md') as live:
        live.log_params(train_params)
        for epoch in track(range(epochs), 'Epochs'):
            train_loss, train_acc = _run_epoch(
                net, train_dataloader, criterion, metric, device, optimizer
            )
            valid_loss, valid_acc = _run_epoch(
                net, valid_dataloader, criterion, metric, device
            )
            print(
                f'Epoch {epoch} - train_loss:{train_loss} train_acc:{train_acc} '
                f'valid_loss:{valid_loss} valid_acc:{valid_acc}'
            )
            live.log_param('epoch', epoch)
            live.log_metric('train_loss', train_loss)
            live.log_metric('train_acc', train_acc)
            live.log_metric('valid_loss', valid_loss)
            live.log_metric('valid_acc', valid_acc)
            live.next_step()

    # Saves the full pickled module (not just the state_dict); loaders must
    # have this model's classes importable.
    torch.save(net, 'model.pt')


if __name__ == '__main__':
    train()