This commit is contained in:
許登傑
2023-02-21 14:15:39 +08:00
commit e1c2b82109
53 changed files with 546 additions and 0 deletions

21
predict.py Normal file
View File

@ -0,0 +1,21 @@
# predict.py
#
# author: deng
# date : 20230221
import torch
import mlflow
if __name__ == '__main__':
# set MLflow server
mlflow.set_tracking_uri('http://127.0.0.1:5000')
# load production model
model = mlflow.pytorch.load_model('models:/cls_model/production')
# predict
fake_data = torch.randn(10)
output = model(fake_data)
print(output)