77 lines
1.9 KiB
Python
77 lines
1.9 KiB
Python
# log_unsupported_model.py
|
|
#
|
|
# author: deng
|
|
# date : 20230309
|
|
|
|
import mlflow
|
|
import pandas as pd
|
|
|
|
|
|
class CustomModel(mlflow.pyfunc.PythonModel):
    """Package a model type that MLflow has no built-in flavor for as a pyfunc.

    The wrapped callable is created lazily: ``self.model`` stays ``None`` until
    the first ``predict`` call triggers ``load_model``.

    NOTE(review): MLflow documents ``PythonModel.predict`` as
    ``predict(self, context, model_input, params=None)``. The one-argument
    form below works when ``predict`` is called directly on the unwrapped
    object. Calling it through the pyfunc wrapper (``PyFuncModel.predict`` /
    model serving) may raise on some MLflow versions — confirm before serving.
    """

    def __init__(self):
        # Placeholder; load_model() installs the real callable on first use.
        self.model = None

    def load_model(self):
        """Install the underlying model callable on ``self.model``.

        Replace the body with code that loads your real model. The demo
        stand-in returns its input multiplied by 2.
        """
        def _double(value):
            return value * 2

        self.model = _double

    def predict(self,
                model_input: pd.DataFrame) -> pd.DataFrame:
        """Run the wrapped model over every column of ``model_input``.

        Loads the model on the first call. Returns the result of
        ``DataFrame.apply`` with the model callable.
        """
        needs_loading = self.model is None
        if needs_loading:
            self.load_model()
        return model_input.apply(self.model)
|
|
|
|
|
|
def log_model(server_uri: str,
              exp_name: str,
              registered_model_name: str) -> None:
    """Log a fresh ``CustomModel`` to an MLflow server and register it.

    Args:
        server_uri: Tracking URI passed to ``mlflow.set_tracking_uri``.
        exp_name: Experiment selected with ``mlflow.set_experiment``.
        registered_model_name: Registry name the logged model is registered
            under (passed to ``mlflow.pyfunc.log_model``).
    """
    import contextlib

    # init mlflow
    mlflow.set_tracking_uri(server_uri)
    mlflow.set_experiment(exp_name)

    # Use an explicit run so it is ended as soon as logging finishes. Without
    # this, log_model auto-starts a run that stays active for the rest of the
    # process, and later logging (even after set_experiment switches
    # experiments) would be recorded in that same run.
    # If the caller already has an active run, keep logging into it, as before.
    if mlflow.active_run() is None:
        run_context = mlflow.start_run()
    else:
        run_context = contextlib.nullcontext()

    with run_context:
        # register custom model
        model = CustomModel()
        mlflow.pyfunc.log_model(
            artifact_path='model',
            python_model=model,
            registered_model_name=registered_model_name,
        )
|
|
|
|
|
|
def pull_model(server_uri: str,
               exp_name: str,
               registered_model_name: str) -> None:
    """Fetch the newest registered model, smoke-test it, and save a local copy.

    Args:
        server_uri: Tracking URI passed to ``mlflow.set_tracking_uri``.
        exp_name: Experiment selected with ``mlflow.set_experiment``.
        registered_model_name: Registry name to load from.
    """
    mlflow.set_tracking_uri(server_uri)
    mlflow.set_experiment(exp_name)

    # 'latest' asks the registry for the most recent version of the model.
    model_uri = f'models:/{registered_model_name}/latest'
    pyfunc_model = mlflow.pyfunc.load_model(model_uri)
    # Unwrap the pyfunc wrapper to get back the original CustomModel object.
    python_model = pyfunc_model.unwrap_python_model()
    print(f'Model loaded. model type: {type(python_model)}')

    # Smoke test on a tiny single-column frame.
    sample = pd.DataFrame([1, 3, 5])
    predictions = python_model.predict(sample)
    print(f'input data: {sample}, predictions: {predictions}')

    # NOTE(review): mlflow.pyfunc.save_model refuses a target directory that
    # already exists and is non-empty. Re-running this script will likely
    # fail here unless ./model is removed first — confirm for your MLflow version.
    mlflow.pyfunc.save_model(path='./model', python_model=python_model)
|
|
|
|
|
|
if __name__ == '__main__':

    # Local tracking server. The experiment and the registry entry share one name.
    server_uri = 'http://127.0.0.1:5001'
    exp_name = registered_model_name = 'custom_model'

    # Log and register first, then pull the latest version back.
    for step in (log_model, pull_model):
        step(server_uri, exp_name, registered_model_name)