diff --git a/optimize_model.py b/optimize_model.py
new file mode 100644
index 0000000..343ded8
--- /dev/null
+++ b/optimize_model.py
@@ -0,0 +1,47 @@
+# optimize_model.py
+#
+# author: deng
+# date : 20230418
+
+import shutil
+from pathlib import Path
+
+import torch
+import mlflow
+
+
+def optimize_pytorch_model(run_id: str) -> None:
+    """Optimize a PyTorch model stored on an MLflow server and send the optimized model back
+
+    Args:
+        run_id (str): mlflow run id
+    """
+
+    download_path = Path('./model/downloaded_pytorch_model')
+    if download_path.is_dir():
+        print(f'Removing existing dir: {download_path}')
+        shutil.rmtree(download_path)
+
+    # Download the PyTorch model to the local file system
+    mlflow_model = mlflow.pytorch.load_model(f'runs:/{run_id}/model')
+    mlflow.pytorch.save_model(mlflow_model, download_path)
+
+    # Export the model to ONNX as the first optimization step
+    model = torch.load(download_path.joinpath('data/model.pth'))
+    dummy_input = torch.randn(5)
+    torch.onnx.export(model, dummy_input, download_path.joinpath('data/model.onnx'))
+    # TensorRT cannot run on macOS, so pretend we obtained a serialized engine
+    download_path.joinpath('data/model.trt').touch()
+
+    # Save the optimized model back to the given run
+    with mlflow.start_run(run_id=run_id):
+        mlflow.log_artifact(download_path.joinpath('data/model.trt'), 'model/data')
+        print(f'Optimized model has been uploaded to server: {mlflow.get_tracking_uri()}')
+
+
+if __name__ == '__main__':
+
+    mlflow.set_tracking_uri('http://127.0.0.1:5001')
+    optimize_pytorch_model(
+        run_id='f1b7b9a5ba934f158c07975a8a332de5'
+    )
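
For reference, below is a rough sketch of what the TensorRT step stubbed out by `model.trt.touch()` might look like on a Linux machine with GPU support. It is not part of this change: it assumes the `tensorrt` Python package (8.4+) is installed, and `build_trt_engine`, `onnx_path`, and `engine_path` are illustrative names.

import tensorrt as trt

def build_trt_engine(onnx_path: str, engine_path: str) -> None:
    """Parse an ONNX file and serialize a TensorRT engine to disk."""
    logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(logger)
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    parser = trt.OnnxParser(network, logger)

    with open(onnx_path, 'rb') as f:
        if not parser.parse(f.read()):
            raise RuntimeError('Failed to parse ONNX model')

    config = builder.create_builder_config()
    # Cap the builder workspace at 1 GiB (TensorRT >= 8.4 API)
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)

    # build_serialized_network returns an IHostMemory buffer,
    # which can be written straight to disk
    engine = builder.build_serialized_network(network, config)
    with open(engine_path, 'wb') as f:
        f.write(engine)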
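
Once the run has the `model.trt` artifact, consumers can pull it back down with MLflow's artifact API. A usage sketch, assuming a recent MLflow client that provides `mlflow.artifacts.download_artifacts`:

import mlflow

mlflow.set_tracking_uri('http://127.0.0.1:5001')
local_path = mlflow.artifacts.download_artifacts(
    run_id='f1b7b9a5ba934f158c07975a8a332de5',
    artifact_path='model/data/model.trt')
print(f'Downloaded optimized model to: {local_path}')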