# predict.py
#
# author: deng
# date : 20230221
import mlflow
import mlflow.pytorch
import torch


def predict(model, inputs):
    """Run a single inference forward pass of *model* on *inputs*.

    The model is switched to evaluation mode first, so training-only layer
    behaviour (e.g. dropout, batch-norm running-stat updates) is disabled.
    The forward pass runs under ``torch.no_grad()``, so no autograd graph is
    recorded.

    Args:
        model: A ``torch.nn.Module``, such as the one returned by
            ``mlflow.pytorch.load_model``. It is left in eval mode afterwards.
        inputs: Tensor passed unchanged to ``model(...)``.

    Returns:
        Whatever ``model(inputs)`` returns.
    """
    model.eval()  # a freshly loaded model may still be in training mode
    with torch.no_grad():  # inference only: skip gradient bookkeeping
        return model(inputs)


if __name__ == '__main__':

    # Point the MLflow client at the local tracking server.
    mlflow.set_tracking_uri('http://127.0.0.1:5000')

    # Load the version of registered model 'cls_model' that is currently in
    # the Production stage (URI form: models:/<name>/<stage>).
    model = mlflow.pytorch.load_model('models:/cls_model/production')

    # Predict on random data.
    # NOTE(review): assumes the model accepts an unbatched 1-D tensor of 10
    # features; TODO confirm the expected input shape.
    fake_data = torch.randn(10)
    output = predict(model, fake_data)

    print(output)