74 lines
3.1 KiB
Python
74 lines
3.1 KiB
Python
# eval.py
|
|
#
|
|
# author: deng
|
|
# date : 20260618
|
|
|
|
import torch
|
|
from dvclive import Live
|
|
from torch.utils.data import DataLoader
|
|
from torchmetrics import MetricCollection
|
|
from torchmetrics.classification import Accuracy, ConfusionMatrix, F1Score, Precision, Recall
|
|
|
|
from quickdraw_bot.utils.dataset import QuickDrawDataset
|
|
from quickdraw_bot.utils.utils import load_config
|
|
|
|
|
|
class Eval:
    """Evaluate a trained classifier on the test set and log the results with DVCLive.

    Reads the ``eval`` section of the project config. It loads the pickled model
    at ``model_path`` and runs it over the dataset in ``test_data_dir``. It then logs
    top-1 accuracy, macro precision, macro recall, macro F1 and a confusion-matrix
    image to ``./doc/exp/eval``.
    """

    def __init__(self, config_path: str = './assets/config.yaml'):
        """Load the ``eval`` config section and choose the compute device.

        Args:
            config_path: Path to the project config file. Only its ``eval``
                section is kept on ``self.config``.
        """
        self.config = load_config(config_path)['eval']
        self._device = self._select_device()

    @staticmethod
    def _select_device() -> torch.device:
        """Return the preferred available device: Apple MPS, then CUDA, then CPU.

        ``torch.backends.mps.is_available()`` is the documented MPS availability
        check. It is used instead of ``torch.mps.is_available()`` for portability
        across PyTorch releases.
        """
        if torch.backends.mps.is_available():
            return torch.device('mps')
        if torch.cuda.is_available():
            return torch.device('cuda')
        return torch.device('cpu')

    def _get_dataloader(self) -> DataLoader:
        """Build a deterministic (unshuffled, batch size 1) loader over the test split.

        The loader is expected to yield ``(inputs, targets)`` pairs; ``run`` unpacks
        it that way. ``return_cate_name=False`` presumably makes the dataset return
        integer labels rather than category names. Confirm this against
        ``QuickDrawDataset``.
        """
        test_dataset = QuickDrawDataset(data_dir=self.config['test_data_dir'], return_cate_name=False)
        return DataLoader(test_dataset, batch_size=1, shuffle=False)

    def _get_model(self) -> tuple[torch.nn.Module, int]:
        """Load the full pickled model onto the selected device in eval mode.

        Returns:
            ``(model, num_classes)``. ``num_classes`` is the ``out_features`` of the
            last ``nn.Linear`` found in ``model.modules()`` order. This assumes that
            layer is the classification head (TODO confirm for new architectures).

        Raises:
            ValueError: if the model contains no ``nn.Linear`` layer, so the class
                count cannot be inferred.
        """
        # weights_only=False unpickles arbitrary Python objects (the whole nn.Module),
        # which can execute code on load. Only point ``model_path`` at trusted checkpoints.
        model = torch.load(self.config['model_path'], map_location=self._device, weights_only=False)
        model.eval()

        linear_layers = [m for m in model.modules() if isinstance(m, torch.nn.Linear)]
        if not linear_layers:
            raise ValueError(
                f"Cannot infer num_classes: model at {self.config['model_path']!r} has no torch.nn.Linear layer"
            )
        return model, linear_layers[-1].out_features

    def _get_metrics(self, num_classes: int) -> tuple[MetricCollection, ConfusionMatrix]:
        """Create the multiclass metrics on the evaluation device.

        Returns:
            ``(metric_collection, confusion_matrix)``. The collection holds top-1
            accuracy and macro-averaged precision, recall and F1. Its ``compute()``
            result is keyed by metric class name (e.g. ``'MulticlassAccuracy'``).
        """
        metric_collection = MetricCollection(
            [
                Accuracy(task='multiclass', num_classes=num_classes, top_k=1),
                Precision(task='multiclass', num_classes=num_classes, average='macro'),
                Recall(task='multiclass', num_classes=num_classes, average='macro'),
                F1Score(task='multiclass', num_classes=num_classes, average='macro'),
            ]
        ).to(self._device)
        # No ``threshold``: it only applies to binary/multilabel tasks, and the
        # previously passed 0.5 was the default anyway.
        confusion_matrix = ConfusionMatrix(
            task='multiclass',
            num_classes=num_classes,
        ).to(self._device)
        return metric_collection, confusion_matrix

    def run(self):
        """Evaluate the model over the whole test loader and log results to DVCLive.

        Logs the metrics ``test/accuracy``, ``test/precision``, ``test/recall`` and
        ``test/f1``, plus the image ``test_confusion_matrix.png``.
        """
        test_dataloader = self._get_dataloader()
        model, num_classes = self._get_model()
        metrics, confusion_matrix = self._get_metrics(num_classes=num_classes)

        with Live(dir='./doc/exp/eval', dvcyaml='./assets/dvc.yaml') as live:
            # Reset both accumulators so a reused metric object never mixes runs.
            metrics.reset()
            confusion_matrix.reset()

            with torch.no_grad():
                for inputs, targets in test_dataloader:
                    inputs = inputs.to(self._device)
                    targets = targets.to(self._device)
                    outputs = model(inputs)
                    metrics.update(outputs, targets)
                    confusion_matrix.update(outputs, targets)

            test_metrics = metrics.compute()
            confusion_matrix_fig, _ = confusion_matrix.plot()

            live.log_metric('test/accuracy', test_metrics['MulticlassAccuracy'].item())
            live.log_metric('test/precision', test_metrics['MulticlassPrecision'].item())
            live.log_metric('test/recall', test_metrics['MulticlassRecall'].item())
            live.log_metric('test/f1', test_metrics['MulticlassF1Score'].item())
            live.log_image('test_confusion_matrix.png', confusion_matrix_fig)
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: construct the evaluator from the default config path and run it.
    evaluator = Eval()
    evaluator.run()
|