Compare commits
	
		
			2 Commits
		
	
	
		
			b31dbcd0f0
			...
			3c8580f0f4
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 3c8580f0f4 | |||
| 8001876359 | 
| @ -0,0 +1,98 @@ | |||||||
|  | # train.py | ||||||
|  | # | ||||||
|  | # author: deng | ||||||
|  | # date  : 20230221 | ||||||
|  |  | ||||||
|  | import torch | ||||||
|  | import torch.nn as nn | ||||||
|  | from torch.optim import SGD | ||||||
|  | import mlflow | ||||||
|  | from mlflow.models.signature import ModelSignature | ||||||
|  | from mlflow.types.schema import Schema, ColSpec | ||||||
|  | from tqdm import tqdm | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Net(nn.Module): | ||||||
|  |     """ define a simple neural network model """ | ||||||
|  |     def __init__(self): | ||||||
|  |         super(Net, self).__init__() | ||||||
|  |         self.fc1 = nn.Linear(5, 3) | ||||||
|  |         self.fc2 = nn.Linear(3, 1) | ||||||
|  |  | ||||||
|  |     def forward(self, x): | ||||||
|  |         x = self.fc1(x) | ||||||
|  |         x = torch.relu(x) | ||||||
|  |         x = self.fc2(x) | ||||||
|  |         return x | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def train(model, dataloader, criterion, optimizer, epochs): | ||||||
|  |     """ define the training function """ | ||||||
|  |     for epoch in tqdm(range(epochs), 'Epochs'): | ||||||
|  |  | ||||||
|  |         for batch, (inputs, labels) in enumerate(dataloader): | ||||||
|  |  | ||||||
|  |             # forwarding | ||||||
|  |             outputs = model(inputs) | ||||||
|  |             loss = criterion(outputs, labels) | ||||||
|  |  | ||||||
|  |             # update gradient | ||||||
|  |             optimizer.zero_grad() | ||||||
|  |             loss.backward() | ||||||
|  |             optimizer.step() | ||||||
|  |  | ||||||
|  |         # log loss | ||||||
|  |         mlflow.log_metric('train_loss', loss.item(), step=epoch) | ||||||
|  |  | ||||||
|  |     return loss | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if __name__ == '__main__': | ||||||
|  |  | ||||||
|  |     # set hyper parameters | ||||||
|  |     learning_rate = 1e-2 | ||||||
|  |     batch_size = 10 | ||||||
|  |     epochs = 20 | ||||||
|  |  | ||||||
|  |     # create a dataloader with fake data | ||||||
|  |     dataloader = [(torch.randn(5), torch.randn(1)) for _ in range(100)] | ||||||
|  |     dataloader = torch.utils.data.DataLoader(dataloader, batch_size=batch_size) | ||||||
|  |  | ||||||
|  |     # create the model, criterion, and optimizer | ||||||
|  |     model = Net() | ||||||
|  |     criterion = nn.MSELoss() | ||||||
|  |     optimizer = SGD(model.parameters(), lr=learning_rate) | ||||||
|  |  | ||||||
|  |     # set the tracking URI to the model registry | ||||||
|  |     mlflow.set_tracking_uri('http://127.0.0.1:5000') | ||||||
|  |     mlflow.set_experiment('train_fortune_predict_model') | ||||||
|  |  | ||||||
|  |     # start a new MLflow run | ||||||
|  |     with mlflow.start_run(): | ||||||
|  |  | ||||||
|  |         # train the model | ||||||
|  |         loss = train(model, dataloader, criterion, optimizer, epochs) | ||||||
|  |  | ||||||
|  |         # log some additional metrics | ||||||
|  |         mlflow.log_metric('final_loss', loss.item()) | ||||||
|  |         mlflow.log_param('learning_rate', learning_rate) | ||||||
|  |         mlflow.log_param('batch_size', batch_size) | ||||||
|  |  | ||||||
|  |         # create a signature to record model input and output info | ||||||
|  |         input_schema = Schema([ | ||||||
|  |             ColSpec('float', 'age'), | ||||||
|  |             ColSpec('float', 'mood level'), | ||||||
|  |             ColSpec('float', 'health level'), | ||||||
|  |             ColSpec('float', 'hungry level'), | ||||||
|  |             ColSpec('float', 'sexy level') | ||||||
|  |         ]) | ||||||
|  |         output_schema = Schema([ColSpec('float', 'fortune')]) | ||||||
|  |         signature = ModelSignature(inputs=input_schema, outputs=output_schema) | ||||||
|  |  | ||||||
|  |         # log trained model | ||||||
|  |         mlflow.pytorch.log_model(model, 'model', signature=signature) | ||||||
|  |  | ||||||
|  |         # log training code | ||||||
|  |         mlflow.log_artifact('./train.py', 'code') | ||||||
|  |  | ||||||
|  |     print('Completed.') | ||||||
| @ -0,0 +1,21 @@ | |||||||
|  | artifact_path: model | ||||||
|  | flavors: | ||||||
|  |   python_function: | ||||||
|  |     data: data | ||||||
|  |     env: conda.yaml | ||||||
|  |     loader_module: mlflow.pytorch | ||||||
|  |     pickle_module_name: mlflow.pytorch.pickle_module | ||||||
|  |     python_version: 3.10.9 | ||||||
|  |   pytorch: | ||||||
|  |     code: null | ||||||
|  |     model_data: data | ||||||
|  |     pytorch_version: 1.13.1 | ||||||
|  | mlflow_version: 1.30.0 | ||||||
|  | model_uuid: 1e929c95d90347419e3e0a49d5d783fd | ||||||
|  | run_id: 128f833fc0a2426db86e5073db557a3e | ||||||
|  | signature: | ||||||
|  |   inputs: '[{"name": "age", "type": "float"}, {"name": "mood level", "type": "float"}, | ||||||
|  |     {"name": "health level", "type": "float"}, {"name": "hungry level", "type": "float"}, | ||||||
|  |     {"name": "sexy level", "type": "float"}]' | ||||||
|  |   outputs: '[{"name": "fortune", "type": "float"}]' | ||||||
|  | utc_time_created: '2023-02-23 01:38:39.421914' | ||||||
| @ -0,0 +1,11 @@ | |||||||
|  | channels: | ||||||
|  | - conda-forge | ||||||
|  | dependencies: | ||||||
|  | - python=3.10.9 | ||||||
|  | - pip<=23.0.1 | ||||||
|  | - pip: | ||||||
|  |   - mlflow | ||||||
|  |   - cloudpickle==2.2.1 | ||||||
|  |   - torch==1.13.1 | ||||||
|  |   - tqdm==4.64.1 | ||||||
|  | name: mlflow-env | ||||||
										
											Binary file not shown.
										
									
								
							| @ -0,0 +1 @@ | |||||||
|  | mlflow.pytorch.pickle_module | ||||||
| @ -0,0 +1,7 @@ | |||||||
|  | python: 3.10.9 | ||||||
|  | build_dependencies: | ||||||
|  | - pip==23.0.1 | ||||||
|  | - setuptools==67.3.2 | ||||||
|  | - wheel==0.38.4 | ||||||
|  | dependencies: | ||||||
|  | - -r requirements.txt | ||||||
| @ -0,0 +1,4 @@ | |||||||
|  | mlflow | ||||||
|  | cloudpickle==2.2.1 | ||||||
|  | torch==1.13.1 | ||||||
|  | tqdm==4.64.1 | ||||||
							
								
								
									
										30
									
								
								get_registered_model_via_rest_api.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										30
									
								
								get_registered_model_via_rest_api.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,30 @@ | |||||||
|  | # get_registered_model_via_rest_api.py | ||||||
|  | # | ||||||
|  | # author: deng | ||||||
|  | # date  : 20230224 | ||||||
|  |  | ||||||
|  | import json | ||||||
|  | import requests | ||||||
|  |  | ||||||
|  | def main(): | ||||||
|  |  | ||||||
|  |     registered_model_name = 'fortune_predict_model' | ||||||
|  |     production_model_version = None | ||||||
|  |  | ||||||
|  |     query = {'name': registered_model_name} | ||||||
|  |     res = requests.get('http://127.0.0.1:5000/api/2.0/mlflow/registered-models/get', params=query) | ||||||
|  |     content = json.loads(res.text) | ||||||
|  |     print(content) | ||||||
|  |  | ||||||
|  |     for model in content['registered_model']['latest_versions']: | ||||||
|  |  | ||||||
|  |         if model['current_stage'] == 'Production': | ||||||
|  |             production_model_version = model['version'] | ||||||
|  |  | ||||||
|  |     if production_model_version is not None: | ||||||
|  |         query = {'name': registered_model_name, 'version': production_model_version} | ||||||
|  |     res = requests.get('http://127.0.0.1:5000/api/2.0/mlflow/model-versions/get-download-uri', params=query) | ||||||
|  |     print(res.text) | ||||||
|  |  | ||||||
|  | if __name__ == '__main__': | ||||||
|  |     main() | ||||||
							
								
								
									
										0
									
								
								start_mlflow_server.sh
									
									
									
									
									
										
										
										Normal file → Executable file
									
								
							
							
						
						
									
										0
									
								
								start_mlflow_server.sh
									
									
									
									
									
										
										
										Normal file → Executable file
									
								
							
		Reference in New Issue
	
	Block a user
	