From 4bf037de155c5458df5389bd5437d622a40694b2 Mon Sep 17 00:00:00 2001
From: deng
Date: Fri, 10 Mar 2023 15:03:21 +0800
Subject: [PATCH] package unsupported ml model

---
 .gitignore               |  2 +-
 README.md                |  3 +-
 log_unsupported_model.py | 77 ++++++++++++++++++++++++++++++++++++++++
 predict.py               |  8 +++--
 4 files changed, 86 insertions(+), 4 deletions(-)
 create mode 100644 log_unsupported_model.py

diff --git a/.gitignore b/.gitignore
index ee30a05..99986cc 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,2 @@
 .DS_Store
-fortune_predict_model
\ No newline at end of file
+model
\ No newline at end of file
diff --git a/README.md b/README.md
index 812209a..31257d0 100644
--- a/README.md
+++ b/README.md
@@ -27,6 +27,7 @@ Try to use [MLflow](https://mlflow.org) platform to log PyTorch model training,
 * a sample code to call registered model to predict testing data and save model to local file system
 * **get_registered_model_via_rest_api.py**
   * a script to test MLflow REST api
-
+* **log_unsupported_model.py**
+  * a sample script to apply mlflow.pyfunc to package unsupported ml model which can be logged and registered by mlflow
 
 ###### tags: `MLOps`
\ No newline at end of file
diff --git a/log_unsupported_model.py b/log_unsupported_model.py
new file mode 100644
index 0000000..7291d2d
--- /dev/null
+++ b/log_unsupported_model.py
@@ -0,0 +1,77 @@
+# log_unsupported_model.py
+#
+# author: deng
+# date : 20230309
+
+import mlflow
+import pandas as pd
+
+
+class CustomModel(mlflow.pyfunc.PythonModel):
+    """ A mlflow wrapper to package unsupported model """
+
+    def __init__(self):
+        self.model = None
+
+    def load_model(self):
+        # load your custom model here
+        self.model = lambda value: value * 2
+
+    def predict(self,
+                model_input: pd.DataFrame) -> pd.DataFrame:
+
+        if self.model is None:
+            self.load_model()
+
+        output = model_input.apply(self.model)
+
+        return output
+
+
+def log_model(server_uri:str,
+              exp_name: str,
+              registered_model_name: str) -> None:
+
+    # init mlflow
+    mlflow.set_tracking_uri(server_uri)
+    mlflow.set_experiment(exp_name)
+
+    # register custom model
+    model = CustomModel()
+    mlflow.pyfunc.log_model(
+        artifact_path='model',
+        python_model=model,
+        registered_model_name=registered_model_name
+    )
+
+
+def pull_model(server_uri: str,
+               exp_name: str,
+               registered_model_name: str) -> None:
+
+    # init mlflow
+    mlflow.set_tracking_uri(server_uri)
+    mlflow.set_experiment(exp_name)
+
+    # pull model from registry
+    model = mlflow.pyfunc.load_model(f'models:/{registered_model_name}/latest')
+    model = model.unwrap_python_model() # get CustomModel object
+    print(f'Model loaded. model type: {type(model)}')
+
+    # test model availability
+    fake_data = pd.DataFrame([1, 3, 5])
+    output = model.predict(fake_data)
+    print(f'input data: {fake_data}, predictions: {output}')
+
+    # save it to local file system
+    mlflow.pyfunc.save_model(path='./model', python_model=model)
+
+
+if __name__ == '__main__':
+
+    server_uri = 'http://127.0.0.1:5001'
+    exp_name = 'custom_model'
+    registered_model_name = 'custom_model'
+
+    log_model(server_uri, exp_name, registered_model_name)
+    pull_model(server_uri, exp_name, registered_model_name)
\ No newline at end of file
diff --git a/predict.py b/predict.py
index cfdcc75..77c1cb1 100644
--- a/predict.py
+++ b/predict.py
@@ -7,7 +7,7 @@
 import torch
 import mlflow
 
-if __name__ == '__main__':
+def main():
 
     # set MLflow server
     mlflow.set_tracking_uri('http://127.0.0.1:5001')
@@ -21,4 +21,8 @@ if __name__ == '__main__':
     print(my_fortune)
 
     # save model and env to local file system
-    mlflow.pytorch.save_model(model, './fortune_predict_model')
\ No newline at end of file
+    mlflow.pytorch.save_model(model, './model')
+
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file