# optimize_model.py
#
# author: deng
# date : 20230418
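#
# Workflow: download a PyTorch model logged to an MLflow run, export it to ONNX
# (and, on a device with TensorRT installed, build a serialized engine), then log
# the optimized artifact back to the same run so that clients can fetch it.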

import shutil
from pathlib import Path

import torch
import mlflow


def optimize_pytorch_model(run_id: str,
                           model_artifact_dir: str = 'model') -> None:
    """Optimize a PyTorch model from the MLflow server on an edge device

    Args:
        run_id (str): mlflow run id
        model_artifact_dir (str, optional): model dir of run on server. Defaults to 'model'.
    """

    download_path = Path('./model/downloaded_pytorch_model')
    if download_path.is_dir():
        print(f'Remove existing dir: {download_path}')
        shutil.rmtree(download_path)

    # Download model artifacts to the local file system
    mlflow_model = mlflow.pytorch.load_model(f'runs:/{run_id}/{model_artifact_dir}')
    mlflow.pytorch.save_model(mlflow_model, download_path)

    # Optimize the model: export it to ONNX first
    model = torch.load(download_path.joinpath('data/model.pth'))
    dummy_input = torch.randn(5)  # dummy input shape must match what the model expects
    torch.onnx.export(model, dummy_input, download_path.joinpath('data/model.onnx').as_posix())

    # We cannot call TensorRT on macOS, so imagine we get a serialized model😘
    download_path.joinpath('data/model.trt').touch()
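    # On an edge device with TensorRT installed, the engine could be built from the
    # ONNX export instead of the empty placeholder above, e.g. (sketch, not part of
    # the original run):
    #   trtexec --onnx=./model/downloaded_pytorch_model/data/model.onnx \
    #           --saveEngine=./model/downloaded_pytorch_model/data/model.trt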

    # Send the optimized model back to the given run
    with mlflow.start_run(run_id=run_id):
        mlflow.log_artifact(download_path.joinpath('data/model.trt').as_posix(), 'model/data')
    print(f'Optimized model has been uploaded to server: {mlflow.get_tracking_uri()}')


def download_optimized_model(run_id: str,
                             save_dir: str,
                             model_artifact_path: str = 'model/data/model.trt') -> None:
    """Download the optimized model from the MLflow server on a client

    Args:
        run_id (str): mlflow run id
        save_dir (str): dir of local file system to save model
        model_artifact_path (str, optional): artifact path of model on server. Defaults to 'model/data/model.trt'.
    """

    mlflow.artifacts.download_artifacts(
        run_id=run_id,
        artifact_path=model_artifact_path,
        dst_path=save_dir
    )
    print(f'Optimized model has been saved, please check: {Path(save_dir).joinpath(model_artifact_path)}')


if __name__ == '__main__':

    mlflow.set_tracking_uri('http://127.0.0.1:5001')

    optimize_pytorch_model(
        run_id='f1b7b9a5ba934f158c07975a8a332de5'
    )

    download_optimized_model(
        run_id='f1b7b9a5ba934f158c07975a8a332de5',
        save_dir='./model/download_tensorrt'
    )