fix accuracy computation

This commit is contained in:
deng 2024-01-15 21:09:58 +08:00
parent 9404557763
commit 391c70753d
17 changed files with 41 additions and 42 deletions

View File

@ -32,8 +32,8 @@ stages:
nfiles: 60001 nfiles: 60001
- path: train.py - path: train.py
hash: md5 hash: md5
md5: aabaf1a407badf48c97b14a69b0072ea md5: d32feb4bad10a201fdde2b6424238d16
size: 3420 size: 3401
params: params:
params.yaml: params.yaml:
train: train:
@ -44,7 +44,7 @@ stages:
outs: outs:
- path: model.pt - path: model.pt
hash: md5 hash: md5
md5: 8ead2a7cd52d70b359d3cdc3df5e43e3 md5: a0823af18fbf58ff562d804a02dca7d2
size: 102592994 size: 102592994
evaluate: evaluate:
cmd: python evaluate.py cmd: python evaluate.py
@ -56,11 +56,11 @@ stages:
nfiles: 60001 nfiles: 60001
- path: evaluate.py - path: evaluate.py
hash: md5 hash: md5
md5: 8a9a2e95a6b64e632a4f2feac62d294b md5: 8eb9acfdc80a5ca6b7bd782643a0997b
size: 1473 size: 1483
- path: model.pt - path: model.pt
hash: md5 hash: md5
md5: 8ead2a7cd52d70b359d3cdc3df5e43e3 md5: a0823af18fbf58ff562d804a02dca7d2
size: 102592994 size: 102592994
params: params:
params.yaml: params.yaml:

View File

@ -1,3 +1,3 @@
{ {
"test_acc": 0.7336809039115906 "test_acc": 0.20589999854564667
} }

View File

@ -1,2 +1,2 @@
step test_acc step test_acc
0 0.7336809039115906 0 0.20589999854564667

1 step test_acc
2 0 0.7336809039115906 0.20589999854564667

View File

@ -10,6 +10,6 @@ metrics.json
| test_acc | | test_acc |
|------------| |------------|
| 0.733681 | | 0.2059 |
![static/test_acc](static/test_acc.png) ![static/test_acc](static/test_acc.png)

Binary file not shown.

Before

Width:  |  Height:  |  Size: 14 KiB

After

Width:  |  Height:  |  Size: 12 KiB

View File

@ -1,7 +1,7 @@
{ {
"train_loss": 2.2422571182250977, "train_loss": 2.2407546043395996,
"train_acc": 0.7347080707550049, "train_acc": 0.22537143528461456,
"valid_loss": 2.3184266090393066, "valid_loss": 2.324389696121216,
"valid_acc": 0.7381500005722046, "valid_acc": 0.20746666193008423,
"step": 4 "step": 4
} }

View File

@ -1,6 +1,6 @@
step train_acc step train_acc
0 0.6712241768836975 0 0.0997999981045723
1 0.6976224184036255 1 0.13202856481075287
2 0.7157850861549377 2 0.16345714032649994
3 0.7277812957763672 3 0.19614285230636597
4 0.7347080707550049 4 0.22537143528461456

1 step train_acc
2 0 0.6712241768836975 0.0997999981045723
3 1 0.6976224184036255 0.13202856481075287
4 2 0.7157850861549377 0.16345714032649994
5 3 0.7277812957763672 0.19614285230636597
6 4 0.7347080707550049 0.22537143528461456

View File

@ -1,6 +1,6 @@
step train_loss step train_loss
0 3.0726168155670166 0 3.096755027770996
1 2.7409346103668213 1 2.749246120452881
2 2.5224294662475586 2 2.522728681564331
3 2.364570140838623 3 2.3631479740142822
4 2.2422571182250977 4 2.2407546043395996

1 step train_loss
2 0 3.0726168155670166 3.096755027770996
3 1 2.7409346103668213 2.749246120452881
4 2 2.5224294662475586 2.522728681564331
5 3 2.364570140838623 2.3631479740142822
6 4 2.2422571182250977 2.2407546043395996

View File

@ -1,6 +1,6 @@
step valid_acc step valid_acc
0 0.6918894052505493 0 0.10953333228826523
1 0.7131190896034241 1 0.13466666638851166
2 0.7261338233947754 2 0.15913332998752594
3 0.7339118123054504 3 0.18406666815280914
4 0.7381500005722046 4 0.20746666193008423

1 step valid_acc
2 0 0.6918894052505493 0.10953333228826523
3 1 0.7131190896034241 0.13466666638851166
4 2 0.7261338233947754 0.15913332998752594
5 3 0.7339118123054504 0.18406666815280914
6 4 0.7381500005722046 0.20746666193008423

View File

@ -1,6 +1,6 @@
step valid_loss step valid_loss
0 2.890321969985962 0 2.919710636138916
1 2.669679880142212 1 2.68220853805542
2 2.5183584690093994 2 2.524806022644043
3 2.4061686992645264 3 2.4115779399871826
4 2.3184266090393066 4 2.324389696121216

1 step valid_loss
2 0 2.890321969985962 2.919710636138916
3 1 2.669679880142212 2.68220853805542
4 2 2.5183584690093994 2.524806022644043
5 3 2.4061686992645264 2.4115779399871826
6 4 2.3184266090393066 2.324389696121216

View File

@ -10,7 +10,7 @@ metrics.json
| train_loss | train_acc | valid_loss | valid_acc | step | | train_loss | train_acc | valid_loss | valid_acc | step |
|--------------|-------------|--------------|-------------|--------| |--------------|-------------|--------------|-------------|--------|
| 2.24226 | 0.734708 | 2.31843 | 0.73815 | 4 | | 2.24075 | 0.225371 | 2.32439 | 0.207467 | 4 |
![static/valid_loss](static/valid_loss.png) ![static/valid_loss](static/valid_loss.png)

Binary file not shown.

Before

Width:  |  Height:  |  Size: 21 KiB

After

Width:  |  Height:  |  Size: 22 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 20 KiB

After

Width:  |  Height:  |  Size: 21 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 21 KiB

After

Width:  |  Height:  |  Size: 20 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 22 KiB

After

Width:  |  Height:  |  Size: 22 KiB

View File

@ -8,7 +8,7 @@ import yaml
import torch import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torchmetrics.classification import MulticlassAccuracy from torchmetrics.classification import Accuracy
from dvclive import Live from dvclive import Live
from utils.dataset import ProcessedDataset from utils.dataset import ProcessedDataset
@ -33,7 +33,7 @@ def evaluate(params_path: str = 'params.yaml') -> None:
net.to(device) net.to(device)
net.eval() net.eval()
metric = MulticlassAccuracy(num_classes=10, top_k=1, average='weighted') metric = Accuracy(task='multiclass', num_classes=10)
metric.to(device) metric.to(device)
with Live(dir='dvclive/eval', report='md') as live: with Live(dir='dvclive/eval', report='md') as live:
@ -43,7 +43,7 @@ def evaluate(params_path: str = 'params.yaml') -> None:
for data in test_dataloader: for data in test_dataloader:
inputs, labels = data[0].to(device), data[1].to(device) inputs, labels = data[0].to(device), data[1].to(device)
outputs = net(inputs) outputs = net(inputs)
_ = metric(outputs, labels) _ = metric(outputs.topk(k=1, dim=1)[1], labels.topk(k=1, dim=1)[1])
test_acc = metric.compute() test_acc = metric.compute()
print(f'test_acc:{test_acc}') print(f'test_acc:{test_acc}')

View File

@ -10,7 +10,7 @@ import torch
from rich.progress import track from rich.progress import track
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torchvision.models import resnet50 from torchvision.models import resnet50
from torchmetrics.classification import MulticlassAccuracy from torchmetrics.classification import Accuracy
from dvclive import Live from dvclive import Live
from utils.dataset import ProcessedDataset from utils.dataset import ProcessedDataset
@ -45,7 +45,7 @@ def train(params_path: str = 'params.yaml') -> None:
criterion = torch.nn.CrossEntropyLoss() criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate) optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate)
metric = MulticlassAccuracy(num_classes=10, top_k=1, average='weighted') metric = Accuracy(task='multiclass', num_classes=10)
metric.to(device) metric.to(device)
with Live(dir='dvclive/train', report='md') as live: with Live(dir='dvclive/train', report='md') as live:
@ -63,7 +63,7 @@ def train(params_path: str = 'params.yaml') -> None:
train_loss += loss train_loss += loss
loss.backward() loss.backward()
optimizer.step() optimizer.step()
_ = metric(outputs, labels) _ = metric(outputs.topk(k=1, dim=1)[1], labels.topk(k=1, dim=1)[1])
train_loss /= len(train_dataloader) train_loss /= len(train_dataloader)
train_acc = metric.compute() train_acc = metric.compute()
metric.reset() metric.reset()
@ -76,7 +76,7 @@ def train(params_path: str = 'params.yaml') -> None:
outputs = net(inputs) outputs = net(inputs)
loss = criterion(outputs, labels) loss = criterion(outputs, labels)
valid_loss += loss valid_loss += loss
_ = metric(outputs, labels) _ = metric(outputs.topk(k=1, dim=1)[1], labels.topk(k=1, dim=1)[1])
valid_loss /= len(valid_dataloader) valid_loss /= len(valid_dataloader)
valid_acc = metric.compute() valid_acc = metric.compute()
metric.reset() metric.reset()
@ -93,7 +93,6 @@ def train(params_path: str = 'params.yaml') -> None:
live.next_step() live.next_step()
torch.save(net, 'model.pt') torch.save(net, 'model.pt')
live.log_artifact('model.pt', type='model', name='resnet50')
if __name__ == '__main__': if __name__ == '__main__':