diff --git a/dvc.lock b/dvc.lock index 6729b23..5ad4755 100644 --- a/dvc.lock +++ b/dvc.lock @@ -32,8 +32,8 @@ stages: nfiles: 60001 - path: train.py hash: md5 - md5: aabaf1a407badf48c97b14a69b0072ea - size: 3420 + md5: d32feb4bad10a201fdde2b6424238d16 + size: 3401 params: params.yaml: train: @@ -44,7 +44,7 @@ stages: outs: - path: model.pt hash: md5 - md5: 8ead2a7cd52d70b359d3cdc3df5e43e3 + md5: a0823af18fbf58ff562d804a02dca7d2 size: 102592994 evaluate: cmd: python evaluate.py @@ -56,11 +56,11 @@ stages: nfiles: 60001 - path: evaluate.py hash: md5 - md5: 8a9a2e95a6b64e632a4f2feac62d294b - size: 1473 + md5: 8eb9acfdc80a5ca6b7bd782643a0997b + size: 1483 - path: model.pt hash: md5 - md5: 8ead2a7cd52d70b359d3cdc3df5e43e3 + md5: a0823af18fbf58ff562d804a02dca7d2 size: 102592994 params: params.yaml: diff --git a/dvclive/eval/metrics.json b/dvclive/eval/metrics.json index 5b1327d..58132b6 100644 --- a/dvclive/eval/metrics.json +++ b/dvclive/eval/metrics.json @@ -1,3 +1,3 @@ { - "test_acc": 0.7336809039115906 + "test_acc": 0.20589999854564667 } diff --git a/dvclive/eval/plots/metrics/test_acc.tsv b/dvclive/eval/plots/metrics/test_acc.tsv index 6d4a371..ca9e105 100644 --- a/dvclive/eval/plots/metrics/test_acc.tsv +++ b/dvclive/eval/plots/metrics/test_acc.tsv @@ -1,2 +1,2 @@ step test_acc -0 0.7336809039115906 +0 0.20589999854564667 diff --git a/dvclive/eval/report.md b/dvclive/eval/report.md index 678f542..778881c 100644 --- a/dvclive/eval/report.md +++ b/dvclive/eval/report.md @@ -10,6 +10,6 @@ metrics.json | test_acc | |------------| -| 0.733681 | +| 0.2059 | ![static/test_acc](static/test_acc.png) diff --git a/dvclive/eval/static/test_acc.png b/dvclive/eval/static/test_acc.png index 27643cd..2b4f11e 100644 Binary files a/dvclive/eval/static/test_acc.png and b/dvclive/eval/static/test_acc.png differ diff --git a/dvclive/train/metrics.json b/dvclive/train/metrics.json index 9a0ddb1..fd3d620 100644 --- a/dvclive/train/metrics.json +++ b/dvclive/train/metrics.json @@ -1,7 +1,7 @@ { - "train_loss": 2.2422571182250977, - "train_acc": 0.7347080707550049, - "valid_loss": 2.3184266090393066, - "valid_acc": 0.7381500005722046, + "train_loss": 2.2407546043395996, + "train_acc": 0.22537143528461456, + "valid_loss": 2.324389696121216, + "valid_acc": 0.20746666193008423, "step": 4 } diff --git a/dvclive/train/plots/metrics/train_acc.tsv b/dvclive/train/plots/metrics/train_acc.tsv index 3535742..50940e3 100644 --- a/dvclive/train/plots/metrics/train_acc.tsv +++ b/dvclive/train/plots/metrics/train_acc.tsv @@ -1,6 +1,6 @@ step train_acc -0 0.6712241768836975 -1 0.6976224184036255 -2 0.7157850861549377 -3 0.7277812957763672 -4 0.7347080707550049 +0 0.0997999981045723 +1 0.13202856481075287 +2 0.16345714032649994 +3 0.19614285230636597 +4 0.22537143528461456 diff --git a/dvclive/train/plots/metrics/train_loss.tsv b/dvclive/train/plots/metrics/train_loss.tsv index 26010c9..1045549 100644 --- a/dvclive/train/plots/metrics/train_loss.tsv +++ b/dvclive/train/plots/metrics/train_loss.tsv @@ -1,6 +1,6 @@ step train_loss -0 3.0726168155670166 -1 2.7409346103668213 -2 2.5224294662475586 -3 2.364570140838623 -4 2.2422571182250977 +0 3.096755027770996 +1 2.749246120452881 +2 2.522728681564331 +3 2.3631479740142822 +4 2.2407546043395996 diff --git a/dvclive/train/plots/metrics/valid_acc.tsv b/dvclive/train/plots/metrics/valid_acc.tsv index 9818ddd..53a9dab 100644 --- a/dvclive/train/plots/metrics/valid_acc.tsv +++ b/dvclive/train/plots/metrics/valid_acc.tsv @@ -1,6 +1,6 @@ step valid_acc -0 0.6918894052505493 -1 0.7131190896034241 -2 0.7261338233947754 -3 0.7339118123054504 -4 0.7381500005722046 +0 0.10953333228826523 +1 0.13466666638851166 +2 0.15913332998752594 +3 0.18406666815280914 +4 0.20746666193008423 diff --git a/dvclive/train/plots/metrics/valid_loss.tsv b/dvclive/train/plots/metrics/valid_loss.tsv index a556569..9747623 100644 --- a/dvclive/train/plots/metrics/valid_loss.tsv +++ b/dvclive/train/plots/metrics/valid_loss.tsv @@ -1,6 +1,6 @@ step valid_loss -0 2.890321969985962 -1 2.669679880142212 -2 2.5183584690093994 -3 2.4061686992645264 -4 2.3184266090393066 +0 2.919710636138916 +1 2.68220853805542 +2 2.524806022644043 +3 2.4115779399871826 +4 2.324389696121216 diff --git a/dvclive/train/report.md b/dvclive/train/report.md index e584bb1..fefc80c 100644 --- a/dvclive/train/report.md +++ b/dvclive/train/report.md @@ -10,7 +10,7 @@ metrics.json | train_loss | train_acc | valid_loss | valid_acc | step | |--------------|-------------|--------------|-------------|--------| -| 2.24226 | 0.734708 | 2.31843 | 0.73815 | 4 | +| 2.24075 | 0.225371 | 2.32439 | 0.207467 | 4 | ![static/valid_loss](static/valid_loss.png) diff --git a/dvclive/train/static/train_acc.png b/dvclive/train/static/train_acc.png index cda0996..f942939 100644 Binary files a/dvclive/train/static/train_acc.png and b/dvclive/train/static/train_acc.png differ diff --git a/dvclive/train/static/train_loss.png b/dvclive/train/static/train_loss.png index d2a49db..9a99624 100644 Binary files a/dvclive/train/static/train_loss.png and b/dvclive/train/static/train_loss.png differ diff --git a/dvclive/train/static/valid_acc.png b/dvclive/train/static/valid_acc.png index 6cc9d85..8fa2696 100644 Binary files a/dvclive/train/static/valid_acc.png and b/dvclive/train/static/valid_acc.png differ diff --git a/dvclive/train/static/valid_loss.png b/dvclive/train/static/valid_loss.png index 7cd6fd3..132fd29 100644 Binary files a/dvclive/train/static/valid_loss.png and b/dvclive/train/static/valid_loss.png differ diff --git a/evaluate.py b/evaluate.py index 772cfca..ceca635 100644 --- a/evaluate.py +++ b/evaluate.py @@ -8,7 +8,7 @@ import yaml import torch from torch.utils.data import DataLoader -from torchmetrics.classification import MulticlassAccuracy +from torchmetrics.classification import Accuracy from dvclive import Live from utils.dataset import ProcessedDataset @@ -33,7 +33,7 @@ def evaluate(params_path: str = 'params.yaml') -> None: net.to(device) net.eval() - metric = MulticlassAccuracy(num_classes=10, top_k=1, average='weighted') + metric = Accuracy(task='multiclass', num_classes=10) metric.to(device) with Live(dir='dvclive/eval', report='md') as live: @@ -43,7 +43,7 @@ def evaluate(params_path: str = 'params.yaml') -> None: for data in test_dataloader: inputs, labels = data[0].to(device), data[1].to(device) outputs = net(inputs) - _ = metric(outputs, labels) + _ = metric(outputs.topk(k=1, dim=1)[1], labels.topk(k=1, dim=1)[1]) test_acc = metric.compute() print(f'test_acc:{test_acc}') diff --git a/train.py b/train.py index 8bb54f8..0f3e68d 100644 --- a/train.py +++ b/train.py @@ -10,7 +10,7 @@ import torch from rich.progress import track from torch.utils.data import DataLoader from torchvision.models import resnet50 -from torchmetrics.classification import MulticlassAccuracy +from torchmetrics.classification import Accuracy from dvclive import Live from utils.dataset import ProcessedDataset @@ -45,7 +45,7 @@ def train(params_path: str = 'params.yaml') -> None: criterion = torch.nn.CrossEntropyLoss() optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate) - metric = MulticlassAccuracy(num_classes=10, top_k=1, average='weighted') + metric = Accuracy(task='multiclass', num_classes=10) metric.to(device) with Live(dir='dvclive/train', report='md') as live: @@ -63,7 +63,7 @@ def train(params_path: str = 'params.yaml') -> None: train_loss += loss loss.backward() optimizer.step() - _ = metric(outputs, labels) + _ = metric(outputs.topk(k=1, dim=1)[1], labels.topk(k=1, dim=1)[1]) train_loss /= len(train_dataloader) train_acc = metric.compute() metric.reset() @@ -76,7 +76,7 @@ def train(params_path: str = 'params.yaml') -> None: outputs = net(inputs) loss = criterion(outputs, labels) valid_loss += loss - _ = metric(outputs, labels) + _ = metric(outputs.topk(k=1, dim=1)[1], labels.topk(k=1, dim=1)[1]) valid_loss /= len(valid_dataloader) valid_acc = metric.compute() metric.reset() @@ -93,7 +93,6 @@ def train(params_path: str = 'params.yaml') -> None: live.next_step() torch.save(net, 'model.pt') - live.log_artifact('model.pt', type='model', name='resnet50') if __name__ == '__main__':