init
This commit is contained in:
3
.dvc/.gitignore
vendored
Normal file
3
.dvc/.gitignore
vendored
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
/config.local
|
||||||
|
/tmp
|
||||||
|
/cache
|
4
.dvc/config
Normal file
4
.dvc/config
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
[core]
|
||||||
|
remote = nas
|
||||||
|
['remote "nas"']
|
||||||
|
url = https://webdav.guineapig.love/home/dvc
|
3
.dvcignore
Normal file
3
.dvcignore
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
# Add patterns of files dvc should ignore, which could improve
|
||||||
|
# the performance. Learn more at
|
||||||
|
# https://dvc.org/doc/user-guide/dvcignore
|
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
data
|
||||||
|
*.pt
|
24
README.md
Normal file
24
README.md
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
# Abstract
|
||||||
|
|
||||||
|
Attempt to use [DVC](https://dvc.ai/), a data versioning tool, to track model training with PyTorch, including data, trained model file, and used parameters. The data will be recorded and pushed to my private DVC remote via webdav🎁
|
||||||
|
|
||||||
|
# Requirements
|
||||||
|
|
||||||
|
* MacOS 13.3
|
||||||
|
|
||||||
|
# Dirs
|
||||||
|
|
||||||
|
* **env**
|
||||||
|
* **pt.yaml**
|
||||||
|
* conda env yaml to run this repo
|
||||||
|
|
||||||
|
# Files
|
||||||
|
|
||||||
|
* **prepare.py**
|
||||||
|
* prepare materials for model training
|
||||||
|
* **train.py**
|
||||||
|
* try to train a small neural network
|
||||||
|
* **evaluate.py**
|
||||||
|
* evaluate trained model with some metrics
|
||||||
|
|
||||||
|
###### tags: `DVC`
|
28
dvc.yaml
Normal file
28
dvc.yaml
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
stages:
|
||||||
|
prepare:
|
||||||
|
cmd: python prepare.py
|
||||||
|
deps:
|
||||||
|
- prepare.py
|
||||||
|
params:
|
||||||
|
- prepare
|
||||||
|
outs:
|
||||||
|
- data/processed
|
||||||
|
train:
|
||||||
|
cmd: python train.py
|
||||||
|
deps:
|
||||||
|
- data/processed
|
||||||
|
- train.py
|
||||||
|
params:
|
||||||
|
- train
|
||||||
|
outs:
|
||||||
|
- model.pt
|
||||||
|
evaluate:
|
||||||
|
cmd: python evaluate.py
|
||||||
|
deps:
|
||||||
|
- data/processed
|
||||||
|
- evaluate.py
|
||||||
|
- model.pt
|
||||||
|
params:
|
||||||
|
- evaluate
|
||||||
|
outs:
|
||||||
|
- eval
|
231
env/pt.yaml
vendored
Normal file
231
env/pt.yaml
vendored
Normal file
@ -0,0 +1,231 @@
|
|||||||
|
name: pt
|
||||||
|
channels:
|
||||||
|
- pytorch
|
||||||
|
- anaconda
|
||||||
|
- conda-forge
|
||||||
|
dependencies:
|
||||||
|
- aiohttp=3.9.1
|
||||||
|
- aiohttp-retry=2.8.3
|
||||||
|
- aiosignal=1.3.1
|
||||||
|
- amqp=5.2.0
|
||||||
|
- annotated-types=0.6.0
|
||||||
|
- antlr-python-runtime=4.9.3
|
||||||
|
- aom=3.7.1
|
||||||
|
- appdirs=1.4.4
|
||||||
|
- async-timeout=4.0.3
|
||||||
|
- asyncssh=2.14.1
|
||||||
|
- atk-1.0=2.38.0
|
||||||
|
- atpublic=3.0.1
|
||||||
|
- attrs=23.1.0
|
||||||
|
- backports.zoneinfo=0.2.1
|
||||||
|
- billiard=4.1.0
|
||||||
|
- boto3=1.34.9
|
||||||
|
- botocore=1.34.9
|
||||||
|
- brotli-python=1.1.0
|
||||||
|
- bzip2=1.0.8
|
||||||
|
- ca-certificates=2023.08.22
|
||||||
|
- cairo=1.18.0
|
||||||
|
- celery=5.3.4
|
||||||
|
- certifi=2023.11.17
|
||||||
|
- cffi=1.16.0
|
||||||
|
- charset-normalizer=3.3.2
|
||||||
|
- click=8.1.7
|
||||||
|
- click-didyoumean=0.3.0
|
||||||
|
- click-plugins=1.1.1
|
||||||
|
- click-repl=0.3.0
|
||||||
|
- colorama=0.4.6
|
||||||
|
- configobj=5.0.8
|
||||||
|
- cryptography=41.0.7
|
||||||
|
- dav1d=1.2.1
|
||||||
|
- decorator=5.1.1
|
||||||
|
- dictdiffer=0.9.0
|
||||||
|
- diskcache=5.6.3
|
||||||
|
- distro=1.8.0
|
||||||
|
- dpath=2.1.6
|
||||||
|
- dulwich=0.21.7
|
||||||
|
- dvc=3.37.0
|
||||||
|
- dvc-data=3.5.0
|
||||||
|
- dvc-http=2.32.0
|
||||||
|
- dvc-objects=3.0.0
|
||||||
|
- dvc-render=1.0.0
|
||||||
|
- dvc-studio-client=0.18.0
|
||||||
|
- dvc-task=0.3.0
|
||||||
|
- dvclive=3.5.1
|
||||||
|
- entrypoints=0.4
|
||||||
|
- expat=2.5.0
|
||||||
|
- ffmpeg=6.1.0
|
||||||
|
- filelock=3.13.1
|
||||||
|
- flatten-dict=0.4.2
|
||||||
|
- flufl.lock=7.1
|
||||||
|
- font-ttf-dejavu-sans-mono=2.37
|
||||||
|
- font-ttf-inconsolata=3.000
|
||||||
|
- font-ttf-source-code-pro=2.038
|
||||||
|
- font-ttf-ubuntu=0.83
|
||||||
|
- fontconfig=2.14.2
|
||||||
|
- fonts-conda-ecosystem=1
|
||||||
|
- fonts-conda-forge=1
|
||||||
|
- freetype=2.12.1
|
||||||
|
- fribidi=1.0.10
|
||||||
|
- frozenlist=1.4.1
|
||||||
|
- fsspec=2023.12.2
|
||||||
|
- funcy=2.0
|
||||||
|
- future=0.18.3
|
||||||
|
- gdk-pixbuf=2.42.10
|
||||||
|
- gettext=0.21.1
|
||||||
|
- giflib=5.2.1
|
||||||
|
- gitdb=4.0.11
|
||||||
|
- gitpython=3.1.40
|
||||||
|
- gmp=6.3.0
|
||||||
|
- gmpy2=2.1.2
|
||||||
|
- gnutls=3.7.9
|
||||||
|
- grandalf=0.7
|
||||||
|
- graphite2=1.3.13
|
||||||
|
- graphviz=9.0.0
|
||||||
|
- gtk2=2.24.33
|
||||||
|
- gto=1.6.1
|
||||||
|
- gts=0.7.6
|
||||||
|
- harfbuzz=8.3.0
|
||||||
|
- hydra-core=1.3.2
|
||||||
|
- icu=73.2
|
||||||
|
- idna=3.6
|
||||||
|
- importlib-metadata=7.0.1
|
||||||
|
- importlib_resources=6.1.1
|
||||||
|
- iterative-telemetry=0.0.8
|
||||||
|
- jinja2=3.1.2
|
||||||
|
- jmespath=1.0.1
|
||||||
|
- joblib=1.2.0
|
||||||
|
- kombu=5.3.4
|
||||||
|
- krb5=1.21.2
|
||||||
|
- lame=3.100
|
||||||
|
- lcms2=2.16
|
||||||
|
- lerc=4.0.0
|
||||||
|
- libass=0.17.1
|
||||||
|
- libblas=3.9.0
|
||||||
|
- libcblas=3.9.0
|
||||||
|
- libcxx=16.0.6
|
||||||
|
- libdeflate=1.19
|
||||||
|
- libedit=3.1.20191231
|
||||||
|
- libexpat=2.5.0
|
||||||
|
- libffi=3.4.2
|
||||||
|
- libgd=2.3.3
|
||||||
|
- libgfortran=5.0.0
|
||||||
|
- libgfortran5=13.2.0
|
||||||
|
- libgit2=1.7.1
|
||||||
|
- libglib=2.78.3
|
||||||
|
- libiconv=1.17
|
||||||
|
- libidn2=2.3.4
|
||||||
|
- libjpeg-turbo=3.0.0
|
||||||
|
- liblapack=3.9.0
|
||||||
|
- libopenblas=0.3.25
|
||||||
|
- libopus=1.3.1
|
||||||
|
- libpng=1.6.39
|
||||||
|
- librsvg=2.56.3
|
||||||
|
- libsqlite=3.44.2
|
||||||
|
- libssh2=1.11.0
|
||||||
|
- libtasn1=4.19.0
|
||||||
|
- libtiff=4.6.0
|
||||||
|
- libunistring=0.9.10
|
||||||
|
- libvpx=1.13.1
|
||||||
|
- libwebp=1.3.2
|
||||||
|
- libwebp-base=1.3.2
|
||||||
|
- libxcb=1.15
|
||||||
|
- libxml2=2.12.3
|
||||||
|
- libzlib=1.2.13
|
||||||
|
- llvm-openmp=17.0.6
|
||||||
|
- markdown-it-py=3.0.0
|
||||||
|
- markupsafe=2.1.3
|
||||||
|
- mdurl=0.1.0
|
||||||
|
- mpc=1.3.1
|
||||||
|
- mpfr=4.2.1
|
||||||
|
- mpmath=1.3.0
|
||||||
|
- multidict=6.0.4
|
||||||
|
- nanotime=0.5.2
|
||||||
|
- ncurses=6.4
|
||||||
|
- nettle=3.9.1
|
||||||
|
- networkx=3.2.1
|
||||||
|
- numpy=1.26.2
|
||||||
|
- omegaconf=2.3.0
|
||||||
|
- openh264=2.4.0
|
||||||
|
- openjpeg=2.5.0
|
||||||
|
- openssl=3.2.0
|
||||||
|
- orjson=3.9.10
|
||||||
|
- p11-kit=0.24.1
|
||||||
|
- packaging=23.2
|
||||||
|
- pango=1.50.14
|
||||||
|
- pathlib2=2.3.7.post1
|
||||||
|
- pathspec=0.12.1
|
||||||
|
- pcre2=10.42
|
||||||
|
- pillow=10.1.0
|
||||||
|
- pip=23.3.2
|
||||||
|
- pixman=0.42.2
|
||||||
|
- platformdirs=3.11.0
|
||||||
|
- prompt-toolkit=3.0.42
|
||||||
|
- prompt_toolkit=3.0.42
|
||||||
|
- psutil=5.9.7
|
||||||
|
- pthread-stubs=0.4
|
||||||
|
- pycparser=2.21
|
||||||
|
- pydantic=2.5.3
|
||||||
|
- pydantic-core=2.14.6
|
||||||
|
- pydot=1.4.2
|
||||||
|
- pygit2=1.13.3
|
||||||
|
- pygments=2.17.2
|
||||||
|
- pygtrie=2.5.0
|
||||||
|
- pyopenssl=23.3.0
|
||||||
|
- pyparsing=3.1.1
|
||||||
|
- pysocks=1.7.1
|
||||||
|
- python=3.10.13
|
||||||
|
- python-dateutil=2.8.2
|
||||||
|
- python-gssapi=1.8.3
|
||||||
|
- python-tzdata=2023.3
|
||||||
|
- python_abi=3.10
|
||||||
|
- pytorch=2.1.2
|
||||||
|
- pytz=2023.3.post1
|
||||||
|
- pywin32-on-windows=0.1.0
|
||||||
|
- pyyaml=6.0.1
|
||||||
|
- readline=8.2
|
||||||
|
- requests=2.31.0
|
||||||
|
- rich=13.7.0
|
||||||
|
- ruamel.yaml=0.18.5
|
||||||
|
- ruamel.yaml.clib=0.2.7
|
||||||
|
- s3transfer=0.10.0
|
||||||
|
- scikit-learn=1.3.0
|
||||||
|
- scipy=1.11.4
|
||||||
|
- scmrepo=2.0.2
|
||||||
|
- semver=3.0.2
|
||||||
|
- setuptools=68.2.2
|
||||||
|
- shellingham=1.5.4
|
||||||
|
- shortuuid=1.0.11
|
||||||
|
- shtab=1.6.5
|
||||||
|
- six=1.16.0
|
||||||
|
- smmap=5.0.0
|
||||||
|
- sqltrie=0.11.0
|
||||||
|
- svt-av1=1.8.0
|
||||||
|
- sympy=1.12
|
||||||
|
- tabulate=0.9.0
|
||||||
|
- threadpoolctl=2.2.0
|
||||||
|
- tk=8.6.13
|
||||||
|
- tomlkit=0.12.3
|
||||||
|
- torchaudio=2.1.2
|
||||||
|
- torchvision=0.16.2
|
||||||
|
- tqdm=4.66.1
|
||||||
|
- typer=0.9.0
|
||||||
|
- typing-extensions=4.9.0
|
||||||
|
- typing_extensions=4.9.0
|
||||||
|
- tzdata=2023d
|
||||||
|
- urllib3=1.26.18
|
||||||
|
- vine=5.0.0
|
||||||
|
- voluptuous=0.14.1
|
||||||
|
- wcwidth=0.2.12
|
||||||
|
- wheel=0.42.0
|
||||||
|
- x264=1!164.3095
|
||||||
|
- x265=3.5
|
||||||
|
- xorg-libxau=1.0.11
|
||||||
|
- xorg-libxdmcp=1.1.3
|
||||||
|
- xz=5.2.6
|
||||||
|
- yaml=0.2.5
|
||||||
|
- yarl=1.9.3
|
||||||
|
- zc.lockfile=3.0.post1
|
||||||
|
- zipp=3.17.0
|
||||||
|
- zlib=1.2.13
|
||||||
|
- zstd=1.5.5
|
||||||
|
prefix: /Users/xiao_deng/miniforge3/envs/pt
|
23
evaluate.py
Normal file
23
evaluate.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
# evaluate.py
|
||||||
|
#
|
||||||
|
# author: deng
|
||||||
|
# date : 20231228
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate(params_path: str = 'params.yaml') -> None:
|
||||||
|
"""Evaluate model and save results to eval dir
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params_path (str, optional): path of parameter yaml. Defaults to 'params.yaml'.
|
||||||
|
"""
|
||||||
|
|
||||||
|
with open(params_path, 'r') as f:
|
||||||
|
params = yaml.safe_load(f)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
evaluate()
|
3
params.yaml
Normal file
3
params.yaml
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
prepare:
|
||||||
|
train:
|
||||||
|
evaluate:
|
23
prepare.py
Normal file
23
prepare.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
# prepare.py
|
||||||
|
#
|
||||||
|
# author: deng
|
||||||
|
# date : 20231228
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def prepare(params_path: str = 'params.yaml') -> None:
|
||||||
|
"""Preprocess data save as npz for model training
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params_path (str, optional): path of parameter yaml. Defaults to 'params.yaml'.
|
||||||
|
"""
|
||||||
|
|
||||||
|
with open(params_path, 'r') as f:
|
||||||
|
params = yaml.safe_load(f)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
prepare()
|
24
train.py
Normal file
24
train.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
# train.py
|
||||||
|
#
|
||||||
|
# author: deng
|
||||||
|
# date : 20231228
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from dvclive import Live
|
||||||
|
|
||||||
|
|
||||||
|
def train(params_path: str = 'params.yaml') -> None:
|
||||||
|
"""Train a simple model using Pytorch
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params_path (str, optional): path of config yaml. Defaults to 'params.yaml'.
|
||||||
|
"""
|
||||||
|
|
||||||
|
with open(params_path, 'r') as f:
|
||||||
|
params = yaml.safe_load(f)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
train()
|
Reference in New Issue
Block a user