# test_mlflow/predict.py
#
# 24 lines
# 521 B
# Python

# predict.py
#
# author: deng
# date : 20230221
import torch
import mlflow
if __name__ == '__main__':
    # Script-only dependencies. `mlflow.pytorch` is imported explicitly rather
    # than relying on `import mlflow` to expose the flavor submodule.
    import shutil
    from pathlib import Path

    import mlflow.pytorch

    # Registry URI of the model to load and the local export directory.
    # NOTE(review): the `/production` suffix uses MLflow model *stages*, which
    # newer MLflow releases deprecate in favour of aliases — confirm the
    # server version still supports it.
    model_uri = 'models:/fortune_predict_model/production'
    export_dir = './fortune_predict_model'

    # set MLflow server
    mlflow.set_tracking_uri('http://127.0.0.1:5001')

    # load production model
    model = mlflow.pytorch.load_model(model_uri)

    # predict: put the module in evaluation mode (affects layers such as
    # dropout / batch-norm, if the model has any) and skip autograd graph
    # construction, which inference does not need.
    model.eval()
    with torch.no_grad():
        # NOTE(review): assumes the model accepts a 1-D tensor of 5 features
        # — confirm against the training code.
        my_personal_info = torch.randn(5)
        my_fortune = model(my_personal_info)
    print(my_fortune)

    # save model and env to local file system. save_model refuses to write
    # into an existing non-empty directory, so a previous export is removed
    # first to keep the script re-runnable (the export is regenerated below).
    if Path(export_dir).exists():
        shutil.rmtree(export_dir)
    mlflow.pytorch.save_model(model, export_dir)