Use ResNet50 to train a CIFAR-10 classifier

This commit is contained in:
deng 2023-12-30 00:03:36 +08:00
parent 7a891969e0
commit a27d0a24d9
27 changed files with 393 additions and 19 deletions

4
.gitignore vendored

@@ -1,2 +1,4 @@
data
*.pt
.DS_Store
__pycache__

README.md

@@ -1,6 +1,6 @@
# Abstract
Attempt to use [DVC](https://dvc.ai/), a data versioning tool, to track model training with PyTorch, including data, trained model file, and used parameters. The data will be recorded and pushed to my private DVC remote via webdav🎁
Attempt to use [DVC](https://dvc.ai/), a data versioning tool, to track image classification model training with PyTorch, including the data, the trained model file, and the parameters used. The data will be recorded and pushed to my private DVC remote via WebDAV 🎁
# Requirements
@@ -11,6 +11,8 @@ Attempt to use [DVC](https://dvc.ai/), a data versioning tool, to track model tr
* **env**
  * **pt.yaml**
    * conda env yaml to run this repo
* **utils**
  * houses pre-built helper functions
# Files

48
dvc.lock Normal file

@@ -0,0 +1,48 @@
schema: '2.0'
stages:
  prepare:
    cmd: python prepare.py
    deps:
    - path: prepare.py
      hash: md5
      md5: a1c07d1d5caf6e5288560a189415785c
      size: 2979
    params:
      params.yaml:
        prepare:
          data_dir: data/raw
          save_dir: data/processed
          train_valid_split:
          - 0.7
          - 0.3
          random_seed: 0
    outs:
    - path: data/processed
      hash: md5
      md5: f4bf62ffa725ca9144b7852a283dc1da.dir
      size: 295118798
      nfiles: 60000
  train:
    cmd: python train.py
    deps:
    - path: data/processed
      hash: md5
      md5: f4bf62ffa725ca9144b7852a283dc1da.dir
      size: 295118798
      nfiles: 60000
    - path: train.py
      hash: md5
      md5: b797ccf2fe61952bbf6d83fa51b0b11f
      size: 3407
    params:
      params.yaml:
        train:
          data_dir: data/processed
          epochs: 5
          batch_size: 256
          learning_rate: 5e-05
    outs:
    - path: model.pt
      hash: md5
      md5: 8ead2a7cd52d70b359d3cdc3df5e43e3
      size: 102592994
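
DVC uses these md5 entries to decide whether a stage is up to date: if every dep and out still hashes to the recorded value, `dvc repro` skips the stage. A minimal sketch of how a file-level hash like the ones above can be recomputed (the `.dir` hashes are built over a JSON manifest of the directory contents, so this covers plain files only; the chunk size is an illustrative assumption, not DVC's internal value):

```python
# Sketch: recompute a file-level md5 as recorded in dvc.lock.
import hashlib
from pathlib import Path

def file_md5(path: Path, chunk_size: int = 8 * 1024 * 1024) -> str:
    """Hash a file in chunks so large outputs like model.pt fit in memory."""
    md5 = hashlib.md5()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            md5.update(chunk)
    return md5.hexdigest()

if __name__ == '__main__':
    # For the exact train.py revision in this commit, this should print
    # b797ccf2fe61952bbf6d83fa51b0b11f
    print(file_md5(Path('train.py')))
```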

dvc.yaml

@@ -2,27 +2,42 @@ stages:
  prepare:
    cmd: python prepare.py
    deps:
    - prepare.py
    params:
    - prepare
    outs:
    - data/processed
  train:
    cmd: python train.py
    deps:
    - data/processed
    - train.py
    params:
    - train
    outs:
    - model.pt
  evaluate:
    cmd: python evaluate.py
    deps:
    - data/processed
    - evaluate.py
    - model.pt
    params:
    - evaluate
    outs:
    - eval
params:
- dvclive/train/params.yaml
- dvclive/eval/params.yaml
artifacts:
  resnet50:
    path: model.pt
    type: model
metrics:
- dvclive/train/metrics.json
- dvclive/eval/metrics.json
plots:
- dvclive/train/plots/metrics:
    x: step
- dvclive/eval/plots/metrics:
    x: step
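
With the stages, params, metrics, and plots declared here, `dvc repro` re-runs only the stages whose deps changed, and `dvc plots show` renders the TSV files below. The tracked parameters can also be read through DVC's Python API instead of parsing params.yaml by hand; a hedged sketch, assuming `dvc.api.params_show()` (available in recent DVC releases; the scripts in this commit use `yaml.safe_load` directly instead):

```python
# Sketch: read pipeline params via DVC's Python API.
# Assumption: dvc.api.params_show() (DVC >= 2.18); not what this repo does.
import dvc.api

params = dvc.api.params_show()
print(params['train']['learning_rate'])  # 5e-05 with the params.yaml below
```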

3
dvclive/eval/metrics.json Normal file

@@ -0,0 +1,3 @@
{
  "test_acc": 0.7336928844451904
}

1
dvclive/eval/params.yaml Normal file

@@ -0,0 +1 @@
data_dir: data/processed

2
dvclive/eval/plots/metrics/test_acc.tsv Normal file

@@ -0,0 +1,2 @@
step test_acc
0 0.7336928844451904

15
dvclive/eval/report.md Normal file

@@ -0,0 +1,15 @@
# DVC Report
params.yaml
| data_dir |
|----------------|
| data/processed |
metrics.json
| test_acc |
|------------|
| 0.733693 |
![static/test_acc](static/test_acc.png)

dvclive/eval/static/test_acc.png: binary plot image added (14 KiB), not shown

7
dvclive/train/metrics.json Normal file

@@ -0,0 +1,7 @@
{
  "train_loss": 2.2422571182250977,
  "train_acc": 0.7347080707550049,
  "valid_loss": 2.3184266090393066,
  "valid_acc": 0.7381500005722046,
  "step": 4
}
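
This file, params.yaml, and the TSVs below are all emitted by the `Live` context in train.py. A stripped-down sketch of the calls that produce them (placeholder values, same DVCLive API as used in this commit):

```python
# Sketch: the DVCLive calls behind dvclive/train/* (values are placeholders).
from dvclive import Live

with Live(dir='dvclive/train', report='md') as live:
    live.log_params({'epochs': 5, 'batch_size': 256})    # -> params.yaml
    for step in range(5):
        live.log_metric('train_loss', 3.0 - 0.2 * step)  # -> plots/metrics/*.tsv
        live.next_step()                                 # advances the step column
# on exit, the latest values land in metrics.json and report.md
```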

5
dvclive/train/params.yaml Normal file

@@ -0,0 +1,5 @@
data_dir: data/processed
epochs: 5
batch_size: 256
learning_rate: 5e-05
epoch: 4

6
dvclive/train/plots/metrics/train_acc.tsv Normal file

@@ -0,0 +1,6 @@
step train_acc
0 0.6712241768836975
1 0.6976224184036255
2 0.7157850861549377
3 0.7277812957763672
4 0.7347080707550049

6
dvclive/train/plots/metrics/train_loss.tsv Normal file

@@ -0,0 +1,6 @@
step train_loss
0 3.0726168155670166
1 2.7409346103668213
2 2.5224294662475586
3 2.364570140838623
4 2.2422571182250977

6
dvclive/train/plots/metrics/valid_acc.tsv Normal file

@@ -0,0 +1,6 @@
step valid_acc
0 0.6918894052505493
1 0.7131190896034241
2 0.7261338233947754
3 0.7339118123054504
4 0.7381500005722046

6
dvclive/train/plots/metrics/valid_loss.tsv Normal file

@@ -0,0 +1,6 @@
step valid_loss
0 2.890321969985962
1 2.669679880142212
2 2.5183584690093994
3 2.4061686992645264
4 2.3184266090393066

21
dvclive/train/report.md Normal file

@@ -0,0 +1,21 @@
# DVC Report
params.yaml
| data_dir | epochs | batch_size | learning_rate | epoch |
|----------------|----------|--------------|-----------------|---------|
| data/processed | 5 | 256 | 5e-05 | 4 |
metrics.json
| train_loss | train_acc | valid_loss | valid_acc | step |
|--------------|-------------|--------------|-------------|--------|
| 2.24226 | 0.734708 | 2.31843 | 0.73815 | 4 |
![static/valid_loss](static/valid_loss.png)
![static/train_acc](static/train_acc.png)
![static/valid_acc](static/valid_acc.png)
![static/train_loss](static/train_loss.png)

dvclive/train/static/train_acc.png, train_loss.png, valid_acc.png, valid_loss.png: four binary plot images added (20-22 KiB each), not shown

18
env/pt.yaml vendored

@@ -21,9 +21,11 @@ dependencies:
- billiard=4.1.0
- boto3=1.34.9
- botocore=1.34.9
- brotli=1.1.0
- brotli-bin=1.1.0
- brotli-python=1.1.0
- bzip2=1.0.8
- ca-certificates=2023.08.22
- ca-certificates=2023.11.17
- cairo=1.18.0
- celery=5.3.4
- certifi=2023.11.17
@@ -35,7 +37,9 @@ dependencies:
- click-repl=0.3.0
- colorama=0.4.6
- configobj=5.0.8
- contourpy=1.2.0
- cryptography=41.0.7
- cycler=0.12.1
- dav1d=1.2.1
- decorator=5.1.1
- dictdiffer=0.9.0
@@ -64,6 +68,7 @@ dependencies:
- fontconfig=2.14.2
- fonts-conda-ecosystem=1
- fonts-conda-forge=1
- fonttools=4.47.0
- freetype=2.12.1
- fribidi=1.0.10
- frozenlist=1.4.1
@@ -94,6 +99,7 @@ dependencies:
- jinja2=3.1.2
- jmespath=1.0.1
- joblib=1.2.0
- kiwisolver=1.4.5
- kombu=5.3.4
- krb5=1.21.2
- lame=3.100
@@ -101,6 +107,9 @@ dependencies:
- lerc=4.0.0
- libass=0.17.1
- libblas=3.9.0
- libbrotlicommon=1.1.0
- libbrotlidec=1.1.0
- libbrotlienc=1.1.0
- libcblas=3.9.0
- libcxx=16.0.6
- libdeflate=1.19
@@ -131,14 +140,18 @@ dependencies:
- libxcb=1.15
- libxml2=2.12.3
- libzlib=1.2.13
- lightning-utilities=0.10.0
- llvm-openmp=17.0.6
- markdown-it-py=3.0.0
- markupsafe=2.1.3
- matplotlib=3.8.2
- matplotlib-base=3.8.2
- mdurl=0.1.0
- mpc=1.3.1
- mpfr=4.2.1
- mpmath=1.3.0
- multidict=6.0.4
- munkres=1.1.4
- nanotime=0.5.2
- ncurses=6.4
- nettle=3.9.1
@@ -206,12 +219,15 @@ dependencies:
- tk=8.6.13
- tomlkit=0.12.3
- torchaudio=2.1.2
- torchmetrics=1.2.1
- torchvision=0.16.2
- tornado=6.3.3
- tqdm=4.66.1
- typer=0.9.0
- typing-extensions=4.9.0
- typing_extensions=4.9.0
- tzdata=2023d
- unicodedata2=15.1.0
- urllib3=1.26.18
- vine=5.0.0
- voluptuous=0.14.1

evaluate.py

@@ -3,9 +3,15 @@
# author: deng
# date : 20231228

from pathlib import Path

import yaml
import torch
from torch.utils.data import DataLoader
from torchmetrics.classification import MulticlassAccuracy
from dvclive import Live

from utils.dataset import ProcessedDataset


def evaluate(params_path: str = 'params.yaml') -> None:

@@ -17,6 +23,30 @@ def evaluate(params_path: str = 'params.yaml') -> None:
    with open(params_path, 'r') as f:
        params = yaml.safe_load(f)

    data_dir = Path(params['evaluate']['data_dir'])

    test_dataset = ProcessedDataset(dataset_dir=data_dir.joinpath('test'))
    test_dataloader = DataLoader(test_dataset, batch_size=64)

    device = torch.device('mps')
    net = torch.load('model.pt')
    net.to(device)
    net.eval()

    metric = MulticlassAccuracy(num_classes=10, top_k=1, average='weighted')
    metric.to(device)

    with Live(dir='dvclive/eval', report='md') as live:
        live.log_params(params['evaluate'])
        for data in test_dataloader:
            inputs, labels = data[0].to(device), data[1].to(device)
            outputs = net(inputs)
            _ = metric(outputs, labels)
        test_acc = metric.compute()
        print(f'test_acc:{test_acc}')
        live.log_metric('test_acc', float(test_acc.cpu()))

if __name__ == '__main__':
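
For reference, `average='weighted'` here means per-class recall weighted by class support. A small illustration with made-up logits and three classes instead of ten:

```python
# Sketch: what weighted multiclass accuracy computes (toy values).
import torch
from torchmetrics.functional.classification import multiclass_accuracy

preds = torch.tensor([[2.0, 0.1, 0.3],   # -> class 0 (correct)
                      [0.2, 1.5, 0.1],   # -> class 1 (correct)
                      [0.1, 0.2, 3.0],   # -> class 2 (correct)
                      [1.0, 0.9, 0.8]])  # -> class 0 (wrong, target is 1)
target = torch.tensor([0, 1, 2, 1])
acc = multiclass_accuracy(preds, target, num_classes=3,
                          top_k=1, average='weighted')
print(acc)  # tensor(0.7500): (1*1.0 + 2*0.5 + 1*1.0) / 4
```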

params.yaml

@@ -1,3 +1,12 @@
prepare:
  data_dir: data/raw
  save_dir: data/processed
  train_valid_split: [0.7, 0.3]
  random_seed: 0
train:
  data_dir: data/processed
  epochs: 5
  batch_size: 256
  learning_rate: 0.00005
evaluate:
  data_dir: data/processed

prepare.py

@@ -3,9 +3,13 @@
# author: deng
# date : 20231228

from pathlib import Path
from shutil import rmtree
import random
import pickle
import yaml
import torch
import numpy as np


def prepare(params_path: str = 'params.yaml') -> None:

@@ -17,7 +21,75 @@ def prepare(params_path: str = 'params.yaml') -> None:
    with open(params_path, 'r') as f:
        params = yaml.safe_load(f)

    data_dir = Path(params['prepare']['data_dir'])
    save_dir = Path(params['prepare']['save_dir'])
    train_valid_split = params['prepare']['train_valid_split']
    random_seed = params['prepare']['random_seed']

    train_dir = save_dir.joinpath('train')
    valid_dir = save_dir.joinpath('valid')
    test_dir = save_dir.joinpath('test')
    if train_dir.is_dir():
        rmtree(train_dir)
    train_dir.mkdir()
    if valid_dir.is_dir():
        rmtree(valid_dir)
    valid_dir.mkdir()
    if test_dir.is_dir():
        rmtree(test_dir)
    test_dir.mkdir()

    # Process training data
    ids = list(range(50000))
    random.Random(random_seed).shuffle(ids)
    train_ids = ids[:int(50000 * train_valid_split[0])]
    valid_ids = ids[int(50000 * train_valid_split[0]):]

    current_id, train_count, valid_count = 0, 0, 0
    cifar_10_dir = data_dir.joinpath('cifar-10-batches-py')
    for data_path in cifar_10_dir.glob('data_batch_*'):
        with open(data_path, 'rb') as f:
            data = pickle.load(f, encoding='bytes')
        for i, label in enumerate(data[b'labels']):
            x = data[b'data'][i]
            x = x.reshape(3, 32, 32)
            x = x / 255
            x = x.astype(np.float32)  # mps does not support float64
            y = np.zeros(10, dtype=np.float32)
            y[label] = 1.
            if current_id in train_ids:
                npz_path = train_dir.joinpath(f'{train_count}.npz')
                train_ids.remove(current_id)
                train_count += 1
            else:
                npz_path = valid_dir.joinpath(f'{valid_count}.npz')
                valid_ids.remove(current_id)
                valid_count += 1
            np.savez_compressed(
                npz_path,
                x=x,
                y=y
            )
            current_id += 1

    # Process testing data
    data_path = cifar_10_dir.joinpath('test_batch')
    with open(data_path, 'rb') as f:
        data = pickle.load(f, encoding='bytes')
    for i, label in enumerate(data[b'labels']):
        x = data[b'data'][i]
        x = x.reshape(3, 32, 32)
        x = x / 255
        x = x.astype(np.float32)
        y = np.zeros(10, dtype=np.float32)
        y[label] = 1.
        np.savez_compressed(
            save_dir.joinpath('test', f'{i}.npz'),
            x=x,
            y=y
        )

if __name__ == '__main__':
    prepare('params.yaml')
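
Each processed sample is a compressed `.npz` holding a float32 CHW image scaled to [0, 1] plus a one-hot float32 label, which is exactly what `utils/dataset.py` below reads back. A quick sanity check on one output file (path assumed from the params above):

```python
# Sketch: inspect one sample written by prepare.py.
import numpy as np

npz = np.load('data/processed/train/0.npz')
x, y = npz['x'], npz['y']
assert x.shape == (3, 32, 32) and x.dtype == np.float32
assert 0.0 <= x.min() and x.max() <= 1.0
assert y.shape == (10,) and y.sum() == 1.0  # one-hot CIFAR-10 label
print('class index:', int(y.argmax()))
```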

train.py

@@ -3,21 +3,97 @@
# author: deng
# date : 20231228

from pathlib import Path

import yaml
import torch
from rich.progress import track
from torch.utils.data import DataLoader
from torchvision.models import resnet50
from torchmetrics.classification import MulticlassAccuracy
from dvclive import Live

from utils.dataset import ProcessedDataset


def train(params_path: str = 'params.yaml') -> None:
    """Train a simple model using PyTorch

    Args:
        params_path (str, optional): path of parameter yaml. Defaults to 'params.yaml'.
    """
    with open(params_path, 'r') as f:
        params = yaml.safe_load(f)

    data_dir = Path(params['train']['data_dir'])
    epochs = params['train']['epochs']
    batch_size = params['train']['batch_size']
    learning_rate = params['train']['learning_rate']

    train_dataset = ProcessedDataset(dataset_dir=data_dir.joinpath('train'))
    valid_dataset = ProcessedDataset(dataset_dir=data_dir.joinpath('valid'))
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=2)
    valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, num_workers=2)

    device = torch.device('mps')
    net = torch.nn.Sequential(
        resnet50(weights='IMAGENET1K_V1'),
        torch.nn.Linear(1000, 10)
    )
    net.to(device)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate)
    metric = MulticlassAccuracy(num_classes=10, top_k=1, average='weighted')
    metric.to(device)

    with Live(dir='dvclive/train', report='md') as live:
        live.log_params(params['train'])
        for epoch in track(range(epochs), 'Epochs'):
            net.train()
            train_loss = 0.
            for data in train_dataloader:
                inputs, labels = data[0].to(device), data[1].to(device)
                optimizer.zero_grad()
                outputs = net(inputs)
                loss = criterion(outputs, labels)
                train_loss += loss
                loss.backward()
                optimizer.step()
                _ = metric(outputs, labels)
            train_loss /= len(train_dataloader)
            train_acc = metric.compute()
            metric.reset()

            net.eval()
            with torch.no_grad():
                valid_loss = 0.
                for data in valid_dataloader:
                    inputs, labels = data[0].to(device), data[1].to(device)
                    outputs = net(inputs)
                    loss = criterion(outputs, labels)
                    valid_loss += loss
                    _ = metric(outputs, labels)
                valid_loss /= len(valid_dataloader)
                valid_acc = metric.compute()
                metric.reset()

            print(
                f'Epoch {epoch} - train_loss:{train_loss} train_acc:{train_acc} '
                f'valid_loss:{valid_loss} valid_acc:{valid_acc}'
            )
            live.log_param('epoch', epoch)
            live.log_metric('train_loss', float(train_loss.cpu()))
            live.log_metric('train_acc', float(train_acc.cpu()))
            live.log_metric('valid_loss', float(valid_loss.cpu()))
            live.log_metric('valid_acc', float(valid_acc.cpu()))
            live.next_step()

        torch.save(net, 'model.pt')
        live.log_artifact('model.pt', type='model', name='resnet50')

if __name__ == '__main__':
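
Because the whole `torch.nn.Sequential` module is saved with `torch.save(net, 'model.pt')`, the checkpoint can be reloaded directly, as evaluate.py does. A minimal inference sketch (the `mps` device mirrors this commit; a random tensor stands in for a real CIFAR-10 image):

```python
# Sketch: reload the saved model and run one forward pass.
import torch

device = torch.device('mps')   # as used throughout this commit
net = torch.load('model.pt')   # whole-module checkpoint, not a state_dict
net.to(device)
net.eval()

with torch.no_grad():
    dummy = torch.rand(1, 3, 32, 32, device=device)  # CIFAR-10-sized input
    probs = torch.softmax(net(dummy), dim=1)
    print(int(probs.argmax(dim=1)))  # predicted class index
```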

0
utils/__init__.py Normal file

26
utils/dataset.py Normal file

@@ -0,0 +1,26 @@
# dataset.py
#
# author: deng
# date : 20231229

from pathlib import PosixPath

import torch
import numpy as np
from torch.utils.data import Dataset


class ProcessedDataset(Dataset):
    """Load processed data"""

    def __init__(self, dataset_dir: PosixPath):
        self.dataset_dir = dataset_dir
        self.file_paths = list(self.dataset_dir.glob('*.npz'))

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        npz = np.load(self.file_paths[idx])
        x = torch.from_numpy(npz['x'])
        y = torch.from_numpy(npz['y'])
        return x, y
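
A typical way to consume this dataset, mirroring the DataLoader setup in train.py:

```python
# Sketch: wire ProcessedDataset into a DataLoader as train.py does.
from pathlib import Path

from torch.utils.data import DataLoader

from utils.dataset import ProcessedDataset

dataset = ProcessedDataset(dataset_dir=Path('data/processed/train'))
loader = DataLoader(dataset, batch_size=256, num_workers=2)
x, y = next(iter(loader))
print(x.shape, y.shape)  # torch.Size([256, 3, 32, 32]) torch.Size([256, 10])
```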