test_mlflow/predict.py

21 lines
377 B
Python

# predict.py
#
# author: deng
# date : 20230221
import torch
import mlflow
def run_inference(model, data):
    """Run a forward pass of *model* on *data* in inference mode.

    The model is switched to evaluation mode so that layers whose behaviour
    differs between training and inference (e.g. dropout or batch-norm, if
    the model contains any) act deterministically. Autograd tracking is
    disabled because gradients are not needed for prediction.

    Args:
        model: A ``torch.nn.Module``, such as the model returned by
            ``mlflow.pytorch.load_model``.
        data: Input tensor passed directly to ``model``.

    Returns:
        Whatever ``model(data)`` returns.
    """
    model.eval()
    with torch.no_grad():
        return model(data)


def main(tracking_uri='http://127.0.0.1:5000',
         model_uri='models:/cls_model/production'):
    """Load a registered model from MLflow and print its output on random input.

    Args:
        tracking_uri: URI of the MLflow tracking server to talk to.
        model_uri: MLflow model URI to load. The default selects the
            ``production`` stage of the registered model ``cls_model``.
    """
    # set MLflow server
    mlflow.set_tracking_uri(tracking_uri)
    # load production model
    model = mlflow.pytorch.load_model(model_uri)
    # predict on a random input
    # NOTE(review): assumes the model accepts a 1-D tensor of 10 features;
    # confirm against the model's expected input shape.
    fake_data = torch.randn(10)
    output = run_inference(model, fake_data)
    print(output)


if __name__ == '__main__':
    main()