Compare commits
	
		
			2 Commits
		
	
	
		
			ac6400e93a
			...
			e1f143736e
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| e1f143736e | |||
| 3c39c48242 | 
							
								
								
									
										47
									
								
								optimize_model.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										47
									
								
								optimize_model.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,47 @@ | ||||
| # optimize_model.py | ||||
| # | ||||
| # author: deng | ||||
| # date  : 20230418 | ||||
|  | ||||
| import shutil | ||||
| from pathlib import Path | ||||
|  | ||||
| import torch | ||||
| import mlflow | ||||
|  | ||||
|  | ||||
def optimize_pytorch_model(run_id: str) -> None:
    """Optimize the Pytorch model of an MLflow run and send the result back to that run.

    The model logged under ``runs:/<run_id>/model`` is loaded from the tracking
    server, saved to ``./model/downloaded_pytorch_model``, exported to ONNX, and a
    placeholder ``model.trt`` file is logged to the same run under ``model/data``.

    Args:
        run_id (str): mlflow run id

    Raises:
        ValueError: if ``run_id`` is empty or only whitespace
    """

    # Reject a blank id before the local directory is wiped; it would only
    # produce a malformed "runs:/" URI further down
    if not run_id or not run_id.strip():
        raise ValueError('run_id must be a non-empty string')

    download_path = Path('./model/downloaded_pytorch_model')
    data_path = download_path.joinpath('data')
    if download_path.is_dir():
        print(f'Remove existed dir: {download_path}')
        shutil.rmtree(download_path)

    # Download Pytorch model to local file system
    mlflow_model = mlflow.pytorch.load_model(f'runs:/{run_id}/model')
    mlflow.pytorch.save_model(mlflow_model, download_path)

    # Optimize model
    # The object returned by load_model is exported directly. Reading
    # data/model.pth again with torch.load would unpickle the same module a
    # second time, and it is rejected by PyTorch releases where torch.load
    # defaults to weights_only=True.
    dummy_input = torch.randn(5)
    torch.onnx.export(mlflow_model, dummy_input, str(data_path.joinpath('model.onnx')))
    # we cannot call TensorRT on macOS, so imagine we get a serialized model
    trt_path = data_path.joinpath('model.trt')
    trt_path.touch()

    # Save optimized model to given run
    with mlflow.start_run(run_id=run_id):
        mlflow.log_artifact(str(trt_path), 'model/data')
        print(f'Optimized model had been uploaded to server: {mlflow.get_tracking_uri()}')
|  | ||||
|  | ||||
if __name__ == '__main__':
    # point the client at the local tracking server, then optimize one fixed run
    mlflow.set_tracking_uri('http://127.0.0.1:5001')
    optimize_pytorch_model(run_id='f1b7b9a5ba934f158c07975a8a332de5')
| @ -0,0 +1,98 @@ | ||||
| # train.py | ||||
| # | ||||
| # author: deng | ||||
| # date  : 20230221 | ||||
|  | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from torch.optim import SGD | ||||
| import mlflow | ||||
| from mlflow.models.signature import ModelSignature | ||||
| from mlflow.types.schema import Schema, ColSpec | ||||
| from tqdm import tqdm | ||||
|  | ||||
|  | ||||
class Net(nn.Module):
    """ simple fully connected network: 5 inputs -> 3 hidden units -> 1 output """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(5, 3)
        self.fc2 = nn.Linear(3, 1)

    def forward(self, x):
        """ apply fc1, ReLU and fc2 in sequence and return the result """
        hidden = torch.relu(self.fc1(x))
        return self.fc2(hidden)
|  | ||||
|  | ||||
def train(model, dataloader, criterion, optimizer, epochs):
    """Train ``model`` for ``epochs`` passes over ``dataloader``.

    After every epoch the loss of that epoch's last batch is logged to MLflow
    as ``train_loss``, using the epoch index as the step.

    Args:
        model: callable applied to each batch of inputs
        dataloader: iterable of ``(inputs, labels)`` batches, iterated once per epoch
        criterion: callable computing the loss from ``(outputs, labels)``
        optimizer: object providing ``zero_grad()`` and ``step()``
        epochs (int): number of passes over ``dataloader``, at least 1

    Returns:
        The loss of the last batch of the last epoch.

    Raises:
        ValueError: if ``epochs`` is less than 1 or ``dataloader`` yields no batches
    """
    # without at least one epoch there is no loss to log or return
    if epochs < 1:
        raise ValueError(f'epochs must be at least 1, got {epochs}')

    loss = None
    for epoch in tqdm(range(epochs), 'Epochs'):

        for inputs, labels in dataloader:

            # forwarding
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # update gradient
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # an empty dataloader never assigns loss; fail with a clear message
        if loss is None:
            raise ValueError('dataloader yielded no batches')

        # log loss
        mlflow.log_metric('train_loss', loss.item(), step=epoch)

    return loss
|  | ||||
|  | ||||
if __name__ == '__main__':

    # hyper parameters
    learning_rate = 1e-2
    batch_size = 10
    epochs = 20

    # fake data: 100 samples of 5 random features with 1 random target each
    dataloader = torch.utils.data.DataLoader(
        [(torch.randn(5), torch.randn(1)) for _ in range(100)],
        batch_size=batch_size,
    )

    # model, criterion, and optimizer
    model = Net()
    criterion = nn.MSELoss()
    optimizer = SGD(model.parameters(), lr=learning_rate)

    # tracking server and experiment used by the run below
    mlflow.set_tracking_uri('http://127.0.0.1:5001')
    mlflow.set_experiment('train_fortune_predict_model')

    # everything inside this block is recorded in a new MLflow run
    with mlflow.start_run():

        # train the model
        loss = train(model, dataloader, criterion, optimizer, epochs)

        # log the final loss and the hyper parameters
        mlflow.log_metric('final_loss', loss.item())
        mlflow.log_param('learning_rate', learning_rate)
        mlflow.log_param('batch_size', batch_size)

        # signature recording the model input and output columns
        feature_names = ['age', 'mood level', 'health level', 'hungry level', 'sexy level']
        signature = ModelSignature(
            inputs=Schema([ColSpec('float', name) for name in feature_names]),
            outputs=Schema([ColSpec('float', 'fortune')]),
        )

        # log trained model
        mlflow.pytorch.log_model(model, 'model', signature=signature)

        # log training code
        mlflow.log_artifact('./train.py', 'code')

    print('Completed.')
										
											Binary file not shown.
										
									
								
							
		Reference in New Issue
	
	Block a user
	