use resnet50 to train a cifar10 classifier

parent 7a891969e0 · commit a27d0a24d9
.gitignore
@@ -1,2 +1,4 @@
 data
-*.pt
+*.pt
+.DS_Store
+__pycache__
README.md
@@ -1,6 +1,6 @@
 # Abstract
 
-Attempt to use [DVC](https://dvc.ai/), a data versioning tool, to track model training with PyTorch, including data, trained model file, and used parameters. The data will be recorded and pushed to my private DVC remote via webdav🎁
+Attempt to use [DVC](https://dvc.ai/), a data versioning tool, to track image classification model training with PyTorch, including the data, the trained model file, and the parameters used. Everything will be recorded and pushed to my private DVC remote via WebDAV 🎁
 
 # Requirements
 
@@ -11,6 +11,8 @@ Attempt to use [DVC](https://dvc.ai/), a data versioning tool, to track model tr
 * **env**
   * **pt.yaml**
     * conda env yaml to run this repo
+* **utils**
+  * houses pre-built helper functions
 
 # Files
 
dvc.lock
@@ -0,0 +1,48 @@
+schema: '2.0'
+stages:
+  prepare:
+    cmd: python prepare.py
+    deps:
+    - path: prepare.py
+      hash: md5
+      md5: a1c07d1d5caf6e5288560a189415785c
+      size: 2979
+    params:
+      params.yaml:
+        prepare:
+          data_dir: data/raw
+          save_dir: data/processed
+          train_valid_split:
+          - 0.7
+          - 0.3
+          random_seed: 0
+    outs:
+    - path: data/processed
+      hash: md5
+      md5: f4bf62ffa725ca9144b7852a283dc1da.dir
+      size: 295118798
+      nfiles: 60000
+  train:
+    cmd: python train.py
+    deps:
+    - path: data/processed
+      hash: md5
+      md5: f4bf62ffa725ca9144b7852a283dc1da.dir
+      size: 295118798
+      nfiles: 60000
+    - path: train.py
+      hash: md5
+      md5: b797ccf2fe61952bbf6d83fa51b0b11f
+      size: 3407
+    params:
+      params.yaml:
+        train:
+          data_dir: data/processed
+          epochs: 5
+          batch_size: 256
+          learning_rate: 5e-05
+    outs:
+    - path: model.pt
+      hash: md5
+      md5: 8ead2a7cd52d70b359d3cdc3df5e43e3
+      size: 102592994
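For reference, each `md5` above is a plain content hash of the corresponding dependency or output; directory outputs get a `.dir` hash that DVC derives from a manifest of the directory's files, so only single-file hashes can be re-derived this simply. A minimal sketch, assuming it is run from the repo root:

```python
# Sketch: re-derive a single-file hash recorded in dvc.lock by hand.
import hashlib
from pathlib import Path

def md5_of_file(path: str, chunk_size: int = 1 << 20) -> str:
    """Stream the file so large artifacts like model.pt fit in memory."""
    digest = hashlib.md5()
    with Path(path).open('rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()

print(md5_of_file('train.py'))  # should match the md5 recorded above
```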
dvc.yaml
@@ -2,27 +2,42 @@ stages:
   prepare:
     cmd: python prepare.py
     deps:
-      - prepare.py
+    - prepare.py
     params:
-      - prepare
+    - prepare
     outs:
-      - data/processed
+    - data/processed
   train:
     cmd: python train.py
     deps:
-      - data/processed
-      - train.py
+    - data/processed
+    - train.py
     params:
-      - train
+    - train
     outs:
-      - model.pt
+    - model.pt
   evaluate:
     cmd: python evaluate.py
     deps:
-      - data/processed
-      - evaluate.py
-      - model.pt
+    - data/processed
+    - evaluate.py
+    - model.pt
     params:
-      - evaluate
+    - evaluate
     outs:
-      - eval
+    - eval
+params:
+- dvclive/train/params.yaml
+- dvclive/eval/params.yaml
+artifacts:
+  resnet50:
+    path: model.pt
+    type: model
+metrics:
+- dvclive/train/metrics.json
+- dvclive/eval/metrics.json
+plots:
+- dvclive/train/plots/metrics:
+    x: step
+- dvclive/eval/plots/metrics:
+    x: step
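The stage graph defined above can be inspected without invoking DVC by loading dvc.yaml with PyYAML; a minimal sketch, assuming it is run from the repo root:

```python
# Sketch: print each stage's command, dependencies, and outputs from dvc.yaml.
import yaml

with open('dvc.yaml') as f:
    pipeline = yaml.safe_load(f)

for name, stage in pipeline['stages'].items():
    print(f'stage: {name}')
    print(f"  cmd : {stage['cmd']}")
    print(f"  deps: {stage.get('deps', [])}")
    print(f"  outs: {stage.get('outs', [])}")
```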
|
dvclive/eval/metrics.json
@@ -0,0 +1,3 @@
+{
+    "test_acc": 0.7336928844451904
+}
dvclive/eval/params.yaml
@@ -0,0 +1 @@
+data_dir: data/processed
dvclive/eval/plots/metrics/test_acc.tsv
@@ -0,0 +1,2 @@
+step	test_acc
+0	0.7336928844451904
dvclive/eval/report.md
@@ -0,0 +1,15 @@
+# DVC Report
+
+params.yaml
+
+| data_dir       |
+|----------------|
+| data/processed |
+
+metrics.json
+
+| test_acc |
+|----------|
+| 0.733693 |
+
+
Binary file not shown: dvclive/eval/static/test_acc.png (new image, 14 KiB)
dvclive/train/metrics.json
@@ -0,0 +1,7 @@
+{
+    "train_loss": 2.2422571182250977,
+    "train_acc": 0.7347080707550049,
+    "valid_loss": 2.3184266090393066,
+    "valid_acc": 0.7381500005722046,
+    "step": 4
+}
dvclive/train/params.yaml
@@ -0,0 +1,5 @@
+data_dir: data/processed
+epochs: 5
+batch_size: 256
+learning_rate: 5e-05
+epoch: 4
dvclive/train/plots/metrics/train_acc.tsv
@@ -0,0 +1,6 @@
+step	train_acc
+0	0.6712241768836975
+1	0.6976224184036255
+2	0.7157850861549377
+3	0.7277812957763672
+4	0.7347080707550049
dvclive/train/plots/metrics/train_loss.tsv
@@ -0,0 +1,6 @@
+step	train_loss
+0	3.0726168155670166
+1	2.7409346103668213
+2	2.5224294662475586
+3	2.364570140838623
+4	2.2422571182250977
dvclive/train/plots/metrics/valid_acc.tsv
@@ -0,0 +1,6 @@
+step	valid_acc
+0	0.6918894052505493
+1	0.7131190896034241
+2	0.7261338233947754
+3	0.7339118123054504
+4	0.7381500005722046
dvclive/train/plots/metrics/valid_loss.tsv
@@ -0,0 +1,6 @@
+step	valid_loss
+0	2.890321969985962
+1	2.669679880142212
+2	2.5183584690093994
+3	2.4061686992645264
+4	2.3184266090393066
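These step-indexed TSV logs make it easy to re-plot the training curves outside the generated report; a minimal sketch using matplotlib (already in env/pt.yaml) and pandas (assumed to be available):

```python
# Sketch: re-plot one dvclive metric from its TSV log.
import pandas as pd
import matplotlib.pyplot as plt

df = pd.read_csv('dvclive/train/plots/metrics/valid_loss.tsv', sep='\t')
plt.plot(df['step'], df['valid_loss'], marker='o')
plt.xlabel('step')
plt.ylabel('valid_loss')
plt.savefig('valid_loss_replot.png')
```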
dvclive/train/report.md
@@ -0,0 +1,21 @@
+# DVC Report
+
+params.yaml
+
+| data_dir       | epochs   | batch_size   | learning_rate   | epoch   |
+|----------------|----------|--------------|-----------------|---------|
+| data/processed | 5        | 256          | 5e-05           | 4       |
+
+metrics.json
+
+| train_loss   | train_acc   | valid_loss   | valid_acc   | step   |
+|--------------|-------------|--------------|-------------|--------|
+| 2.24226      | 0.734708    | 2.31843      | 0.73815     | 4      |
+
+
+
+
+
+
+
+
Binary files not shown: the four static plot images referenced above (21 KiB, 20 KiB, 21 KiB, 22 KiB)
env/pt.yaml
@@ -21,9 +21,11 @@ dependencies:
   - billiard=4.1.0
+  - boto3=1.34.9
+  - botocore=1.34.9
   - brotli=1.1.0
   - brotli-bin=1.1.0
   - brotli-python=1.1.0
   - bzip2=1.0.8
-  - ca-certificates=2023.08.22
+  - ca-certificates=2023.11.17
   - cairo=1.18.0
   - celery=5.3.4
   - certifi=2023.11.17
@@ -35,7 +37,9 @@ dependencies:
   - click-repl=0.3.0
   - colorama=0.4.6
   - configobj=5.0.8
+  - contourpy=1.2.0
   - cryptography=41.0.7
+  - cycler=0.12.1
   - dav1d=1.2.1
   - decorator=5.1.1
   - dictdiffer=0.9.0
@@ -64,6 +68,7 @@ dependencies:
   - fontconfig=2.14.2
   - fonts-conda-ecosystem=1
   - fonts-conda-forge=1
+  - fonttools=4.47.0
   - freetype=2.12.1
   - fribidi=1.0.10
   - frozenlist=1.4.1
@@ -94,6 +99,7 @@ dependencies:
   - jinja2=3.1.2
   - jmespath=1.0.1
   - joblib=1.2.0
+  - kiwisolver=1.4.5
   - kombu=5.3.4
   - krb5=1.21.2
   - lame=3.100
@@ -101,6 +107,9 @@ dependencies:
   - lerc=4.0.0
   - libass=0.17.1
+  - libblas=3.9.0
   - libbrotlicommon=1.1.0
   - libbrotlidec=1.1.0
   - libbrotlienc=1.1.0
+  - libcblas=3.9.0
   - libcxx=16.0.6
   - libdeflate=1.19
@@ -131,14 +140,18 @@ dependencies:
   - libxcb=1.15
   - libxml2=2.12.3
   - libzlib=1.2.13
+  - lightning-utilities=0.10.0
   - llvm-openmp=17.0.6
   - markdown-it-py=3.0.0
   - markupsafe=2.1.3
+  - matplotlib=3.8.2
+  - matplotlib-base=3.8.2
   - mdurl=0.1.0
   - mpc=1.3.1
   - mpfr=4.2.1
   - mpmath=1.3.0
   - multidict=6.0.4
+  - munkres=1.1.4
   - nanotime=0.5.2
   - ncurses=6.4
   - nettle=3.9.1
@@ -206,12 +219,15 @@ dependencies:
   - tk=8.6.13
   - tomlkit=0.12.3
   - torchaudio=2.1.2
+  - torchmetrics=1.2.1
+  - torchvision=0.16.2
   - tornado=6.3.3
   - tqdm=4.66.1
   - typer=0.9.0
   - typing-extensions=4.9.0
   - typing_extensions=4.9.0
   - tzdata=2023d
+  - unicodedata2=15.1.0
   - urllib3=1.26.18
   - vine=5.0.0
   - voluptuous=0.14.1
evaluate.py
@@ -3,9 +3,15 @@
 # author: deng
 # date : 20231228
 
+from pathlib import Path
 import yaml
 
 import torch
+from torch.utils.data import DataLoader
+from torchmetrics.classification import MulticlassAccuracy
 from dvclive import Live
 
+from utils.dataset import ProcessedDataset
+
 
 def evaluate(params_path: str = 'params.yaml') -> None:
@@ -17,6 +23,30 @@ def evaluate(params_path: str = 'params.yaml') -> None:
 
     with open(params_path, 'r') as f:
         params = yaml.safe_load(f)
+    data_dir = Path(params['evaluate']['data_dir'])
+
+    test_dataset = ProcessedDataset(dataset_dir=data_dir.joinpath('test'))
+    test_dataloader = DataLoader(test_dataset, batch_size=64)
+
+    device = torch.device('mps')
+    net = torch.load('model.pt')
+    net.to(device)
+    net.eval()
+
+    metric = MulticlassAccuracy(num_classes=10, top_k=1, average='weighted')
+    metric.to(device)
+
+    with Live(dir='dvclive/eval', report='md') as live:
+        live.log_params(params['evaluate'])
+
+        for data in test_dataloader:
+            inputs, labels = data[0].to(device), data[1].to(device)
+            outputs = net(inputs)
+            _ = metric(outputs, labels)
+        test_acc = metric.compute()
+
+        print(f'test_acc:{test_acc}')
+        live.log_metric('test_acc', float(test_acc.cpu()))
 
 
 if __name__ == '__main__':
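Note that evaluate.py, like train.py below, hard-codes the Apple-silicon `mps` device. A hypothetical helper (not part of this commit) would make the scripts portable across machines:

```python
# Hypothetical helper (not in this repo): pick the best available device.
import torch

def pick_device() -> torch.device:
    if torch.backends.mps.is_available():  # Apple-silicon GPU
        return torch.device('mps')
    if torch.cuda.is_available():          # NVIDIA GPU
        return torch.device('cuda')
    return torch.device('cpu')             # portable fallback

# usage: device = pick_device() instead of torch.device('mps')
```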
|
params.yaml
@@ -1,3 +1,12 @@
-evaluate:
+prepare:
+  data_dir: data/raw
+  save_dir: data/processed
+  train_valid_split: [0.7, 0.3]
+  random_seed: 0
+train:
   data_dir: data/processed
+  epochs: 5
+  batch_size: 256
+  learning_rate: 0.00005
+evaluate:
+  data_dir: data/processed
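Each script reads only its own top-level section of this file; a minimal sketch of that access pattern, mirroring the scripts:

```python
# Sketch: read stage parameters the same way prepare.py/train.py/evaluate.py do.
import yaml

with open('params.yaml') as f:
    params = yaml.safe_load(f)

print(params['train']['learning_rate'])  # 0.00005, recorded as 5e-05 in dvc.lock
print(params['evaluate']['data_dir'])    # data/processed
```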
prepare.py
@@ -3,9 +3,13 @@
 # author: deng
 # date : 20231228
 
 from pathlib import Path
+from shutil import rmtree
+import random
+import pickle
 import yaml
 
 import torch
+import numpy as np
 
 
 def prepare(params_path: str = 'params.yaml') -> None:
@@ -17,7 +21,75 @@ def prepare(params_path: str = 'params.yaml') -> None:
 
     with open(params_path, 'r') as f:
         params = yaml.safe_load(f)
+    data_dir = Path(params['prepare']['data_dir'])
+    save_dir = Path(params['prepare']['save_dir'])
+    train_valid_split = params['prepare']['train_valid_split']
+    random_seed = params['prepare']['random_seed']
+
+    train_dir = save_dir.joinpath('train')
+    valid_dir = save_dir.joinpath('valid')
+    test_dir = save_dir.joinpath('test')
+
+    # parents=True so save_dir itself is created when it does not exist yet
+    if train_dir.is_dir():
+        rmtree(train_dir)
+    train_dir.mkdir(parents=True)
+    if valid_dir.is_dir():
+        rmtree(valid_dir)
+    valid_dir.mkdir(parents=True)
+    if test_dir.is_dir():
+        rmtree(test_dir)
+    test_dir.mkdir(parents=True)
+
+    # Process training data
+    ids = list(range(50000))
+    random.Random(random_seed).shuffle(ids)
+    train_ids = ids[:int(50000 * train_valid_split[0])]
+    valid_ids = ids[int(50000 * train_valid_split[0]):]
+
+    current_id, train_count, valid_count = 0, 0, 0
+    cifar_10_dir = data_dir.joinpath('cifar-10-batches-py')
+    for data_path in cifar_10_dir.glob('data_batch_*'):
+        with open(data_path, 'rb') as f:
+            data = pickle.load(f, encoding='bytes')
+        for i, label in enumerate(data[b'labels']):
+            x = data[b'data'][i]
+            x = x.reshape(3, 32, 32)
+            x = x / 255
+            x = x.astype(np.float32)  # mps does not support float64
+            y = np.zeros(10, dtype=np.float32)
+            y[label] = 1.
+            if current_id in train_ids:
+                npz_path = train_dir.joinpath(f'{train_count}.npz')
+                train_ids.remove(current_id)
+                train_count += 1
+            else:
+                npz_path = valid_dir.joinpath(f'{valid_count}.npz')
+                valid_ids.remove(current_id)
+                valid_count += 1
+            np.savez_compressed(
+                npz_path,
+                x=x,
+                y=y
+            )
+            current_id += 1
+
+    # Process testing data
+    data_path = cifar_10_dir.joinpath('test_batch')
+    with open(data_path, 'rb') as f:
+        data = pickle.load(f, encoding='bytes')
+    for i, label in enumerate(data[b'labels']):
+        x = data[b'data'][i]
+        x = x.reshape(3, 32, 32)
+        x = x / 255
+        x = x.astype(np.float32)
+        y = np.zeros(10, dtype=np.float32)
+        y[label] = 1.
+        np.savez_compressed(
+            save_dir.joinpath('test', f'{i}.npz'),
+            x=x,
+            y=y
+        )
 
 
 if __name__ == '__main__':
-    prepare()
+    prepare('params.yaml')
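Each sample prepare.py writes is a compressed `.npz` holding a normalized CHW image and a one-hot label; a quick way to verify the format, assuming the prepare stage has already produced data/processed:

```python
# Sketch: inspect one processed sample written by prepare.py.
import numpy as np

npz = np.load('data/processed/train/0.npz')
x, y = npz['x'], npz['y']
assert x.shape == (3, 32, 32) and x.dtype == np.float32
assert y.shape == (10,) and y.sum() == 1.0  # one-hot label
print(x.min(), x.max())  # pixel values scaled to [0, 1]
```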
|
train.py
@@ -3,21 +3,97 @@
 # author: deng
 # date : 20231228
 
 from pathlib import Path
 import yaml
 
 import torch
+from rich.progress import track
+from torch.utils.data import DataLoader
+from torchvision.models import resnet50
+from torchmetrics.classification import MulticlassAccuracy
 from dvclive import Live
 
 from utils.dataset import ProcessedDataset
 
 
 def train(params_path: str = 'params.yaml') -> None:
     """Train a simple model using PyTorch
 
     Args:
-        params_path (str, optional): path of config yaml. Defaults to 'params.yaml'.
+        params_path (str, optional): path of parameter yaml. Defaults to 'params.yaml'.
     """
 
     with open(params_path, 'r') as f:
         params = yaml.safe_load(f)
+    data_dir = Path(params['train']['data_dir'])
+    epochs = params['train']['epochs']
+    batch_size = params['train']['batch_size']
+    learning_rate = params['train']['learning_rate']
+
+    train_dataset = ProcessedDataset(dataset_dir=data_dir.joinpath('train'))
+    valid_dataset = ProcessedDataset(dataset_dir=data_dir.joinpath('valid'))
+
+    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=2)
+    valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, num_workers=2)
+
+    device = torch.device('mps')
+    net = torch.nn.Sequential(
+        resnet50(weights='IMAGENET1K_V1'),
+        torch.nn.Linear(1000, 10)
+    )
+    net.to(device)
+
+    criterion = torch.nn.CrossEntropyLoss()
+    optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate)
+    metric = MulticlassAccuracy(num_classes=10, top_k=1, average='weighted')
+    metric.to(device)
+
+    with Live(dir='dvclive/train', report='md') as live:
+        live.log_params(params['train'])
+
+        for epoch in track(range(epochs), 'Epochs'):
+
+            net.train()
+            train_loss = 0.
+            for data in train_dataloader:
+                inputs, labels = data[0].to(device), data[1].to(device)
+                optimizer.zero_grad()
+                outputs = net(inputs)
+                loss = criterion(outputs, labels)
+                train_loss += loss
+                loss.backward()
+                optimizer.step()
+                _ = metric(outputs, labels)
+            train_loss /= len(train_dataloader)
+            train_acc = metric.compute()
+            metric.reset()
+
+            net.eval()
+            with torch.no_grad():
+                valid_loss = 0.
+                for data in valid_dataloader:
+                    inputs, labels = data[0].to(device), data[1].to(device)
+                    outputs = net(inputs)
+                    loss = criterion(outputs, labels)
+                    valid_loss += loss
+                    _ = metric(outputs, labels)
+                valid_loss /= len(valid_dataloader)
+                valid_acc = metric.compute()
+                metric.reset()
+
+            print(
+                f'Epoch {epoch} - train_loss:{train_loss} train_acc:{train_acc} '
+                f'valid_loss:{valid_loss} valid_acc:{valid_acc}'
+            )
+            live.log_param('epoch', epoch)
+            live.log_metric('train_loss', float(train_loss.cpu()))
+            live.log_metric('train_acc', float(train_acc.cpu()))
+            live.log_metric('valid_loss', float(valid_loss.cpu()))
+            live.log_metric('valid_acc', float(valid_acc.cpu()))
+            live.next_step()
+
+        torch.save(net, 'model.pt')
+        live.log_artifact('model.pt', type='model', name='resnet50')
 
 
 if __name__ == '__main__':
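One subtlety: train.py passes one-hot float targets to `torch.nn.CrossEntropyLoss`. Since PyTorch 1.10 the loss accepts class probabilities as targets, and for one-hot vectors this is equivalent to passing integer class indices; a quick check:

```python
# Sketch: one-hot probability targets give the same CE loss as class indices.
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)
labels = torch.tensor([1, 3, 5, 7])
one_hot = F.one_hot(labels, num_classes=10).float()

loss = torch.nn.CrossEntropyLoss()
assert torch.allclose(loss(logits, labels), loss(logits, one_hot))
```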
@ -0,0 +1,26 @@
|
|||
# prepare.py
|
||||
#
|
||||
# author: deng
|
||||
# date : 20231229
|
||||
|
||||
from pathlib import PosixPath
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
class ProcessedDataset(Dataset):
|
||||
""""Load processed data"""
|
||||
def __init__(self, dataset_dir: PosixPath):
|
||||
self.dataset_dir = dataset_dir
|
||||
self.file_paths = list(self.dataset_dir.glob('*.npz'))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.file_paths)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
npz = np.load(self.file_paths[idx])
|
||||
x = torch.from_numpy(npz['x'])
|
||||
y = torch.from_numpy(npz['y'])
|
||||
return x, y
|
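Usage mirrors train.py and evaluate.py: point `ProcessedDataset` at a processed split and wrap it in a `DataLoader`. A minimal sketch, assuming data/processed exists:

```python
# Sketch: iterate the processed training split, as train.py does.
from pathlib import Path
from torch.utils.data import DataLoader

from utils.dataset import ProcessedDataset

dataset = ProcessedDataset(dataset_dir=Path('data/processed/train'))
loader = DataLoader(dataset, batch_size=256)
x, y = next(iter(loader))
print(x.shape, y.shape)  # torch.Size([256, 3, 32, 32]) torch.Size([256, 10])
```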