init
This commit is contained in:
commit
7a891969e0
|
@ -0,0 +1,3 @@
|
|||
/config.local
|
||||
/tmp
|
||||
/cache
|
|
@ -0,0 +1,4 @@
|
|||
[core]
|
||||
remote = nas
|
||||
['remote "nas"']
|
||||
url = https://webdav.guineapig.love/home/dvc
|
|
@ -0,0 +1,3 @@
|
|||
# Add patterns of files dvc should ignore, which could improve
|
||||
# the performance. Learn more at
|
||||
# https://dvc.org/doc/user-guide/dvcignore
|
|
@ -0,0 +1,2 @@
|
|||
data
|
||||
*.pt
|
|
@ -0,0 +1,24 @@
|
|||
# Abstract
|
||||
|
||||
Attempt to use [DVC](https://dvc.ai/), a data versioning tool, to track model training with PyTorch, including data, trained model file, and used parameters. The data will be recorded and pushed to my private DVC remote via webdav🎁
|
||||
|
||||
# Requirements
|
||||
|
||||
* MacOS 13.3
|
||||
|
||||
# Dirs
|
||||
|
||||
* **env**
|
||||
* **pt.yaml**
|
||||
* conda env yaml to run this repo
|
||||
|
||||
# Files
|
||||
|
||||
* **prepare.py**
|
||||
* prepare materials for model training
|
||||
* **train.py**
|
||||
* try to train a small neural network
|
||||
* **evaluate.py**
|
||||
* evaluate trained model with some metrics
|
||||
|
||||
###### tags: `DVC`
|
|
@ -0,0 +1,28 @@
|
|||
stages:
|
||||
prepare:
|
||||
cmd: python prepare.py
|
||||
deps:
|
||||
- prepare.py
|
||||
params:
|
||||
- prepare
|
||||
outs:
|
||||
- data/processed
|
||||
train:
|
||||
cmd: python train.py
|
||||
deps:
|
||||
- data/processed
|
||||
- train.py
|
||||
params:
|
||||
- train
|
||||
outs:
|
||||
- model.pt
|
||||
evaluate:
|
||||
cmd: python evaluate.py
|
||||
deps:
|
||||
- data/processed
|
||||
- evaluate.py
|
||||
- model.pt
|
||||
params:
|
||||
- evaluate
|
||||
outs:
|
||||
- eval
|
|
@ -0,0 +1,231 @@
|
|||
name: pt
|
||||
channels:
|
||||
- pytorch
|
||||
- anaconda
|
||||
- conda-forge
|
||||
dependencies:
|
||||
- aiohttp=3.9.1
|
||||
- aiohttp-retry=2.8.3
|
||||
- aiosignal=1.3.1
|
||||
- amqp=5.2.0
|
||||
- annotated-types=0.6.0
|
||||
- antlr-python-runtime=4.9.3
|
||||
- aom=3.7.1
|
||||
- appdirs=1.4.4
|
||||
- async-timeout=4.0.3
|
||||
- asyncssh=2.14.1
|
||||
- atk-1.0=2.38.0
|
||||
- atpublic=3.0.1
|
||||
- attrs=23.1.0
|
||||
- backports.zoneinfo=0.2.1
|
||||
- billiard=4.1.0
|
||||
- boto3=1.34.9
|
||||
- botocore=1.34.9
|
||||
- brotli-python=1.1.0
|
||||
- bzip2=1.0.8
|
||||
- ca-certificates=2023.08.22
|
||||
- cairo=1.18.0
|
||||
- celery=5.3.4
|
||||
- certifi=2023.11.17
|
||||
- cffi=1.16.0
|
||||
- charset-normalizer=3.3.2
|
||||
- click=8.1.7
|
||||
- click-didyoumean=0.3.0
|
||||
- click-plugins=1.1.1
|
||||
- click-repl=0.3.0
|
||||
- colorama=0.4.6
|
||||
- configobj=5.0.8
|
||||
- cryptography=41.0.7
|
||||
- dav1d=1.2.1
|
||||
- decorator=5.1.1
|
||||
- dictdiffer=0.9.0
|
||||
- diskcache=5.6.3
|
||||
- distro=1.8.0
|
||||
- dpath=2.1.6
|
||||
- dulwich=0.21.7
|
||||
- dvc=3.37.0
|
||||
- dvc-data=3.5.0
|
||||
- dvc-http=2.32.0
|
||||
- dvc-objects=3.0.0
|
||||
- dvc-render=1.0.0
|
||||
- dvc-studio-client=0.18.0
|
||||
- dvc-task=0.3.0
|
||||
- dvclive=3.5.1
|
||||
- entrypoints=0.4
|
||||
- expat=2.5.0
|
||||
- ffmpeg=6.1.0
|
||||
- filelock=3.13.1
|
||||
- flatten-dict=0.4.2
|
||||
- flufl.lock=7.1
|
||||
- font-ttf-dejavu-sans-mono=2.37
|
||||
- font-ttf-inconsolata=3.000
|
||||
- font-ttf-source-code-pro=2.038
|
||||
- font-ttf-ubuntu=0.83
|
||||
- fontconfig=2.14.2
|
||||
- fonts-conda-ecosystem=1
|
||||
- fonts-conda-forge=1
|
||||
- freetype=2.12.1
|
||||
- fribidi=1.0.10
|
||||
- frozenlist=1.4.1
|
||||
- fsspec=2023.12.2
|
||||
- funcy=2.0
|
||||
- future=0.18.3
|
||||
- gdk-pixbuf=2.42.10
|
||||
- gettext=0.21.1
|
||||
- giflib=5.2.1
|
||||
- gitdb=4.0.11
|
||||
- gitpython=3.1.40
|
||||
- gmp=6.3.0
|
||||
- gmpy2=2.1.2
|
||||
- gnutls=3.7.9
|
||||
- grandalf=0.7
|
||||
- graphite2=1.3.13
|
||||
- graphviz=9.0.0
|
||||
- gtk2=2.24.33
|
||||
- gto=1.6.1
|
||||
- gts=0.7.6
|
||||
- harfbuzz=8.3.0
|
||||
- hydra-core=1.3.2
|
||||
- icu=73.2
|
||||
- idna=3.6
|
||||
- importlib-metadata=7.0.1
|
||||
- importlib_resources=6.1.1
|
||||
- iterative-telemetry=0.0.8
|
||||
- jinja2=3.1.2
|
||||
- jmespath=1.0.1
|
||||
- joblib=1.2.0
|
||||
- kombu=5.3.4
|
||||
- krb5=1.21.2
|
||||
- lame=3.100
|
||||
- lcms2=2.16
|
||||
- lerc=4.0.0
|
||||
- libass=0.17.1
|
||||
- libblas=3.9.0
|
||||
- libcblas=3.9.0
|
||||
- libcxx=16.0.6
|
||||
- libdeflate=1.19
|
||||
- libedit=3.1.20191231
|
||||
- libexpat=2.5.0
|
||||
- libffi=3.4.2
|
||||
- libgd=2.3.3
|
||||
- libgfortran=5.0.0
|
||||
- libgfortran5=13.2.0
|
||||
- libgit2=1.7.1
|
||||
- libglib=2.78.3
|
||||
- libiconv=1.17
|
||||
- libidn2=2.3.4
|
||||
- libjpeg-turbo=3.0.0
|
||||
- liblapack=3.9.0
|
||||
- libopenblas=0.3.25
|
||||
- libopus=1.3.1
|
||||
- libpng=1.6.39
|
||||
- librsvg=2.56.3
|
||||
- libsqlite=3.44.2
|
||||
- libssh2=1.11.0
|
||||
- libtasn1=4.19.0
|
||||
- libtiff=4.6.0
|
||||
- libunistring=0.9.10
|
||||
- libvpx=1.13.1
|
||||
- libwebp=1.3.2
|
||||
- libwebp-base=1.3.2
|
||||
- libxcb=1.15
|
||||
- libxml2=2.12.3
|
||||
- libzlib=1.2.13
|
||||
- llvm-openmp=17.0.6
|
||||
- markdown-it-py=3.0.0
|
||||
- markupsafe=2.1.3
|
||||
- mdurl=0.1.0
|
||||
- mpc=1.3.1
|
||||
- mpfr=4.2.1
|
||||
- mpmath=1.3.0
|
||||
- multidict=6.0.4
|
||||
- nanotime=0.5.2
|
||||
- ncurses=6.4
|
||||
- nettle=3.9.1
|
||||
- networkx=3.2.1
|
||||
- numpy=1.26.2
|
||||
- omegaconf=2.3.0
|
||||
- openh264=2.4.0
|
||||
- openjpeg=2.5.0
|
||||
- openssl=3.2.0
|
||||
- orjson=3.9.10
|
||||
- p11-kit=0.24.1
|
||||
- packaging=23.2
|
||||
- pango=1.50.14
|
||||
- pathlib2=2.3.7.post1
|
||||
- pathspec=0.12.1
|
||||
- pcre2=10.42
|
||||
- pillow=10.1.0
|
||||
- pip=23.3.2
|
||||
- pixman=0.42.2
|
||||
- platformdirs=3.11.0
|
||||
- prompt-toolkit=3.0.42
|
||||
- prompt_toolkit=3.0.42
|
||||
- psutil=5.9.7
|
||||
- pthread-stubs=0.4
|
||||
- pycparser=2.21
|
||||
- pydantic=2.5.3
|
||||
- pydantic-core=2.14.6
|
||||
- pydot=1.4.2
|
||||
- pygit2=1.13.3
|
||||
- pygments=2.17.2
|
||||
- pygtrie=2.5.0
|
||||
- pyopenssl=23.3.0
|
||||
- pyparsing=3.1.1
|
||||
- pysocks=1.7.1
|
||||
- python=3.10.13
|
||||
- python-dateutil=2.8.2
|
||||
- python-gssapi=1.8.3
|
||||
- python-tzdata=2023.3
|
||||
- python_abi=3.10
|
||||
- pytorch=2.1.2
|
||||
- pytz=2023.3.post1
|
||||
- pywin32-on-windows=0.1.0
|
||||
- pyyaml=6.0.1
|
||||
- readline=8.2
|
||||
- requests=2.31.0
|
||||
- rich=13.7.0
|
||||
- ruamel.yaml=0.18.5
|
||||
- ruamel.yaml.clib=0.2.7
|
||||
- s3transfer=0.10.0
|
||||
- scikit-learn=1.3.0
|
||||
- scipy=1.11.4
|
||||
- scmrepo=2.0.2
|
||||
- semver=3.0.2
|
||||
- setuptools=68.2.2
|
||||
- shellingham=1.5.4
|
||||
- shortuuid=1.0.11
|
||||
- shtab=1.6.5
|
||||
- six=1.16.0
|
||||
- smmap=5.0.0
|
||||
- sqltrie=0.11.0
|
||||
- svt-av1=1.8.0
|
||||
- sympy=1.12
|
||||
- tabulate=0.9.0
|
||||
- threadpoolctl=2.2.0
|
||||
- tk=8.6.13
|
||||
- tomlkit=0.12.3
|
||||
- torchaudio=2.1.2
|
||||
- torchvision=0.16.2
|
||||
- tqdm=4.66.1
|
||||
- typer=0.9.0
|
||||
- typing-extensions=4.9.0
|
||||
- typing_extensions=4.9.0
|
||||
- tzdata=2023d
|
||||
- urllib3=1.26.18
|
||||
- vine=5.0.0
|
||||
- voluptuous=0.14.1
|
||||
- wcwidth=0.2.12
|
||||
- wheel=0.42.0
|
||||
- x264=1!164.3095
|
||||
- x265=3.5
|
||||
- xorg-libxau=1.0.11
|
||||
- xorg-libxdmcp=1.1.3
|
||||
- xz=5.2.6
|
||||
- yaml=0.2.5
|
||||
- yarl=1.9.3
|
||||
- zc.lockfile=3.0.post1
|
||||
- zipp=3.17.0
|
||||
- zlib=1.2.13
|
||||
- zstd=1.5.5
|
||||
prefix: /Users/xiao_deng/miniforge3/envs/pt
|
|
@ -0,0 +1,23 @@
|
|||
# evaluate.py
|
||||
#
|
||||
# author: deng
|
||||
# date : 20231228
|
||||
|
||||
import yaml
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def evaluate(params_path: str = 'params.yaml') -> None:
|
||||
"""Evaluate model and save results to eval dir
|
||||
|
||||
Args:
|
||||
params_path (str, optional): path of parameter yaml. Defaults to 'params.yaml'.
|
||||
"""
|
||||
|
||||
with open(params_path, 'r') as f:
|
||||
params = yaml.safe_load(f)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
evaluate()
|
|
@ -0,0 +1,3 @@
|
|||
prepare:
|
||||
train:
|
||||
evaluate:
|
|
@ -0,0 +1,23 @@
|
|||
# prepare.py
|
||||
#
|
||||
# author: deng
|
||||
# date : 20231228
|
||||
|
||||
import yaml
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def prepare(params_path: str = 'params.yaml') -> None:
|
||||
"""Preprocess data save as npz for model training
|
||||
|
||||
Args:
|
||||
params_path (str, optional): path of parameter yaml. Defaults to 'params.yaml'.
|
||||
"""
|
||||
|
||||
with open(params_path, 'r') as f:
|
||||
params = yaml.safe_load(f)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
prepare()
|
|
@ -0,0 +1,24 @@
|
|||
# train.py
|
||||
#
|
||||
# author: deng
|
||||
# date : 20231228
|
||||
|
||||
import yaml
|
||||
|
||||
import torch
|
||||
from dvclive import Live
|
||||
|
||||
|
||||
def train(params_path: str = 'params.yaml') -> None:
|
||||
"""Train a simple model using Pytorch
|
||||
|
||||
Args:
|
||||
params_path (str, optional): path of config yaml. Defaults to 'params.yaml'.
|
||||
"""
|
||||
|
||||
with open(params_path, 'r') as f:
|
||||
params = yaml.safe_load(f)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
train()
|
Loading…
Reference in New Issue