package unsupported ml model

This commit is contained in:
deng 2023-03-10 15:03:21 +08:00
parent 1f86146b12
commit 4bf037de15
4 changed files with 86 additions and 4 deletions

.gitignore

@@ -1,2 +1,2 @@
 .DS_Store
-fortune_predict_model
+model

README.md

@@ -27,6 +27,7 @@ Try to use [MLflow](https://mlflow.org) platform to log PyTorch model training,
   * a sample code to call registered model to predict testing data and save model to local file system
 * **get_registered_model_via_rest_api.py**
   * a script to test MLflow REST api
+* **log_unsupported_model.py**
+  * a sample script that applies mlflow.pyfunc to package an unsupported ML model so it can be logged and registered by MLflow
 ###### tags: `MLOps`

log_unsupported_model.py (new file)

@@ -0,0 +1,77 @@
# log_unsupported_model.py
#
# author: deng
# date  : 20230309

import mlflow
import pandas as pd


class CustomModel(mlflow.pyfunc.PythonModel):
    """An mlflow.pyfunc wrapper to package a model type that MLflow does not support natively."""

    def __init__(self):
        self.model = None

    def load_model(self):
        # load your custom model here; a trivial doubling function stands in for it
        self.model = lambda value: value * 2

    def predict(self,
                context,
                model_input: pd.DataFrame) -> pd.DataFrame:
        # follows MLflow's PythonModel signature: predict(self, context, model_input)
        if self.model is None:
            self.load_model()

        output = model_input.apply(self.model)

        return output


def log_model(server_uri: str,
              exp_name: str,
              registered_model_name: str) -> None:

    # init mlflow
    mlflow.set_tracking_uri(server_uri)
    mlflow.set_experiment(exp_name)

    # log the wrapped model and register it in the model registry
    model = CustomModel()
    mlflow.pyfunc.log_model(
        artifact_path='model',
        python_model=model,
        registered_model_name=registered_model_name
    )


def pull_model(server_uri: str,
               exp_name: str,
               registered_model_name: str) -> None:

    # init mlflow
    mlflow.set_tracking_uri(server_uri)
    mlflow.set_experiment(exp_name)

    # pull the latest registered version from the model registry
    model = mlflow.pyfunc.load_model(f'models:/{registered_model_name}/latest')
    model = model.unwrap_python_model()  # get the underlying CustomModel object
    print(f'Model loaded. model type: {type(model)}')

    # test model availability (context is not needed when calling the unwrapped model directly)
    fake_data = pd.DataFrame([1, 3, 5])
    output = model.predict(None, fake_data)
    print(f'input data: {fake_data}, predictions: {output}')

    # save it to local file system
    mlflow.pyfunc.save_model(path='./model', python_model=model)


if __name__ == '__main__':

    server_uri = 'http://127.0.0.1:5001'
    exp_name = 'custom_model'
    registered_model_name = 'custom_model'

    log_model(server_uri, exp_name, registered_model_name)
    pull_model(server_uri, exp_name, registered_model_name)
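
Side note: once logged this way, the model can also be consumed through the generic pyfunc interface without unwrapping it, which is how downstream code would normally call it. A minimal sketch, assuming the same tracking server and registered model name as above (illustrative only, not part of this commit):

# load_via_pyfunc.py (illustrative, not part of this commit)
import mlflow
import pandas as pd

mlflow.set_tracking_uri('http://127.0.0.1:5001')

# the pyfunc flavor hides CustomModel behind a generic predict() interface
pyfunc_model = mlflow.pyfunc.load_model('models:/custom_model/latest')
predictions = pyfunc_model.predict(pd.DataFrame([1, 3, 5]))
print(predictions)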


@@ -7,7 +7,7 @@ import torch
 import mlflow


-if __name__ == '__main__':
+def main():

     # set MLflow server
     mlflow.set_tracking_uri('http://127.0.0.1:5001')
@@ -21,4 +21,8 @@ if __name__ == '__main__':
     print(my_fortune)

     # save model and env to local file system
-    mlflow.pytorch.save_model(model, './fortune_predict_model')
+    mlflow.pytorch.save_model(model, './model')
+
+
+if __name__ == '__main__':
+    main()
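
The renamed output directory lines up with the .gitignore change above: the PyTorch model is now saved to ./model, which git ignores. A quick sanity check of the saved artifact, sketched under the assumption that the script has been run first (not part of this commit):

# check_saved_model.py (illustrative, not part of this commit)
import mlflow

# load the PyTorch model back from the local ./model directory written by save_model()
model = mlflow.pytorch.load_model('./model')
model.eval()
print(type(model))  # a torch.nn.Module subclass, ready for inference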