# log_unsupported_model.py
#
# author: deng
# date  : 20230309
"""Log a model MLflow has no built-in flavor for, then pull it back.

The model is wrapped in a ``mlflow.pyfunc.PythonModel`` subclass, logged and
registered on a tracking server, loaded again from the registry, exercised
on a small frame and finally saved to the local file system.
"""

import contextlib

import mlflow
import pandas as pd


class CustomModel(mlflow.pyfunc.PythonModel):
    """A mlflow wrapper to package unsupported model.

    The wrapped callable is created lazily, on the first ``predict`` call,
    so a freshly constructed instance only holds ``None``.
    """

    def __init__(self):
        # Filled in by load_model(); None means "not loaded yet".
        self.model = None

    def load_model(self):
        """Create the wrapped model and store it on ``self.model``."""
        # load your custom model here
        self.model = lambda value: value * 2

    # NOTE(review): MLflow documents PythonModel.predict as taking a
    # `context` argument before `model_input`; this signature omits it, so
    # the class is presumably only usable through unwrap_python_model() on
    # MLflow releases that always pass a context - confirm against the
    # MLflow version in use before serving this model.
    def predict(self, model_input: pd.DataFrame) -> pd.DataFrame:
        """Apply the wrapped model to ``model_input`` and return the result.

        Args:
            model_input: Frame handed to ``DataFrame.apply`` together with
                the wrapped callable.

        Returns:
            Whatever ``model_input.apply(self.model)`` produces.
        """
        if self.model is None:
            self.load_model()
        return model_input.apply(self.model)


def log_model(server_uri: str, exp_name: str, registered_model_name: str) -> None:
    """Log a ``CustomModel`` to the tracking server and register it.

    Args:
        server_uri: Tracking server URI passed to ``mlflow.set_tracking_uri``.
        exp_name: Experiment name passed to ``mlflow.set_experiment``.
        registered_model_name: Name under which the model is registered.
    """
    # init mlflow
    mlflow.set_tracking_uri(server_uri)
    mlflow.set_experiment(exp_name)

    # Without an explicit run, the fluent API would create one implicitly
    # and leave it active after this function returns. Open a run here and
    # let the `with` block close it; if the caller already has an active
    # run, log into that one and leave its lifecycle to the caller.
    if mlflow.active_run() is None:
        run_scope = mlflow.start_run()
    else:
        run_scope = contextlib.nullcontext()

    # register custom model
    with run_scope:
        mlflow.pyfunc.log_model(
            artifact_path='model',
            python_model=CustomModel(),
            registered_model_name=registered_model_name,
        )


def pull_model(server_uri: str, exp_name: str, registered_model_name: str) -> None:
    """Load the latest registered model, smoke-test it and save it locally.

    Args:
        server_uri: Tracking server URI passed to ``mlflow.set_tracking_uri``.
        exp_name: Experiment name passed to ``mlflow.set_experiment``.
        registered_model_name: Registry name whose ``latest`` version is
            loaded.
    """
    # init mlflow
    mlflow.set_tracking_uri(server_uri)
    mlflow.set_experiment(exp_name)

    # pull model from registry
    model = mlflow.pyfunc.load_model(f'models:/{registered_model_name}/latest')
    model = model.unwrap_python_model()  # get CustomModel object
    print(f'Model loaded. model type: {type(model)}')

    # test model availability
    fake_data = pd.DataFrame([1, 3, 5])
    output = model.predict(fake_data)
    print(f'input data: {fake_data}, predictions: {output}')

    # save it to local file system
    # NOTE(review): MLflow appears to refuse saving into an existing,
    # non-empty directory, so a second run of this script may fail here
    # unless ./model is removed first - confirm.
    mlflow.pyfunc.save_model(path='./model', python_model=model)


if __name__ == '__main__':

    server_uri = 'http://127.0.0.1:5001'
    exp_name = 'custom_model'
    registered_model_name = 'custom_model'

    log_model(server_uri, exp_name, registered_model_name)
    pull_model(server_uri, exp_name, registered_model_name)