diff --git a/log_unsupported_model.py b/log_unsupported_model.py index 7291d2d..bfa33e2 100644 --- a/log_unsupported_model.py +++ b/log_unsupported_model.py @@ -64,7 +64,7 @@ def pull_model(server_uri: str, print(f'input data: {fake_data}, predictions: {output}') # save it to local file system - mlflow.pyfunc.save_model(path='./model', python_model=model) + mlflow.pyfunc.save_model(path=f'./model/{exp_name}', python_model=model) if __name__ == '__main__': diff --git a/predict.py b/predict.py index 77c1cb1..958689a 100644 --- a/predict.py +++ b/predict.py @@ -21,7 +21,7 @@ def main(): print(my_fortune) # save model and env to local file system - mlflow.pytorch.save_model(model, './model') + mlflow.pytorch.save_model(model, './model/fortune_predict_model') if __name__ == '__main__':