# train.py
#
# author: deng
# date : 20231228
from pathlib import Path

import yaml
import torch
from rich.progress import track
from torch.utils.data import DataLoader
from torchvision.models import resnet50
from torchmetrics.classification import Accuracy
from dvclive import Live

from utils.dataset import ProcessedDataset


def train(params_path: str = 'params.yaml') -> None:
    """Train a simple model using PyTorch.

    Args:
        params_path (str, optional): path of the parameter YAML file. Defaults to 'params.yaml'.
    """

    # Load the training hyperparameters from the `train` section of the params file
    with open(params_path, encoding='utf-8') as f:
        params = yaml.safe_load(f)
    data_dir = Path(params['train']['data_dir'])
    epochs = params['train']['epochs']
    batch_size = params['train']['batch_size']
    learning_rate = params['train']['learning_rate']
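    # Only the four keys above are read from params.yaml. An illustrative layout,
    # with placeholder values that are assumptions rather than the repo's actual settings:
    #
    #   train:
    #     data_dir: data/processed
    #     epochs: 10
    #     batch_size: 32
    #     learning_rate: 0.001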
    train_dataset = ProcessedDataset(dataset_dir=data_dir.joinpath('train'))
    valid_dataset = ProcessedDataset(dataset_dir=data_dir.joinpath('valid'))

    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=2)
    valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, num_workers=2)

    # Prefer Apple's Metal (MPS) backend when available, otherwise fall back to CPU
    device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')

    # ImageNet-pretrained ResNet-50 with an extra linear head mapping the 1000 logits to 10 classes
    net = torch.nn.Sequential(
        resnet50(weights='IMAGENET1K_V1'),
        torch.nn.Linear(1000, 10)
    )
    net.to(device)

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate)
    metric = Accuracy(task='multiclass', num_classes=10)
    metric.to(device)
    with Live(dir='dvclive/train', report='md') as live:
        live.log_params(params['train'])

        for epoch in track(range(epochs), 'Epochs'):

            # Training phase
            net.train()
            train_loss = 0.
            for data in train_dataloader:
                inputs, labels = data[0].to(device), data[1].to(device)
                optimizer.zero_grad()
                outputs = net(inputs)
                loss = criterion(outputs, labels)
                train_loss += loss.item()
                loss.backward()
                optimizer.step()
                # labels appear to be one-hot encoded, so compare top-1 class indices
                _ = metric(outputs.topk(k=1, dim=1)[1], labels.topk(k=1, dim=1)[1])
            train_loss /= len(train_dataloader)
            train_acc = metric.compute()
            metric.reset()

            # Validation phase
            net.eval()
            with torch.no_grad():
                valid_loss = 0.
                for data in valid_dataloader:
                    inputs, labels = data[0].to(device), data[1].to(device)
                    outputs = net(inputs)
                    loss = criterion(outputs, labels)
                    valid_loss += loss.item()
                    _ = metric(outputs.topk(k=1, dim=1)[1], labels.topk(k=1, dim=1)[1])
                valid_loss /= len(valid_dataloader)
                valid_acc = metric.compute()
                metric.reset()

            print(
                f'Epoch {epoch} - train_loss:{train_loss:.4f} train_acc:{train_acc:.4f} '
                f'valid_loss:{valid_loss:.4f} valid_acc:{valid_acc:.4f}'
            )
            live.log_param('epoch', epoch)
            live.log_metric('train_loss', train_loss)
            live.log_metric('train_acc', float(train_acc.cpu()))
            live.log_metric('valid_loss', valid_loss)
            live.log_metric('valid_acc', float(valid_acc.cpu()))
            live.next_step()
    torch.save(net, 'model.pt')


if __name__ == '__main__':
    train()
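# How this script could be wired into a DVC pipeline (dvc.yaml). This stage
# definition is an illustrative sketch only; the dependency paths are assumptions
# and should be adjusted to the actual repo layout:
#
# stages:
#   train:
#     cmd: python train.py
#     deps:
#       - train.py
#       - data/processed
#     params:
#       - train
#     outs:
#       - model.pt
#     metrics:
#       - dvclive/train/metrics.json:
#           cache: false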