From 7a891969e044bac17c8d26792ae7314edc610647 Mon Sep 17 00:00:00 2001 From: deng Date: Thu, 28 Dec 2023 22:06:25 +0800 Subject: [PATCH] init --- .dvc/.gitignore | 3 + .dvc/config | 4 + .dvcignore | 3 + .gitignore | 2 + README.md | 24 +++++ dvc.yaml | 28 ++++++ env/pt.yaml | 231 ++++++++++++++++++++++++++++++++++++++++++++++++ evaluate.py | 23 +++++ params.yaml | 3 + prepare.py | 23 +++++ train.py | 24 +++++ 11 files changed, 368 insertions(+) create mode 100644 .dvc/.gitignore create mode 100644 .dvc/config create mode 100644 .dvcignore create mode 100644 .gitignore create mode 100644 README.md create mode 100644 dvc.yaml create mode 100644 env/pt.yaml create mode 100644 evaluate.py create mode 100644 params.yaml create mode 100644 prepare.py create mode 100644 train.py diff --git a/.dvc/.gitignore b/.dvc/.gitignore new file mode 100644 index 0000000..528f30c --- /dev/null +++ b/.dvc/.gitignore @@ -0,0 +1,3 @@ +/config.local +/tmp +/cache diff --git a/.dvc/config b/.dvc/config new file mode 100644 index 0000000..1ac08f6 --- /dev/null +++ b/.dvc/config @@ -0,0 +1,4 @@ +[core] + remote = nas +['remote "nas"'] + url = https://webdav.guineapig.love/home/dvc diff --git a/.dvcignore b/.dvcignore new file mode 100644 index 0000000..5197305 --- /dev/null +++ b/.dvcignore @@ -0,0 +1,3 @@ +# Add patterns of files dvc should ignore, which could improve +# the performance. Learn more at +# https://dvc.org/doc/user-guide/dvcignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..1ffdb32 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +data +*.pt \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..659abf8 --- /dev/null +++ b/README.md @@ -0,0 +1,24 @@ +# Abstract + +Attempt to use [DVC](https://dvc.ai/), a data versioning tool, to track model training with PyTorch, including data, trained model file, and used parameters. 
The data will be recorded and pushed to my private DVC remote via WebDAV. + +# Requirements + +* MacOS 13.3 + +# Dirs + +* **env** + * **pt.yaml** + * conda env yaml to run this repo + +# Files + +* **prepare.py** + * prepare materials for model training +* **train.py** + * try to train a small neural network +* **evaluate.py** + * evaluate trained model with some metrics + +###### tags: `DVC` \ No newline at end of file diff --git a/dvc.yaml b/dvc.yaml new file mode 100644 index 0000000..ccf68d4 --- /dev/null +++ b/dvc.yaml @@ -0,0 +1,28 @@ +stages: + prepare: + cmd: python prepare.py + deps: + - prepare.py + params: + - prepare + outs: + - data/processed + train: + cmd: python train.py + deps: + - data/processed + - train.py + params: + - train + outs: + - model.pt + evaluate: + cmd: python evaluate.py + deps: + - data/processed + - evaluate.py + - model.pt + params: + - evaluate + outs: + - eval \ No newline at end of file diff --git a/env/pt.yaml b/env/pt.yaml new file mode 100644 index 0000000..1f32466 --- /dev/null +++ b/env/pt.yaml @@ -0,0 +1,231 @@ +name: pt +channels: + - pytorch + - anaconda + - conda-forge +dependencies: + - aiohttp=3.9.1 + - aiohttp-retry=2.8.3 + - aiosignal=1.3.1 + - amqp=5.2.0 + - annotated-types=0.6.0 + - antlr-python-runtime=4.9.3 + - aom=3.7.1 + - appdirs=1.4.4 + - async-timeout=4.0.3 + - asyncssh=2.14.1 + - atk-1.0=2.38.0 + - atpublic=3.0.1 + - attrs=23.1.0 + - backports.zoneinfo=0.2.1 + - billiard=4.1.0 + - boto3=1.34.9 + - botocore=1.34.9 + - brotli-python=1.1.0 + - bzip2=1.0.8 + - ca-certificates=2023.08.22 + - cairo=1.18.0 + - celery=5.3.4 + - certifi=2023.11.17 + - cffi=1.16.0 + - charset-normalizer=3.3.2 + - click=8.1.7 + - click-didyoumean=0.3.0 + - click-plugins=1.1.1 + - click-repl=0.3.0 + - colorama=0.4.6 + - configobj=5.0.8 + - cryptography=41.0.7 + - dav1d=1.2.1 + - decorator=5.1.1 + - dictdiffer=0.9.0 + - diskcache=5.6.3 + - distro=1.8.0 + - dpath=2.1.6 + - dulwich=0.21.7 + - dvc=3.37.0 + - dvc-data=3.5.0 + - 
dvc-http=2.32.0 + - dvc-objects=3.0.0 + - dvc-render=1.0.0 + - dvc-studio-client=0.18.0 + - dvc-task=0.3.0 + - dvclive=3.5.1 + - entrypoints=0.4 + - expat=2.5.0 + - ffmpeg=6.1.0 + - filelock=3.13.1 + - flatten-dict=0.4.2 + - flufl.lock=7.1 + - font-ttf-dejavu-sans-mono=2.37 + - font-ttf-inconsolata=3.000 + - font-ttf-source-code-pro=2.038 + - font-ttf-ubuntu=0.83 + - fontconfig=2.14.2 + - fonts-conda-ecosystem=1 + - fonts-conda-forge=1 + - freetype=2.12.1 + - fribidi=1.0.10 + - frozenlist=1.4.1 + - fsspec=2023.12.2 + - funcy=2.0 + - future=0.18.3 + - gdk-pixbuf=2.42.10 + - gettext=0.21.1 + - giflib=5.2.1 + - gitdb=4.0.11 + - gitpython=3.1.40 + - gmp=6.3.0 + - gmpy2=2.1.2 + - gnutls=3.7.9 + - grandalf=0.7 + - graphite2=1.3.13 + - graphviz=9.0.0 + - gtk2=2.24.33 + - gto=1.6.1 + - gts=0.7.6 + - harfbuzz=8.3.0 + - hydra-core=1.3.2 + - icu=73.2 + - idna=3.6 + - importlib-metadata=7.0.1 + - importlib_resources=6.1.1 + - iterative-telemetry=0.0.8 + - jinja2=3.1.2 + - jmespath=1.0.1 + - joblib=1.2.0 + - kombu=5.3.4 + - krb5=1.21.2 + - lame=3.100 + - lcms2=2.16 + - lerc=4.0.0 + - libass=0.17.1 + - libblas=3.9.0 + - libcblas=3.9.0 + - libcxx=16.0.6 + - libdeflate=1.19 + - libedit=3.1.20191231 + - libexpat=2.5.0 + - libffi=3.4.2 + - libgd=2.3.3 + - libgfortran=5.0.0 + - libgfortran5=13.2.0 + - libgit2=1.7.1 + - libglib=2.78.3 + - libiconv=1.17 + - libidn2=2.3.4 + - libjpeg-turbo=3.0.0 + - liblapack=3.9.0 + - libopenblas=0.3.25 + - libopus=1.3.1 + - libpng=1.6.39 + - librsvg=2.56.3 + - libsqlite=3.44.2 + - libssh2=1.11.0 + - libtasn1=4.19.0 + - libtiff=4.6.0 + - libunistring=0.9.10 + - libvpx=1.13.1 + - libwebp=1.3.2 + - libwebp-base=1.3.2 + - libxcb=1.15 + - libxml2=2.12.3 + - libzlib=1.2.13 + - llvm-openmp=17.0.6 + - markdown-it-py=3.0.0 + - markupsafe=2.1.3 + - mdurl=0.1.0 + - mpc=1.3.1 + - mpfr=4.2.1 + - mpmath=1.3.0 + - multidict=6.0.4 + - nanotime=0.5.2 + - ncurses=6.4 + - nettle=3.9.1 + - networkx=3.2.1 + - numpy=1.26.2 + - omegaconf=2.3.0 + - openh264=2.4.0 + - 
openjpeg=2.5.0 + - openssl=3.2.0 + - orjson=3.9.10 + - p11-kit=0.24.1 + - packaging=23.2 + - pango=1.50.14 + - pathlib2=2.3.7.post1 + - pathspec=0.12.1 + - pcre2=10.42 + - pillow=10.1.0 + - pip=23.3.2 + - pixman=0.42.2 + - platformdirs=3.11.0 + - prompt-toolkit=3.0.42 + - prompt_toolkit=3.0.42 + - psutil=5.9.7 + - pthread-stubs=0.4 + - pycparser=2.21 + - pydantic=2.5.3 + - pydantic-core=2.14.6 + - pydot=1.4.2 + - pygit2=1.13.3 + - pygments=2.17.2 + - pygtrie=2.5.0 + - pyopenssl=23.3.0 + - pyparsing=3.1.1 + - pysocks=1.7.1 + - python=3.10.13 + - python-dateutil=2.8.2 + - python-gssapi=1.8.3 + - python-tzdata=2023.3 + - python_abi=3.10 + - pytorch=2.1.2 + - pytz=2023.3.post1 + - pywin32-on-windows=0.1.0 + - pyyaml=6.0.1 + - readline=8.2 + - requests=2.31.0 + - rich=13.7.0 + - ruamel.yaml=0.18.5 + - ruamel.yaml.clib=0.2.7 + - s3transfer=0.10.0 + - scikit-learn=1.3.0 + - scipy=1.11.4 + - scmrepo=2.0.2 + - semver=3.0.2 + - setuptools=68.2.2 + - shellingham=1.5.4 + - shortuuid=1.0.11 + - shtab=1.6.5 + - six=1.16.0 + - smmap=5.0.0 + - sqltrie=0.11.0 + - svt-av1=1.8.0 + - sympy=1.12 + - tabulate=0.9.0 + - threadpoolctl=2.2.0 + - tk=8.6.13 + - tomlkit=0.12.3 + - torchaudio=2.1.2 + - torchvision=0.16.2 + - tqdm=4.66.1 + - typer=0.9.0 + - typing-extensions=4.9.0 + - typing_extensions=4.9.0 + - tzdata=2023d + - urllib3=1.26.18 + - vine=5.0.0 + - voluptuous=0.14.1 + - wcwidth=0.2.12 + - wheel=0.42.0 + - x264=1!164.3095 + - x265=3.5 + - xorg-libxau=1.0.11 + - xorg-libxdmcp=1.1.3 + - xz=5.2.6 + - yaml=0.2.5 + - yarl=1.9.3 + - zc.lockfile=3.0.post1 + - zipp=3.17.0 + - zlib=1.2.13 + - zstd=1.5.5 +prefix: /Users/xiao_deng/miniforge3/envs/pt diff --git a/evaluate.py b/evaluate.py new file mode 100644 index 0000000..a1c9852 --- /dev/null +++ b/evaluate.py @@ -0,0 +1,23 @@ +# evaluate.py +# +# author: deng +# date : 20231228 + +import yaml + +import torch + + +def evaluate(params_path: str = 'params.yaml') -> None: + """Evaluate model and save results to eval dir + + Args: + 
params_path (str, optional): path of parameter yaml. Defaults to 'params.yaml'. + """ + + with open(params_path, 'r') as f: + params = yaml.safe_load(f) + + +if __name__ == '__main__': + evaluate() diff --git a/params.yaml b/params.yaml new file mode 100644 index 0000000..80f5bb5 --- /dev/null +++ b/params.yaml @@ -0,0 +1,3 @@ +prepare: +train: +evaluate: \ No newline at end of file diff --git a/prepare.py b/prepare.py new file mode 100644 index 0000000..b9dc5cc --- /dev/null +++ b/prepare.py @@ -0,0 +1,23 @@ +# prepare.py +# +# author: deng +# date : 20231228 + +import yaml + +import torch + + +def prepare(params_path: str = 'params.yaml') -> None: + """Preprocess data and save as npz for model training + + Args: + params_path (str, optional): path of parameter yaml. Defaults to 'params.yaml'. + """ + + with open(params_path, 'r') as f: + params = yaml.safe_load(f) + + +if __name__ == '__main__': + prepare() diff --git a/train.py b/train.py new file mode 100644 index 0000000..fe371e0 --- /dev/null +++ b/train.py @@ -0,0 +1,24 @@ +# train.py +# +# author: deng +# date : 20231228 + +import yaml + +import torch +from dvclive import Live + + +def train(params_path: str = 'params.yaml') -> None: + """Train a simple model using PyTorch + + Args: + params_path (str, optional): path of config yaml. Defaults to 'params.yaml'. + """ + + with open(params_path, 'r') as f: + params = yaml.safe_load(f) + + +if __name__ == '__main__': + train()