# optimize_model.py
#
# author: deng
# date : 20230418
import argparse
import shutil
from pathlib import Path

import torch
import mlflow


def optimize_pytorch_model(
    run_id: str,
    download_dir: str = './model/downloaded_pytorch_model',
    input_shape: tuple = (5,),
) -> None:
    """Optimize a PyTorch model stored on the MLflow server and upload the result.

    The model logged under ``runs:/<run_id>/model`` goes through these steps:

    1. It is loaded from the server and saved locally to ``download_dir``.
    2. It is exported to ONNX at ``<download_dir>/data/model.onnx``.
    3. It is "compiled" to a TensorRT engine. This is stubbed as an empty
       ``model.trt`` file because TensorRT is not available on macOS.
    4. The engine file is logged back to the same run under ``model/data``.

    Args:
        run_id (str): MLflow run id that owns the ``model`` artifact.
        download_dir (str): Local directory the model is saved to. If it
            already exists it is removed first. Defaults to
            ``./model/downloaded_pytorch_model``.
        input_shape (tuple): Shape of the random dummy tensor used to trace
            the model during ONNX export. Defaults to ``(5,)``.
    """
    download_path = Path(download_dir)
    if download_path.is_dir():
        print(f'Remove existed dir: {download_path}')
        # save_model refuses to write into an existing directory, so start clean.
        shutil.rmtree(download_path)

    # Download the PyTorch model from the server to the local file system.
    model = mlflow.pytorch.load_model(f'runs:/{run_id}/model')
    mlflow.pytorch.save_model(model, download_path)

    # Optimize the model. Reuse the module that is already in memory instead of
    # re-reading data/model.pth with torch.load. That file path is an MLflow
    # internal detail. Also, torch>=2.6 defaults torch.load to
    # weights_only=True, which refuses to unpickle a full nn.Module.
    data_dir = download_path / 'data'
    dummy_input = torch.randn(tuple(input_shape))
    torch.onnx.export(model, dummy_input, str(data_dir / 'model.onnx'))

    # We cannot call TensorRT on macOS, so pretend we got a serialized engine.
    trt_path = data_dir / 'model.trt'
    trt_path.touch()

    # Save the optimized model into the given run, next to the original model data.
    with mlflow.start_run(run_id=run_id):
        mlflow.log_artifact(str(trt_path), 'model/data')
    print(f'Optimized model had been uploaded to server: {mlflow.get_tracking_uri()}')


if __name__ == '__main__':
    # Script entry point. The defaults reproduce the originally hard-coded
    # server and run; pass --tracking-uri / --run-id to target others.
    parser = argparse.ArgumentParser(
        description='Optimize a PyTorch model logged to an MLflow run.'
    )
    parser.add_argument(
        '--tracking-uri',
        default='http://127.0.0.1:5001',
        help='MLflow tracking server URI (default: %(default)s)',
    )
    parser.add_argument(
        '--run-id',
        default='f1b7b9a5ba934f158c07975a8a332de5',
        help='MLflow run id whose "model" artifact is optimized (default: %(default)s)',
    )
    args = parser.parse_args()

    mlflow.set_tracking_uri(args.tracking_uri)
    optimize_pytorch_model(run_id=args.run_id)