# evaluate.py
#
# author: deng
# date  : 20231228

from pathlib import Path

import yaml
import torch
from torch.utils.data import DataLoader
from torchmetrics.classification import Accuracy
from dvclive import Live

from utils.dataset import ProcessedDataset


def _select_device() -> torch.device:
    """Return the first available device, preferring MPS, then CUDA, then CPU."""
    if torch.backends.mps.is_available():
        return torch.device('mps')
    if torch.cuda.is_available():
        return torch.device('cuda')
    return torch.device('cpu')


def evaluate(params_path: str = 'params.yaml', model_path: str = 'model.pt') -> None:
    """Evaluate the saved model on the test split and log accuracy with DVCLive.

    Reads ``params['evaluate']`` from the parameter yaml, loads the test
    split from ``<data_dir>/test``, computes multiclass accuracy over it and
    writes the params and the ``test_acc`` metric to ``dvclive/eval``.

    Args:
        params_path (str, optional): path of parameter yaml. Defaults to 'params.yaml'.
        model_path (str, optional): path of the saved model object.
            Defaults to 'model.pt'.
    """
    with open(params_path, encoding='utf-8') as f:
        params = yaml.safe_load(f)

    data_dir = Path(params['evaluate']['data_dir'])
    test_dataset = ProcessedDataset(dataset_dir=data_dir.joinpath('test'))
    test_dataloader = DataLoader(test_dataset, batch_size=64)

    # Fall back to CUDA/CPU instead of failing outright where MPS is missing.
    device = _select_device()

    # torch.load unpickles arbitrary objects, so model_path must be trusted.
    # Loading to CPU first keeps this working even if the model was saved
    # from a device that is not present on this machine.
    net = torch.load(model_path, map_location='cpu')
    net.to(device)
    net.eval()

    metric = Accuracy(task='multiclass', num_classes=10)
    metric.to(device)

    with Live(dir='dvclive/eval', report='md') as live:
        live.log_params(params['evaluate'])

        with torch.no_grad():
            for data in test_dataloader:
                inputs, labels = data[0].to(device), data[1].to(device)
                outputs = net(inputs)
                # Predicted class = highest-scoring output column. Labels are
                # reduced the same way, which assumes they are one-hot /
                # per-class score vectors (as the original topk over dim=1
                # implied) — TODO confirm against ProcessedDataset.
                preds = outputs.argmax(dim=1)
                targets = labels.argmax(dim=1)
                # update() only accumulates state; forward() would also
                # compute a per-batch accuracy that was being discarded.
                metric.update(preds, targets)

        test_acc = float(metric.compute())
        print(f'test_acc:{test_acc}')
        live.log_metric('test_acc', test_acc)


if __name__ == '__main__':
    evaluate()