diff --git a/README.md b/README.md index ac0e0e1..812209a 100644 --- a/README.md +++ b/README.md @@ -7,17 +7,18 @@ Try to use [MLflow](https://mlflow.org) platform to log PyTorch model training, * MacOS 12.5 * Docker 20.10 -# Dir +# Dirs * **service** * House MLflow service data, including MLflow artifacts, backend store and model registry +* **env** + * **mlflow.yaml** + * conda env yaml to run this repo # Files -* **conda.yaml** - * conda env yaml to run this repo -* **start_mlflow_server.sh** - * a script to start MLflow server with basic configuration +* **docker-compose.yaml** + * a yaml to apply docker-compose to start MLflow service with basic configuration (run ```docker-compose -f docker-compose.yaml up```) * **test_pytorch_m1.py** * a script to test PyTorch on Apple M1 platform with GPU acceleration * **train.py** diff --git a/docker-compose.yaml b/docker-compose.yaml new file mode 100644 index 0000000..98577e3 --- /dev/null +++ b/docker-compose.yaml @@ -0,0 +1,18 @@ +version: '3.7' + +services: + mlflow_server: + image: ghcr.io/mlflow/mlflow:v2.1.1 + restart: always + ports: + - 5001:5001 + volumes: + - ~/python/test_mlflow/service:/home + healthcheck: + test: ["CMD", "curl", "-f", "http://0.0.0.0:5001"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 20s + command: bash -c "apt update && apt install -y curl && mlflow server --host 0.0.0.0 --port 5001 --backend-store-uri sqlite:////home/backend.db --registry-store-uri sqlite:////home/registry.db --artifacts-destination /home/artifacts --serve-artifacts" + diff --git a/conda.yaml b/env/mlflow.yaml similarity index 100% rename from conda.yaml rename to env/mlflow.yaml diff --git a/get_registered_model_via_rest_api.py b/get_registered_model_via_rest_api.py index c5f6664..30e33bc 100644 --- a/get_registered_model_via_rest_api.py +++ b/get_registered_model_via_rest_api.py @@ -12,7 +12,7 @@ def main(): production_model_version = None query = {'name': registered_model_name} - res = requests.get('http://127.0.0.1:5000/api/2.0/mlflow/registered-models/get', params=query) + res = requests.get('http://127.0.0.1:5001/api/2.0/mlflow/registered-models/get', params=query) content = json.loads(res.text) print(content) @@ -23,7 +23,7 @@ def main(): if production_model_version is not None: query = {'name': registered_model_name, 'version': production_model_version} - res = requests.get('http://127.0.0.1:5000/api/2.0/mlflow/model-versions/get-download-uri', params=query) + res = requests.get('http://127.0.0.1:5001/api/2.0/mlflow/model-versions/get-download-uri', params=query) print(res.text) if __name__ == '__main__': diff --git a/predict.py b/predict.py index a4b5d3b..cfdcc75 100644 --- a/predict.py +++ b/predict.py @@ -10,7 +10,7 @@ import mlflow if __name__ == '__main__': # set MLflow server - mlflow.set_tracking_uri('http://127.0.0.1:5000') + mlflow.set_tracking_uri('http://127.0.0.1:5001') # load production model model = mlflow.pytorch.load_model('models:/fortune_predict_model/production') diff --git a/start_mlflow_server.sh b/start_mlflow_server.sh deleted file mode 100755 index 564b436..0000000 --- a/start_mlflow_server.sh +++ /dev/null @@ -1,7 +0,0 @@ -#!/bin/bash -# start_mlflow_server.sh -# -# author: deng -# date : 20230221 - -mlflow server --backend-store-uri sqlite:///service/backend.db --registry-store-uri sqlite:///service/registry.db --default-artifact-root ./service/artifacts --host 127.0.0.1 --port 5000 --serve-artifacts \ No newline at end of file diff --git a/train.py b/train.py index b7f91d2..efa0f3b 100644 --- a/train.py +++ b/train.py @@ -64,7 +64,7 @@ if __name__ == '__main__': optimizer = SGD(model.parameters(), lr=learning_rate) # set the tracking URI to the model registry - mlflow.set_tracking_uri('http://127.0.0.1:5000') + mlflow.set_tracking_uri('http://127.0.0.1:5001') mlflow.set_experiment('train_fortune_predict_model') # start a new MLflow run