# optimize_model.py
#
# author: deng
# date  : 20230418

import shutil
from pathlib import Path

import torch
import mlflow


def optimize_pytorch_model(
    run_id: str,
    model_artifact_dir: str = 'model',
    input_shape: tuple = (5,),
) -> None:
    """Optimize a PyTorch model from the MLflow server on an edge device.

    The model logged at ``runs:/<run_id>/<model_artifact_dir>`` is loaded and
    saved to ``./model/downloaded_pytorch_model``, then exported to ONNX. A
    placeholder ``model.trt`` is created (TensorRT is not run here). That file
    is uploaded back to the same run under ``<model_artifact_dir>/data``.

    Args:
        run_id (str): mlflow run id
        model_artifact_dir (str, optional): model dir of run on server.
            Defaults to 'model'.
        input_shape (tuple, optional): shape of the dummy tensor used to trace
            the model during ONNX export. Defaults to (5,), which matches the
            previously hard-coded ``torch.randn(5)``.
    """
    download_path = Path('./model/downloaded_pytorch_model')
    # save_model refuses to write into an existing, non-empty directory, so
    # clear out any leftovers from a previous run first.
    if download_path.is_dir():
        print(f'Remove existed dir: {download_path}')
        shutil.rmtree(download_path)

    # Download model artifacts to local file system.
    # MLflow model URIs are always '/'-separated, so build the URI as a plain
    # string rather than through pathlib (which is meant for local paths).
    model_uri = f'runs:/{run_id}/{model_artifact_dir}'
    mlflow_model = mlflow.pytorch.load_model(model_uri)
    mlflow.pytorch.save_model(mlflow_model, download_path)
    data_dir = download_path / 'data'

    # Optimize model.
    # Export the module already returned by load_model instead of
    # re-unpickling data/model.pth with torch.load: it is the same object, and
    # torch.load defaults to weights_only=True since PyTorch 2.6, which rejects
    # a fully pickled nn.Module.
    dummy_input = torch.randn(tuple(input_shape))
    torch.onnx.export(mlflow_model, dummy_input, str(data_dir / 'model.onnx'))

    # TensorRT cannot be called on macOS, so an empty placeholder file stands
    # in for the serialized TensorRT engine.
    trt_path = data_dir / 'model.trt'
    trt_path.touch()

    # Send the optimized model back to the given run. It goes under the same
    # artifact dir the model was loaded from, rather than a fixed 'model/data'.
    with mlflow.start_run(run_id=run_id):
        mlflow.log_artifact(str(trt_path), f'{model_artifact_dir}/data')
    print(f'Optimized model had been uploaded to server: {mlflow.get_tracking_uri()}')


def download_optimized_model(run_id: str, save_dir: str, model_artifact_path: str = 'model/data/model.trt') -> None:
    """Download the optimized model from the MLflow server on a client.

    Args:
        run_id (str): mlflow run id
        save_dir (str): dir of local file system to save model
        model_artifact_path (str, optional): artifact path of model on server.
            Defaults to 'model/data/model.trt'.
    """
    # download_artifacts returns the local path it actually wrote to; report
    # that instead of re-deriving it by hand.
    local_path = mlflow.artifacts.download_artifacts(
        run_id=run_id,
        artifact_path=model_artifact_path,
        dst_path=save_dir,
    )
    print(f'Optimized model had been saved, please check: {local_path}')


if __name__ == '__main__':
    mlflow.set_tracking_uri('http://127.0.0.1:5001')
    target_run_id = 'f1b7b9a5ba934f158c07975a8a332de5'
    optimize_pytorch_model(run_id=target_run_id)
    download_optimized_model(
        run_id=target_run_id,
        save_dir='./model/download_tensorrt',
    )