This commit is contained in:
deng 2023-12-28 22:06:25 +08:00
commit 7a891969e0
11 changed files with 368 additions and 0 deletions

3
.dvc/.gitignore vendored Normal file
View File

@ -0,0 +1,3 @@
/config.local
/tmp
/cache

4
.dvc/config Normal file
View File

@ -0,0 +1,4 @@
[core]
remote = nas
['remote "nas"']
url = https://webdav.guineapig.love/home/dvc

3
.dvcignore Normal file
View File

@ -0,0 +1,3 @@
# Add patterns of files dvc should ignore, which could improve
# the performance. Learn more at
# https://dvc.org/doc/user-guide/dvcignore

2
.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
data
*.pt

24
README.md Normal file
View File

@ -0,0 +1,24 @@
# Abstract
Attempt to use [DVC](https://dvc.ai/), a data versioning tool, to track model training with PyTorch, including data, trained model file, and used parameters. The data will be recorded and pushed to my private DVC remote via webdav🎁
# Requirements
* MacOS 13.3
# Dirs
* **env**
* **pt.yaml**
* conda env yaml to run this repo
# Files
* **prepare.py**
* prepare materials for model training
* **train.py**
* try to train a small neural network
* **evaluate.py**
* evaluate trained model with some metrics
###### tags: `DVC`

28
dvc.yaml Normal file
View File

@ -0,0 +1,28 @@
stages:
prepare:
cmd: python prepare.py
deps:
- prepare.py
params:
- prepare
outs:
- data/processed
train:
cmd: python train.py
deps:
- data/processed
- train.py
params:
- train
outs:
- model.pt
evaluate:
cmd: python evaluate.py
deps:
- data/processed
- evaluate.py
- model.pt
params:
- evaluate
outs:
- eval

231
env/pt.yaml vendored Normal file
View File

@ -0,0 +1,231 @@
name: pt
channels:
- pytorch
- anaconda
- conda-forge
dependencies:
- aiohttp=3.9.1
- aiohttp-retry=2.8.3
- aiosignal=1.3.1
- amqp=5.2.0
- annotated-types=0.6.0
- antlr-python-runtime=4.9.3
- aom=3.7.1
- appdirs=1.4.4
- async-timeout=4.0.3
- asyncssh=2.14.1
- atk-1.0=2.38.0
- atpublic=3.0.1
- attrs=23.1.0
- backports.zoneinfo=0.2.1
- billiard=4.1.0
- boto3=1.34.9
- botocore=1.34.9
- brotli-python=1.1.0
- bzip2=1.0.8
- ca-certificates=2023.08.22
- cairo=1.18.0
- celery=5.3.4
- certifi=2023.11.17
- cffi=1.16.0
- charset-normalizer=3.3.2
- click=8.1.7
- click-didyoumean=0.3.0
- click-plugins=1.1.1
- click-repl=0.3.0
- colorama=0.4.6
- configobj=5.0.8
- cryptography=41.0.7
- dav1d=1.2.1
- decorator=5.1.1
- dictdiffer=0.9.0
- diskcache=5.6.3
- distro=1.8.0
- dpath=2.1.6
- dulwich=0.21.7
- dvc=3.37.0
- dvc-data=3.5.0
- dvc-http=2.32.0
- dvc-objects=3.0.0
- dvc-render=1.0.0
- dvc-studio-client=0.18.0
- dvc-task=0.3.0
- dvclive=3.5.1
- entrypoints=0.4
- expat=2.5.0
- ffmpeg=6.1.0
- filelock=3.13.1
- flatten-dict=0.4.2
- flufl.lock=7.1
- font-ttf-dejavu-sans-mono=2.37
- font-ttf-inconsolata=3.000
- font-ttf-source-code-pro=2.038
- font-ttf-ubuntu=0.83
- fontconfig=2.14.2
- fonts-conda-ecosystem=1
- fonts-conda-forge=1
- freetype=2.12.1
- fribidi=1.0.10
- frozenlist=1.4.1
- fsspec=2023.12.2
- funcy=2.0
- future=0.18.3
- gdk-pixbuf=2.42.10
- gettext=0.21.1
- giflib=5.2.1
- gitdb=4.0.11
- gitpython=3.1.40
- gmp=6.3.0
- gmpy2=2.1.2
- gnutls=3.7.9
- grandalf=0.7
- graphite2=1.3.13
- graphviz=9.0.0
- gtk2=2.24.33
- gto=1.6.1
- gts=0.7.6
- harfbuzz=8.3.0
- hydra-core=1.3.2
- icu=73.2
- idna=3.6
- importlib-metadata=7.0.1
- importlib_resources=6.1.1
- iterative-telemetry=0.0.8
- jinja2=3.1.2
- jmespath=1.0.1
- joblib=1.2.0
- kombu=5.3.4
- krb5=1.21.2
- lame=3.100
- lcms2=2.16
- lerc=4.0.0
- libass=0.17.1
- libblas=3.9.0
- libcblas=3.9.0
- libcxx=16.0.6
- libdeflate=1.19
- libedit=3.1.20191231
- libexpat=2.5.0
- libffi=3.4.2
- libgd=2.3.3
- libgfortran=5.0.0
- libgfortran5=13.2.0
- libgit2=1.7.1
- libglib=2.78.3
- libiconv=1.17
- libidn2=2.3.4
- libjpeg-turbo=3.0.0
- liblapack=3.9.0
- libopenblas=0.3.25
- libopus=1.3.1
- libpng=1.6.39
- librsvg=2.56.3
- libsqlite=3.44.2
- libssh2=1.11.0
- libtasn1=4.19.0
- libtiff=4.6.0
- libunistring=0.9.10
- libvpx=1.13.1
- libwebp=1.3.2
- libwebp-base=1.3.2
- libxcb=1.15
- libxml2=2.12.3
- libzlib=1.2.13
- llvm-openmp=17.0.6
- markdown-it-py=3.0.0
- markupsafe=2.1.3
- mdurl=0.1.0
- mpc=1.3.1
- mpfr=4.2.1
- mpmath=1.3.0
- multidict=6.0.4
- nanotime=0.5.2
- ncurses=6.4
- nettle=3.9.1
- networkx=3.2.1
- numpy=1.26.2
- omegaconf=2.3.0
- openh264=2.4.0
- openjpeg=2.5.0
- openssl=3.2.0
- orjson=3.9.10
- p11-kit=0.24.1
- packaging=23.2
- pango=1.50.14
- pathlib2=2.3.7.post1
- pathspec=0.12.1
- pcre2=10.42
- pillow=10.1.0
- pip=23.3.2
- pixman=0.42.2
- platformdirs=3.11.0
- prompt-toolkit=3.0.42
- prompt_toolkit=3.0.42
- psutil=5.9.7
- pthread-stubs=0.4
- pycparser=2.21
- pydantic=2.5.3
- pydantic-core=2.14.6
- pydot=1.4.2
- pygit2=1.13.3
- pygments=2.17.2
- pygtrie=2.5.0
- pyopenssl=23.3.0
- pyparsing=3.1.1
- pysocks=1.7.1
- python=3.10.13
- python-dateutil=2.8.2
- python-gssapi=1.8.3
- python-tzdata=2023.3
- python_abi=3.10
- pytorch=2.1.2
- pytz=2023.3.post1
- pywin32-on-windows=0.1.0
- pyyaml=6.0.1
- readline=8.2
- requests=2.31.0
- rich=13.7.0
- ruamel.yaml=0.18.5
- ruamel.yaml.clib=0.2.7
- s3transfer=0.10.0
- scikit-learn=1.3.0
- scipy=1.11.4
- scmrepo=2.0.2
- semver=3.0.2
- setuptools=68.2.2
- shellingham=1.5.4
- shortuuid=1.0.11
- shtab=1.6.5
- six=1.16.0
- smmap=5.0.0
- sqltrie=0.11.0
- svt-av1=1.8.0
- sympy=1.12
- tabulate=0.9.0
- threadpoolctl=2.2.0
- tk=8.6.13
- tomlkit=0.12.3
- torchaudio=2.1.2
- torchvision=0.16.2
- tqdm=4.66.1
- typer=0.9.0
- typing-extensions=4.9.0
- typing_extensions=4.9.0
- tzdata=2023d
- urllib3=1.26.18
- vine=5.0.0
- voluptuous=0.14.1
- wcwidth=0.2.12
- wheel=0.42.0
- x264=1!164.3095
- x265=3.5
- xorg-libxau=1.0.11
- xorg-libxdmcp=1.1.3
- xz=5.2.6
- yaml=0.2.5
- yarl=1.9.3
- zc.lockfile=3.0.post1
- zipp=3.17.0
- zlib=1.2.13
- zstd=1.5.5
prefix: /Users/xiao_deng/miniforge3/envs/pt

23
evaluate.py Normal file
View File

@ -0,0 +1,23 @@
# evaluate.py
#
# author: deng
# date : 20231228
import yaml
import torch
def evaluate(params_path: str = 'params.yaml') -> None:
"""Evaluate model and save results to eval dir
Args:
params_path (str, optional): path of parameter yaml. Defaults to 'params.yaml'.
"""
with open(params_path, 'r') as f:
params = yaml.safe_load(f)
if __name__ == '__main__':
evaluate()

3
params.yaml Normal file
View File

@ -0,0 +1,3 @@
prepare:
train:
evaluate:

23
prepare.py Normal file
View File

@ -0,0 +1,23 @@
# prepare.py
#
# author: deng
# date : 20231228
import yaml
import torch
def prepare(params_path: str = 'params.yaml') -> None:
"""Preprocess data save as npz for model training
Args:
params_path (str, optional): path of parameter yaml. Defaults to 'params.yaml'.
"""
with open(params_path, 'r') as f:
params = yaml.safe_load(f)
if __name__ == '__main__':
prepare()

24
train.py Normal file
View File

@ -0,0 +1,24 @@
# train.py
#
# author: deng
# date : 20231228
import yaml
import torch
from dvclive import Live
def train(params_path: str = 'params.yaml') -> None:
"""Train a simple model using Pytorch
Args:
params_path (str, optional): path of config yaml. Defaults to 'params.yaml'.
"""
with open(params_path, 'r') as f:
params = yaml.safe_load(f)
if __name__ == '__main__':
train()