Compare commits
	
		
			2 Commits
		
	
	
		
			1f86146b12
			...
			578f0ceea1
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 578f0ceea1 | |||
| 4bf037de15 | 
							
								
								
									
										2
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @ -1,2 +1,2 @@ | ||||
| .DS_Store | ||||
| fortune_predict_model | ||||
| model | ||||
| @ -27,6 +27,7 @@ Try to use [MLflow](https://mlflow.org) platform to log PyTorch model training, | ||||
|   * a sample code to call registered model to predict testing data and save model to local file system | ||||
| * **get_registered_model_via_rest_api.py** | ||||
|   * a script to test the MLflow REST API | ||||
|  | ||||
| * **log_unsupported_model.py** | ||||
|   * a sample script that uses mlflow.pyfunc to package an unsupported ML model so it can be logged and registered by mlflow | ||||
|  | ||||
| ###### tags: `MLOps` | ||||
							
								
								
									
										77
									
								
								log_unsupported_model.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										77
									
								
								log_unsupported_model.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,77 @@ | ||||
| # log_unsupported_model.py | ||||
| # | ||||
| # author: deng | ||||
| # date  : 20230309 | ||||
|  | ||||
| import mlflow | ||||
| import pandas as pd | ||||
|  | ||||
|  | ||||
class CustomModel(mlflow.pyfunc.PythonModel):
    """MLflow pyfunc wrapper used to package a model that has no built-in mlflow flavor."""

    def __init__(self):
        # Loading is deferred until the first prediction request.
        self.model = None

    def load_model(self):
        # Stand-in for real custom-model loading logic: a callable
        # that doubles whatever value it receives.
        def _double(value):
            return value * 2

        self.model = _double

    def predict(self,
                model_input: pd.DataFrame) -> pd.DataFrame:
        """Apply the wrapped model column-wise to *model_input* and return the result."""

        # Lazy-load the underlying model on first use.
        if self.model is None:
            self.load_model()

        return model_input.apply(self.model)
|  | ||||
|  | ||||
def log_model(server_uri: str,
              exp_name: str,
              registered_model_name: str) -> None:
    """Log a CustomModel run to the MLflow tracking server and register it.

    Args:
        server_uri: URI of the MLflow tracking server.
        exp_name: experiment name to log the run under.
        registered_model_name: name under which the model is registered.
    """

    # Point the client at the tracking server and select the experiment.
    mlflow.set_tracking_uri(server_uri)
    mlflow.set_experiment(exp_name)

    # Package the custom model via the pyfunc flavor and register it
    # in the model registry in one call.
    custom_model = CustomModel()
    mlflow.pyfunc.log_model(artifact_path='model',
                            python_model=custom_model,
                            registered_model_name=registered_model_name)
|  | ||||
|  | ||||
def pull_model(server_uri: str,
               exp_name: str,
               registered_model_name: str) -> None:
    """Pull the latest registered model from the registry, smoke-test it,
    and save it to the local file system.

    Args:
        server_uri: URI of the MLflow tracking server.
        exp_name: experiment name to operate under.
        registered_model_name: registry name of the model to fetch.
    """

    # Point the client at the tracking server and select the experiment.
    mlflow.set_tracking_uri(server_uri)
    mlflow.set_experiment(exp_name)

    # Fetch the newest registered version and unwrap the pyfunc
    # wrapper to recover the original CustomModel instance.
    loaded = mlflow.pyfunc.load_model(f'models:/{registered_model_name}/latest')
    model = loaded.unwrap_python_model()
    print(f'Model loaded. model type: {type(model)}')

    # Quick availability check with a tiny synthetic frame.
    fake_data = pd.DataFrame([1, 3, 5])
    output = model.predict(fake_data)
    print(f'input data: {fake_data}, predictions: {output}')

    # Persist the model (plus its environment spec) locally.
    mlflow.pyfunc.save_model(path='./model', python_model=model)
|  | ||||
|  | ||||
if __name__ == '__main__':

    # Local demo: log + register the custom model, then pull it back
    # from the registry and exercise it.
    tracking_uri = 'http://127.0.0.1:5001'
    experiment = 'custom_model'
    model_name = 'custom_model'

    log_model(tracking_uri, experiment, model_name)
    pull_model(tracking_uri, experiment, model_name)
| @ -7,7 +7,7 @@ import torch | ||||
| import mlflow | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
| def main(): | ||||
|  | ||||
|     # set MLflow server | ||||
|     mlflow.set_tracking_uri('http://127.0.0.1:5001') | ||||
| @ -21,4 +21,8 @@ if __name__ == '__main__': | ||||
|     print(my_fortune) | ||||
|  | ||||
|     # save model and env to local file system | ||||
|     mlflow.pytorch.save_model(model, './fortune_predict_model') | ||||
|     mlflow.pytorch.save_model(model, './model') | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     main() | ||||
										
											Binary file not shown.
										
									
								
							
										
											Binary file not shown.
										
									
								
							
		Reference in New Issue
	
	Block a user
	