Replace bash script with docker-compose to start the server
This commit is contained in:
		
							
								
								
									
										11
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										11
									
								
								README.md
									
									
									
									
									
								
							| @ -7,17 +7,18 @@ Try to use [MLflow](https://mlflow.org) platform to log PyTorch model training, | ||||
| * MacOS 12.5 | ||||
| * Docker 20.10 | ||||
|  | ||||
| # Dir | ||||
| # Dirs | ||||
|  | ||||
| * **service** | ||||
|   * House MLflow service data, including MLflow artifacts, backend store and model registry | ||||
| * **env** | ||||
|   * **mlflow.yaml** | ||||
|     * conda env yaml to run this repo | ||||
|  | ||||
| # Files | ||||
|  | ||||
| * **conda.yaml** | ||||
|   * conda env yaml to run this repo | ||||
| * **start_mlflow_server.sh** | ||||
|   * a script to start MLflow server with basic configuration | ||||
| * **docker-compose.yaml** | ||||
|   * a yaml to apply docker-compose to start MLflow service with basic configuration (run ```docker-compose -f docker-compose.yaml up```) | ||||
| * **test_pytorch_m1.py** | ||||
|   * a script to test PyTorch on Apple M1 platform with GPU acceleration | ||||
| * **train.py** | ||||
|  | ||||
							
								
								
									
										18
									
								
								docker-compose.yaml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										18
									
								
								docker-compose.yaml
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,18 @@ | ||||
| version: '3.7' | ||||
|  | ||||
| services: | ||||
|   mlflow_server: | ||||
|     image: ghcr.io/mlflow/mlflow:v2.1.1 | ||||
|     restart: always | ||||
|     ports: | ||||
|       - 5001:5001 | ||||
|     volumes: | ||||
|       - ~/python/test_mlflow/service:/home | ||||
|     healthcheck: | ||||
|       test: ["CMD", "curl", "-f", "http://0.0.0.0:5001"] | ||||
|       interval: 30s | ||||
|       timeout: 10s | ||||
|       retries: 3 | ||||
|       start_period: 20s | ||||
|     command: bash -c "apt update && apt install -y curl && mlflow server --host 0.0.0.0 --port 5001 --backend-store-uri sqlite:////home/backend.db --registry-store-uri sqlite:////home/registry.db --artifacts-destination /home/artifacts --serve-artifacts" | ||||
|  | ||||
							
								
								
									
										0
									
								
								conda.yaml → env/mlflow.yaml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										0
									
								
								conda.yaml → env/mlflow.yaml
									
									
									
									
										vendored
									
									
								
							| @ -12,7 +12,7 @@ def main(): | ||||
|     production_model_version = None | ||||
|  | ||||
|     query = {'name': registered_model_name} | ||||
|     res = requests.get('http://127.0.0.1:5000/api/2.0/mlflow/registered-models/get', params=query) | ||||
|     res = requests.get('http://127.0.0.1:5001/api/2.0/mlflow/registered-models/get', params=query) | ||||
|     content = json.loads(res.text) | ||||
|     print(content) | ||||
|  | ||||
| @ -23,7 +23,7 @@ def main(): | ||||
|  | ||||
|     if production_model_version is not None: | ||||
|         query = {'name': registered_model_name, 'version': production_model_version} | ||||
|     res = requests.get('http://127.0.0.1:5000/api/2.0/mlflow/model-versions/get-download-uri', params=query) | ||||
|     res = requests.get('http://127.0.0.1:5001/api/2.0/mlflow/model-versions/get-download-uri', params=query) | ||||
|     print(res.text) | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|  | ||||
| @ -10,7 +10,7 @@ import mlflow | ||||
| if __name__ == '__main__': | ||||
|  | ||||
|     # set MLflow server | ||||
|     mlflow.set_tracking_uri('http://127.0.0.1:5000') | ||||
|     mlflow.set_tracking_uri('http://127.0.0.1:5001') | ||||
|  | ||||
|     # load production model | ||||
|     model = mlflow.pytorch.load_model('models:/fortune_predict_model/production') | ||||
|  | ||||
| @ -1,7 +0,0 @@ | ||||
| #!/bin/bash | ||||
| # start_mlflow_server.sh | ||||
| # | ||||
| # author: deng | ||||
| # date  : 20230221 | ||||
|  | ||||
| mlflow server --backend-store-uri sqlite:///service/backend.db --registry-store-uri sqlite:///service/registry.db --default-artifact-root ./service/artifacts --host 127.0.0.1 --port 5000 --serve-artifacts | ||||
							
								
								
									
										2
									
								
								train.py
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								train.py
									
									
									
									
									
								
							| @ -64,7 +64,7 @@ if __name__ == '__main__': | ||||
|     optimizer = SGD(model.parameters(), lr=learning_rate) | ||||
|  | ||||
|     # set the tracking URI to the model registry | ||||
|     mlflow.set_tracking_uri('http://127.0.0.1:5000') | ||||
|     mlflow.set_tracking_uri('http://127.0.0.1:5001') | ||||
|     mlflow.set_experiment('train_fortune_predict_model') | ||||
|  | ||||
|     # start a new MLflow run | ||||
|  | ||||
		Reference in New Issue
	
	Block a user