Compare commits

...

18 Commits

Author SHA1 Message Date
2ce9a31883 update service data 2023-04-19 11:46:42 +08:00
1b90edee59 add demo to download optimized model on client 2023-04-19 11:46:33 +08:00
e1f143736e update service data 2023-04-18 17:13:15 +08:00
3c39c48242 script to optimize pytorch model on server 2023-04-18 17:13:05 +08:00
ac6400e93a add one more layer to model dir 2023-04-11 11:56:56 +08:00
0e6a5b8925 update service 2023-04-11 11:56:07 +08:00
578f0ceea1 update service 2023-03-10 15:03:31 +08:00
4bf037de15 package unsupported ml model 2023-03-10 15:03:21 +08:00
1f86146b12 replace bash script to docker-compose to build server 2023-03-07 16:33:45 +08:00
ec19042d0d update file description and reference 2023-03-01 17:13:17 +08:00
a0ac14d0f7 mlflow 1.30 -> 2.1 2023-03-01 14:56:27 +08:00
0327ebf1f4 mod service file structure 2023-03-01 14:55:51 +08:00
3c8580f0f4 update data storage 2023-02-26 05:10:58 +08:00
8001876359 test rest api 2023-02-26 05:10:28 +08:00
b31dbcd0f0 ignore saved model 2023-02-22 16:27:19 +08:00
2dd734b87b update server data 2023-02-22 16:26:02 +08:00
7015b5c1a5 record model input&output, save model to file system 2023-02-22 16:25:26 +08:00
6e80927c40 add .gitignore 2023-02-21 16:09:05 +08:00
63 changed files with 465 additions and 338 deletions

3
.gitignore vendored Normal file
View File

@ -0,0 +1,3 @@
.DS_Store
__pycache__
model

View File

@ -1,3 +1,35 @@
# test_mlflow # Abstract
測試使用MLflow紀錄Pytorch模型訓練以及從Model registry中拉取Production model進行推論。 Try to use [MLflow](https://mlflow.org) platform to log PyTorch model training, and pull production model from model registry to run inference⛩
# Requirements
* MacOS 12.5
* Docker 20.10
# Dirs
* **service**
* House MLflow service data, including MLflow artifacts, backend store and model registry
* **env**
* **mlflow.yaml**
* conda env yaml to run this repo
# Files
* **docker-compose.yaml**
* a YAML file used with docker-compose to start the MLflow service with a basic configuration (run ```docker-compose -f docker-compose.yaml up```)
* **test_pytorch_m1.py**
* a script to test PyTorch on Apple M1 platform with GPU acceleration
* **train.py**
* a sample code to apply PyTorch to train a small neural network to predict fortune with MLflow logging
* **predict.py**
* a sample code to call registered model to predict testing data and save model to local file system
* **get_registered_model_via_rest_api.py**
* a script to test MLflow REST api
* **log_unsupported_model.py**
* a sample script that uses mlflow.pyfunc to package an unsupported ML model so it can be logged and registered by MLflow
* **optimize_model.py**
* a sample script demonstrating how to use MLflow and TensorRT to optimize a PyTorch model on an edge device and download the optimized model on the client
###### tags: `MLOps`

View File

@ -1,16 +0,0 @@
artifact_path: model
flavors:
python_function:
data: data
env: conda.yaml
loader_module: mlflow.pytorch
pickle_module_name: mlflow.pytorch.pickle_module
python_version: 3.10.9
pytorch:
code: null
model_data: data
pytorch_version: 1.13.1
mlflow_version: 1.30.0
model_uuid: 2382b7a39c064e7b9b1465cfd84140a3
run_id: 24469fc083d6470a9cad7f17a6eeeea0
utc_time_created: '2023-02-21 05:57:41.973454'

View File

@ -1,11 +0,0 @@
channels:
- conda-forge
dependencies:
- python=3.10.9
- pip<=23.0.1
- pip:
- mlflow
- cloudpickle==2.2.1
- torch==1.13.1
- tqdm==4.64.1
name: mlflow-env

View File

@ -1,7 +0,0 @@
python: 3.10.9
build_dependencies:
- pip==23.0.1
- setuptools==67.3.2
- wheel==0.38.4
dependencies:
- -r requirements.txt

View File

@ -1,4 +0,0 @@
mlflow
cloudpickle==2.2.1
torch==1.13.1
tqdm==4.64.1

View File

@ -1,16 +0,0 @@
artifact_path: cls_model
flavors:
python_function:
data: data
env: conda.yaml
loader_module: mlflow.pytorch
pickle_module_name: mlflow.pytorch.pickle_module
python_version: 3.10.9
pytorch:
code: null
model_data: data
pytorch_version: 1.13.1
mlflow_version: 1.30.0
model_uuid: e40643f3e1b9481896e1ae6ed30e8654
run_id: 2820b379bfc945358bfd516e5577846c
utc_time_created: '2023-02-21 05:33:10.779919'

View File

@ -1,10 +0,0 @@
channels:
- conda-forge
dependencies:
- python=3.10.9
- pip<=23.0.1
- pip:
- mlflow
- cloudpickle==2.2.1
- torch==1.13.1
name: mlflow-env

View File

@ -1,7 +0,0 @@
python: 3.10.9
build_dependencies:
- pip==23.0.1
- setuptools==67.3.2
- wheel==0.38.4
dependencies:
- -r requirements.txt

View File

@ -1,3 +0,0 @@
mlflow
cloudpickle==2.2.1
torch==1.13.1

View File

@ -1,16 +0,0 @@
artifact_path: models
flavors:
python_function:
data: data
env: conda.yaml
loader_module: mlflow.pytorch
pickle_module_name: mlflow.pytorch.pickle_module
python_version: 3.10.9
pytorch:
code: null
model_data: data
pytorch_version: 1.13.1
mlflow_version: 1.30.0
model_uuid: faf1bec9ecb64581b22a0b8e09a9cca8
run_id: 3ef01a1e3e3d4ba2be705da789bbb8e1
utc_time_created: '2023-02-21 05:07:17.344052'

View File

@ -1,10 +0,0 @@
channels:
- conda-forge
dependencies:
- python=3.10.9
- pip<=23.0.1
- pip:
- mlflow
- cloudpickle==2.2.1
- torch==1.13.1
name: mlflow-env

View File

@ -1,7 +0,0 @@
python: 3.10.9
build_dependencies:
- pip==23.0.1
- setuptools==67.3.2
- wheel==0.38.4
dependencies:
- -r requirements.txt

View File

@ -1,3 +0,0 @@
mlflow
cloudpickle==2.2.1
torch==1.13.1

View File

@ -1,16 +0,0 @@
artifact_path: cls_model
flavors:
python_function:
data: data
env: conda.yaml
loader_module: mlflow.pytorch
pickle_module_name: mlflow.pytorch.pickle_module
python_version: 3.10.9
pytorch:
code: null
model_data: data
pytorch_version: 1.13.1
mlflow_version: 1.30.0
model_uuid: a0ecc970cadb47a9b839283e9514732d
run_id: 63c7363339e042f4848d9041ba8deb82
utc_time_created: '2023-02-21 05:37:55.904472'

View File

@ -1,10 +0,0 @@
channels:
- conda-forge
dependencies:
- python=3.10.9
- pip<=23.0.1
- pip:
- mlflow
- cloudpickle==2.2.1
- torch==1.13.1
name: mlflow-env

View File

@ -1,7 +0,0 @@
python: 3.10.9
build_dependencies:
- pip==23.0.1
- setuptools==67.3.2
- wheel==0.38.4
dependencies:
- -r requirements.txt

View File

@ -1,3 +0,0 @@
mlflow
cloudpickle==2.2.1
torch==1.13.1

View File

@ -1,16 +0,0 @@
artifact_path: models
flavors:
python_function:
data: data
env: conda.yaml
loader_module: mlflow.pytorch
pickle_module_name: mlflow.pytorch.pickle_module
python_version: 3.10.9
pytorch:
code: null
model_data: data
pytorch_version: 1.13.1
mlflow_version: 1.30.0
model_uuid: 4ebe94bd0249452a90b3497d3b00a1c3
run_id: 6845ef0d54024cb3bdb32050f6a46fea
utc_time_created: '2023-02-21 05:25:14.020335'

View File

@ -1,10 +0,0 @@
channels:
- conda-forge
dependencies:
- python=3.10.9
- pip<=23.0.1
- pip:
- mlflow
- cloudpickle==2.2.1
- torch==1.13.1
name: mlflow-env

View File

@ -1,7 +0,0 @@
python: 3.10.9
build_dependencies:
- pip==23.0.1
- setuptools==67.3.2
- wheel==0.38.4
dependencies:
- -r requirements.txt

View File

@ -1,3 +0,0 @@
mlflow
cloudpickle==2.2.1
torch==1.13.1

View File

@ -1,16 +0,0 @@
artifact_path: models
flavors:
python_function:
data: data
env: conda.yaml
loader_module: mlflow.pytorch
pickle_module_name: mlflow.pytorch.pickle_module
python_version: 3.10.9
pytorch:
code: null
model_data: data
pytorch_version: 1.13.1
mlflow_version: 1.30.0
model_uuid: 8cd1e70114e548ea8d9bfb1bf468e285
run_id: 68e8a3cbbafa46538ebb8a60d80f185d
utc_time_created: '2023-02-21 05:05:46.624814'

View File

@ -1,10 +0,0 @@
channels:
- conda-forge
dependencies:
- python=3.10.9
- pip<=23.0.1
- pip:
- mlflow
- cloudpickle==2.2.1
- torch==1.13.1
name: mlflow-env

View File

@ -1,7 +0,0 @@
python: 3.10.9
build_dependencies:
- pip==23.0.1
- setuptools==67.3.2
- wheel==0.38.4
dependencies:
- -r requirements.txt

View File

@ -1,3 +0,0 @@
mlflow
cloudpickle==2.2.1
torch==1.13.1

View File

@ -1,16 +0,0 @@
artifact_path: models
flavors:
python_function:
data: data
env: conda.yaml
loader_module: mlflow.pytorch
pickle_module_name: mlflow.pytorch.pickle_module
python_version: 3.10.9
pytorch:
code: null
model_data: data
pytorch_version: 1.13.1
mlflow_version: 1.30.0
model_uuid: dd1b1e3a6b5f4274a5690a8843751ff3
run_id: 8ba27f225a7442be8816977c2077c510
utc_time_created: '2023-02-21 05:05:04.660670'

View File

@ -1,10 +0,0 @@
channels:
- conda-forge
dependencies:
- python=3.10.9
- pip<=23.0.1
- pip:
- mlflow
- cloudpickle==2.2.1
- torch==1.13.1
name: mlflow-env

View File

@ -1,7 +0,0 @@
python: 3.10.9
build_dependencies:
- pip==23.0.1
- setuptools==67.3.2
- wheel==0.38.4
dependencies:
- -r requirements.txt

View File

@ -1,3 +0,0 @@
mlflow
cloudpickle==2.2.1
torch==1.13.1

View File

@ -1,16 +0,0 @@
artifact_path: cls_model
flavors:
python_function:
data: data
env: conda.yaml
loader_module: mlflow.pytorch
pickle_module_name: mlflow.pytorch.pickle_module
python_version: 3.10.9
pytorch:
code: null
model_data: data
pytorch_version: 1.13.1
mlflow_version: 1.30.0
model_uuid: aaa800b217da4dd0b8f17e8dbfdc5c45
run_id: f1320882f24c4f489cbf85159627eaf8
utc_time_created: '2023-02-21 05:34:08.242864'

View File

@ -1,10 +0,0 @@
channels:
- conda-forge
dependencies:
- python=3.10.9
- pip<=23.0.1
- pip:
- mlflow
- cloudpickle==2.2.1
- torch==1.13.1
name: mlflow-env

View File

@ -1,7 +0,0 @@
python: 3.10.9
build_dependencies:
- pip==23.0.1
- setuptools==67.3.2
- wheel==0.38.4
dependencies:
- -r requirements.txt

View File

@ -1,3 +0,0 @@
mlflow
cloudpickle==2.2.1
torch==1.13.1

18
docker-compose.yaml Normal file
View File

@ -0,0 +1,18 @@
version: '3.7'

services:
  # Single-container MLflow tracking server + model registry.
  mlflow_server:
    image: ghcr.io/mlflow/mlflow:v2.1.1
    restart: always
    ports:
      - 5001:5001
    volumes:
      # Host directory that persists the backend db, registry db and artifacts.
      - ~/python/test_mlflow/service:/home
    healthcheck:
      # curl is installed by the container command below before the server starts.
      test: ["CMD", "curl", "-f", "http://0.0.0.0:5001"]
      interval: 30s
      timeout: 10s
      retries: 3
      start_period: 20s
    # Install curl for the healthcheck, then run mlflow with sqlite-backed
    # backend/registry stores and proxied artifact serving.
    command: bash -c "apt update && apt install -y curl && mlflow server --host 0.0.0.0 --port 5001 --backend-store-uri sqlite:////home/backend.db --registry-store-uri sqlite:////home/registry.db --artifacts-destination /home/artifacts --serve-artifacts"

View File

@ -1,15 +1,31 @@
name: torch name: mlflow
channels: channels:
- pytorch - pytorch
- conda-forge - conda-forge
dependencies: dependencies:
- alembic=1.9.4 - alembic=1.9.4
- aom=3.5.0 - aom=3.5.0
- appdirs=1.4.4 - arrow-cpp=9.0.0
- aws-c-auth=0.6.24
- aws-c-cal=0.5.20
- aws-c-common=0.8.11
- aws-c-compression=0.2.16
- aws-c-event-stream=0.2.18
- aws-c-http=0.7.4
- aws-c-io=0.13.17
- aws-c-mqtt=0.8.6
- aws-c-s3=0.2.4
- aws-c-sdkutils=0.1.7
- aws-checksums=0.1.14
- aws-crt-cpp=0.19.7
- aws-sdk-cpp=1.10.57
- bcrypt=3.2.2 - bcrypt=3.2.2
- blinker=1.5 - blinker=1.5
- brotli=1.0.9
- brotli-bin=1.0.9
- brotlipy=0.7.0 - brotlipy=0.7.0
- bzip2=1.0.8 - bzip2=1.0.8
- c-ares=1.18.1
- ca-certificates=2022.12.7 - ca-certificates=2022.12.7
- certifi=2022.12.7 - certifi=2022.12.7
- cffi=1.15.1 - cffi=1.15.1
@ -18,7 +34,9 @@ dependencies:
- cloudpickle=2.2.1 - cloudpickle=2.2.1
- colorama=0.4.6 - colorama=0.4.6
- configparser=5.3.0 - configparser=5.3.0
- contourpy=1.0.7
- cryptography=39.0.1 - cryptography=39.0.1
- cycler=0.11.0
- databricks-cli=0.17.4 - databricks-cli=0.17.4
- docker-py=6.0.0 - docker-py=6.0.0
- entrypoints=0.4 - entrypoints=0.4
@ -32,10 +50,13 @@ dependencies:
- fontconfig=2.14.2 - fontconfig=2.14.2
- fonts-conda-ecosystem=1 - fonts-conda-ecosystem=1
- fonts-conda-forge=1 - fonts-conda-forge=1
- fonttools=4.38.0
- freetype=2.12.1 - freetype=2.12.1
- gettext=0.21.1 - gettext=0.21.1
- gflags=2.2.2
- gitdb=4.0.10 - gitdb=4.0.10
- gitpython=3.1.31 - gitpython=3.1.31
- glog=0.6.0
- gmp=6.2.1 - gmp=6.2.1
- gnutls=3.7.8 - gnutls=3.7.8
- greenlet=2.0.2 - greenlet=2.0.2
@ -46,56 +67,85 @@ dependencies:
- importlib_resources=5.12.0 - importlib_resources=5.12.0
- itsdangerous=2.1.2 - itsdangerous=2.1.2
- jinja2=3.1.2 - jinja2=3.1.2
- joblib=1.2.0
- jpeg=9e - jpeg=9e
- kiwisolver=1.4.4
- krb5=1.20.1
- lame=3.100 - lame=3.100
- lcms2=2.14 - lcms2=2.14
- lerc=4.0.0 - lerc=4.0.0
- libabseil=20230125.0
- libblas=3.9.0 - libblas=3.9.0
- libbrotlicommon=1.0.9
- libbrotlidec=1.0.9
- libbrotlienc=1.0.9
- libcblas=3.9.0 - libcblas=3.9.0
- libcxx=14.0.6 - libcrc32c=1.1.2
- libcurl=7.88.1
- libcxx=15.0.7
- libdeflate=1.17 - libdeflate=1.17
- libedit=3.1.20191231
- libev=4.33
- libevent=2.1.10
- libffi=3.4.2 - libffi=3.4.2
- libgfortran=5.0.0 - libgfortran=5.0.0
- libgfortran5=11.3.0 - libgfortran5=11.3.0
- libgoogle-cloud=2.7.0
- libgrpc=1.51.1
- libiconv=1.17 - libiconv=1.17
- libidn2=2.3.4 - libidn2=2.3.4
- liblapack=3.9.0 - liblapack=3.9.0
- libllvm11=11.1.0
- libnghttp2=1.51.0
- libopenblas=0.3.21 - libopenblas=0.3.21
- libopus=1.3.1 - libopus=1.3.1
- libpng=1.6.39 - libpng=1.6.39
- libprotobuf=3.21.12 - libprotobuf=3.21.12
- libsodium=1.0.18 - libsodium=1.0.18
- libsqlite=3.40.0 - libsqlite=3.40.0
- libssh2=1.10.0
- libtasn1=4.19.0 - libtasn1=4.19.0
- libthrift=0.18.0
- libtiff=4.5.0 - libtiff=4.5.0
- libunistring=0.9.10 - libunistring=0.9.10
- libutf8proc=2.8.0
- libvpx=1.11.0 - libvpx=1.11.0
- libwebp-base=1.2.4 - libwebp-base=1.2.4
- libxcb=1.13 - libxcb=1.13
- libxml2=2.10.3 - libxml2=2.10.3
- libzlib=1.2.13 - libzlib=1.2.13
- llvm-openmp=15.0.7 - llvm-openmp=15.0.7
- llvmlite=0.39.1
- lz4-c=1.9.4
- mako=1.2.4 - mako=1.2.4
- markdown=3.4.1
- markupsafe=2.1.2 - markupsafe=2.1.2
- mlflow=1.30.0 - matplotlib-base=3.7.0
- mlflow=2.1.1
- munkres=1.1.4
- ncurses=6.3 - ncurses=6.3
- nettle=3.8.1 - nettle=3.8.1
- numpy=1.24.2 - numba=0.56.4
- numpy=1.23.5
- oauthlib=3.2.2 - oauthlib=3.2.2
- openh264=2.3.1 - openh264=2.3.1
- openjpeg=2.5.0 - openjpeg=2.5.0
- openssl=3.0.8 - openssl=3.0.8
- orc=1.8.2
- p11-kit=0.24.1 - p11-kit=0.24.1
- packaging=21.3 - packaging=22.0
- pandas=1.5.3 - pandas=1.5.3
- paramiko=3.0.0 - paramiko=3.0.0
- parquet-cpp=1.5.1
- pillow=9.4.0 - pillow=9.4.0
- pip=23.0.1 - pip=23.0.1
- pooch=1.6.0 - platformdirs=3.0.0
- pooch=1.7.0
- prometheus_client=0.16.0 - prometheus_client=0.16.0
- prometheus_flask_exporter=0.22.0 - prometheus_flask_exporter=0.22.2
- protobuf=4.21.12 - protobuf=4.21.12
- pthread-stubs=0.4 - pthread-stubs=0.4
- pyarrow=9.0.0
- pycparser=2.21 - pycparser=2.21
- pyjwt=2.6.0 - pyjwt=2.6.0
- pynacl=1.5.0 - pynacl=1.5.0
@ -110,22 +160,30 @@ dependencies:
- pywin32-on-windows=0.1.0 - pywin32-on-windows=0.1.0
- pyyaml=6.0 - pyyaml=6.0
- querystring_parser=1.2.4 - querystring_parser=1.2.4
- re2=2023.02.02
- readline=8.1.2 - readline=8.1.2
- requests=2.28.2 - requests=2.28.2
- scipy=1.10.0 - scikit-learn=1.2.1
- setuptools=67.3.2 - scipy=1.10.1
- setuptools=67.4.0
- shap=0.41.0
- six=1.16.0 - six=1.16.0
- slicer=0.0.7
- smmap=3.0.5 - smmap=3.0.5
- snappy=1.1.9
- sqlalchemy=1.4.46 - sqlalchemy=1.4.46
- sqlparse=0.4.3 - sqlparse=0.4.3
- svt-av1=1.4.1 - svt-av1=1.4.1
- tabulate=0.9.0 - tabulate=0.9.0
- threadpoolctl=3.1.0
- tk=8.6.12 - tk=8.6.12
- torchaudio=0.13.1 - torchaudio=0.13.1
- torchvision=0.14.1 - torchvision=0.14.1
- tqdm=4.64.1 - tqdm=4.64.1
- typing-extensions=4.4.0
- typing_extensions=4.4.0 - typing_extensions=4.4.0
- tzdata=2022g - tzdata=2022g
- unicodedata2=15.0.0
- urllib3=1.26.14 - urllib3=1.26.14
- websocket-client=1.5.1 - websocket-client=1.5.1
- werkzeug=2.2.3 - werkzeug=2.2.3
@ -136,6 +194,7 @@ dependencies:
- xorg-libxdmcp=1.1.3 - xorg-libxdmcp=1.1.3
- xz=5.2.6 - xz=5.2.6
- yaml=0.2.5 - yaml=0.2.5
- zipp=3.14.0 - zipp=3.15.0
- zlib=1.2.13
- zstd=1.5.2 - zstd=1.5.2
prefix: /Users/xiao_deng/miniforge3/envs/torch prefix: /Users/xiao_deng/miniforge3/envs/mlflow

View File

@ -0,0 +1,30 @@
# get_registered_model_via_rest_api.py
#
# author: deng
# date : 20230224
import json
import requests
def main():
    """Query the MLflow REST API for a registered model and print the
    download URI of its current Production version, if one exists."""
    registered_model_name = 'fortune_predict_model'
    production_model_version = None

    # look up the registered model and its latest version per stage
    query = {'name': registered_model_name}
    res = requests.get(
        'http://127.0.0.1:5001/api/2.0/mlflow/registered-models/get',
        params=query,
        timeout=10  # fix: without a timeout this can hang forever if the server is down
    )
    content = json.loads(res.text)
    print(content)

    # find the version currently staged as Production
    for model in content['registered_model']['latest_versions']:
        if model['current_stage'] == 'Production':
            production_model_version = model['version']

    # resolve the artifact download URI for that version
    if production_model_version is not None:
        query = {'name': registered_model_name, 'version': production_model_version}
        res = requests.get(
            'http://127.0.0.1:5001/api/2.0/mlflow/model-versions/get-download-uri',
            params=query,
            timeout=10
        )
        print(res.text)
# run the REST API demo only when executed as a script
if __name__ == '__main__':
    main()

77
log_unsupported_model.py Normal file
View File

@ -0,0 +1,77 @@
# log_unsupported_model.py
#
# author: deng
# date : 20230309
import mlflow
import pandas as pd
class CustomModel(mlflow.pyfunc.PythonModel):
    """An mlflow.pyfunc wrapper used to package a model type that MLflow has
    no built-in flavor for, so it can still be logged and registered.

    The wrapped model is loaded lazily on the first predict() call.
    """

    def __init__(self):
        # underlying model; stays None until load_model() runs
        self.model = None

    def load_model(self):
        # load your custom model here
        # (demo stand-in: a callable that simply doubles its input)
        self.model = lambda value: value * 2

    def predict(self,
                model_input: pd.DataFrame) -> pd.DataFrame:
        # NOTE(review): mlflow's PythonModel.predict is normally invoked as
        # predict(context, model_input); in this repo the wrapper is only
        # called directly after unwrap_python_model(), so this works — confirm
        # before serving it through the generic pyfunc predict interface.
        if self.model is None:
            self.load_model()
        # apply the model element-wise over the input frame
        output = model_input.apply(self.model)
        return output
def log_model(server_uri: str,
              exp_name: str,
              registered_model_name: str) -> None:
    """Package a CustomModel via mlflow.pyfunc and register it on the server.

    Args:
        server_uri: tracking server URI
        exp_name: experiment to log the run under
        registered_model_name: name to register the model as
    """
    # point the client at the tracking server and select the experiment
    mlflow.set_tracking_uri(server_uri)
    mlflow.set_experiment(exp_name)

    # wrap the custom model and push it to the registry in a single call
    mlflow.pyfunc.log_model(
        artifact_path='model',
        python_model=CustomModel(),
        registered_model_name=registered_model_name
    )
def pull_model(server_uri: str,
               exp_name: str,
               registered_model_name: str) -> None:
    """Fetch the latest registered CustomModel, smoke-test it with dummy
    data, and save it to the local file system.

    Args:
        server_uri: tracking server URI
        exp_name: experiment name (also used as the local save sub-dir)
        registered_model_name: name of the registered model to pull
    """
    # point the client at the tracking server and select the experiment
    mlflow.set_tracking_uri(server_uri)
    mlflow.set_experiment(exp_name)

    # fetch the newest registered version, then unwrap the raw python object
    model_uri = f'models:/{registered_model_name}/latest'
    model = mlflow.pyfunc.load_model(model_uri)
    model = model.unwrap_python_model()  # get CustomModel object
    print(f'Model loaded. model type: {type(model)}')

    # quick availability check with fabricated inputs
    fake_data = pd.DataFrame([1, 3, 5])
    output = model.predict(fake_data)
    print(f'input data: {fake_data}, predictions: {output}')

    # persist the model (plus its env files) locally
    mlflow.pyfunc.save_model(path=f'./model/{exp_name}', python_model=model)
if __name__ == '__main__':

    # demo configuration: local MLflow server plus experiment/model names
    server_uri = 'http://127.0.0.1:5001'
    exp_name = 'custom_model'
    registered_model_name = 'custom_model'

    # round-trip: register the custom model, then pull it back and test it
    log_model(server_uri, exp_name, registered_model_name)
    pull_model(server_uri, exp_name, registered_model_name)

BIN
mlflow.db

Binary file not shown.

75
optimize_model.py Normal file
View File

@ -0,0 +1,75 @@
# optimize_model.py
#
# author: deng
# date : 20230418
import shutil
from pathlib import Path
import torch
import mlflow
def optimize_pytorch_model(run_id: str,
                           model_artifact_dir: str = 'model') -> None:
    """Download a PyTorch model from the MLflow server, export it to ONNX,
    produce a (mock) TensorRT engine, and upload that artifact back to the
    same run.

    Args:
        run_id (str): mlflow run id
        model_artifact_dir (str, optional): model dir of run on server. Defaults to 'model'.
    """

    download_path = Path('./model/downloaded_pytorch_model')
    if download_path.is_dir():
        # start from a clean slate so stale artifacts are never re-uploaded
        print(f'Remove existed dir: {download_path}')
        shutil.rmtree(download_path)

    # Download model artifacts to local file system
    mlflow_model = mlflow.pytorch.load_model(Path(f'runs:/{run_id}').joinpath(model_artifact_dir).as_posix())
    mlflow.pytorch.save_model(mlflow_model, download_path)

    # Export the raw pytorch weights to ONNX as the optimization input.
    # NOTE(review): assumes the model accepts a flat 5-element tensor — confirm
    model = torch.load(download_path.joinpath('data/model.pth'))
    dummy_input = torch.randn(5)
    torch.onnx.export(model, dummy_input, download_path.joinpath('data/model.onnx'))

    # we can not call TensorRT on macOS, so imagine we get a serialized model😘
    download_path.joinpath('data/model.trt').touch()

    # Send the optimized model back to the given run as model/data/model.trt
    with mlflow.start_run(run_id=run_id):
        mlflow.log_artifact(download_path.joinpath('data/model.trt'), 'model/data')
        print(f'Optimized model had been uploaded to server: {mlflow.get_tracking_uri()}')
def download_optimized_model(run_id: str,
                             save_dir: str,
                             model_artifact_path: str = 'model/data/model.trt') -> None:
    """Download the optimized model artifact from the MLflow server on the client.

    Args:
        run_id (str): mlflow run id
        save_dir (str): dir of local file system to save model
        model_artifact_path (str, optional): artifact path of model on server. Defaults to 'model/data/model.trt'.
    """
    # pulls a single artifact; its relative artifact path is recreated
    # underneath save_dir
    mlflow.artifacts.download_artifacts(
        run_id=run_id,
        artifact_path=model_artifact_path,
        dst_path=save_dir
    )
    print(f'Optimized model had been saved, please check: {Path(save_dir).joinpath(model_artifact_path)}')
if __name__ == '__main__':

    # talk to the local MLflow server started via docker-compose
    mlflow.set_tracking_uri('http://127.0.0.1:5001')

    # "server" side: fetch, optimize, and re-upload the model of this run
    optimize_pytorch_model(
        run_id='f1b7b9a5ba934f158c07975a8a332de5'
    )

    # "client" side: download the optimized engine produced above
    download_optimized_model(
        run_id='f1b7b9a5ba934f158c07975a8a332de5',
        save_dir='./model/download_tensorrt'
    )

View File

@ -7,15 +7,22 @@ import torch
import mlflow import mlflow
if __name__ == '__main__': def main():
# set MLflow server # set MLflow server
mlflow.set_tracking_uri('http://127.0.0.1:5000') mlflow.set_tracking_uri('http://127.0.0.1:5001')
# load production model # load production model
model = mlflow.pytorch.load_model('models:/cls_model/production') model = mlflow.pytorch.load_model('models:/fortune_predict_model/production')
# predict # predict
fake_data = torch.randn(10) my_personal_info = torch.randn(5)
output = model(fake_data) my_fortune = model(my_personal_info)
print(output) print(my_fortune)
# save model and env to local file system
mlflow.pytorch.save_model(model, './model/fortune_predict_model')
if __name__ == '__main__':
main()

View File

@ -0,0 +1,98 @@
# train.py
#
# author: deng
# date : 20230221
import torch
import torch.nn as nn
from torch.optim import SGD
import mlflow
from mlflow.models.signature import ModelSignature
from mlflow.types.schema import Schema, ColSpec
from tqdm import tqdm
class Net(nn.Module):
    """A tiny fully-connected regressor: 5 inputs -> 3 hidden (ReLU) -> 1 output."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(5, 3)
        self.fc2 = nn.Linear(3, 1)

    def forward(self, x):
        hidden = torch.relu(self.fc1(x))
        return self.fc2(hidden)
def train(model, dataloader, criterion, optimizer, epochs):
    """Run a simple training loop and log the per-batch loss to MLflow.

    Args:
        model: torch module to optimize
        dataloader: iterable of (inputs, labels) batches
        criterion: loss function
        optimizer: torch optimizer bound to model's parameters
        epochs: number of passes over the dataloader

    Returns:
        The loss tensor of the last processed batch.
    """
    for epoch in tqdm(range(epochs), 'Epochs'):
        # fix: the enumerate() batch index was unused — iterate batches directly
        for inputs, labels in dataloader:

            # forwarding
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # update gradient
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # log loss (one point per batch, stamped with the epoch index)
            mlflow.log_metric('train_loss', loss.item(), step=epoch)

    return loss
if __name__ == '__main__':

    # set hyper parameters
    learning_rate = 1e-2
    batch_size = 10
    epochs = 20

    # create a dataloader with fake data: 100 samples of 5 features -> 1 target
    dataloader = [(torch.randn(5), torch.randn(1)) for _ in range(100)]
    dataloader = torch.utils.data.DataLoader(dataloader, batch_size=batch_size)

    # create the model, criterion, and optimizer
    model = Net()
    criterion = nn.MSELoss()
    optimizer = SGD(model.parameters(), lr=learning_rate)

    # set the tracking URI to the model registry (local docker-compose server)
    mlflow.set_tracking_uri('http://127.0.0.1:5001')
    mlflow.set_experiment('train_fortune_predict_model')

    # start a new MLflow run
    with mlflow.start_run():

        # train the model
        loss = train(model, dataloader, criterion, optimizer, epochs)

        # log the final loss and the hyper parameters used for this run
        mlflow.log_metric('final_loss', loss.item())
        mlflow.log_param('learning_rate', learning_rate)
        mlflow.log_param('batch_size', batch_size)

        # create a signature documenting the model's expected input/output
        # columns (five float features in, one float prediction out)
        input_schema = Schema([
            ColSpec('float', 'age'),
            ColSpec('float', 'mood level'),
            ColSpec('float', 'health level'),
            ColSpec('float', 'hungry level'),
            ColSpec('float', 'sexy level')
        ])
        output_schema = Schema([ColSpec('float', 'fortune')])
        signature = ModelSignature(inputs=input_schema, outputs=output_schema)

        # log trained model
        mlflow.pytorch.log_model(model, 'model', signature=signature)

        # log the training code itself so the run is reproducible
        mlflow.log_artifact('./train.py', 'code')

    print('Completed.')

BIN
service/backend.db Normal file

Binary file not shown.

BIN
service/registry.db Normal file

Binary file not shown.

View File

@ -1,7 +0,0 @@
#!/bin/bash
# start_mlflow_server.sh
#
# author: deng
# date : 20230221
mlflow server --backend-store-uri sqlite:///mlflow.db --default-artifact-root ./artifacts

14
test_pytorch_m1.py Normal file
View File

@ -0,0 +1,14 @@
# test_pytorch_m1.py
# Ref: https://towardsdatascience.com/installing-pytorch-on-apple-m1-chip-with-gpu-acceleration-3351dc44d67c
#
# author: deng
# date : 20230301
import torch
import math
# Sanity checks for PyTorch MPS (Apple-silicon GPU) support.
print('This ensures that the current MacOS version is at least 12.3+')
print(torch.backends.mps.is_available())

# fix: removed the duplicated word ("current current") from the original message
print('\nThis ensures that the current PyTorch installation was built with MPS activated.')
print(torch.backends.mps.is_built())

View File

@ -6,16 +6,18 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.optim import SGD from torch.optim import SGD
from tqdm import tqdm
import mlflow import mlflow
from mlflow.models.signature import ModelSignature
from mlflow.types.schema import Schema, ColSpec
from tqdm import tqdm
class Net(nn.Module): class Net(nn.Module):
""" define a simple neural network model """ """ define a simple neural network model """
def __init__(self): def __init__(self):
super(Net, self).__init__() super(Net, self).__init__()
self.fc1 = nn.Linear(10, 5) self.fc1 = nn.Linear(5, 3)
self.fc2 = nn.Linear(5, 1) self.fc2 = nn.Linear(3, 1)
def forward(self, x): def forward(self, x):
x = self.fc1(x) x = self.fc1(x)
@ -28,7 +30,7 @@ def train(model, dataloader, criterion, optimizer, epochs):
""" define the training function """ """ define the training function """
for epoch in tqdm(range(epochs), 'Epochs'): for epoch in tqdm(range(epochs), 'Epochs'):
for i, (inputs, labels) in enumerate(dataloader): for batch, (inputs, labels) in enumerate(dataloader):
# forwarding # forwarding
outputs = model(inputs) outputs = model(inputs)
@ -40,7 +42,7 @@ def train(model, dataloader, criterion, optimizer, epochs):
optimizer.step() optimizer.step()
# log loss # log loss
mlflow.log_metric('train_loss', loss.item(), step=i) mlflow.log_metric('train_loss', loss.item(), step=epoch)
return loss return loss
@ -49,11 +51,12 @@ if __name__ == '__main__':
# set hyper parameters # set hyper parameters
learning_rate = 1e-2 learning_rate = 1e-2
batch_size = 10
epochs = 20 epochs = 20
# create a dataloader with fake data # create a dataloader with fake data
dataloader = [(torch.randn(10), torch.randn(1)) for _ in range(100)] dataloader = [(torch.randn(5), torch.randn(1)) for _ in range(100)]
dataloader = torch.utils.data.DataLoader(dataloader, batch_size=10) dataloader = torch.utils.data.DataLoader(dataloader, batch_size=batch_size)
# create the model, criterion, and optimizer # create the model, criterion, and optimizer
model = Net() model = Net()
@ -61,19 +64,35 @@ if __name__ == '__main__':
optimizer = SGD(model.parameters(), lr=learning_rate) optimizer = SGD(model.parameters(), lr=learning_rate)
# set the tracking URI to the model registry # set the tracking URI to the model registry
mlflow.set_tracking_uri('http://127.0.0.1:5000') mlflow.set_tracking_uri('http://127.0.0.1:5001')
mlflow.set_experiment('train_fortune_predict_model')
# start the MLflow run # start a new MLflow run
with mlflow.start_run(): with mlflow.start_run():
# train the model and log the loss # train the model
loss = train(model, dataloader, criterion, optimizer, epochs) loss = train(model, dataloader, criterion, optimizer, epochs)
# log some additional metrics # log some additional metrics
mlflow.log_metric('final_loss', loss.item()) mlflow.log_metric('final_loss', loss.item())
mlflow.log_param('learning_rate', learning_rate) mlflow.log_param('learning_rate', learning_rate)
mlflow.log_param('batch_size', batch_size)
# create a signature to record model input and output info
input_schema = Schema([
ColSpec('float', 'age'),
ColSpec('float', 'mood level'),
ColSpec('float', 'health level'),
ColSpec('float', 'hungry level'),
ColSpec('float', 'sexy level')
])
output_schema = Schema([ColSpec('float', 'fortune')])
signature = ModelSignature(inputs=input_schema, outputs=output_schema)
# log trained model # log trained model
mlflow.pytorch.log_model(model, 'model') mlflow.pytorch.log_model(model, 'model', signature=signature)
# log training code
mlflow.log_artifact('./train.py', 'code')
print('Completed.') print('Completed.')