test_dvc/evaluate.py

# evaluate.py
#
# author: deng
# date : 20231228
from pathlib import Path

import torch
import yaml
from torch.utils.data import DataLoader
from torchmetrics.classification import Accuracy
from dvclive import Live

from utils.dataset import ProcessedDataset
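# NOTE: `ProcessedDataset` is project-local and not shown here. Judging from
# how its batches are consumed below (one-hot labels converted back to class
# indices via `topk`), it is assumed to be a torch Dataset yielding
# `(feature_tensor, one_hot_label_tensor)` pairs for a 10-class task.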


def evaluate(params_path: str = 'params.yaml') -> None:
    """Evaluate the trained model and log the results to the eval dir.

    Args:
        params_path (str, optional): path to the parameter YAML file. Defaults to 'params.yaml'.
    """
    with open(params_path, encoding='utf-8') as f:
        params = yaml.safe_load(f)
    data_dir = Path(params['evaluate']['data_dir'])
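    # A minimal params.yaml that satisfies this script might look like the
    # following; only the `evaluate.data_dir` key is actually read here, so
    # any other keys are assumptions:
    #
    #   evaluate:
    #     data_dir: data/processed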

    test_dataset = ProcessedDataset(dataset_dir=data_dir.joinpath('test'))
    test_dataloader = DataLoader(test_dataset, batch_size=64)

    # fall back to CPU when Apple's MPS backend is unavailable
    device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')

    # model.pt is expected to hold a full serialized module (saved with
    # torch.save(net, ...)); on PyTorch >= 2.6 this also needs weights_only=False
    net = torch.load('model.pt', map_location=device)
    net.to(device)
    net.eval()

    metric = Accuracy(task='multiclass', num_classes=10)
    metric.to(device)
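    # torchmetrics accumulates state across batches: each metric(...) call
    # updates the running counts, and metric.compute() aggregates them into
    # a single accuracy value at the end.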

    with Live(dir='dvclive/eval', report='md') as live:
        live.log_params(params['evaluate'])

        with torch.no_grad():
            for data in test_dataloader:
                inputs, labels = data[0].to(device), data[1].to(device)
                outputs = net(inputs)
                # outputs and labels are (batch, num_classes); topk(...)[1]
                # reduces logits and one-hot labels to class indices
                _ = metric(outputs.topk(k=1, dim=1)[1], labels.topk(k=1, dim=1)[1])

        test_acc = metric.compute()
        print(f'test_acc: {test_acc}')
        live.log_metric('test_acc', float(test_acc.cpu()))


if __name__ == '__main__':
    evaluate()
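
# A minimal sketch of how this script could be wired into a DVC pipeline; the
# stage layout and paths below are assumptions, not taken from the project's
# actual dvc.yaml:
#
#   stages:
#     evaluate:
#       cmd: python evaluate.py
#       deps:
#         - model.pt
#       params:
#         - evaluate
#       metrics:
#         - dvclive/eval/metrics.json:
#             cache: false
#
# With report='md', dvclive also renders a Markdown report alongside the
# metrics under dvclive/eval/.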