script to optimize pytorch model on server
optimize_model.py (new file, 47 lines)
@@ -0,0 +1,47 @@
# optimize_model.py
#
# author: deng
# date  : 20230418

import shutil
from pathlib import Path

import torch
import mlflow


def optimize_pytorch_model(run_id: str) -> None:
    """Optimize a PyTorch model stored on the MLflow server; the optimized model is uploaded back to the same run.

    Args:
        run_id (str): MLflow run id
    """

    download_path = Path('./model/downloaded_pytorch_model')
    if download_path.is_dir():
        print(f'Removing existing dir: {download_path}')
        shutil.rmtree(download_path)

    # Download PyTorch model to local file system
    mlflow_model = mlflow.pytorch.load_model(f'runs:/{run_id}/model')
    mlflow.pytorch.save_model(mlflow_model, download_path)

    # Optimize model
    model = torch.load(download_path.joinpath('data/model.pth'))
    dummy_input = torch.randn(5)
    torch.onnx.export(model, dummy_input, download_path.joinpath('data/model.onnx'))
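    # Optional sanity check of the exported graph (a sketch, assuming the `onnx`
    # package is installed; not part of the original flow):
    # import onnx
    # onnx.checker.check_model(onnx.load(str(download_path.joinpath('data/model.onnx'))))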
    # TensorRT cannot be called on macOS, so pretend we obtained a serialized engine
    download_path.joinpath('data/model.trt').touch()
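    # Rough sketch of the real ONNX -> TensorRT conversion on a CUDA-capable Linux host
    # (hypothetical, assumes the `tensorrt` 8.x Python package; API details vary between versions):
    #
    #   import tensorrt as trt
    #   logger = trt.Logger(trt.Logger.WARNING)
    #   builder = trt.Builder(logger)
    #   network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    #   parser = trt.OnnxParser(network, logger)
    #   if not parser.parse(download_path.joinpath('data/model.onnx').read_bytes()):
    #       raise RuntimeError('Failed to parse ONNX model')
    #   config = builder.create_builder_config()
    #   serialized_engine = builder.build_serialized_network(network, config)
    #   with open(download_path.joinpath('data/model.trt'), 'wb') as f:
    #       f.write(serialized_engine)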

    # Save optimized model back to the given run
    with mlflow.start_run(run_id=run_id):
        mlflow.log_artifact(download_path.joinpath('data/model.trt'), 'model/data')
        print(f'Optimized model has been uploaded to server: {mlflow.get_tracking_uri()}')
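    # The uploaded engine could later be pulled back down with something like
    # (hypothetical usage, MLflow 2.x client API):
    # mlflow.artifacts.download_artifacts(run_id=run_id, artifact_path='model/data/model.trt')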


if __name__ == '__main__':
    mlflow.set_tracking_uri('http://127.0.0.1:5001')
    optimize_pytorch_model(
        run_id='f1b7b9a5ba934f158c07975a8a332de5'
    )