55 lines
1.4 KiB
Python
55 lines
1.4 KiB
Python
# evaluate.py
#
# author: deng
# date : 20231228
|
|
|
|
from pathlib import Path
|
|
import yaml
|
|
|
|
import torch
|
|
from torch.utils.data import DataLoader
|
|
from torchmetrics.classification import Accuracy
|
|
from dvclive import Live
|
|
|
|
from utils.dataset import ProcessedDataset
|
|
|
|
|
|
def _select_device() -> torch.device:
    """Pick the best available torch device.

    Prefers Apple-silicon MPS (the device this script originally hard-coded),
    then CUDA, and falls back to CPU so evaluation also runs on machines
    without a GPU.

    Returns:
        torch.device: the selected device.
    """
    if torch.backends.mps.is_available():
        return torch.device('mps')
    if torch.cuda.is_available():
        return torch.device('cuda')
    return torch.device('cpu')


def evaluate(params_path: str = 'params.yaml') -> None:
    """Evaluate model and save results to eval dir

    Reads the ``evaluate`` section of the parameter yaml, runs the model
    stored in ``model.pt`` over the ``test`` split found under
    ``evaluate.data_dir``, prints the multiclass accuracy and logs it (plus
    the ``evaluate`` params) with DVCLive to ``dvclive/eval``.

    Optional keys in the ``evaluate`` section:
        batch_size (int): test batch size. Defaults to 64.
        num_classes (int): number of target classes. Defaults to 10.

    Args:
        params_path (str, optional): path of parameter yaml. Defaults to 'params.yaml'.
    """
    with open(params_path, encoding='utf-8') as f:
        params = yaml.safe_load(f)
    eval_params = params['evaluate']
    data_dir = Path(eval_params['data_dir'])
    batch_size = eval_params.get('batch_size', 64)
    num_classes = eval_params.get('num_classes', 10)

    test_dataset = ProcessedDataset(dataset_dir=data_dir.joinpath('test'))
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

    device = _select_device()
    # model.pt holds a whole pickled module (it is called and .eval()'d
    # directly below), not a state_dict. weights_only=False keeps that
    # loadable on torch >= 2.6, where the default became True.
    # NOTE(review): this unpickles arbitrary objects; only load trusted files
    # (presumably written by the training stage).
    # map_location lets a checkpoint saved on MPS/CUDA load on this device.
    net = torch.load('model.pt', map_location=device, weights_only=False)
    net.to(device)
    net.eval()

    metric = Accuracy(task='multiclass', num_classes=num_classes)
    metric.to(device)

    with Live(dir='dvclive/eval', report='md') as live:
        live.log_params(eval_params)

        with torch.no_grad():
            for batch in test_dataloader:
                inputs = batch[0].to(device)
                labels = batch[1].to(device)
                outputs = net(inputs)
                # Predicted class = highest score along dim 1. Labels are
                # reduced the same way, so they appear to be one-hot / score
                # vectors — TODO confirm against ProcessedDataset.
                # update() only accumulates state; calling the metric
                # directly would also compute a discarded per-batch value.
                metric.update(outputs.argmax(dim=1), labels.argmax(dim=1))

        test_acc = metric.compute()

        print(f'test_acc:{test_acc}')
        live.log_metric('test_acc', test_acc.item())
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: evaluate using the default parameter file.
    evaluate()
|