fix accuracy computation
12
dvc.lock
|
@ -32,8 +32,8 @@ stages:
|
||||||
nfiles: 60001
|
nfiles: 60001
|
||||||
- path: train.py
|
- path: train.py
|
||||||
hash: md5
|
hash: md5
|
||||||
md5: aabaf1a407badf48c97b14a69b0072ea
|
md5: d32feb4bad10a201fdde2b6424238d16
|
||||||
size: 3420
|
size: 3401
|
||||||
params:
|
params:
|
||||||
params.yaml:
|
params.yaml:
|
||||||
train:
|
train:
|
||||||
|
@ -44,7 +44,7 @@ stages:
|
||||||
outs:
|
outs:
|
||||||
- path: model.pt
|
- path: model.pt
|
||||||
hash: md5
|
hash: md5
|
||||||
md5: 8ead2a7cd52d70b359d3cdc3df5e43e3
|
md5: a0823af18fbf58ff562d804a02dca7d2
|
||||||
size: 102592994
|
size: 102592994
|
||||||
evaluate:
|
evaluate:
|
||||||
cmd: python evaluate.py
|
cmd: python evaluate.py
|
||||||
|
@ -56,11 +56,11 @@ stages:
|
||||||
nfiles: 60001
|
nfiles: 60001
|
||||||
- path: evaluate.py
|
- path: evaluate.py
|
||||||
hash: md5
|
hash: md5
|
||||||
md5: 8a9a2e95a6b64e632a4f2feac62d294b
|
md5: 8eb9acfdc80a5ca6b7bd782643a0997b
|
||||||
size: 1473
|
size: 1483
|
||||||
- path: model.pt
|
- path: model.pt
|
||||||
hash: md5
|
hash: md5
|
||||||
md5: 8ead2a7cd52d70b359d3cdc3df5e43e3
|
md5: a0823af18fbf58ff562d804a02dca7d2
|
||||||
size: 102592994
|
size: 102592994
|
||||||
params:
|
params:
|
||||||
params.yaml:
|
params.yaml:
|
||||||
|
|
|
@ -1,3 +1,3 @@
|
||||||
{
|
{
|
||||||
"test_acc": 0.7336809039115906
|
"test_acc": 0.20589999854564667
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,2 +1,2 @@
|
||||||
step test_acc
|
step test_acc
|
||||||
0 0.7336809039115906
|
0 0.20589999854564667
|
||||||
|
|
|
|
@ -10,6 +10,6 @@ metrics.json
|
||||||
|
|
||||||
| test_acc |
|
| test_acc |
|
||||||
|------------|
|
|------------|
|
||||||
| 0.733681 |
|
| 0.2059 |
|
||||||
|
|
||||||

|

|
||||||
|
|
Before Width: | Height: | Size: 14 KiB After Width: | Height: | Size: 12 KiB |
|
@ -1,7 +1,7 @@
|
||||||
{
|
{
|
||||||
"train_loss": 2.2422571182250977,
|
"train_loss": 2.2407546043395996,
|
||||||
"train_acc": 0.7347080707550049,
|
"train_acc": 0.22537143528461456,
|
||||||
"valid_loss": 2.3184266090393066,
|
"valid_loss": 2.324389696121216,
|
||||||
"valid_acc": 0.7381500005722046,
|
"valid_acc": 0.20746666193008423,
|
||||||
"step": 4
|
"step": 4
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
step train_acc
|
step train_acc
|
||||||
0 0.6712241768836975
|
0 0.0997999981045723
|
||||||
1 0.6976224184036255
|
1 0.13202856481075287
|
||||||
2 0.7157850861549377
|
2 0.16345714032649994
|
||||||
3 0.7277812957763672
|
3 0.19614285230636597
|
||||||
4 0.7347080707550049
|
4 0.22537143528461456
|
||||||
|
|
|
|
@ -1,6 +1,6 @@
|
||||||
step train_loss
|
step train_loss
|
||||||
0 3.0726168155670166
|
0 3.096755027770996
|
||||||
1 2.7409346103668213
|
1 2.749246120452881
|
||||||
2 2.5224294662475586
|
2 2.522728681564331
|
||||||
3 2.364570140838623
|
3 2.3631479740142822
|
||||||
4 2.2422571182250977
|
4 2.2407546043395996
|
||||||
|
|
|
|
@ -1,6 +1,6 @@
|
||||||
step valid_acc
|
step valid_acc
|
||||||
0 0.6918894052505493
|
0 0.10953333228826523
|
||||||
1 0.7131190896034241
|
1 0.13466666638851166
|
||||||
2 0.7261338233947754
|
2 0.15913332998752594
|
||||||
3 0.7339118123054504
|
3 0.18406666815280914
|
||||||
4 0.7381500005722046
|
4 0.20746666193008423
|
||||||
|
|
|
|
@ -1,6 +1,6 @@
|
||||||
step valid_loss
|
step valid_loss
|
||||||
0 2.890321969985962
|
0 2.919710636138916
|
||||||
1 2.669679880142212
|
1 2.68220853805542
|
||||||
2 2.5183584690093994
|
2 2.524806022644043
|
||||||
3 2.4061686992645264
|
3 2.4115779399871826
|
||||||
4 2.3184266090393066
|
4 2.324389696121216
|
||||||
|
|
|
|
@ -10,7 +10,7 @@ metrics.json
|
||||||
|
|
||||||
| train_loss | train_acc | valid_loss | valid_acc | step |
|
| train_loss | train_acc | valid_loss | valid_acc | step |
|
||||||
|--------------|-------------|--------------|-------------|--------|
|
|--------------|-------------|--------------|-------------|--------|
|
||||||
| 2.24226 | 0.734708 | 2.31843 | 0.73815 | 4 |
|
| 2.24075 | 0.225371 | 2.32439 | 0.207467 | 4 |
|
||||||
|
|
||||||

|

|
||||||
|
|
||||||
|
|
Before Width: | Height: | Size: 21 KiB After Width: | Height: | Size: 22 KiB |
Before Width: | Height: | Size: 20 KiB After Width: | Height: | Size: 21 KiB |
Before Width: | Height: | Size: 21 KiB After Width: | Height: | Size: 20 KiB |
Before Width: | Height: | Size: 22 KiB After Width: | Height: | Size: 22 KiB |
|
@ -8,7 +8,7 @@ import yaml
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from torchmetrics.classification import MulticlassAccuracy
|
from torchmetrics.classification import Accuracy
|
||||||
from dvclive import Live
|
from dvclive import Live
|
||||||
|
|
||||||
from utils.dataset import ProcessedDataset
|
from utils.dataset import ProcessedDataset
|
||||||
|
@ -33,7 +33,7 @@ def evaluate(params_path: str = 'params.yaml') -> None:
|
||||||
net.to(device)
|
net.to(device)
|
||||||
net.eval()
|
net.eval()
|
||||||
|
|
||||||
metric = MulticlassAccuracy(num_classes=10, top_k=1, average='weighted')
|
metric = Accuracy(task='multiclass', num_classes=10)
|
||||||
metric.to(device)
|
metric.to(device)
|
||||||
|
|
||||||
with Live(dir='dvclive/eval', report='md') as live:
|
with Live(dir='dvclive/eval', report='md') as live:
|
||||||
|
@ -43,7 +43,7 @@ def evaluate(params_path: str = 'params.yaml') -> None:
|
||||||
for data in test_dataloader:
|
for data in test_dataloader:
|
||||||
inputs, labels = data[0].to(device), data[1].to(device)
|
inputs, labels = data[0].to(device), data[1].to(device)
|
||||||
outputs = net(inputs)
|
outputs = net(inputs)
|
||||||
_ = metric(outputs, labels)
|
_ = metric(outputs.topk(k=1, dim=1)[1], labels.topk(k=1, dim=1)[1])
|
||||||
test_acc = metric.compute()
|
test_acc = metric.compute()
|
||||||
|
|
||||||
print(f'test_acc:{test_acc}')
|
print(f'test_acc:{test_acc}')
|
||||||
|
|
9
train.py
|
@ -10,7 +10,7 @@ import torch
|
||||||
from rich.progress import track
|
from rich.progress import track
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from torchvision.models import resnet50
|
from torchvision.models import resnet50
|
||||||
from torchmetrics.classification import MulticlassAccuracy
|
from torchmetrics.classification import Accuracy
|
||||||
from dvclive import Live
|
from dvclive import Live
|
||||||
|
|
||||||
from utils.dataset import ProcessedDataset
|
from utils.dataset import ProcessedDataset
|
||||||
|
@ -45,7 +45,7 @@ def train(params_path: str = 'params.yaml') -> None:
|
||||||
|
|
||||||
criterion = torch.nn.CrossEntropyLoss()
|
criterion = torch.nn.CrossEntropyLoss()
|
||||||
optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate)
|
optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate)
|
||||||
metric = MulticlassAccuracy(num_classes=10, top_k=1, average='weighted')
|
metric = Accuracy(task='multiclass', num_classes=10)
|
||||||
metric.to(device)
|
metric.to(device)
|
||||||
|
|
||||||
with Live(dir='dvclive/train', report='md') as live:
|
with Live(dir='dvclive/train', report='md') as live:
|
||||||
|
@ -63,7 +63,7 @@ def train(params_path: str = 'params.yaml') -> None:
|
||||||
train_loss += loss
|
train_loss += loss
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
_ = metric(outputs, labels)
|
_ = metric(outputs.topk(k=1, dim=1)[1], labels.topk(k=1, dim=1)[1])
|
||||||
train_loss /= len(train_dataloader)
|
train_loss /= len(train_dataloader)
|
||||||
train_acc = metric.compute()
|
train_acc = metric.compute()
|
||||||
metric.reset()
|
metric.reset()
|
||||||
|
@ -76,7 +76,7 @@ def train(params_path: str = 'params.yaml') -> None:
|
||||||
outputs = net(inputs)
|
outputs = net(inputs)
|
||||||
loss = criterion(outputs, labels)
|
loss = criterion(outputs, labels)
|
||||||
valid_loss += loss
|
valid_loss += loss
|
||||||
_ = metric(outputs, labels)
|
_ = metric(outputs.topk(k=1, dim=1)[1], labels.topk(k=1, dim=1)[1])
|
||||||
valid_loss /= len(valid_dataloader)
|
valid_loss /= len(valid_dataloader)
|
||||||
valid_acc = metric.compute()
|
valid_acc = metric.compute()
|
||||||
metric.reset()
|
metric.reset()
|
||||||
|
@ -93,7 +93,6 @@ def train(params_path: str = 'params.yaml') -> None:
|
||||||
live.next_step()
|
live.next_step()
|
||||||
|
|
||||||
torch.save(net, 'model.pt')
|
torch.save(net, 'model.pt')
|
||||||
live.log_artifact('model.pt', type='model', name='resnet50')
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|