Compare commits
	
		
			2 Commits
		
	
	
		
			e1f143736e
			...
			2ce9a31883
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 2ce9a31883 | |||
| 1b90edee59 | 
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @ -1,2 +1,3 @@ | |||||||
| .DS_Store | .DS_Store | ||||||
|  | __pycache__ | ||||||
| model | model | ||||||
| @ -29,5 +29,7 @@ Try to use [MLflow](https://mlflow.org) platform to log PyTorch model training, | |||||||
|   * a script to test MLflow REST api |   * a script to test MLflow REST api | ||||||
| * **log_unsupported_model.py** | * **log_unsupported_model.py** | ||||||
|   * a sample script to apply mlflow.pyfunc to package unsupported ml model which can be logged and registered by mlflow |   * a sample script to apply mlflow.pyfunc to package unsupported ml model which can be logged and registered by mlflow | ||||||
|  | * **optimize_model.py** | ||||||
|  |   * a sample script demonstrating how to use MLflow and TensorRT to optimize a PyTorch model on an edge device and then fetch the optimized model from a client | ||||||
|  |  | ||||||
| ###### tags: `MLOps` | ###### tags: `MLOps` | ||||||
| @ -10,11 +10,13 @@ import torch | |||||||
| import mlflow | import mlflow | ||||||
|  |  | ||||||
|  |  | ||||||
| def optimize_pytorch_model(run_id: str) -> None: | def optimize_pytorch_model(run_id: str, | ||||||
|     """Optimize Pytorch model on MLflow server, the optimized model will be sent back |                            model_artifact_dir: str = 'model') -> None: | ||||||
|  |     """Optimize Pytorch model from MLflow server on edge device | ||||||
|  |  | ||||||
|     Args: |     Args: | ||||||
|         run_id (str): mlflow run id |         run_id (str): mlflow run id | ||||||
|  |         model_artifact_dir (str, optional): model dir of run on server. Defaults to 'model'. | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     download_path = Path('./model/downloaded_pytorch_model') |     download_path = Path('./model/downloaded_pytorch_model') | ||||||
| @ -22,26 +24,52 @@ def optimize_pytorch_model(run_id: str) -> None: | |||||||
|         print(f'Remove existed dir: {download_path}') |         print(f'Remove existed dir: {download_path}') | ||||||
|         shutil.rmtree(download_path) |         shutil.rmtree(download_path) | ||||||
|  |  | ||||||
|     # Download Pytorch model to local file system |     # Download model artifacts to local file system | ||||||
|     mlflow_model = mlflow.pytorch.load_model(f'runs:/{run_id}/model') |     mlflow_model = mlflow.pytorch.load_model(Path(f'runs:/{run_id}').joinpath(model_artifact_dir).as_posix()) | ||||||
|     mlflow.pytorch.save_model(mlflow_model, download_path) |     mlflow.pytorch.save_model(mlflow_model, download_path) | ||||||
|  |  | ||||||
|     # Optimize model |     # Optimize model | ||||||
|     model = torch.load(download_path.joinpath('data/model.pth')) |     model = torch.load(download_path.joinpath('data/model.pth')) | ||||||
|     dummy_input = torch.randn(5) |     dummy_input = torch.randn(5) | ||||||
|     torch.onnx.export(model, dummy_input, download_path.joinpath('data/model.onnx')) |     torch.onnx.export(model, dummy_input, download_path.joinpath('data/model.onnx')) | ||||||
|     # we can not call TensorRT on macOS, so imagine we get a serialized model |     # we can not call TensorRT on macOS, so imagine we get a serialized model😘 | ||||||
|     download_path.joinpath('data/model.trt').touch() |     download_path.joinpath('data/model.trt').touch() | ||||||
|  |  | ||||||
|     # Save optimized model to given run |     # Send optimized model back to given run | ||||||
|     with mlflow.start_run(run_id=run_id): |     with mlflow.start_run(run_id=run_id): | ||||||
|         mlflow.log_artifact(download_path.joinpath('data/model.trt'), 'model/data') |         mlflow.log_artifact(download_path.joinpath('data/model.trt'), 'model/data') | ||||||
|         print(f'Optimized model had been uploaded to server: {mlflow.get_tracking_uri()}') |         print(f'Optimized model had been uploaded to server: {mlflow.get_tracking_uri()}') | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def download_optimized_model(run_id: str, | ||||||
|  |                              save_dir: str, | ||||||
|  |                              model_artifact_path: str = 'model/data/model.trt') -> None: | ||||||
|  |     """Download optimized model from MLflow server on client | ||||||
|  |  | ||||||
|  |     Args: | ||||||
|  |         run_id (str): mlflow run id | ||||||
|  |         save_dir (str): dir of local file system to save model | ||||||
|  |         model_artifact_path (str, optional): artifact path of model on server. Defaults to 'model/data/model.trt'. | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     mlflow.artifacts.download_artifacts( | ||||||
|  |         run_id= run_id, | ||||||
|  |         artifact_path=model_artifact_path, | ||||||
|  |         dst_path=save_dir | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     print(f'Optimized model had been saved, please check: {Path(save_dir).joinpath(model_artifact_path)}') | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|  |  | ||||||
|     mlflow.set_tracking_uri('http://127.0.0.1:5001') |     mlflow.set_tracking_uri('http://127.0.0.1:5001') | ||||||
|  |  | ||||||
|     optimize_pytorch_model( |     optimize_pytorch_model( | ||||||
|         run_id='f1b7b9a5ba934f158c07975a8a332de5' |         run_id='f1b7b9a5ba934f158c07975a8a332de5' | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|  |     download_optimized_model( | ||||||
|  |         run_id='f1b7b9a5ba934f158c07975a8a332de5', | ||||||
|  |         save_dir='./model/download_tensorrt' | ||||||
|  |     ) | ||||||
|  | |||||||
										
											Binary file not shown.
										
									
								
							
		Reference in New Issue
	
	Block a user
	