fix accuracy computation

This commit is contained in:
deng 2024-01-15 21:09:58 +08:00
parent 9404557763
commit 391c70753d
17 changed files with 41 additions and 42 deletions

View File

@ -32,8 +32,8 @@ stages:
nfiles: 60001
- path: train.py
hash: md5
md5: aabaf1a407badf48c97b14a69b0072ea
size: 3420
md5: d32feb4bad10a201fdde2b6424238d16
size: 3401
params:
params.yaml:
train:
@ -44,7 +44,7 @@ stages:
outs:
- path: model.pt
hash: md5
md5: 8ead2a7cd52d70b359d3cdc3df5e43e3
md5: a0823af18fbf58ff562d804a02dca7d2
size: 102592994
evaluate:
cmd: python evaluate.py
@ -56,11 +56,11 @@ stages:
nfiles: 60001
- path: evaluate.py
hash: md5
md5: 8a9a2e95a6b64e632a4f2feac62d294b
size: 1473
md5: 8eb9acfdc80a5ca6b7bd782643a0997b
size: 1483
- path: model.pt
hash: md5
md5: 8ead2a7cd52d70b359d3cdc3df5e43e3
md5: a0823af18fbf58ff562d804a02dca7d2
size: 102592994
params:
params.yaml:

View File

@ -1,3 +1,3 @@
{
"test_acc": 0.7336809039115906
"test_acc": 0.20589999854564667
}

View File

@ -1,2 +1,2 @@
step test_acc
0 0.7336809039115906
0 0.20589999854564667

1 step test_acc
2 0 0.7336809039115906 0.20589999854564667

View File

@ -10,6 +10,6 @@ metrics.json
| test_acc |
|------------|
| 0.733681 |
| 0.2059 |
![static/test_acc](static/test_acc.png)

Binary file not shown.

Before

Width:  |  Height:  |  Size: 14 KiB

After

Width:  |  Height:  |  Size: 12 KiB

View File

@ -1,7 +1,7 @@
{
"train_loss": 2.2422571182250977,
"train_acc": 0.7347080707550049,
"valid_loss": 2.3184266090393066,
"valid_acc": 0.7381500005722046,
"train_loss": 2.2407546043395996,
"train_acc": 0.22537143528461456,
"valid_loss": 2.324389696121216,
"valid_acc": 0.20746666193008423,
"step": 4
}

View File

@ -1,6 +1,6 @@
step train_acc
0 0.6712241768836975
1 0.6976224184036255
2 0.7157850861549377
3 0.7277812957763672
4 0.7347080707550049
0 0.0997999981045723
1 0.13202856481075287
2 0.16345714032649994
3 0.19614285230636597
4 0.22537143528461456

1 step train_acc
2 0 0.6712241768836975 0.0997999981045723
3 1 0.6976224184036255 0.13202856481075287
4 2 0.7157850861549377 0.16345714032649994
5 3 0.7277812957763672 0.19614285230636597
6 4 0.7347080707550049 0.22537143528461456

View File

@ -1,6 +1,6 @@
step train_loss
0 3.0726168155670166
1 2.7409346103668213
2 2.5224294662475586
3 2.364570140838623
4 2.2422571182250977
0 3.096755027770996
1 2.749246120452881
2 2.522728681564331
3 2.3631479740142822
4 2.2407546043395996

1 step train_loss
2 0 3.0726168155670166 3.096755027770996
3 1 2.7409346103668213 2.749246120452881
4 2 2.5224294662475586 2.522728681564331
5 3 2.364570140838623 2.3631479740142822
6 4 2.2422571182250977 2.2407546043395996

View File

@ -1,6 +1,6 @@
step valid_acc
0 0.6918894052505493
1 0.7131190896034241
2 0.7261338233947754
3 0.7339118123054504
4 0.7381500005722046
0 0.10953333228826523
1 0.13466666638851166
2 0.15913332998752594
3 0.18406666815280914
4 0.20746666193008423

1 step valid_acc
2 0 0.6918894052505493 0.10953333228826523
3 1 0.7131190896034241 0.13466666638851166
4 2 0.7261338233947754 0.15913332998752594
5 3 0.7339118123054504 0.18406666815280914
6 4 0.7381500005722046 0.20746666193008423

View File

@ -1,6 +1,6 @@
step valid_loss
0 2.890321969985962
1 2.669679880142212
2 2.5183584690093994
3 2.4061686992645264
4 2.3184266090393066
0 2.919710636138916
1 2.68220853805542
2 2.524806022644043
3 2.4115779399871826
4 2.324389696121216

1 step valid_loss
2 0 2.890321969985962 2.919710636138916
3 1 2.669679880142212 2.68220853805542
4 2 2.5183584690093994 2.524806022644043
5 3 2.4061686992645264 2.4115779399871826
6 4 2.3184266090393066 2.324389696121216

View File

@ -10,7 +10,7 @@ metrics.json
| train_loss | train_acc | valid_loss | valid_acc | step |
|--------------|-------------|--------------|-------------|--------|
| 2.24226 | 0.734708 | 2.31843 | 0.73815 | 4 |
| 2.24075 | 0.225371 | 2.32439 | 0.207467 | 4 |
![static/valid_loss](static/valid_loss.png)

Binary file not shown.

Before

Width:  |  Height:  |  Size: 21 KiB

After

Width:  |  Height:  |  Size: 22 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 20 KiB

After

Width:  |  Height:  |  Size: 21 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 21 KiB

After

Width:  |  Height:  |  Size: 20 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 22 KiB

After

Width:  |  Height:  |  Size: 22 KiB

View File

@ -8,7 +8,7 @@ import yaml
import torch
from torch.utils.data import DataLoader
from torchmetrics.classification import MulticlassAccuracy
from torchmetrics.classification import Accuracy
from dvclive import Live
from utils.dataset import ProcessedDataset
@ -33,7 +33,7 @@ def evaluate(params_path: str = 'params.yaml') -> None:
net.to(device)
net.eval()
metric = MulticlassAccuracy(num_classes=10, top_k=1, average='weighted')
metric = Accuracy(task='multiclass', num_classes=10)
metric.to(device)
with Live(dir='dvclive/eval', report='md') as live:
@ -43,7 +43,7 @@ def evaluate(params_path: str = 'params.yaml') -> None:
for data in test_dataloader:
inputs, labels = data[0].to(device), data[1].to(device)
outputs = net(inputs)
_ = metric(outputs, labels)
_ = metric(outputs.topk(k=1, dim=1)[1], labels.topk(k=1, dim=1)[1])
test_acc = metric.compute()
print(f'test_acc:{test_acc}')

View File

@ -10,7 +10,7 @@ import torch
from rich.progress import track
from torch.utils.data import DataLoader
from torchvision.models import resnet50
from torchmetrics.classification import MulticlassAccuracy
from torchmetrics.classification import Accuracy
from dvclive import Live
from utils.dataset import ProcessedDataset
@ -45,7 +45,7 @@ def train(params_path: str = 'params.yaml') -> None:
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate)
metric = MulticlassAccuracy(num_classes=10, top_k=1, average='weighted')
metric = Accuracy(task='multiclass', num_classes=10)
metric.to(device)
with Live(dir='dvclive/train', report='md') as live:
@ -63,7 +63,7 @@ def train(params_path: str = 'params.yaml') -> None:
train_loss += loss
loss.backward()
optimizer.step()
_ = metric(outputs, labels)
_ = metric(outputs.topk(k=1, dim=1)[1], labels.topk(k=1, dim=1)[1])
train_loss /= len(train_dataloader)
train_acc = metric.compute()
metric.reset()
@ -76,7 +76,7 @@ def train(params_path: str = 'params.yaml') -> None:
outputs = net(inputs)
loss = criterion(outputs, labels)
valid_loss += loss
_ = metric(outputs, labels)
_ = metric(outputs.topk(k=1, dim=1)[1], labels.topk(k=1, dim=1)[1])
valid_loss /= len(valid_dataloader)
valid_acc = metric.compute()
metric.reset()
@ -93,7 +93,6 @@ def train(params_path: str = 'params.yaml') -> None:
live.next_step()
torch.save(net, 'model.pt')
live.log_artifact('model.pt', type='model', name='resnet50')
if __name__ == '__main__':