diff --git a/.gitignore b/.gitignore index 1ffdb32..b9cea6d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,4 @@ data -*.pt \ No newline at end of file +*.pt +.DS_Store +__pycache__ \ No newline at end of file diff --git a/README.md b/README.md index 659abf8..10edd78 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # Abstract -Attempt to use [DVC](https://dvc.ai/), a data versioning tool, to track model training with PyTorch, including data, trained model file, and used parameters. The data will be recorded and pushed to my private DVC remote via webdavšŸŽ +Attempt to use [DVC](https://dvc.ai/), a data versioning tool, to track image classification model training with PyTorch, including data, trained model file, and used parameters. The data will be recorded and pushed to my private DVC remote via webdavšŸŽ # Requirements @@ -11,6 +11,8 @@ Attempt to use [DVC](https://dvc.ai/), a data versioning tool, to track model tr * **env** * **pt.yaml** * conda env yaml to run this repo +* **utils** + * house pre-built functions # Files diff --git a/dvc.lock b/dvc.lock new file mode 100644 index 0000000..5cf30dc --- /dev/null +++ b/dvc.lock @@ -0,0 +1,48 @@ +schema: '2.0' +stages: + prepare: + cmd: python prepare.py + deps: + - path: prepare.py + hash: md5 + md5: a1c07d1d5caf6e5288560a189415785c + size: 2979 + params: + params.yaml: + prepare: + data_dir: data/raw + save_dir: data/processed + train_valid_split: + - 0.7 + - 0.3 + random_seed: 0 + outs: + - path: data/processed + hash: md5 + md5: f4bf62ffa725ca9144b7852a283dc1da.dir + size: 295118798 + nfiles: 60000 + train: + cmd: python train.py + deps: + - path: data/processed + hash: md5 + md5: f4bf62ffa725ca9144b7852a283dc1da.dir + size: 295118798 + nfiles: 60000 + - path: train.py + hash: md5 + md5: b797ccf2fe61952bbf6d83fa51b0b11f + size: 3407 + params: + params.yaml: + train: + data_dir: data/processed + epochs: 5 + batch_size: 256 + learning_rate: 5e-05 + outs: + - path: model.pt + hash: md5 + md5: 
8ead2a7cd52d70b359d3cdc3df5e43e3 + size: 102592994 diff --git a/dvc.yaml b/dvc.yaml index ccf68d4..33c2a8a 100644 --- a/dvc.yaml +++ b/dvc.yaml @@ -2,27 +2,42 @@ stages: prepare: cmd: python prepare.py deps: - - prepare.py + - prepare.py params: - - prepare + - prepare outs: - - data/processed + - data/processed train: cmd: python train.py deps: - - data/processed - - train.py + - data/processed + - train.py params: - - train + - train outs: - - model.pt + - model.pt evaluate: cmd: python evaluate.py deps: - - data/processed - - evaluate.py - - model.pt + - data/processed + - evaluate.py + - model.pt params: - - evaluate + - evaluate outs: - - eval \ No newline at end of file + - eval +params: +- dvclive/train/params.yaml +- dvclive/eval/params.yaml +artifacts: + resnet50: + path: model.pt + type: model +metrics: +- dvclive/train/metrics.json +- dvclive/eval/metrics.json +plots: +- dvclive/train/plots/metrics: + x: step +- dvclive/eval/plots/metrics: + x: step diff --git a/dvclive/eval/metrics.json b/dvclive/eval/metrics.json new file mode 100644 index 0000000..59e3ec6 --- /dev/null +++ b/dvclive/eval/metrics.json @@ -0,0 +1,3 @@ +{ + "test_acc": 0.7336928844451904 +} diff --git a/dvclive/eval/params.yaml b/dvclive/eval/params.yaml new file mode 100644 index 0000000..a800f2b --- /dev/null +++ b/dvclive/eval/params.yaml @@ -0,0 +1 @@ +data_dir: data/processed diff --git a/dvclive/eval/plots/metrics/test_acc.tsv b/dvclive/eval/plots/metrics/test_acc.tsv new file mode 100644 index 0000000..f4169b0 --- /dev/null +++ b/dvclive/eval/plots/metrics/test_acc.tsv @@ -0,0 +1,2 @@ +step test_acc +0 0.7336928844451904 diff --git a/dvclive/eval/report.md b/dvclive/eval/report.md new file mode 100644 index 0000000..6e43a9c --- /dev/null +++ b/dvclive/eval/report.md @@ -0,0 +1,15 @@ +# DVC Report + +params.yaml + +| data_dir | +|----------------| +| data/processed | + +metrics.json + +| test_acc | +|------------| +| 0.733693 | + +![static/test_acc](static/test_acc.png) diff --git 
a/dvclive/eval/static/test_acc.png b/dvclive/eval/static/test_acc.png new file mode 100644 index 0000000..27643cd Binary files /dev/null and b/dvclive/eval/static/test_acc.png differ diff --git a/dvclive/train/metrics.json b/dvclive/train/metrics.json new file mode 100644 index 0000000..9a0ddb1 --- /dev/null +++ b/dvclive/train/metrics.json @@ -0,0 +1,7 @@ +{ + "train_loss": 2.2422571182250977, + "train_acc": 0.7347080707550049, + "valid_loss": 2.3184266090393066, + "valid_acc": 0.7381500005722046, + "step": 4 +} diff --git a/dvclive/train/params.yaml b/dvclive/train/params.yaml new file mode 100644 index 0000000..b2dbe4b --- /dev/null +++ b/dvclive/train/params.yaml @@ -0,0 +1,5 @@ +data_dir: data/processed +epochs: 5 +batch_size: 256 +learning_rate: 5e-05 +epoch: 4 diff --git a/dvclive/train/plots/metrics/train_acc.tsv b/dvclive/train/plots/metrics/train_acc.tsv new file mode 100644 index 0000000..3535742 --- /dev/null +++ b/dvclive/train/plots/metrics/train_acc.tsv @@ -0,0 +1,6 @@ +step train_acc +0 0.6712241768836975 +1 0.6976224184036255 +2 0.7157850861549377 +3 0.7277812957763672 +4 0.7347080707550049 diff --git a/dvclive/train/plots/metrics/train_loss.tsv b/dvclive/train/plots/metrics/train_loss.tsv new file mode 100644 index 0000000..26010c9 --- /dev/null +++ b/dvclive/train/plots/metrics/train_loss.tsv @@ -0,0 +1,6 @@ +step train_loss +0 3.0726168155670166 +1 2.7409346103668213 +2 2.5224294662475586 +3 2.364570140838623 +4 2.2422571182250977 diff --git a/dvclive/train/plots/metrics/valid_acc.tsv b/dvclive/train/plots/metrics/valid_acc.tsv new file mode 100644 index 0000000..9818ddd --- /dev/null +++ b/dvclive/train/plots/metrics/valid_acc.tsv @@ -0,0 +1,6 @@ +step valid_acc +0 0.6918894052505493 +1 0.7131190896034241 +2 0.7261338233947754 +3 0.7339118123054504 +4 0.7381500005722046 diff --git a/dvclive/train/plots/metrics/valid_loss.tsv b/dvclive/train/plots/metrics/valid_loss.tsv new file mode 100644 index 0000000..a556569 --- /dev/null +++ 
b/dvclive/train/plots/metrics/valid_loss.tsv @@ -0,0 +1,6 @@ +step valid_loss +0 2.890321969985962 +1 2.669679880142212 +2 2.5183584690093994 +3 2.4061686992645264 +4 2.3184266090393066 diff --git a/dvclive/train/report.md b/dvclive/train/report.md new file mode 100644 index 0000000..e584bb1 --- /dev/null +++ b/dvclive/train/report.md @@ -0,0 +1,21 @@ +# DVC Report + +params.yaml + +| data_dir | epochs | batch_size | learning_rate | epoch | +|----------------|----------|--------------|-----------------|---------| +| data/processed | 5 | 256 | 5e-05 | 4 | + +metrics.json + +| train_loss | train_acc | valid_loss | valid_acc | step | +|--------------|-------------|--------------|-------------|--------| +| 2.24226 | 0.734708 | 2.31843 | 0.73815 | 4 | + +![static/valid_loss](static/valid_loss.png) + +![static/train_acc](static/train_acc.png) + +![static/valid_acc](static/valid_acc.png) + +![static/train_loss](static/train_loss.png) diff --git a/dvclive/train/static/train_acc.png b/dvclive/train/static/train_acc.png new file mode 100644 index 0000000..cda0996 Binary files /dev/null and b/dvclive/train/static/train_acc.png differ diff --git a/dvclive/train/static/train_loss.png b/dvclive/train/static/train_loss.png new file mode 100644 index 0000000..d2a49db Binary files /dev/null and b/dvclive/train/static/train_loss.png differ diff --git a/dvclive/train/static/valid_acc.png b/dvclive/train/static/valid_acc.png new file mode 100644 index 0000000..6cc9d85 Binary files /dev/null and b/dvclive/train/static/valid_acc.png differ diff --git a/dvclive/train/static/valid_loss.png b/dvclive/train/static/valid_loss.png new file mode 100644 index 0000000..7cd6fd3 Binary files /dev/null and b/dvclive/train/static/valid_loss.png differ diff --git a/env/pt.yaml b/env/pt.yaml index 1f32466..a21af17 100644 --- a/env/pt.yaml +++ b/env/pt.yaml @@ -21,9 +21,11 @@ dependencies: - billiard=4.1.0 - boto3=1.34.9 - botocore=1.34.9 + - brotli=1.1.0 + - brotli-bin=1.1.0 - brotli-python=1.1.0 - 
bzip2=1.0.8 - - ca-certificates=2023.08.22 + - ca-certificates=2023.11.17 - cairo=1.18.0 - celery=5.3.4 - certifi=2023.11.17 @@ -35,7 +37,9 @@ dependencies: - click-repl=0.3.0 - colorama=0.4.6 - configobj=5.0.8 + - contourpy=1.2.0 - cryptography=41.0.7 + - cycler=0.12.1 - dav1d=1.2.1 - decorator=5.1.1 - dictdiffer=0.9.0 @@ -64,6 +68,7 @@ dependencies: - fontconfig=2.14.2 - fonts-conda-ecosystem=1 - fonts-conda-forge=1 + - fonttools=4.47.0 - freetype=2.12.1 - fribidi=1.0.10 - frozenlist=1.4.1 @@ -94,6 +99,7 @@ dependencies: - jinja2=3.1.2 - jmespath=1.0.1 - joblib=1.2.0 + - kiwisolver=1.4.5 - kombu=5.3.4 - krb5=1.21.2 - lame=3.100 @@ -101,6 +107,9 @@ dependencies: - lerc=4.0.0 - libass=0.17.1 - libblas=3.9.0 + - libbrotlicommon=1.1.0 + - libbrotlidec=1.1.0 + - libbrotlienc=1.1.0 - libcblas=3.9.0 - libcxx=16.0.6 - libdeflate=1.19 @@ -131,14 +140,18 @@ dependencies: - libxcb=1.15 - libxml2=2.12.3 - libzlib=1.2.13 + - lightning-utilities=0.10.0 - llvm-openmp=17.0.6 - markdown-it-py=3.0.0 - markupsafe=2.1.3 + - matplotlib=3.8.2 + - matplotlib-base=3.8.2 - mdurl=0.1.0 - mpc=1.3.1 - mpfr=4.2.1 - mpmath=1.3.0 - multidict=6.0.4 + - munkres=1.1.4 - nanotime=0.5.2 - ncurses=6.4 - nettle=3.9.1 @@ -206,12 +219,15 @@ dependencies: - tk=8.6.13 - tomlkit=0.12.3 - torchaudio=2.1.2 + - torchmetrics=1.2.1 - torchvision=0.16.2 + - tornado=6.3.3 - tqdm=4.66.1 - typer=0.9.0 - typing-extensions=4.9.0 - typing_extensions=4.9.0 - tzdata=2023d + - unicodedata2=15.1.0 - urllib3=1.26.18 - vine=5.0.0 - voluptuous=0.14.1 diff --git a/evaluate.py b/evaluate.py index a1c9852..d0dcad8 100644 --- a/evaluate.py +++ b/evaluate.py @@ -3,9 +3,15 @@ # author: deng # date : 20231228 +from pathlib import Path import yaml import torch +from torch.utils.data import DataLoader +from torchmetrics.classification import MulticlassAccuracy +from dvclive import Live + +from utils.dataset import ProcessedDataset def evaluate(params_path: str = 'params.yaml') -> None: @@ -17,6 +23,30 @@ def evaluate(params_path: 
str = 'params.yaml') -> None: with open(params_path, 'r') as f: params = yaml.safe_load(f) + data_dir = Path(params['evaluate']['data_dir']) + + test_dataset = ProcessedDataset(dataset_dir=data_dir.joinpath('test')) + test_dataloader = DataLoader(test_dataset, batch_size=64) + + device = torch.device('mps') + net = torch.load('model.pt') + net.to(device) + net.eval() + + metric = MulticlassAccuracy(num_classes=10, top_k=1, average='weighted') + metric.to(device) + + with Live(dir='dvclive/eval', report='md') as live: + live.log_params(params['evaluate']) + + for data in test_dataloader: + inputs, labels = data[0].to(device), data[1].to(device) + outputs = net(inputs) + _ = metric(outputs, labels) + test_acc = metric.compute() + + print(f'test_acc:{test_acc}') + live.log_metric('test_acc', float(test_acc.cpu())) if __name__ == '__main__': diff --git a/params.yaml b/params.yaml index 80f5bb5..763681f 100644 --- a/params.yaml +++ b/params.yaml @@ -1,3 +1,12 @@ prepare: + data_dir: data/raw + save_dir: data/processed + train_valid_split: [0.7, 0.3] + random_seed: 0 train: -evaluate: \ No newline at end of file + data_dir: data/processed + epochs: 5 + batch_size: 256 + learning_rate: 0.00005 +evaluate: + data_dir: data/processed \ No newline at end of file diff --git a/prepare.py b/prepare.py index b9dc5cc..0c36846 100644 --- a/prepare.py +++ b/prepare.py @@ -3,9 +3,13 @@ # author: deng # date : 20231228 +from pathlib import Path +from shutil import rmtree +import random +import pickle import yaml -import torch +import numpy as np def prepare(params_path: str = 'params.yaml') -> None: @@ -17,7 +21,75 @@ def prepare(params_path: str = 'params.yaml') -> None: with open(params_path, 'r') as f: params = yaml.safe_load(f) + data_dir = Path(params['prepare']['data_dir']) + save_dir = Path(params['prepare']['save_dir']) + train_valid_split = params['prepare']['train_valid_split'] + random_seed = params['prepare']['random_seed'] + + train_dir = save_dir.joinpath('train') + 
valid_dir = save_dir.joinpath('valid') + test_dir = save_dir.joinpath('test') + + if train_dir.is_dir(): + rmtree(train_dir) + train_dir.mkdir() + if valid_dir.is_dir(): + rmtree(valid_dir) + valid_dir.mkdir() + if test_dir.is_dir(): + rmtree(test_dir) + test_dir.mkdir() + + # Process training data + ids = list(range(50000)) + random.Random(random_seed).shuffle(ids) + train_ids = ids[:int(50000 * train_valid_split[0])] + valid_ids = ids[int(50000 * train_valid_split[0]):] + + current_id, train_count, valid_count = 0, 0, 0 + cifar_10_dir = data_dir.joinpath('cifar-10-batches-py') + for data_path in cifar_10_dir.glob('data_batch_*'): + with open(data_path, 'rb') as f: + data = pickle.load(f, encoding='bytes') + for i, label in enumerate(data[b'labels']): + x = data[b'data'][i] + x = x.reshape(3, 32, 32) + x = x / 255 + x = x.astype(np.float32) # mps does not support float64 + y = np.zeros(10, dtype=np.float32) + y[label] = 1. + if current_id in train_ids: + npz_path = train_dir.joinpath(f'{train_count}.npz') + train_ids.remove(current_id) + train_count += 1 + else: + npz_path = valid_dir.joinpath(f'{valid_count}.npz') + valid_ids.remove(current_id) + valid_count += 1 + np.savez_compressed( + npz_path, + x=x, + y=y + ) + current_id += 1 + + # Process testing data + data_path = cifar_10_dir.joinpath('test_batch') + with open(data_path, 'rb') as f: + data = pickle.load(f, encoding='bytes') + for i, label in enumerate(data[b'labels']): + x = data[b'data'][i] + x = x.reshape(3, 32, 32) + x = x / 255 + x = x.astype(np.float32) + y = np.zeros(10, dtype=np.float32) + y[label] = 1. 
+ np.savez_compressed( + save_dir.joinpath('test', f'{i}.npz'), + x=x, + y=y + ) if __name__ == '__main__': - prepare() + prepare('params.yaml') diff --git a/train.py b/train.py index fe371e0..56c4480 100644 --- a/train.py +++ b/train.py @@ -3,21 +3,97 @@ # author: deng # date : 20231228 +from pathlib import Path import yaml import torch +from rich.progress import track +from torch.utils.data import DataLoader +from torchvision.models import resnet50 +from torchmetrics.classification import MulticlassAccuracy from dvclive import Live +from utils.dataset import ProcessedDataset + def train(params_path: str = 'params.yaml') -> None: """Train a simple model using Pytorch Args: - params_path (str, optional): path of config yaml. Defaults to 'params.yaml'. + params_path (str, optional): path of parameter yaml. Defaults to 'params.yaml'. """ with open(params_path, 'r') as f: params = yaml.safe_load(f) + data_dir = Path(params['train']['data_dir']) + epochs = params['train']['epochs'] + batch_size = params['train']['batch_size'] + learning_rate = params['train']['learning_rate'] + + train_dataset = ProcessedDataset(dataset_dir=data_dir.joinpath('train')) + valid_dataset = ProcessedDataset(dataset_dir=data_dir.joinpath('valid')) + + train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=2) + valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, num_workers=2) + + device = torch.device('mps') + net = torch.nn.Sequential( + resnet50(weights='IMAGENET1K_V1'), + torch.nn.Linear(1000, 10) + ) + net.to(device) + + criterion = torch.nn.CrossEntropyLoss() + optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate) + metric = MulticlassAccuracy(num_classes=10, top_k=1, average='weighted') + metric.to(device) + + with Live(dir='dvclive/train', report='md') as live: + live.log_params(params['train']) + + for epoch in track(range(epochs), 'Epochs'): + + net.train() + train_loss = 0. 
+ for data in train_dataloader: + inputs, labels = data[0].to(device), data[1].to(device) + optimizer.zero_grad() + outputs = net(inputs) + loss = criterion(outputs, labels) + train_loss += loss + loss.backward() + optimizer.step() + _ = metric(outputs, labels) + train_loss /= len(train_dataloader) + train_acc = metric.compute() + metric.reset() + + net.eval() + with torch.no_grad(): + valid_loss = 0. + for data in valid_dataloader: + inputs, labels = data[0].to(device), data[1].to(device) + outputs = net(inputs) + loss = criterion(outputs, labels) + valid_loss += loss + _ = metric(outputs, labels) + valid_loss /= len(valid_dataloader) + valid_acc = metric.compute() + metric.reset() + + print( + f'Epoch {epoch} - train_loss:{train_loss} train_acc:{train_acc} ' + f'valid_loss:{valid_loss} valid_acc:{valid_acc}' + ) + live.log_param('epoch', epoch) + live.log_metric('train_loss', float(train_loss.cpu())) + live.log_metric('train_acc', float(train_acc.cpu())) + live.log_metric('valid_loss', float(valid_loss.cpu())) + live.log_metric('valid_acc', float(valid_acc.cpu())) + live.next_step() + + torch.save(net, 'model.pt') + live.log_artifact('model.pt', type='model', name='resnet50') if __name__ == '__main__': diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/dataset.py b/utils/dataset.py new file mode 100644 index 0000000..08104e7 --- /dev/null +++ b/utils/dataset.py @@ -0,0 +1,26 @@ +# dataset.py +# +# author: deng +# date : 20231229 + +from pathlib import PosixPath + +import torch +import numpy as np +from torch.utils.data import Dataset + + +class ProcessedDataset(Dataset): + """Load processed data""" + def __init__(self, dataset_dir: PosixPath): + self.dataset_dir = dataset_dir + self.file_paths = list(self.dataset_dir.glob('*.npz')) + + def __len__(self): + return len(self.file_paths) + + def __getitem__(self, idx): + npz = np.load(self.file_paths[idx]) + x = torch.from_numpy(npz['x']) + y = 
torch.from_numpy(npz['y']) + return x, y