package unsupported ml model
This commit is contained in:
parent
1f86146b12
commit
4bf037de15
|
@ -1,2 +1,2 @@
|
|||
.DS_Store
|
||||
fortune_predict_model
|
||||
model
|
|
@ -27,6 +27,7 @@ Try to use [MLflow](https://mlflow.org) platform to log PyTorch model training,
|
|||
* sample code that calls a registered model to predict on test data and saves the model to the local file system
|
||||
* **get_registered_model_via_rest_api.py**
|
||||
* a script to test MLflow REST api
|
||||
|
||||
* **log_unsupported_model.py**
|
||||
* a sample script that applies mlflow.pyfunc to package an unsupported ML model so that it can be logged and registered by MLflow
|
||||
|
||||
###### tags: `MLOps`
|
|
@ -0,0 +1,77 @@
|
|||
# log_unsupported_model.py
|
||||
#
|
||||
# author: deng
|
||||
# date : 20230309
|
||||
|
||||
import mlflow
|
||||
import pandas as pd
|
||||
|
||||
|
||||
class CustomModel(mlflow.pyfunc.PythonModel):
    """MLflow pyfunc wrapper that packages a model MLflow has no native flavor for.

    The wrapped model is created lazily on first prediction, so the object can
    be logged to the registry without the underlying model being loaded.

    NOTE(review): MLflow's pyfunc serving layer normally invokes
    ``predict(context, model_input)``; this class is used via
    ``unwrap_python_model()`` where ``predict`` is called directly — confirm
    before serving it through the generic pyfunc interface.
    """

    def __init__(self):
        # Lazily-initialized callable; populated by load_model() on first use.
        self.model = None

    def load_model(self):
        """Instantiate the custom model (stand-in here: doubles each value)."""
        self.model = lambda value: value * 2

    def predict(self, model_input: pd.DataFrame) -> pd.DataFrame:
        """Apply the wrapped model element-wise to *model_input*.

        Loads the model on first call, then maps it over the DataFrame.
        """
        if self.model is None:
            self.load_model()
        return model_input.apply(self.model)
|
||||
|
||||
|
||||
def log_model(server_uri: str,
              exp_name: str,
              registered_model_name: str) -> None:
    """Log a CustomModel run to an MLflow server and register the model.

    Args:
        server_uri: tracking URI of the MLflow server.
        exp_name: experiment under which the run is recorded.
        registered_model_name: name to register the model under in the registry.
    """
    # Point the client at the tracking server and select the experiment.
    mlflow.set_tracking_uri(server_uri)
    mlflow.set_experiment(exp_name)

    # Wrap the custom model in the pyfunc flavor and push it to the registry.
    mlflow.pyfunc.log_model(
        artifact_path='model',
        python_model=CustomModel(),
        registered_model_name=registered_model_name,
    )
|
||||
|
||||
|
||||
def pull_model(server_uri: str,
               exp_name: str,
               registered_model_name: str) -> None:
    """Pull the latest registered model, smoke-test it, and save it locally.

    Args:
        server_uri: tracking URI of the MLflow server.
        exp_name: experiment to activate on the client.
        registered_model_name: registry name of the model to fetch.
    """
    # Point the client at the tracking server and select the experiment.
    mlflow.set_tracking_uri(server_uri)
    mlflow.set_experiment(exp_name)

    # Resolve the newest registered version and recover the CustomModel object.
    wrapped = mlflow.pyfunc.load_model(f'models:/{registered_model_name}/latest')
    model = wrapped.unwrap_python_model()  # get CustomModel object
    print(f'Model loaded. model type: {type(model)}')

    # Smoke-test the model on a tiny DataFrame to confirm it is callable.
    fake_data = pd.DataFrame([1, 3, 5])
    output = model.predict(fake_data)
    print(f'input data: {fake_data}, predictions: {output}')

    # Persist the model (plus its environment spec) to the local file system.
    # NOTE(review): save_model raises if './model' already exists — confirm
    # the target directory is cleaned between runs.
    mlflow.pyfunc.save_model(path='./model', python_model=model)
|
||||
|
||||
|
||||
if __name__ == '__main__':

    # Local MLflow server plus the experiment / registry names to exercise.
    server_uri = 'http://127.0.0.1:5001'
    exp_name = 'custom_model'
    registered_model_name = 'custom_model'

    # Register the custom model, then round-trip it back out of the registry.
    log_model(server_uri, exp_name, registered_model_name)
    pull_model(server_uri, exp_name, registered_model_name)
|
|
@ -7,7 +7,7 @@ import torch
|
|||
import mlflow
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
def main():
|
||||
|
||||
# set MLflow server
|
||||
mlflow.set_tracking_uri('http://127.0.0.1:5001')
|
||||
|
@ -21,4 +21,8 @@ if __name__ == '__main__':
|
|||
print(my_fortune)
|
||||
|
||||
# save model and env to local file system
|
||||
mlflow.pytorch.save_model(model, './fortune_predict_model')
|
||||
mlflow.pytorch.save_model(model, './model')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Loading…
Reference in New Issue