init
This commit is contained in:
		| @ -0,0 +1,16 @@ | |||||||
|  | artifact_path: model | ||||||
|  | flavors: | ||||||
|  |   python_function: | ||||||
|  |     data: data | ||||||
|  |     env: conda.yaml | ||||||
|  |     loader_module: mlflow.pytorch | ||||||
|  |     pickle_module_name: mlflow.pytorch.pickle_module | ||||||
|  |     python_version: 3.10.9 | ||||||
|  |   pytorch: | ||||||
|  |     code: null | ||||||
|  |     model_data: data | ||||||
|  |     pytorch_version: 1.13.1 | ||||||
|  | mlflow_version: 1.30.0 | ||||||
|  | model_uuid: 2382b7a39c064e7b9b1465cfd84140a3 | ||||||
|  | run_id: 24469fc083d6470a9cad7f17a6eeeea0 | ||||||
|  | utc_time_created: '2023-02-21 05:57:41.973454' | ||||||
| @ -0,0 +1,11 @@ | |||||||
|  | channels: | ||||||
|  | - conda-forge | ||||||
|  | dependencies: | ||||||
|  | - python=3.10.9 | ||||||
|  | - pip<=23.0.1 | ||||||
|  | - pip: | ||||||
|  |   - mlflow | ||||||
|  |   - cloudpickle==2.2.1 | ||||||
|  |   - torch==1.13.1 | ||||||
|  |   - tqdm==4.64.1 | ||||||
|  | name: mlflow-env | ||||||
										
											Binary file not shown.
										
									
								
							| @ -0,0 +1 @@ | |||||||
|  | mlflow.pytorch.pickle_module | ||||||
| @ -0,0 +1,7 @@ | |||||||
|  | python: 3.10.9 | ||||||
|  | build_dependencies: | ||||||
|  | - pip==23.0.1 | ||||||
|  | - setuptools==67.3.2 | ||||||
|  | - wheel==0.38.4 | ||||||
|  | dependencies: | ||||||
|  | - -r requirements.txt | ||||||
| @ -0,0 +1,4 @@ | |||||||
|  | mlflow | ||||||
|  | cloudpickle==2.2.1 | ||||||
|  | torch==1.13.1 | ||||||
|  | tqdm==4.64.1 | ||||||
| @ -0,0 +1,16 @@ | |||||||
|  | artifact_path: cls_model | ||||||
|  | flavors: | ||||||
|  |   python_function: | ||||||
|  |     data: data | ||||||
|  |     env: conda.yaml | ||||||
|  |     loader_module: mlflow.pytorch | ||||||
|  |     pickle_module_name: mlflow.pytorch.pickle_module | ||||||
|  |     python_version: 3.10.9 | ||||||
|  |   pytorch: | ||||||
|  |     code: null | ||||||
|  |     model_data: data | ||||||
|  |     pytorch_version: 1.13.1 | ||||||
|  | mlflow_version: 1.30.0 | ||||||
|  | model_uuid: e40643f3e1b9481896e1ae6ed30e8654 | ||||||
|  | run_id: 2820b379bfc945358bfd516e5577846c | ||||||
|  | utc_time_created: '2023-02-21 05:33:10.779919' | ||||||
| @ -0,0 +1,10 @@ | |||||||
|  | channels: | ||||||
|  | - conda-forge | ||||||
|  | dependencies: | ||||||
|  | - python=3.10.9 | ||||||
|  | - pip<=23.0.1 | ||||||
|  | - pip: | ||||||
|  |   - mlflow | ||||||
|  |   - cloudpickle==2.2.1 | ||||||
|  |   - torch==1.13.1 | ||||||
|  | name: mlflow-env | ||||||
										
											Binary file not shown.
										
									
								
							| @ -0,0 +1 @@ | |||||||
|  | mlflow.pytorch.pickle_module | ||||||
| @ -0,0 +1,7 @@ | |||||||
|  | python: 3.10.9 | ||||||
|  | build_dependencies: | ||||||
|  | - pip==23.0.1 | ||||||
|  | - setuptools==67.3.2 | ||||||
|  | - wheel==0.38.4 | ||||||
|  | dependencies: | ||||||
|  | - -r requirements.txt | ||||||
| @ -0,0 +1,3 @@ | |||||||
|  | mlflow | ||||||
|  | cloudpickle==2.2.1 | ||||||
|  | torch==1.13.1 | ||||||
| @ -0,0 +1,16 @@ | |||||||
|  | artifact_path: models | ||||||
|  | flavors: | ||||||
|  |   python_function: | ||||||
|  |     data: data | ||||||
|  |     env: conda.yaml | ||||||
|  |     loader_module: mlflow.pytorch | ||||||
|  |     pickle_module_name: mlflow.pytorch.pickle_module | ||||||
|  |     python_version: 3.10.9 | ||||||
|  |   pytorch: | ||||||
|  |     code: null | ||||||
|  |     model_data: data | ||||||
|  |     pytorch_version: 1.13.1 | ||||||
|  | mlflow_version: 1.30.0 | ||||||
|  | model_uuid: faf1bec9ecb64581b22a0b8e09a9cca8 | ||||||
|  | run_id: 3ef01a1e3e3d4ba2be705da789bbb8e1 | ||||||
|  | utc_time_created: '2023-02-21 05:07:17.344052' | ||||||
| @ -0,0 +1,10 @@ | |||||||
|  | channels: | ||||||
|  | - conda-forge | ||||||
|  | dependencies: | ||||||
|  | - python=3.10.9 | ||||||
|  | - pip<=23.0.1 | ||||||
|  | - pip: | ||||||
|  |   - mlflow | ||||||
|  |   - cloudpickle==2.2.1 | ||||||
|  |   - torch==1.13.1 | ||||||
|  | name: mlflow-env | ||||||
										
											Binary file not shown.
										
									
								
							| @ -0,0 +1 @@ | |||||||
|  | mlflow.pytorch.pickle_module | ||||||
| @ -0,0 +1,7 @@ | |||||||
|  | python: 3.10.9 | ||||||
|  | build_dependencies: | ||||||
|  | - pip==23.0.1 | ||||||
|  | - setuptools==67.3.2 | ||||||
|  | - wheel==0.38.4 | ||||||
|  | dependencies: | ||||||
|  | - -r requirements.txt | ||||||
| @ -0,0 +1,3 @@ | |||||||
|  | mlflow | ||||||
|  | cloudpickle==2.2.1 | ||||||
|  | torch==1.13.1 | ||||||
| @ -0,0 +1,16 @@ | |||||||
|  | artifact_path: cls_model | ||||||
|  | flavors: | ||||||
|  |   python_function: | ||||||
|  |     data: data | ||||||
|  |     env: conda.yaml | ||||||
|  |     loader_module: mlflow.pytorch | ||||||
|  |     pickle_module_name: mlflow.pytorch.pickle_module | ||||||
|  |     python_version: 3.10.9 | ||||||
|  |   pytorch: | ||||||
|  |     code: null | ||||||
|  |     model_data: data | ||||||
|  |     pytorch_version: 1.13.1 | ||||||
|  | mlflow_version: 1.30.0 | ||||||
|  | model_uuid: a0ecc970cadb47a9b839283e9514732d | ||||||
|  | run_id: 63c7363339e042f4848d9041ba8deb82 | ||||||
|  | utc_time_created: '2023-02-21 05:37:55.904472' | ||||||
| @ -0,0 +1,10 @@ | |||||||
|  | channels: | ||||||
|  | - conda-forge | ||||||
|  | dependencies: | ||||||
|  | - python=3.10.9 | ||||||
|  | - pip<=23.0.1 | ||||||
|  | - pip: | ||||||
|  |   - mlflow | ||||||
|  |   - cloudpickle==2.2.1 | ||||||
|  |   - torch==1.13.1 | ||||||
|  | name: mlflow-env | ||||||
										
											Binary file not shown.
										
									
								
							| @ -0,0 +1 @@ | |||||||
|  | mlflow.pytorch.pickle_module | ||||||
| @ -0,0 +1,7 @@ | |||||||
|  | python: 3.10.9 | ||||||
|  | build_dependencies: | ||||||
|  | - pip==23.0.1 | ||||||
|  | - setuptools==67.3.2 | ||||||
|  | - wheel==0.38.4 | ||||||
|  | dependencies: | ||||||
|  | - -r requirements.txt | ||||||
| @ -0,0 +1,3 @@ | |||||||
|  | mlflow | ||||||
|  | cloudpickle==2.2.1 | ||||||
|  | torch==1.13.1 | ||||||
| @ -0,0 +1,16 @@ | |||||||
|  | artifact_path: models | ||||||
|  | flavors: | ||||||
|  |   python_function: | ||||||
|  |     data: data | ||||||
|  |     env: conda.yaml | ||||||
|  |     loader_module: mlflow.pytorch | ||||||
|  |     pickle_module_name: mlflow.pytorch.pickle_module | ||||||
|  |     python_version: 3.10.9 | ||||||
|  |   pytorch: | ||||||
|  |     code: null | ||||||
|  |     model_data: data | ||||||
|  |     pytorch_version: 1.13.1 | ||||||
|  | mlflow_version: 1.30.0 | ||||||
|  | model_uuid: 4ebe94bd0249452a90b3497d3b00a1c3 | ||||||
|  | run_id: 6845ef0d54024cb3bdb32050f6a46fea | ||||||
|  | utc_time_created: '2023-02-21 05:25:14.020335' | ||||||
| @ -0,0 +1,10 @@ | |||||||
|  | channels: | ||||||
|  | - conda-forge | ||||||
|  | dependencies: | ||||||
|  | - python=3.10.9 | ||||||
|  | - pip<=23.0.1 | ||||||
|  | - pip: | ||||||
|  |   - mlflow | ||||||
|  |   - cloudpickle==2.2.1 | ||||||
|  |   - torch==1.13.1 | ||||||
|  | name: mlflow-env | ||||||
										
											Binary file not shown.
										
									
								
							| @ -0,0 +1 @@ | |||||||
|  | mlflow.pytorch.pickle_module | ||||||
| @ -0,0 +1,7 @@ | |||||||
|  | python: 3.10.9 | ||||||
|  | build_dependencies: | ||||||
|  | - pip==23.0.1 | ||||||
|  | - setuptools==67.3.2 | ||||||
|  | - wheel==0.38.4 | ||||||
|  | dependencies: | ||||||
|  | - -r requirements.txt | ||||||
| @ -0,0 +1,3 @@ | |||||||
|  | mlflow | ||||||
|  | cloudpickle==2.2.1 | ||||||
|  | torch==1.13.1 | ||||||
| @ -0,0 +1,16 @@ | |||||||
|  | artifact_path: models | ||||||
|  | flavors: | ||||||
|  |   python_function: | ||||||
|  |     data: data | ||||||
|  |     env: conda.yaml | ||||||
|  |     loader_module: mlflow.pytorch | ||||||
|  |     pickle_module_name: mlflow.pytorch.pickle_module | ||||||
|  |     python_version: 3.10.9 | ||||||
|  |   pytorch: | ||||||
|  |     code: null | ||||||
|  |     model_data: data | ||||||
|  |     pytorch_version: 1.13.1 | ||||||
|  | mlflow_version: 1.30.0 | ||||||
|  | model_uuid: 8cd1e70114e548ea8d9bfb1bf468e285 | ||||||
|  | run_id: 68e8a3cbbafa46538ebb8a60d80f185d | ||||||
|  | utc_time_created: '2023-02-21 05:05:46.624814' | ||||||
| @ -0,0 +1,10 @@ | |||||||
|  | channels: | ||||||
|  | - conda-forge | ||||||
|  | dependencies: | ||||||
|  | - python=3.10.9 | ||||||
|  | - pip<=23.0.1 | ||||||
|  | - pip: | ||||||
|  |   - mlflow | ||||||
|  |   - cloudpickle==2.2.1 | ||||||
|  |   - torch==1.13.1 | ||||||
|  | name: mlflow-env | ||||||
										
											Binary file not shown.
										
									
								
							| @ -0,0 +1 @@ | |||||||
|  | mlflow.pytorch.pickle_module | ||||||
| @ -0,0 +1,7 @@ | |||||||
|  | python: 3.10.9 | ||||||
|  | build_dependencies: | ||||||
|  | - pip==23.0.1 | ||||||
|  | - setuptools==67.3.2 | ||||||
|  | - wheel==0.38.4 | ||||||
|  | dependencies: | ||||||
|  | - -r requirements.txt | ||||||
| @ -0,0 +1,3 @@ | |||||||
|  | mlflow | ||||||
|  | cloudpickle==2.2.1 | ||||||
|  | torch==1.13.1 | ||||||
| @ -0,0 +1,16 @@ | |||||||
|  | artifact_path: models | ||||||
|  | flavors: | ||||||
|  |   python_function: | ||||||
|  |     data: data | ||||||
|  |     env: conda.yaml | ||||||
|  |     loader_module: mlflow.pytorch | ||||||
|  |     pickle_module_name: mlflow.pytorch.pickle_module | ||||||
|  |     python_version: 3.10.9 | ||||||
|  |   pytorch: | ||||||
|  |     code: null | ||||||
|  |     model_data: data | ||||||
|  |     pytorch_version: 1.13.1 | ||||||
|  | mlflow_version: 1.30.0 | ||||||
|  | model_uuid: dd1b1e3a6b5f4274a5690a8843751ff3 | ||||||
|  | run_id: 8ba27f225a7442be8816977c2077c510 | ||||||
|  | utc_time_created: '2023-02-21 05:05:04.660670' | ||||||
| @ -0,0 +1,10 @@ | |||||||
|  | channels: | ||||||
|  | - conda-forge | ||||||
|  | dependencies: | ||||||
|  | - python=3.10.9 | ||||||
|  | - pip<=23.0.1 | ||||||
|  | - pip: | ||||||
|  |   - mlflow | ||||||
|  |   - cloudpickle==2.2.1 | ||||||
|  |   - torch==1.13.1 | ||||||
|  | name: mlflow-env | ||||||
										
											Binary file not shown.
										
									
								
							| @ -0,0 +1 @@ | |||||||
|  | mlflow.pytorch.pickle_module | ||||||
| @ -0,0 +1,7 @@ | |||||||
|  | python: 3.10.9 | ||||||
|  | build_dependencies: | ||||||
|  | - pip==23.0.1 | ||||||
|  | - setuptools==67.3.2 | ||||||
|  | - wheel==0.38.4 | ||||||
|  | dependencies: | ||||||
|  | - -r requirements.txt | ||||||
| @ -0,0 +1,3 @@ | |||||||
|  | mlflow | ||||||
|  | cloudpickle==2.2.1 | ||||||
|  | torch==1.13.1 | ||||||
| @ -0,0 +1,16 @@ | |||||||
|  | artifact_path: cls_model | ||||||
|  | flavors: | ||||||
|  |   python_function: | ||||||
|  |     data: data | ||||||
|  |     env: conda.yaml | ||||||
|  |     loader_module: mlflow.pytorch | ||||||
|  |     pickle_module_name: mlflow.pytorch.pickle_module | ||||||
|  |     python_version: 3.10.9 | ||||||
|  |   pytorch: | ||||||
|  |     code: null | ||||||
|  |     model_data: data | ||||||
|  |     pytorch_version: 1.13.1 | ||||||
|  | mlflow_version: 1.30.0 | ||||||
|  | model_uuid: aaa800b217da4dd0b8f17e8dbfdc5c45 | ||||||
|  | run_id: f1320882f24c4f489cbf85159627eaf8 | ||||||
|  | utc_time_created: '2023-02-21 05:34:08.242864' | ||||||
| @ -0,0 +1,10 @@ | |||||||
|  | channels: | ||||||
|  | - conda-forge | ||||||
|  | dependencies: | ||||||
|  | - python=3.10.9 | ||||||
|  | - pip<=23.0.1 | ||||||
|  | - pip: | ||||||
|  |   - mlflow | ||||||
|  |   - cloudpickle==2.2.1 | ||||||
|  |   - torch==1.13.1 | ||||||
|  | name: mlflow-env | ||||||
										
											Binary file not shown.
										
									
								
							| @ -0,0 +1 @@ | |||||||
|  | mlflow.pytorch.pickle_module | ||||||
| @ -0,0 +1,7 @@ | |||||||
|  | python: 3.10.9 | ||||||
|  | build_dependencies: | ||||||
|  | - pip==23.0.1 | ||||||
|  | - setuptools==67.3.2 | ||||||
|  | - wheel==0.38.4 | ||||||
|  | dependencies: | ||||||
|  | - -r requirements.txt | ||||||
| @ -0,0 +1,3 @@ | |||||||
|  | mlflow | ||||||
|  | cloudpickle==2.2.1 | ||||||
|  | torch==1.13.1 | ||||||
							
								
								
									
										141
									
								
								env.yaml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										141
									
								
								env.yaml
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,141 @@ | |||||||
|  | name: torch | ||||||
|  | channels: | ||||||
|  |   - pytorch | ||||||
|  |   - conda-forge | ||||||
|  | dependencies: | ||||||
|  |   - alembic=1.9.4 | ||||||
|  |   - aom=3.5.0 | ||||||
|  |   - appdirs=1.4.4 | ||||||
|  |   - bcrypt=3.2.2 | ||||||
|  |   - blinker=1.5 | ||||||
|  |   - brotlipy=0.7.0 | ||||||
|  |   - bzip2=1.0.8 | ||||||
|  |   - ca-certificates=2022.12.7 | ||||||
|  |   - certifi=2022.12.7 | ||||||
|  |   - cffi=1.15.1 | ||||||
|  |   - charset-normalizer=2.1.1 | ||||||
|  |   - click=8.1.3 | ||||||
|  |   - cloudpickle=2.2.1 | ||||||
|  |   - colorama=0.4.6 | ||||||
|  |   - configparser=5.3.0 | ||||||
|  |   - cryptography=39.0.1 | ||||||
|  |   - databricks-cli=0.17.4 | ||||||
|  |   - docker-py=6.0.0 | ||||||
|  |   - entrypoints=0.4 | ||||||
|  |   - expat=2.5.0 | ||||||
|  |   - ffmpeg=5.1.2 | ||||||
|  |   - flask=2.2.3 | ||||||
|  |   - font-ttf-dejavu-sans-mono=2.37 | ||||||
|  |   - font-ttf-inconsolata=3.000 | ||||||
|  |   - font-ttf-source-code-pro=2.038 | ||||||
|  |   - font-ttf-ubuntu=0.83 | ||||||
|  |   - fontconfig=2.14.2 | ||||||
|  |   - fonts-conda-ecosystem=1 | ||||||
|  |   - fonts-conda-forge=1 | ||||||
|  |   - freetype=2.12.1 | ||||||
|  |   - gettext=0.21.1 | ||||||
|  |   - gitdb=4.0.10 | ||||||
|  |   - gitpython=3.1.31 | ||||||
|  |   - gmp=6.2.1 | ||||||
|  |   - gnutls=3.7.8 | ||||||
|  |   - greenlet=2.0.2 | ||||||
|  |   - gunicorn=20.1.0 | ||||||
|  |   - icu=70.1 | ||||||
|  |   - idna=3.4 | ||||||
|  |   - importlib-metadata=5.2.0 | ||||||
|  |   - importlib_resources=5.12.0 | ||||||
|  |   - itsdangerous=2.1.2 | ||||||
|  |   - jinja2=3.1.2 | ||||||
|  |   - jpeg=9e | ||||||
|  |   - lame=3.100 | ||||||
|  |   - lcms2=2.14 | ||||||
|  |   - lerc=4.0.0 | ||||||
|  |   - libblas=3.9.0 | ||||||
|  |   - libcblas=3.9.0 | ||||||
|  |   - libcxx=14.0.6 | ||||||
|  |   - libdeflate=1.17 | ||||||
|  |   - libffi=3.4.2 | ||||||
|  |   - libgfortran=5.0.0 | ||||||
|  |   - libgfortran5=11.3.0 | ||||||
|  |   - libiconv=1.17 | ||||||
|  |   - libidn2=2.3.4 | ||||||
|  |   - liblapack=3.9.0 | ||||||
|  |   - libopenblas=0.3.21 | ||||||
|  |   - libopus=1.3.1 | ||||||
|  |   - libpng=1.6.39 | ||||||
|  |   - libprotobuf=3.21.12 | ||||||
|  |   - libsodium=1.0.18 | ||||||
|  |   - libsqlite=3.40.0 | ||||||
|  |   - libtasn1=4.19.0 | ||||||
|  |   - libtiff=4.5.0 | ||||||
|  |   - libunistring=0.9.10 | ||||||
|  |   - libvpx=1.11.0 | ||||||
|  |   - libwebp-base=1.2.4 | ||||||
|  |   - libxcb=1.13 | ||||||
|  |   - libxml2=2.10.3 | ||||||
|  |   - libzlib=1.2.13 | ||||||
|  |   - llvm-openmp=15.0.7 | ||||||
|  |   - mako=1.2.4 | ||||||
|  |   - markupsafe=2.1.2 | ||||||
|  |   - mlflow=1.30.0 | ||||||
|  |   - ncurses=6.3 | ||||||
|  |   - nettle=3.8.1 | ||||||
|  |   - numpy=1.24.2 | ||||||
|  |   - oauthlib=3.2.2 | ||||||
|  |   - openh264=2.3.1 | ||||||
|  |   - openjpeg=2.5.0 | ||||||
|  |   - openssl=3.0.8 | ||||||
|  |   - p11-kit=0.24.1 | ||||||
|  |   - packaging=21.3 | ||||||
|  |   - pandas=1.5.3 | ||||||
|  |   - paramiko=3.0.0 | ||||||
|  |   - pillow=9.4.0 | ||||||
|  |   - pip=23.0.1 | ||||||
|  |   - pooch=1.6.0 | ||||||
|  |   - prometheus_client=0.16.0 | ||||||
|  |   - prometheus_flask_exporter=0.22.0 | ||||||
|  |   - protobuf=4.21.12 | ||||||
|  |   - pthread-stubs=0.4 | ||||||
|  |   - pycparser=2.21 | ||||||
|  |   - pyjwt=2.6.0 | ||||||
|  |   - pynacl=1.5.0 | ||||||
|  |   - pyopenssl=23.0.0 | ||||||
|  |   - pyparsing=3.0.9 | ||||||
|  |   - pysocks=1.7.1 | ||||||
|  |   - python=3.10.9 | ||||||
|  |   - python-dateutil=2.8.2 | ||||||
|  |   - python_abi=3.10 | ||||||
|  |   - pytorch=1.13.1 | ||||||
|  |   - pytz=2022.7.1 | ||||||
|  |   - pywin32-on-windows=0.1.0 | ||||||
|  |   - pyyaml=6.0 | ||||||
|  |   - querystring_parser=1.2.4 | ||||||
|  |   - readline=8.1.2 | ||||||
|  |   - requests=2.28.2 | ||||||
|  |   - scipy=1.10.0 | ||||||
|  |   - setuptools=67.3.2 | ||||||
|  |   - six=1.16.0 | ||||||
|  |   - smmap=3.0.5 | ||||||
|  |   - sqlalchemy=1.4.46 | ||||||
|  |   - sqlparse=0.4.3 | ||||||
|  |   - svt-av1=1.4.1 | ||||||
|  |   - tabulate=0.9.0 | ||||||
|  |   - tk=8.6.12 | ||||||
|  |   - torchaudio=0.13.1 | ||||||
|  |   - torchvision=0.14.1 | ||||||
|  |   - tqdm=4.64.1 | ||||||
|  |   - typing_extensions=4.4.0 | ||||||
|  |   - tzdata=2022g | ||||||
|  |   - urllib3=1.26.14 | ||||||
|  |   - websocket-client=1.5.1 | ||||||
|  |   - werkzeug=2.2.3 | ||||||
|  |   - wheel=0.38.4 | ||||||
|  |   - x264=1!164.3095 | ||||||
|  |   - x265=3.5 | ||||||
|  |   - xorg-libxau=1.0.9 | ||||||
|  |   - xorg-libxdmcp=1.1.3 | ||||||
|  |   - xz=5.2.6 | ||||||
|  |   - yaml=0.2.5 | ||||||
|  |   - zipp=3.14.0 | ||||||
|  |   - zstd=1.5.2 | ||||||
|  | prefix: /Users/xiao_deng/miniforge3/envs/torch | ||||||
							
								
								
									
										21
									
								
								predict.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								predict.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,21 @@ | |||||||
|  | # predict.py | ||||||
|  | # | ||||||
|  | # author: deng | ||||||
|  | # date  : 20230221 | ||||||
|  |  | ||||||
|  | import torch | ||||||
|  | import mlflow | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if __name__ == '__main__': | ||||||
|  |  | ||||||
|  |     # set MLflow server | ||||||
|  |     mlflow.set_tracking_uri('http://127.0.0.1:5000') | ||||||
|  |  | ||||||
|  |     # load production model | ||||||
|  |     model = mlflow.pytorch.load_model('models:/cls_model/production') | ||||||
|  |  | ||||||
|  |     # predict | ||||||
|  |     fake_data = torch.randn(10) | ||||||
|  |     output = model(fake_data) | ||||||
|  |     print(output) | ||||||
							
								
								
									
										7
									
								
								start_mlflow_server.sh
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								start_mlflow_server.sh
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,7 @@ | |||||||
|  | #!/bin/bash | ||||||
|  | # start_mlflow_server.sh | ||||||
|  | # | ||||||
|  | # author: deng | ||||||
|  | # date  : 20230221 | ||||||
|  |  | ||||||
|  | mlflow server --backend-store-uri sqlite:///mlflow.db --default-artifact-root ./artifacts | ||||||
							
								
								
									
										79
									
								
								train.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										79
									
								
								train.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,79 @@ | |||||||
|  | # train.py | ||||||
|  | # | ||||||
|  | # author: deng | ||||||
|  | # date  : 20230221 | ||||||
|  |  | ||||||
|  | import torch | ||||||
|  | import torch.nn as nn | ||||||
|  | from torch.optim import SGD | ||||||
|  | from tqdm import tqdm | ||||||
|  | import mlflow | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Net(nn.Module): | ||||||
|  |     """ define a simple neural network model """ | ||||||
|  |     def __init__(self): | ||||||
|  |         super(Net, self).__init__() | ||||||
|  |         self.fc1 = nn.Linear(10, 5) | ||||||
|  |         self.fc2 = nn.Linear(5, 1) | ||||||
|  |  | ||||||
|  |     def forward(self, x): | ||||||
|  |         x = self.fc1(x) | ||||||
|  |         x = torch.relu(x) | ||||||
|  |         x = self.fc2(x) | ||||||
|  |         return x | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def train(model, dataloader, criterion, optimizer, epochs): | ||||||
|  |     """ define the training function """ | ||||||
|  |     for epoch in tqdm(range(epochs), 'Epochs'): | ||||||
|  |  | ||||||
|  |         for i, (inputs, labels) in enumerate(dataloader): | ||||||
|  |  | ||||||
|  |             # forwarding | ||||||
|  |             outputs = model(inputs) | ||||||
|  |             loss = criterion(outputs, labels) | ||||||
|  |  | ||||||
|  |             # update gradient | ||||||
|  |             optimizer.zero_grad() | ||||||
|  |             loss.backward() | ||||||
|  |             optimizer.step() | ||||||
|  |  | ||||||
|  |             # log loss | ||||||
|  |             mlflow.log_metric('train_loss', loss.item(), step=i) | ||||||
|  |  | ||||||
|  |     return loss | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if __name__ == '__main__': | ||||||
|  |  | ||||||
|  |     # set hyper parameters | ||||||
|  |     learning_rate = 1e-2 | ||||||
|  |     epochs = 20 | ||||||
|  |  | ||||||
|  |     # create a dataloader with fake data | ||||||
|  |     dataloader = [(torch.randn(10), torch.randn(1)) for _ in range(100)] | ||||||
|  |     dataloader = torch.utils.data.DataLoader(dataloader, batch_size=10) | ||||||
|  |  | ||||||
|  |     # create the model, criterion, and optimizer | ||||||
|  |     model = Net() | ||||||
|  |     criterion = nn.MSELoss() | ||||||
|  |     optimizer = SGD(model.parameters(), lr=learning_rate) | ||||||
|  |  | ||||||
|  |     # set the tracking URI to the model registry | ||||||
|  |     mlflow.set_tracking_uri('http://127.0.0.1:5000') | ||||||
|  |  | ||||||
|  |     # start the MLflow run | ||||||
|  |     with mlflow.start_run(): | ||||||
|  |  | ||||||
|  |         # train the model and log the loss | ||||||
|  |         loss = train(model, dataloader, criterion, optimizer, epochs) | ||||||
|  |  | ||||||
|  |         # log some additional metrics | ||||||
|  |         mlflow.log_metric('final_loss', loss.item()) | ||||||
|  |         mlflow.log_param('learning_rate', learning_rate) | ||||||
|  |  | ||||||
|  |         # log trained  model | ||||||
|  |         mlflow.pytorch.log_model(model, 'model') | ||||||
|  |  | ||||||
|  |     print('Completed.') | ||||||
		Reference in New Issue
	
	Block a user