fix accuracy computation
							
								
								
									
										12
									
								
								dvc.lock
									
									
									
									
									
								
							
							
						
						| @ -32,8 +32,8 @@ stages: | ||||
|       nfiles: 60001 | ||||
|     - path: train.py | ||||
|       hash: md5 | ||||
|       md5: aabaf1a407badf48c97b14a69b0072ea | ||||
|       size: 3420 | ||||
|       md5: d32feb4bad10a201fdde2b6424238d16 | ||||
|       size: 3401 | ||||
|     params: | ||||
|       params.yaml: | ||||
|         train: | ||||
| @ -44,7 +44,7 @@ stages: | ||||
|     outs: | ||||
|     - path: model.pt | ||||
|       hash: md5 | ||||
|       md5: 8ead2a7cd52d70b359d3cdc3df5e43e3 | ||||
|       md5: a0823af18fbf58ff562d804a02dca7d2 | ||||
|       size: 102592994 | ||||
|   evaluate: | ||||
|     cmd: python evaluate.py | ||||
| @ -56,11 +56,11 @@ stages: | ||||
|       nfiles: 60001 | ||||
|     - path: evaluate.py | ||||
|       hash: md5 | ||||
|       md5: 8a9a2e95a6b64e632a4f2feac62d294b | ||||
|       size: 1473 | ||||
|       md5: 8eb9acfdc80a5ca6b7bd782643a0997b | ||||
|       size: 1483 | ||||
|     - path: model.pt | ||||
|       hash: md5 | ||||
|       md5: 8ead2a7cd52d70b359d3cdc3df5e43e3 | ||||
|       md5: a0823af18fbf58ff562d804a02dca7d2 | ||||
|       size: 102592994 | ||||
|     params: | ||||
|       params.yaml: | ||||
|  | ||||
| @ -1,3 +1,3 @@ | ||||
| { | ||||
|     "test_acc": 0.7336809039115906 | ||||
|     "test_acc": 0.20589999854564667 | ||||
| } | ||||
|  | ||||
| @ -1,2 +1,2 @@ | ||||
| step	test_acc | ||||
| 0	0.7336809039115906 | ||||
| 0	0.20589999854564667 | ||||
|  | ||||
| 
 | 
| @ -10,6 +10,6 @@ metrics.json | ||||
|  | ||||
| |   test_acc | | ||||
| |------------| | ||||
| |   0.733681 | | ||||
| |     0.2059 | | ||||
|  | ||||
|  | ||||
|  | ||||
| Before Width: | Height: | Size: 14 KiB After Width: | Height: | Size: 12 KiB | 
| @ -1,7 +1,7 @@ | ||||
| { | ||||
|     "train_loss": 2.2422571182250977, | ||||
|     "train_acc": 0.7347080707550049, | ||||
|     "valid_loss": 2.3184266090393066, | ||||
|     "valid_acc": 0.7381500005722046, | ||||
|     "train_loss": 2.2407546043395996, | ||||
|     "train_acc": 0.22537143528461456, | ||||
|     "valid_loss": 2.324389696121216, | ||||
|     "valid_acc": 0.20746666193008423, | ||||
|     "step": 4 | ||||
| } | ||||
|  | ||||
| @ -1,6 +1,6 @@ | ||||
| step	train_acc | ||||
| 0	0.6712241768836975 | ||||
| 1	0.6976224184036255 | ||||
| 2	0.7157850861549377 | ||||
| 3	0.7277812957763672 | ||||
| 4	0.7347080707550049 | ||||
| 0	0.0997999981045723 | ||||
| 1	0.13202856481075287 | ||||
| 2	0.16345714032649994 | ||||
| 3	0.19614285230636597 | ||||
| 4	0.22537143528461456 | ||||
|  | ||||
| 
 | 
| @ -1,6 +1,6 @@ | ||||
| step	train_loss | ||||
| 0	3.0726168155670166 | ||||
| 1	2.7409346103668213 | ||||
| 2	2.5224294662475586 | ||||
| 3	2.364570140838623 | ||||
| 4	2.2422571182250977 | ||||
| 0	3.096755027770996 | ||||
| 1	2.749246120452881 | ||||
| 2	2.522728681564331 | ||||
| 3	2.3631479740142822 | ||||
| 4	2.2407546043395996 | ||||
|  | ||||
| 
 | 
| @ -1,6 +1,6 @@ | ||||
| step	valid_acc | ||||
| 0	0.6918894052505493 | ||||
| 1	0.7131190896034241 | ||||
| 2	0.7261338233947754 | ||||
| 3	0.7339118123054504 | ||||
| 4	0.7381500005722046 | ||||
| 0	0.10953333228826523 | ||||
| 1	0.13466666638851166 | ||||
| 2	0.15913332998752594 | ||||
| 3	0.18406666815280914 | ||||
| 4	0.20746666193008423 | ||||
|  | ||||
| 
 | 
| @ -1,6 +1,6 @@ | ||||
| step	valid_loss | ||||
| 0	2.890321969985962 | ||||
| 1	2.669679880142212 | ||||
| 2	2.5183584690093994 | ||||
| 3	2.4061686992645264 | ||||
| 4	2.3184266090393066 | ||||
| 0	2.919710636138916 | ||||
| 1	2.68220853805542 | ||||
| 2	2.524806022644043 | ||||
| 3	2.4115779399871826 | ||||
| 4	2.324389696121216 | ||||
|  | ||||
| 
 | 
| @ -10,7 +10,7 @@ metrics.json | ||||
|  | ||||
| |   train_loss |   train_acc |   valid_loss |   valid_acc |   step | | ||||
| |--------------|-------------|--------------|-------------|--------| | ||||
| |      2.24226 |    0.734708 |      2.31843 |     0.73815 |      4 | | ||||
| |      2.24075 |    0.225371 |      2.32439 |    0.207467 |      4 | | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
| Before Width: | Height: | Size: 21 KiB After Width: | Height: | Size: 22 KiB | 
| Before Width: | Height: | Size: 20 KiB After Width: | Height: | Size: 21 KiB | 
| Before Width: | Height: | Size: 21 KiB After Width: | Height: | Size: 20 KiB | 
| Before Width: | Height: | Size: 22 KiB After Width: | Height: | Size: 22 KiB | 
| @ -8,7 +8,7 @@ import yaml | ||||
|  | ||||
| import torch | ||||
| from torch.utils.data import DataLoader | ||||
| from torchmetrics.classification import MulticlassAccuracy | ||||
| from torchmetrics.classification import Accuracy | ||||
| from dvclive import Live | ||||
|  | ||||
| from utils.dataset import ProcessedDataset | ||||
| @ -33,7 +33,7 @@ def evaluate(params_path: str = 'params.yaml') -> None: | ||||
|     net.to(device) | ||||
|     net.eval() | ||||
|  | ||||
|     metric = MulticlassAccuracy(num_classes=10, top_k=1, average='weighted') | ||||
|     metric = Accuracy(task='multiclass', num_classes=10) | ||||
|     metric.to(device) | ||||
|  | ||||
|     with Live(dir='dvclive/eval', report='md') as live: | ||||
| @ -43,7 +43,7 @@ def evaluate(params_path: str = 'params.yaml') -> None: | ||||
|             for data in test_dataloader: | ||||
|                 inputs, labels = data[0].to(device), data[1].to(device) | ||||
|                 outputs = net(inputs) | ||||
|                 _ = metric(outputs, labels) | ||||
|                 _ = metric(outputs.topk(k=1, dim=1)[1], labels.topk(k=1, dim=1)[1]) | ||||
|             test_acc = metric.compute() | ||||
|  | ||||
|         print(f'test_acc:{test_acc}') | ||||
|  | ||||
							
								
								
									
										9
									
								
								train.py
									
									
									
									
									
								
							
							
						
						| @ -10,7 +10,7 @@ import torch | ||||
| from rich.progress import track | ||||
| from torch.utils.data import DataLoader | ||||
| from torchvision.models import resnet50 | ||||
| from torchmetrics.classification import MulticlassAccuracy | ||||
| from torchmetrics.classification import Accuracy | ||||
| from dvclive import Live | ||||
|  | ||||
| from utils.dataset import ProcessedDataset | ||||
| @ -45,7 +45,7 @@ def train(params_path: str = 'params.yaml') -> None: | ||||
|  | ||||
|     criterion = torch.nn.CrossEntropyLoss() | ||||
|     optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate) | ||||
|     metric = MulticlassAccuracy(num_classes=10, top_k=1, average='weighted') | ||||
|     metric = Accuracy(task='multiclass', num_classes=10) | ||||
|     metric.to(device) | ||||
|  | ||||
|     with Live(dir='dvclive/train', report='md') as live: | ||||
| @ -63,7 +63,7 @@ def train(params_path: str = 'params.yaml') -> None: | ||||
|                 train_loss += loss | ||||
|                 loss.backward() | ||||
|                 optimizer.step() | ||||
|                 _ = metric(outputs, labels) | ||||
|                 _ = metric(outputs.topk(k=1, dim=1)[1], labels.topk(k=1, dim=1)[1]) | ||||
|             train_loss /= len(train_dataloader) | ||||
|             train_acc = metric.compute() | ||||
|             metric.reset() | ||||
| @ -76,7 +76,7 @@ def train(params_path: str = 'params.yaml') -> None: | ||||
|                     outputs = net(inputs) | ||||
|                     loss = criterion(outputs, labels) | ||||
|                     valid_loss += loss | ||||
|                     _ = metric(outputs, labels) | ||||
|                     _ = metric(outputs.topk(k=1, dim=1)[1], labels.topk(k=1, dim=1)[1]) | ||||
|                 valid_loss /= len(valid_dataloader) | ||||
|                 valid_acc = metric.compute() | ||||
|                 metric.reset() | ||||
| @ -93,7 +93,6 @@ def train(params_path: str = 'params.yaml') -> None: | ||||
|             live.next_step() | ||||
|  | ||||
|         torch.save(net, 'model.pt') | ||||
|         live.log_artifact('model.pt', type='model', name='resnet50') | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|  | ||||