# predict.py
#
# author: deng
# date  : 20230221

import shutil
from pathlib import Path

import mlflow
import mlflow.pytorch
import torch


def main(
    tracking_uri='http://127.0.0.1:5001',
    model_uri='models:/fortune_predict_model/production',
    output_dir='./model',
):
    """Load the production model from MLflow, predict once, and export it locally.

    Args:
        tracking_uri: URI of the MLflow tracking server to connect to.
        model_uri: MLflow model URI of the registered model to load.
        output_dir: Local directory the loaded model is exported to via
            ``mlflow.pytorch.save_model``. An existing directory at this
            path is removed and replaced by the new export.
    """
    # set MLflow server
    mlflow.set_tracking_uri(tracking_uri)

    # load production model
    model = mlflow.pytorch.load_model(model_uri)

    # predict: put the model in inference mode (e.g. dropout / batch-norm use
    # their eval behaviour) and skip autograd bookkeeping for the forward pass
    model.eval()
    # NOTE(review): assumes the model accepts a 1-D tensor of 5 features — confirm
    my_personal_info = torch.randn(5)
    with torch.no_grad():
        my_fortune = model(my_personal_info)
    print(my_fortune)

    # save model and env to local file system; save_model refuses to write
    # into an already-existing (non-empty) target, which made every run after
    # the first one crash, so clear the previous export first
    output_path = Path(output_dir)
    if output_path.is_dir():
        shutil.rmtree(output_path)
    mlflow.pytorch.save_model(model, str(output_path))


if __name__ == '__main__':
    # Script entry point: run the predict/export pipeline only when this file
    # is executed directly, never as a side effect of importing it.
    main()