From 1b90edee59056b73f043d135fb16ae10b253f16a Mon Sep 17 00:00:00 2001
From: deng
Date: Wed, 19 Apr 2023 11:46:33 +0800
Subject: [PATCH] add demo to download optimized model on client

---
 .gitignore        |  1 +
 README.md         |  2 ++
 optimize_model.py | 40 ++++++++++++++++++++++++++++++++++------
 3 files changed, 37 insertions(+), 6 deletions(-)

diff --git a/.gitignore b/.gitignore
index 99986cc..315a2b0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,3 @@
 .DS_Store
+__pycache__
 model
\ No newline at end of file
diff --git a/README.md b/README.md
index 31257d0..80835c4 100644
--- a/README.md
+++ b/README.md
@@ -29,5 +29,7 @@ Try to use [MLflow](https://mlflow.org) platform to log PyTorch model training,
   * a script to test MLflow REST api
 * **log_unsupported_model.py**
   * a sample script to apply mlflow.pyfunc to package unsupported ml model which can be logged and registered by mlflow
+* **optimize_model.py**
+  * a sample script demonstrating how to use MLflow and TensorRT to optimize a PyTorch model on an edge device and fetch it back on the client

 ###### tags: `MLOps`
\ No newline at end of file
diff --git a/optimize_model.py b/optimize_model.py
index 343ded8..f7549d0 100644
--- a/optimize_model.py
+++ b/optimize_model.py
@@ -10,11 +10,13 @@ import torch
 import mlflow


-def optimize_pytorch_model(run_id: str) -> None:
-    """Optimize Pytorch model on MLflow server, the optimized model will be sent back
+def optimize_pytorch_model(run_id: str,
+                           model_artifact_dir: str = 'model') -> None:
+    """Optimize a PyTorch model from the MLflow server on an edge device

     Args:
         run_id (str): mlflow run id
+        model_artifact_dir (str, optional): model artifact dir of the run on the server. Defaults to 'model'.
     """

     download_path = Path('./model/downloaded_pytorch_model')
@@ -22,26 +24,52 @@ def optimize_pytorch_model(run_id: str) -> None:
         print(f'Remove existed dir: {download_path}')
         shutil.rmtree(download_path)

-    # Download Pytorch model to local file system
-    mlflow_model = mlflow.pytorch.load_model(f'runs:/{run_id}/model')
+    # Download model artifacts to local file system
+    mlflow_model = mlflow.pytorch.load_model(f'runs:/{run_id}/{model_artifact_dir}')
     mlflow.pytorch.save_model(mlflow_model, download_path)

     # Optimize model
     model = torch.load(download_path.joinpath('data/model.pth'))
     dummy_input = torch.randn(5)
     torch.onnx.export(model, dummy_input, download_path.joinpath('data/model.onnx'))
-    # we can not call TensorRT on macOS, so imagine we get a serialized model
+    # We cannot call TensorRT on macOS, so pretend we get a serialized engine here
     download_path.joinpath('data/model.trt').touch()

-    # Save optimized model to given run
+    # Send the optimized model back to the given run
     with mlflow.start_run(run_id=run_id):
         mlflow.log_artifact(download_path.joinpath('data/model.trt'), 'model/data')
         print(f'Optimized model had been uploaded to server: {mlflow.get_tracking_uri()}')


+def download_optimized_model(run_id: str,
+                             save_dir: str,
+                             model_artifact_path: str = 'model/data/model.trt') -> None:
+    """Download the optimized model from the MLflow server on the client
+
+    Args:
+        run_id (str): mlflow run id
+        save_dir (str): local dir to save the model to
+        model_artifact_path (str, optional): artifact path of the model on the server. Defaults to 'model/data/model.trt'.
+ """ + + mlflow.artifacts.download_artifacts( + run_id= run_id, + artifact_path=model_artifact_path, + dst_path=save_dir + ) + + print(f'Optimized model had been saved, please check: {Path(save_dir).joinpath(model_artifact_path)}') + + if __name__ == '__main__': mlflow.set_tracking_uri('http://127.0.0.1:5001') + optimize_pytorch_model( run_id='f1b7b9a5ba934f158c07975a8a332de5' ) + + download_optimized_model( + run_id='f1b7b9a5ba934f158c07975a8a332de5', + save_dir='./model/download_tensorrt' + )