fix accuracy computation
12
dvc.lock
|
@ -32,8 +32,8 @@ stages:
|
|||
nfiles: 60001
|
||||
- path: train.py
|
||||
hash: md5
|
||||
md5: aabaf1a407badf48c97b14a69b0072ea
|
||||
size: 3420
|
||||
md5: d32feb4bad10a201fdde2b6424238d16
|
||||
size: 3401
|
||||
params:
|
||||
params.yaml:
|
||||
train:
|
||||
|
@ -44,7 +44,7 @@ stages:
|
|||
outs:
|
||||
- path: model.pt
|
||||
hash: md5
|
||||
md5: 8ead2a7cd52d70b359d3cdc3df5e43e3
|
||||
md5: a0823af18fbf58ff562d804a02dca7d2
|
||||
size: 102592994
|
||||
evaluate:
|
||||
cmd: python evaluate.py
|
||||
|
@ -56,11 +56,11 @@ stages:
|
|||
nfiles: 60001
|
||||
- path: evaluate.py
|
||||
hash: md5
|
||||
md5: 8a9a2e95a6b64e632a4f2feac62d294b
|
||||
size: 1473
|
||||
md5: 8eb9acfdc80a5ca6b7bd782643a0997b
|
||||
size: 1483
|
||||
- path: model.pt
|
||||
hash: md5
|
||||
md5: 8ead2a7cd52d70b359d3cdc3df5e43e3
|
||||
md5: a0823af18fbf58ff562d804a02dca7d2
|
||||
size: 102592994
|
||||
params:
|
||||
params.yaml:
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
{
|
||||
"test_acc": 0.7336809039115906
|
||||
"test_acc": 0.20589999854564667
|
||||
}
|
||||
|
|
|
@ -1,2 +1,2 @@
|
|||
step test_acc
|
||||
0 0.7336809039115906
|
||||
0 0.20589999854564667
|
||||
|
|
|
|
@ -10,6 +10,6 @@ metrics.json
|
|||
|
||||
| test_acc |
|
||||
|------------|
|
||||
| 0.733681 |
|
||||
| 0.2059 |
|
||||
|
||||

|
||||
|
|
Before Width: | Height: | Size: 14 KiB After Width: | Height: | Size: 12 KiB |
|
@ -1,7 +1,7 @@
|
|||
{
|
||||
"train_loss": 2.2422571182250977,
|
||||
"train_acc": 0.7347080707550049,
|
||||
"valid_loss": 2.3184266090393066,
|
||||
"valid_acc": 0.7381500005722046,
|
||||
"train_loss": 2.2407546043395996,
|
||||
"train_acc": 0.22537143528461456,
|
||||
"valid_loss": 2.324389696121216,
|
||||
"valid_acc": 0.20746666193008423,
|
||||
"step": 4
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
step train_acc
|
||||
0 0.6712241768836975
|
||||
1 0.6976224184036255
|
||||
2 0.7157850861549377
|
||||
3 0.7277812957763672
|
||||
4 0.7347080707550049
|
||||
0 0.0997999981045723
|
||||
1 0.13202856481075287
|
||||
2 0.16345714032649994
|
||||
3 0.19614285230636597
|
||||
4 0.22537143528461456
|
||||
|
|
|
|
@ -1,6 +1,6 @@
|
|||
step train_loss
|
||||
0 3.0726168155670166
|
||||
1 2.7409346103668213
|
||||
2 2.5224294662475586
|
||||
3 2.364570140838623
|
||||
4 2.2422571182250977
|
||||
0 3.096755027770996
|
||||
1 2.749246120452881
|
||||
2 2.522728681564331
|
||||
3 2.3631479740142822
|
||||
4 2.2407546043395996
|
||||
|
|
|
|
@ -1,6 +1,6 @@
|
|||
step valid_acc
|
||||
0 0.6918894052505493
|
||||
1 0.7131190896034241
|
||||
2 0.7261338233947754
|
||||
3 0.7339118123054504
|
||||
4 0.7381500005722046
|
||||
0 0.10953333228826523
|
||||
1 0.13466666638851166
|
||||
2 0.15913332998752594
|
||||
3 0.18406666815280914
|
||||
4 0.20746666193008423
|
||||
|
|
|
|
@ -1,6 +1,6 @@
|
|||
step valid_loss
|
||||
0 2.890321969985962
|
||||
1 2.669679880142212
|
||||
2 2.5183584690093994
|
||||
3 2.4061686992645264
|
||||
4 2.3184266090393066
|
||||
0 2.919710636138916
|
||||
1 2.68220853805542
|
||||
2 2.524806022644043
|
||||
3 2.4115779399871826
|
||||
4 2.324389696121216
|
||||
|
|
|
|
@ -10,7 +10,7 @@ metrics.json
|
|||
|
||||
| train_loss | train_acc | valid_loss | valid_acc | step |
|
||||
|--------------|-------------|--------------|-------------|--------|
|
||||
| 2.24226 | 0.734708 | 2.31843 | 0.73815 | 4 |
|
||||
| 2.24075 | 0.225371 | 2.32439 | 0.207467 | 4 |
|
||||
|
||||

|
||||
|
||||
|
|
Before Width: | Height: | Size: 21 KiB After Width: | Height: | Size: 22 KiB |
Before Width: | Height: | Size: 20 KiB After Width: | Height: | Size: 21 KiB |
Before Width: | Height: | Size: 21 KiB After Width: | Height: | Size: 20 KiB |
Before Width: | Height: | Size: 22 KiB After Width: | Height: | Size: 22 KiB |
|
@ -8,7 +8,7 @@ import yaml
|
|||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from torchmetrics.classification import MulticlassAccuracy
|
||||
from torchmetrics.classification import Accuracy
|
||||
from dvclive import Live
|
||||
|
||||
from utils.dataset import ProcessedDataset
|
||||
|
@ -33,7 +33,7 @@ def evaluate(params_path: str = 'params.yaml') -> None:
|
|||
net.to(device)
|
||||
net.eval()
|
||||
|
||||
metric = MulticlassAccuracy(num_classes=10, top_k=1, average='weighted')
|
||||
metric = Accuracy(task='multiclass', num_classes=10)
|
||||
metric.to(device)
|
||||
|
||||
with Live(dir='dvclive/eval', report='md') as live:
|
||||
|
@ -43,7 +43,7 @@ def evaluate(params_path: str = 'params.yaml') -> None:
|
|||
for data in test_dataloader:
|
||||
inputs, labels = data[0].to(device), data[1].to(device)
|
||||
outputs = net(inputs)
|
||||
_ = metric(outputs, labels)
|
||||
_ = metric(outputs.topk(k=1, dim=1)[1], labels.topk(k=1, dim=1)[1])
|
||||
test_acc = metric.compute()
|
||||
|
||||
print(f'test_acc:{test_acc}')
|
||||
|
|
9
train.py
|
@ -10,7 +10,7 @@ import torch
|
|||
from rich.progress import track
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.models import resnet50
|
||||
from torchmetrics.classification import MulticlassAccuracy
|
||||
from torchmetrics.classification import Accuracy
|
||||
from dvclive import Live
|
||||
|
||||
from utils.dataset import ProcessedDataset
|
||||
|
@ -45,7 +45,7 @@ def train(params_path: str = 'params.yaml') -> None:
|
|||
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate)
|
||||
metric = MulticlassAccuracy(num_classes=10, top_k=1, average='weighted')
|
||||
metric = Accuracy(task='multiclass', num_classes=10)
|
||||
metric.to(device)
|
||||
|
||||
with Live(dir='dvclive/train', report='md') as live:
|
||||
|
@ -63,7 +63,7 @@ def train(params_path: str = 'params.yaml') -> None:
|
|||
train_loss += loss
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
_ = metric(outputs, labels)
|
||||
_ = metric(outputs.topk(k=1, dim=1)[1], labels.topk(k=1, dim=1)[1])
|
||||
train_loss /= len(train_dataloader)
|
||||
train_acc = metric.compute()
|
||||
metric.reset()
|
||||
|
@ -76,7 +76,7 @@ def train(params_path: str = 'params.yaml') -> None:
|
|||
outputs = net(inputs)
|
||||
loss = criterion(outputs, labels)
|
||||
valid_loss += loss
|
||||
_ = metric(outputs, labels)
|
||||
_ = metric(outputs.topk(k=1, dim=1)[1], labels.topk(k=1, dim=1)[1])
|
||||
valid_loss /= len(valid_dataloader)
|
||||
valid_acc = metric.compute()
|
||||
metric.reset()
|
||||
|
@ -93,7 +93,6 @@ def train(params_path: str = 'params.yaml') -> None:
|
|||
live.next_step()
|
||||
|
||||
torch.save(net, 'model.pt')
|
||||
live.log_artifact('model.pt', type='model', name='resnet50')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|