add demo to download optimized model on client
This commit is contained in:
parent
e1f143736e
commit
1b90edee59
|
@ -1,2 +1,3 @@
|
|||
.DS_Store
|
||||
__pycache__
|
||||
model
|
|
@ -29,5 +29,7 @@ Try to use [MLflow](https://mlflow.org) platform to log PyTorch model training,
|
|||
* a script to test MLflow REST api
|
||||
* **log_unsupported_model.py**
|
||||
* a sample script to apply mlflow.pyfunc to package unsupported ml model which can be logged and registered by mlflow
|
||||
* **optimize_model.py**
|
||||
* a sample script to demonstrate how to use MLflow and TensorRT libs to optimize Pytorch model on edge devices and fetch it out on client
|
||||
|
||||
###### tags: `MLOps`
|
|
@ -10,11 +10,13 @@ import torch
|
|||
import mlflow
|
||||
|
||||
|
||||
def optimize_pytorch_model(run_id: str) -> None:
|
||||
"""Optimize Pytorch model on MLflow server, the optimized model will be sent back
|
||||
def optimize_pytorch_model(run_id: str,
|
||||
model_artifact_dir: str = 'model') -> None:
|
||||
"""Optimize Pytorch model from MLflow server on edge device
|
||||
|
||||
Args:
|
||||
run_id (str): mlflow run id
|
||||
model_artifact_dir (str, optional): model dir of run on server. Defaults to 'model'.
|
||||
"""
|
||||
|
||||
download_path = Path('./model/downloaded_pytorch_model')
|
||||
|
@ -22,26 +24,52 @@ def optimize_pytorch_model(run_id: str) -> None:
|
|||
print(f'Remove existed dir: {download_path}')
|
||||
shutil.rmtree(download_path)
|
||||
|
||||
# Download Pytorch model to local file system
|
||||
mlflow_model = mlflow.pytorch.load_model(f'runs:/{run_id}/model')
|
||||
# Download model artifacts to local file system
|
||||
mlflow_model = mlflow.pytorch.load_model(Path(f'runs:/{run_id}').joinpath(model_artifact_dir).as_posix())
|
||||
mlflow.pytorch.save_model(mlflow_model, download_path)
|
||||
|
||||
# Optimize model
|
||||
model = torch.load(download_path.joinpath('data/model.pth'))
|
||||
dummy_input = torch.randn(5)
|
||||
torch.onnx.export(model, dummy_input, download_path.joinpath('data/model.onnx'))
|
||||
# we can not call TensorRT on macOS, so imagine we get a serialized model
|
||||
# we can not call TensorRT on macOS, so imagine we get a serialized model😘
|
||||
download_path.joinpath('data/model.trt').touch()
|
||||
|
||||
# Save optimized model to given run
|
||||
# Sent optimized model back to given run
|
||||
with mlflow.start_run(run_id=run_id):
|
||||
mlflow.log_artifact(download_path.joinpath('data/model.trt'), 'model/data')
|
||||
print(f'Optimized model had been uploaded to server: {mlflow.get_tracking_uri()}')
|
||||
|
||||
|
||||
def download_optimized_model(run_id: str,
|
||||
save_dir: str,
|
||||
model_artifact_path: str = 'model/data/model.trt') -> None:
|
||||
"""Download optimized model from MLflow server on clent
|
||||
|
||||
Args:
|
||||
run_id (str): mlflow run id
|
||||
save_dir (str): dir of local file system to save model
|
||||
model_artifact_path (str, optional): artifact path of model on server. Defaults to 'model/data/model.trt'.
|
||||
"""
|
||||
|
||||
mlflow.artifacts.download_artifacts(
|
||||
run_id= run_id,
|
||||
artifact_path=model_artifact_path,
|
||||
dst_path=save_dir
|
||||
)
|
||||
|
||||
print(f'Optimized model had been saved, please check: {Path(save_dir).joinpath(model_artifact_path)}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
mlflow.set_tracking_uri('http://127.0.0.1:5001')
|
||||
|
||||
optimize_pytorch_model(
|
||||
run_id='f1b7b9a5ba934f158c07975a8a332de5'
|
||||
)
|
||||
|
||||
download_optimized_model(
|
||||
run_id='f1b7b9a5ba934f158c07975a8a332de5',
|
||||
save_dir='./model/download_tensorrt'
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue