# File listing metadata (from source page): test_dvc/train.py — 100 lines, 3.3 KiB, Python

# train.py
#
# author: deng
# date : 20231228
from pathlib import Path
import yaml
import torch
from rich.progress import track
from torch.utils.data import DataLoader
from torchvision.models import resnet50
from torchmetrics.classification import Accuracy
from dvclive import Live
from utils.dataset import ProcessedDataset
def train(params_path: str = 'params.yaml') -> None:
"""Train a simple model using Pytorch
Args:
params_path (str, optional): path of parameter yaml. Defaults to 'params.yaml'.
"""
with open(params_path, encoding='utf-8') as f:
params = yaml.safe_load(f)
data_dir = Path(params['train']['data_dir'])
epochs = params['train']['epochs']
batch_size = params['train']['batch_size']
learning_rate = params['train']['learning_rate']
train_dataset = ProcessedDataset(dataset_dir=data_dir.joinpath('train'))
valid_dataset = ProcessedDataset(dataset_dir=data_dir.joinpath('valid'))
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=2)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, num_workers=2)
device = torch.device('mps')
net = torch.nn.Sequential(
resnet50(weights='IMAGENET1K_V1'),
torch.nn.Linear(1000, 10)
)
net.to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate)
metric = Accuracy(task='multiclass', num_classes=10)
metric.to(device)
with Live(dir='dvclive/train', report='md') as live:
live.log_params(params['train'])
for epoch in track(range(epochs), 'Epochs'):
net.train()
train_loss = 0.
for data in train_dataloader:
inputs, labels = data[0].to(device), data[1].to(device)
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, labels)
train_loss += loss
loss.backward()
optimizer.step()
_ = metric(outputs.topk(k=1, dim=1)[1], labels.topk(k=1, dim=1)[1])
train_loss /= len(train_dataloader)
train_acc = metric.compute()
metric.reset()
net.eval()
with torch.no_grad():
valid_loss = 0.
for data in valid_dataloader:
inputs, labels = data[0].to(device), data[1].to(device)
outputs = net(inputs)
loss = criterion(outputs, labels)
valid_loss += loss
_ = metric(outputs.topk(k=1, dim=1)[1], labels.topk(k=1, dim=1)[1])
valid_loss /= len(valid_dataloader)
valid_acc = metric.compute()
metric.reset()
print(
f'Epoch {epoch} - train_loss:{train_loss} train_acc:{train_acc} '
f'valid_loss:{valid_loss} valid_acc:{valid_acc}'
)
live.log_param('epoch', epoch)
live.log_metric('train_loss', float(train_loss.cpu()))
live.log_metric('train_acc', float(train_acc.cpu()))
live.log_metric('valid_loss', float(valid_loss.cpu()))
live.log_metric('valid_acc', float(valid_acc.cpu()))
live.next_step()
torch.save(net, 'model.pt')
# Script entry point: train with the default 'params.yaml'.
if __name__ == '__main__':
    train()