# evaluate.py
#
# author: deng
# date  : 20231228

from pathlib import Path

import yaml
import torch
from torch.utils.data import DataLoader
from torchmetrics.classification import MulticlassAccuracy
from dvclive import Live

from utils.dataset import ProcessedDataset


def evaluate(params_path: str = 'params.yaml') -> None:
    """Evaluate the model on the test split and log results to the eval dir.

    Args:
        params_path (str, optional): path of the parameter YAML. Defaults to 'params.yaml'.
    """
    with open(params_path, 'r') as f:
        params = yaml.safe_load(f)

    # Build the test dataloader from the processed data directory.
    data_dir = Path(params['evaluate']['data_dir'])
    test_dataset = ProcessedDataset(dataset_dir=data_dir.joinpath('test'))
    test_dataloader = DataLoader(test_dataset, batch_size=64)

    # Use Apple's Metal backend when available, otherwise fall back to CPU.
    device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
    net = torch.load('model.pt')  # loads the full serialized model object
    net.to(device)
    net.eval()

    metric = MulticlassAccuracy(num_classes=10, top_k=1, average='weighted')
    metric.to(device)

    with Live(dir='dvclive/eval', report='md') as live:
        live.log_params(params['evaluate'])

        # Accumulate per-batch statistics; gradients are not needed for evaluation.
        with torch.no_grad():
            for data in test_dataloader:
                inputs, labels = data[0].to(device), data[1].to(device)
                outputs = net(inputs)
                _ = metric(outputs, labels)

        test_acc = metric.compute()
        print(f'test_acc: {test_acc}')
        live.log_metric('test_acc', float(test_acc.cpu()))


if __name__ == '__main__':
    evaluate()
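
# --- Reference notes (assumptions for readers, not executed) ---
#
# This script only requires that params.yaml contain an `evaluate` section
# with a `data_dir` key; the value shown here is illustrative:
#
#   evaluate:
#     data_dir: data/processed
#
# ProcessedDataset (imported from utils.dataset, not shown in this repo
# excerpt) is assumed to be a standard map-style torch Dataset whose
# __getitem__ returns an (input_tensor, label) pair, roughly:
#
#   class ProcessedDataset(Dataset):
#       def __init__(self, dataset_dir: Path):
#           # File names below are hypothetical placeholders.
#           self.inputs = torch.load(dataset_dir / 'inputs.pt')
#           self.labels = torch.load(dataset_dir / 'labels.pt')
#
#       def __len__(self):
#           return len(self.labels)
#
#       def __getitem__(self, idx):
#           return self.inputs[idx], self.labels[idx]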