package unsupported ml model
This commit is contained in:
parent
1f86146b12
commit
4bf037de15
|
@ -1,2 +1,2 @@
|
||||||
.DS_Store
|
.DS_Store
|
||||||
fortune_predict_model
|
model
|
|
@ -27,6 +27,7 @@ Try to use [MLflow](https://mlflow.org) platform to log PyTorch model training,
|
||||||
* a sample code to call registered model to predict testing data and save model to local file system
|
* a sample code to call registered model to predict testing data and save model to local file system
|
||||||
* **get_registered_model_via_rest_api.py**
|
* **get_registered_model_via_rest_api.py**
|
||||||
* a script to test MLflow REST api
|
* a script to test MLflow REST api
|
||||||
|
* **log_unsupported_model.py**
|
||||||
|
* a sample script to apply mlflow.pyfunc to package unsupported ml model which can be logged and registered by mlflow
|
||||||
|
|
||||||
###### tags: `MLOps`
|
###### tags: `MLOps`
|
|
@ -0,0 +1,77 @@
|
||||||
|
# log_unsupported_model.py
|
||||||
|
#
|
||||||
|
# author: deng
|
||||||
|
# date : 20230309
|
||||||
|
|
||||||
|
import mlflow
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
|
||||||
|
class CustomModel(mlflow.pyfunc.PythonModel):
|
||||||
|
""" A mlflow wrapper to package unsupported model """
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.model = None
|
||||||
|
|
||||||
|
def load_model(self):
|
||||||
|
# load your custom model here
|
||||||
|
self.model = lambda value: value * 2
|
||||||
|
|
||||||
|
def predict(self,
|
||||||
|
model_input: pd.DataFrame) -> pd.DataFrame:
|
||||||
|
|
||||||
|
if self.model is None:
|
||||||
|
self.load_model()
|
||||||
|
|
||||||
|
output = model_input.apply(self.model)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def log_model(server_uri:str,
|
||||||
|
exp_name: str,
|
||||||
|
registered_model_name: str) -> None:
|
||||||
|
|
||||||
|
# init mlflow
|
||||||
|
mlflow.set_tracking_uri(server_uri)
|
||||||
|
mlflow.set_experiment(exp_name)
|
||||||
|
|
||||||
|
# register custom model
|
||||||
|
model = CustomModel()
|
||||||
|
mlflow.pyfunc.log_model(
|
||||||
|
artifact_path='model',
|
||||||
|
python_model=model,
|
||||||
|
registered_model_name=registered_model_name
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def pull_model(server_uri: str,
|
||||||
|
exp_name: str,
|
||||||
|
registered_model_name: str) -> None:
|
||||||
|
|
||||||
|
# init mlflow
|
||||||
|
mlflow.set_tracking_uri(server_uri)
|
||||||
|
mlflow.set_experiment(exp_name)
|
||||||
|
|
||||||
|
# pull model from registry
|
||||||
|
model = mlflow.pyfunc.load_model(f'models:/{registered_model_name}/latest')
|
||||||
|
model = model.unwrap_python_model() # get CustomModel object
|
||||||
|
print(f'Model loaded. model type: {type(model)}')
|
||||||
|
|
||||||
|
# test model availability
|
||||||
|
fake_data = pd.DataFrame([1, 3, 5])
|
||||||
|
output = model.predict(fake_data)
|
||||||
|
print(f'input data: {fake_data}, predictions: {output}')
|
||||||
|
|
||||||
|
# save it to local file system
|
||||||
|
mlflow.pyfunc.save_model(path='./model', python_model=model)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
server_uri = 'http://127.0.0.1:5001'
|
||||||
|
exp_name = 'custom_model'
|
||||||
|
registered_model_name = 'custom_model'
|
||||||
|
|
||||||
|
log_model(server_uri, exp_name, registered_model_name)
|
||||||
|
pull_model(server_uri, exp_name, registered_model_name)
|
|
@ -7,7 +7,7 @@ import torch
|
||||||
import mlflow
|
import mlflow
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
def main():
|
||||||
|
|
||||||
# set MLflow server
|
# set MLflow server
|
||||||
mlflow.set_tracking_uri('http://127.0.0.1:5001')
|
mlflow.set_tracking_uri('http://127.0.0.1:5001')
|
||||||
|
@ -21,4 +21,8 @@ if __name__ == '__main__':
|
||||||
print(my_fortune)
|
print(my_fortune)
|
||||||
|
|
||||||
# save model and env to local file system
|
# save model and env to local file system
|
||||||
mlflow.pytorch.save_model(model, './fortune_predict_model')
|
mlflow.pytorch.save_model(model, './model')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
Loading…
Reference in New Issue