Files
quickdraw_bot/quickdraw_bot/eval.py
2026-06-18 13:51:36 +08:00

74 lines
3.1 KiB
Python

# eval.py
#
# author: deng
# date : 20260618
import torch
from dvclive import Live
from torch.utils.data import DataLoader
from torchmetrics import MetricCollection
from torchmetrics.classification import Accuracy, ConfusionMatrix, F1Score, Precision, Recall
from quickdraw_bot.utils.dataset import QuickDrawDataset
from quickdraw_bot.utils.utils import load_config
class Eval:
    """Evaluate a trained QuickDraw classifier on the test split and log results with DVCLive.

    Settings come from the ``eval`` section of the YAML config. Keys read here:
    ``test_data_dir`` and ``model_path`` (required), plus the optional ``batch_size``
    (defaults to 1).
    """

    # DVCLive metric name -> key produced by MetricCollection (keys are the metric class names).
    _LOGGED_METRICS = {
        'test/accuracy': 'MulticlassAccuracy',
        'test/precision': 'MulticlassPrecision',
        'test/recall': 'MulticlassRecall',
        'test/f1': 'MulticlassF1Score',
    }

    def __init__(self, config_path: str = './assets/config.yaml'):
        """Load the ``eval`` config section and choose the compute device.

        Args:
            config_path: Path to the project YAML config file.
        """
        self.config = load_config(config_path)['eval']
        # Prefer the Apple-silicon GPU (MPS) when present; otherwise run on CPU.
        self._device = torch.device('mps' if torch.mps.is_available() else 'cpu')

    def _get_dataloader(self) -> DataLoader:
        """Build a non-shuffled DataLoader over the test dataset.

        The batch size comes from the optional ``batch_size`` config key (default 1).
        With the default collate function, values above 1 require the dataset to yield
        samples of equal shape.
        """
        test_dataset = QuickDrawDataset(data_dir=self.config['test_data_dir'], return_cate_name=False)
        batch_size = self.config.get('batch_size', 1)
        return DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    def _get_model(self) -> tuple[torch.nn.Module, int]:
        """Load the pickled model onto ``self._device`` and infer its class count.

        Returns:
            A tuple of (model in eval mode, number of output classes).

        Raises:
            ValueError: If the model contains no ``torch.nn.Linear`` layer.
        """
        # SECURITY: weights_only=False unpickles arbitrary objects and can execute code.
        # Only load checkpoints that come from a trusted source.
        model = torch.load(self.config['model_path'], map_location=self._device, weights_only=False)
        model.eval()
        return model, self._infer_num_classes(model)

    @staticmethod
    def _infer_num_classes(model: torch.nn.Module) -> int:
        """Return ``out_features`` of the last ``torch.nn.Linear`` yielded by ``model.modules()``.

        Raises:
            ValueError: If the model contains no ``torch.nn.Linear`` layer.
        """
        # NOTE(review): modules() follows registration order. The last Linear is assumed to be
        # the classifier head; confirm this holds for new architectures.
        last_linear = None
        for module in model.modules():
            if isinstance(module, torch.nn.Linear):
                last_linear = module
        if last_linear is None:
            raise ValueError('Cannot infer num_classes: model has no torch.nn.Linear layer')
        return last_linear.out_features

    def _get_metrics(self, num_classes: int) -> tuple[MetricCollection, ConfusionMatrix]:
        """Create macro-averaged classification metrics and a confusion matrix on the device.

        Args:
            num_classes: Number of target classes.

        Returns:
            A tuple of (MetricCollection with accuracy/precision/recall/F1, ConfusionMatrix).
        """
        metric_collection = MetricCollection(
            [
                Accuracy(task='multiclass', num_classes=num_classes, top_k=1),
                Precision(task='multiclass', num_classes=num_classes, average='macro'),
                Recall(task='multiclass', num_classes=num_classes, average='macro'),
                F1Score(task='multiclass', num_classes=num_classes, average='macro'),
            ]
        ).to(self._device)
        # `threshold` is omitted: torchmetrics only uses it for the binary/multilabel tasks.
        confusion_matrix = ConfusionMatrix(
            task='multiclass',
            num_classes=num_classes,
        ).to(self._device)
        return metric_collection, confusion_matrix

    def run(self):
        """Evaluate over the whole test set, then log scalar metrics and a confusion-matrix plot."""
        test_dataloader = self._get_dataloader()
        model, num_classes = self._get_model()
        metrics, confusion_matrix = self._get_metrics(num_classes=num_classes)

        with Live(dir='./doc/exp/eval', dvcyaml='./assets/dvc.yaml') as live:
            metrics.reset()
            confusion_matrix.reset()
            # Gradients are not needed for evaluation.
            with torch.no_grad():
                for inputs, targets in test_dataloader:
                    inputs = inputs.to(self._device)
                    targets = targets.to(self._device)
                    outputs = model(inputs)
                    metrics.update(outputs, targets)
                    confusion_matrix.update(outputs, targets)

            test_metrics = metrics.compute()
            # plot() computes the accumulated confusion matrix itself and returns (fig, ax).
            confusion_matrix_fig, _ = confusion_matrix.plot()
            for log_name, metric_key in self._LOGGED_METRICS.items():
                live.log_metric(log_name, test_metrics[metric_key].item())
            live.log_image('test_confusion_matrix.png', confusion_matrix_fig)
if __name__ == '__main__':
    # Script entry point: evaluate with the default config path ('./assets/config.yaml').
    evaluator = Eval()
    evaluator.run()