package unsupported ml model
This commit is contained in:
		
							
								
								
									
										2
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @ -1,2 +1,2 @@ | |||||||
| .DS_Store | .DS_Store | ||||||
| fortune_predict_model | model | ||||||
| @ -27,6 +27,7 @@ Try to use [MLflow](https://mlflow.org) platform to log PyTorch model training, | |||||||
|   * a sample code to call registered model to predict testing data and save model to local file system |   * a sample code to call registered model to predict testing data and save model to local file system | ||||||
| * **get_registered_model_via_rest_api.py** | * **get_registered_model_via_rest_api.py** | ||||||
|   * a script to test MLflow REST api |   * a script to test MLflow REST api | ||||||
|  | * **log_unsupported_model.py** | ||||||
|  |   * a sample script to apply mlflow.pyfunc to package unsupported ml model which can be logged and registered by mlflow | ||||||
|  |  | ||||||
| ###### tags: `MLOps` | ###### tags: `MLOps` | ||||||
							
								
								
									
										77
									
								
								log_unsupported_model.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										77
									
								
								log_unsupported_model.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,77 @@ | |||||||
|  | # log_unsupported_model.py | ||||||
|  | # | ||||||
|  | # author: deng | ||||||
|  | # date  : 20230309 | ||||||
|  |  | ||||||
|  | import mlflow | ||||||
|  | import pandas as pd | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class CustomModel(mlflow.pyfunc.PythonModel): | ||||||
|  |     """ A mlflow wrapper to package unsupported model """ | ||||||
|  |  | ||||||
|  |     def __init__(self): | ||||||
|  |         self.model = None | ||||||
|  |  | ||||||
|  |     def load_model(self): | ||||||
|  |         # load your custom model here | ||||||
|  |         self.model = lambda value: value * 2 | ||||||
|  |  | ||||||
|  |     def predict(self, | ||||||
|  |                 model_input: pd.DataFrame) -> pd.DataFrame: | ||||||
|  |  | ||||||
|  |         if self.model is None: | ||||||
|  |             self.load_model() | ||||||
|  |  | ||||||
|  |         output = model_input.apply(self.model) | ||||||
|  |  | ||||||
|  |         return output | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def log_model(server_uri:str, | ||||||
|  |               exp_name: str, | ||||||
|  |               registered_model_name: str) -> None: | ||||||
|  |  | ||||||
|  |     # init mlflow | ||||||
|  |     mlflow.set_tracking_uri(server_uri) | ||||||
|  |     mlflow.set_experiment(exp_name) | ||||||
|  |  | ||||||
|  |     # register custom model | ||||||
|  |     model = CustomModel() | ||||||
|  |     mlflow.pyfunc.log_model( | ||||||
|  |         artifact_path='model', | ||||||
|  |         python_model=model, | ||||||
|  |         registered_model_name=registered_model_name | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def pull_model(server_uri: str, | ||||||
|  |                exp_name: str, | ||||||
|  |                registered_model_name: str) -> None: | ||||||
|  |  | ||||||
|  |     # init mlflow | ||||||
|  |     mlflow.set_tracking_uri(server_uri) | ||||||
|  |     mlflow.set_experiment(exp_name) | ||||||
|  |  | ||||||
|  |     # pull model from registry | ||||||
|  |     model = mlflow.pyfunc.load_model(f'models:/{registered_model_name}/latest') | ||||||
|  |     model = model.unwrap_python_model()  # get CustomModel object | ||||||
|  |     print(f'Model loaded. model type: {type(model)}') | ||||||
|  |  | ||||||
|  |     # test model availability | ||||||
|  |     fake_data = pd.DataFrame([1, 3, 5]) | ||||||
|  |     output = model.predict(fake_data) | ||||||
|  |     print(f'input data: {fake_data}, predictions: {output}') | ||||||
|  |  | ||||||
|  |     # save it to local file system | ||||||
|  |     mlflow.pyfunc.save_model(path='./model', python_model=model) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if __name__ == '__main__': | ||||||
|  |  | ||||||
|  |     server_uri = 'http://127.0.0.1:5001' | ||||||
|  |     exp_name = 'custom_model' | ||||||
|  |     registered_model_name = 'custom_model' | ||||||
|  |  | ||||||
|  |     log_model(server_uri, exp_name, registered_model_name) | ||||||
|  |     pull_model(server_uri, exp_name, registered_model_name) | ||||||
| @ -7,7 +7,7 @@ import torch | |||||||
| import mlflow | import mlflow | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == '__main__': | def main(): | ||||||
|  |  | ||||||
|     # set MLflow server |     # set MLflow server | ||||||
|     mlflow.set_tracking_uri('http://127.0.0.1:5001') |     mlflow.set_tracking_uri('http://127.0.0.1:5001') | ||||||
| @ -21,4 +21,8 @@ if __name__ == '__main__': | |||||||
|     print(my_fortune) |     print(my_fortune) | ||||||
|  |  | ||||||
|     # save model and env to local file system |     # save model and env to local file system | ||||||
|     mlflow.pytorch.save_model(model, './fortune_predict_model') |     mlflow.pytorch.save_model(model, './model') | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if __name__ == '__main__': | ||||||
|  |     main() | ||||||
		Reference in New Issue
	
	Block a user