commit 7a891969e044bac17c8d26792ae7314edc610647 Author: deng Date: Thu Dec 28 22:06:25 2023 +0800 init diff --git a/.dvc/.gitignore b/.dvc/.gitignore new file mode 100644 index 0000000..528f30c --- /dev/null +++ b/.dvc/.gitignore @@ -0,0 +1,3 @@ +/config.local +/tmp +/cache diff --git a/.dvc/config b/.dvc/config new file mode 100644 index 0000000..1ac08f6 --- /dev/null +++ b/.dvc/config @@ -0,0 +1,4 @@ +[core] + remote = nas +['remote "nas"'] + url = https://webdav.guineapig.love/home/dvc diff --git a/.dvcignore b/.dvcignore new file mode 100644 index 0000000..5197305 --- /dev/null +++ b/.dvcignore @@ -0,0 +1,3 @@ +# Add patterns of files dvc should ignore, which could improve +# the performance. Learn more at +# https://dvc.org/doc/user-guide/dvcignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..1ffdb32 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +data +*.pt \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..659abf8 --- /dev/null +++ b/README.md @@ -0,0 +1,24 @@ +# Abstract + +Attempt to use [DVC](https://dvc.ai/), a data versioning tool, to track model training with PyTorch, including data, trained model file, and used parameters. The data will be recorded and pushed to my private DVC remote via webdavšŸŽ + +# Requirements + +* MacOS 13.3 + +# Dirs + +* **env** + * **pt.yaml** + * conda env yaml to run this repo + +# Files + +* **prepare.py** + * prepare materials for model training +* **train.py** + * try to train a small neural network +* **evaluate.py** + * evaluate trained model with some metrics + +###### tags: `DVC` \ No newline at end of file diff --git a/dvc.yaml b/dvc.yaml new file mode 100644 index 0000000..ccf68d4 --- /dev/null +++ b/dvc.yaml @@ -0,0 +1,28 @@ +stages: + prepare: + cmd: python prepare.py + deps: + - prepare.py + params: + - prepare + outs: + - data/processed + train: + cmd: python train.py + deps: + - data/processed + - train.py + params: + - train + outs: + - model.pt + evaluate: + cmd: python evaluate.py + deps: + - data/processed + - evaluate.py + - model.pt + params: + - evaluate + outs: + - eval \ No newline at end of file diff --git a/env/pt.yaml b/env/pt.yaml new file mode 100644 index 0000000..1f32466 --- /dev/null +++ b/env/pt.yaml @@ -0,0 +1,231 @@ +name: pt +channels: + - pytorch + - anaconda + - conda-forge +dependencies: + - aiohttp=3.9.1 + - aiohttp-retry=2.8.3 + - aiosignal=1.3.1 + - amqp=5.2.0 + - annotated-types=0.6.0 + - antlr-python-runtime=4.9.3 + - aom=3.7.1 + - appdirs=1.4.4 + - async-timeout=4.0.3 + - asyncssh=2.14.1 + - atk-1.0=2.38.0 + - atpublic=3.0.1 + - attrs=23.1.0 + - backports.zoneinfo=0.2.1 + - billiard=4.1.0 + - boto3=1.34.9 + - botocore=1.34.9 + - brotli-python=1.1.0 + - bzip2=1.0.8 + - ca-certificates=2023.08.22 + - cairo=1.18.0 + - celery=5.3.4 + - certifi=2023.11.17 + - cffi=1.16.0 + - charset-normalizer=3.3.2 + - click=8.1.7 + - click-didyoumean=0.3.0 + - click-plugins=1.1.1 + - click-repl=0.3.0 + - colorama=0.4.6 + - configobj=5.0.8 + - cryptography=41.0.7 + - dav1d=1.2.1 + - decorator=5.1.1 + - dictdiffer=0.9.0 + - diskcache=5.6.3 + - distro=1.8.0 + - dpath=2.1.6 + - dulwich=0.21.7 + - dvc=3.37.0 + - dvc-data=3.5.0 + - dvc-http=2.32.0 + - dvc-objects=3.0.0 + - dvc-render=1.0.0 + - dvc-studio-client=0.18.0 + - dvc-task=0.3.0 + - dvclive=3.5.1 + - entrypoints=0.4 + - expat=2.5.0 + - ffmpeg=6.1.0 + - filelock=3.13.1 + - flatten-dict=0.4.2 + - flufl.lock=7.1 + - font-ttf-dejavu-sans-mono=2.37 + - font-ttf-inconsolata=3.000 + - font-ttf-source-code-pro=2.038 + - font-ttf-ubuntu=0.83 + - fontconfig=2.14.2 + - fonts-conda-ecosystem=1 + - fonts-conda-forge=1 + - freetype=2.12.1 + - fribidi=1.0.10 + - frozenlist=1.4.1 + - fsspec=2023.12.2 + - funcy=2.0 + - future=0.18.3 + - gdk-pixbuf=2.42.10 + - gettext=0.21.1 + - giflib=5.2.1 + - gitdb=4.0.11 + - gitpython=3.1.40 + - gmp=6.3.0 + - gmpy2=2.1.2 + - gnutls=3.7.9 + - grandalf=0.7 + - graphite2=1.3.13 + - graphviz=9.0.0 + - gtk2=2.24.33 + - gto=1.6.1 + - gts=0.7.6 + - harfbuzz=8.3.0 + - hydra-core=1.3.2 + - icu=73.2 + - idna=3.6 + - importlib-metadata=7.0.1 + - importlib_resources=6.1.1 + - iterative-telemetry=0.0.8 + - jinja2=3.1.2 + - jmespath=1.0.1 + - joblib=1.2.0 + - kombu=5.3.4 + - krb5=1.21.2 + - lame=3.100 + - lcms2=2.16 + - lerc=4.0.0 + - libass=0.17.1 + - libblas=3.9.0 + - libcblas=3.9.0 + - libcxx=16.0.6 + - libdeflate=1.19 + - libedit=3.1.20191231 + - libexpat=2.5.0 + - libffi=3.4.2 + - libgd=2.3.3 + - libgfortran=5.0.0 + - libgfortran5=13.2.0 + - libgit2=1.7.1 + - libglib=2.78.3 + - libiconv=1.17 + - libidn2=2.3.4 + - libjpeg-turbo=3.0.0 + - liblapack=3.9.0 + - libopenblas=0.3.25 + - libopus=1.3.1 + - libpng=1.6.39 + - librsvg=2.56.3 + - libsqlite=3.44.2 + - libssh2=1.11.0 + - libtasn1=4.19.0 + - libtiff=4.6.0 + - libunistring=0.9.10 + - libvpx=1.13.1 + - libwebp=1.3.2 + - libwebp-base=1.3.2 + - libxcb=1.15 + - libxml2=2.12.3 + - libzlib=1.2.13 + - llvm-openmp=17.0.6 + - markdown-it-py=3.0.0 + - markupsafe=2.1.3 + - mdurl=0.1.0 + - mpc=1.3.1 + - mpfr=4.2.1 + - mpmath=1.3.0 + - multidict=6.0.4 + - nanotime=0.5.2 + - ncurses=6.4 + - nettle=3.9.1 + - networkx=3.2.1 + - numpy=1.26.2 + - omegaconf=2.3.0 + - openh264=2.4.0 + - openjpeg=2.5.0 + - openssl=3.2.0 + - orjson=3.9.10 + - p11-kit=0.24.1 + - packaging=23.2 + - pango=1.50.14 + - pathlib2=2.3.7.post1 + - pathspec=0.12.1 + - pcre2=10.42 + - pillow=10.1.0 + - pip=23.3.2 + - pixman=0.42.2 + - platformdirs=3.11.0 + - prompt-toolkit=3.0.42 + - prompt_toolkit=3.0.42 + - psutil=5.9.7 + - pthread-stubs=0.4 + - pycparser=2.21 + - pydantic=2.5.3 + - pydantic-core=2.14.6 + - pydot=1.4.2 + - pygit2=1.13.3 + - pygments=2.17.2 + - pygtrie=2.5.0 + - pyopenssl=23.3.0 + - pyparsing=3.1.1 + - pysocks=1.7.1 + - python=3.10.13 + - python-dateutil=2.8.2 + - python-gssapi=1.8.3 + - python-tzdata=2023.3 + - python_abi=3.10 + - pytorch=2.1.2 + - pytz=2023.3.post1 + - pywin32-on-windows=0.1.0 + - pyyaml=6.0.1 + - readline=8.2 + - requests=2.31.0 + - rich=13.7.0 + - ruamel.yaml=0.18.5 + - ruamel.yaml.clib=0.2.7 + - s3transfer=0.10.0 + - scikit-learn=1.3.0 + - scipy=1.11.4 + - scmrepo=2.0.2 + - semver=3.0.2 + - setuptools=68.2.2 + - shellingham=1.5.4 + - shortuuid=1.0.11 + - shtab=1.6.5 + - six=1.16.0 + - smmap=5.0.0 + - sqltrie=0.11.0 + - svt-av1=1.8.0 + - sympy=1.12 + - tabulate=0.9.0 + - threadpoolctl=2.2.0 + - tk=8.6.13 + - tomlkit=0.12.3 + - torchaudio=2.1.2 + - torchvision=0.16.2 + - tqdm=4.66.1 + - typer=0.9.0 + - typing-extensions=4.9.0 + - typing_extensions=4.9.0 + - tzdata=2023d + - urllib3=1.26.18 + - vine=5.0.0 + - voluptuous=0.14.1 + - wcwidth=0.2.12 + - wheel=0.42.0 + - x264=1!164.3095 + - x265=3.5 + - xorg-libxau=1.0.11 + - xorg-libxdmcp=1.1.3 + - xz=5.2.6 + - yaml=0.2.5 + - yarl=1.9.3 + - zc.lockfile=3.0.post1 + - zipp=3.17.0 + - zlib=1.2.13 + - zstd=1.5.5 +prefix: /Users/xiao_deng/miniforge3/envs/pt diff --git a/evaluate.py b/evaluate.py new file mode 100644 index 0000000..a1c9852 --- /dev/null +++ b/evaluate.py @@ -0,0 +1,23 @@ +# evaluate.py +# +# author: deng +# date : 20231228 + +import yaml + +import torch + + +def evaluate(params_path: str = 'params.yaml') -> None: + """Evaluate model and save results to eval dir + + Args: + params_path (str, optional): path of parameter yaml. Defaults to 'params.yaml'. + """ + + with open(params_path, 'r') as f: + params = yaml.safe_load(f) + + +if __name__ == '__main__': + evaluate() diff --git a/params.yaml b/params.yaml new file mode 100644 index 0000000..80f5bb5 --- /dev/null +++ b/params.yaml @@ -0,0 +1,3 @@ +prepare: +train: +evaluate: \ No newline at end of file diff --git a/prepare.py b/prepare.py new file mode 100644 index 0000000..b9dc5cc --- /dev/null +++ b/prepare.py @@ -0,0 +1,23 @@ +# prepare.py +# +# author: deng +# date : 20231228 + +import yaml + +import torch + + +def prepare(params_path: str = 'params.yaml') -> None: + """Preprocess data save as npz for model training + + Args: + params_path (str, optional): path of parameter yaml. Defaults to 'params.yaml'. + """ + + with open(params_path, 'r') as f: + params = yaml.safe_load(f) + + +if __name__ == '__main__': + prepare() diff --git a/train.py b/train.py new file mode 100644 index 0000000..fe371e0 --- /dev/null +++ b/train.py @@ -0,0 +1,24 @@ +# train.py +# +# author: deng +# date : 20231228 + +import yaml + +import torch +from dvclive import Live + + +def train(params_path: str = 'params.yaml') -> None: + """Train a simple model using Pytorch + + Args: + params_path (str, optional): path of config yaml. Defaults to 'params.yaml'. + """ + + with open(params_path, 'r') as f: + params = yaml.safe_load(f) + + +if __name__ == '__main__': + train()