# test_mlflow/log_unsupported_model.py  (77 lines, 1.9 KiB, Python)
# log_unsupported_model.py
#
# author: deng
# date : 20230309
import mlflow
import pandas as pd
class CustomModel(mlflow.pyfunc.PythonModel):
    """MLflow pyfunc wrapper used to package a model type MLflow does not support natively."""

    def __init__(self):
        # Wrapped model is loaded lazily on first prediction.
        self.model = None

    def load_model(self):
        """Instantiate the wrapped model (placeholder: doubles each value)."""
        # load your custom model here
        self.model = lambda value: value * 2

    def predict(self,
                model_input: pd.DataFrame) -> pd.DataFrame:
        """Apply the wrapped model element-wise to *model_input* and return the result."""
        if self.model is None:
            self.load_model()
        return model_input.apply(self.model)
def log_model(server_uri: str,
              exp_name: str,
              registered_model_name: str) -> None:
    """Log a CustomModel to the MLflow tracking server and register it.

    Args:
        server_uri: URI of the MLflow tracking server.
        exp_name: experiment to log the run under.
        registered_model_name: name to register the model as in the registry.
    """
    # Point the client at the tracking server and select the experiment.
    mlflow.set_tracking_uri(server_uri)
    mlflow.set_experiment(exp_name)

    # Package the custom model as a pyfunc artifact and register it.
    mlflow.pyfunc.log_model(
        artifact_path='model',
        python_model=CustomModel(),
        registered_model_name=registered_model_name,
    )
def pull_model(server_uri: str,
               exp_name: str,
               registered_model_name: str) -> None:
    """Fetch the latest registered model, smoke-test it, and save a local copy.

    Args:
        server_uri: URI of the MLflow tracking server.
        exp_name: experiment name (also used as the local save sub-directory).
        registered_model_name: registry name of the model to pull.
    """
    # Point the client at the tracking server and select the experiment.
    mlflow.set_tracking_uri(server_uri)
    mlflow.set_experiment(exp_name)

    # Resolve the newest registered version, then recover the raw CustomModel object.
    wrapper = mlflow.pyfunc.load_model(f'models:/{registered_model_name}/latest')
    model = wrapper.unwrap_python_model()  # get CustomModel object
    print(f'Model loaded. model type: {type(model)}')

    # Quick availability check with dummy data.
    fake_data = pd.DataFrame([1, 3, 5])
    output = model.predict(fake_data)
    print(f'input data: {fake_data}, predictions: {output}')

    # Persist a copy to the local file system.
    mlflow.pyfunc.save_model(path=f'./model/{exp_name}', python_model=model)
if __name__ == '__main__':
    # Local tracking server; experiment and registry share one name here.
    tracking_uri = 'http://127.0.0.1:5001'
    experiment = 'custom_model'
    registry_name = 'custom_model'

    log_model(tracking_uri, experiment, registry_name)
    pull_model(tracking_uri, experiment, registry_name)