Compare commits
	
		
			2 Commits
		
	
	
		
			ac6400e93a
			...
			e1f143736e
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| e1f143736e | |||
| 3c39c48242 | 
							
								
								
									
										47
									
								
								optimize_model.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										47
									
								
								optimize_model.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,47 @@ | |||||||
|  | # optimize_model.py | ||||||
|  | # | ||||||
|  | # author: deng | ||||||
|  | # date  : 20230418 | ||||||
|  |  | ||||||
|  | import shutil | ||||||
|  | from pathlib import Path | ||||||
|  |  | ||||||
|  | import torch | ||||||
|  | import mlflow | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def optimize_pytorch_model(run_id: str) -> None: | ||||||
|  |     """Optimize Pytorch model on MLflow server, the optimized model will be sent back | ||||||
|  |  | ||||||
|  |     Args: | ||||||
|  |         run_id (str): mlflow run id | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     download_path = Path('./model/downloaded_pytorch_model') | ||||||
|  |     if download_path.is_dir(): | ||||||
|  |         print(f'Remove existed dir: {download_path}') | ||||||
|  |         shutil.rmtree(download_path) | ||||||
|  |  | ||||||
|  |     # Download Pytorch model to local file system | ||||||
|  |     mlflow_model = mlflow.pytorch.load_model(f'runs:/{run_id}/model') | ||||||
|  |     mlflow.pytorch.save_model(mlflow_model, download_path) | ||||||
|  |  | ||||||
|  |     # Optimize model | ||||||
|  |     model = torch.load(download_path.joinpath('data/model.pth')) | ||||||
|  |     dummy_input = torch.randn(5) | ||||||
|  |     torch.onnx.export(model, dummy_input, download_path.joinpath('data/model.onnx')) | ||||||
|  |     # we can not call TensorRT on macOS, so imagine we get a serialized model | ||||||
|  |     download_path.joinpath('data/model.trt').touch() | ||||||
|  |  | ||||||
|  |     # Save optimized model to given run | ||||||
|  |     with mlflow.start_run(run_id=run_id): | ||||||
|  |         mlflow.log_artifact(download_path.joinpath('data/model.trt'), 'model/data') | ||||||
|  |         print(f'Optimized model had been uploaded to server: {mlflow.get_tracking_uri()}') | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if __name__ == '__main__': | ||||||
|  |  | ||||||
|  |     mlflow.set_tracking_uri('http://127.0.0.1:5001') | ||||||
|  |     optimize_pytorch_model( | ||||||
|  |         run_id='f1b7b9a5ba934f158c07975a8a332de5' | ||||||
|  |     ) | ||||||
| @ -0,0 +1,98 @@ | |||||||
|  | # train.py | ||||||
|  | # | ||||||
|  | # author: deng | ||||||
|  | # date  : 20230221 | ||||||
|  |  | ||||||
|  | import torch | ||||||
|  | import torch.nn as nn | ||||||
|  | from torch.optim import SGD | ||||||
|  | import mlflow | ||||||
|  | from mlflow.models.signature import ModelSignature | ||||||
|  | from mlflow.types.schema import Schema, ColSpec | ||||||
|  | from tqdm import tqdm | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Net(nn.Module): | ||||||
|  |     """ define a simple neural network model """ | ||||||
|  |     def __init__(self): | ||||||
|  |         super(Net, self).__init__() | ||||||
|  |         self.fc1 = nn.Linear(5, 3) | ||||||
|  |         self.fc2 = nn.Linear(3, 1) | ||||||
|  |  | ||||||
|  |     def forward(self, x): | ||||||
|  |         x = self.fc1(x) | ||||||
|  |         x = torch.relu(x) | ||||||
|  |         x = self.fc2(x) | ||||||
|  |         return x | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def train(model, dataloader, criterion, optimizer, epochs): | ||||||
|  |     """ define the training function """ | ||||||
|  |     for epoch in tqdm(range(epochs), 'Epochs'): | ||||||
|  |  | ||||||
|  |         for batch, (inputs, labels) in enumerate(dataloader): | ||||||
|  |  | ||||||
|  |             # forwarding | ||||||
|  |             outputs = model(inputs) | ||||||
|  |             loss = criterion(outputs, labels) | ||||||
|  |  | ||||||
|  |             # update gradient | ||||||
|  |             optimizer.zero_grad() | ||||||
|  |             loss.backward() | ||||||
|  |             optimizer.step() | ||||||
|  |  | ||||||
|  |         # log loss | ||||||
|  |         mlflow.log_metric('train_loss', loss.item(), step=epoch) | ||||||
|  |  | ||||||
|  |     return loss | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if __name__ == '__main__': | ||||||
|  |  | ||||||
|  |     # set hyper parameters | ||||||
|  |     learning_rate = 1e-2 | ||||||
|  |     batch_size = 10 | ||||||
|  |     epochs = 20 | ||||||
|  |  | ||||||
|  |     # create a dataloader with fake data | ||||||
|  |     dataloader = [(torch.randn(5), torch.randn(1)) for _ in range(100)] | ||||||
|  |     dataloader = torch.utils.data.DataLoader(dataloader, batch_size=batch_size) | ||||||
|  |  | ||||||
|  |     # create the model, criterion, and optimizer | ||||||
|  |     model = Net() | ||||||
|  |     criterion = nn.MSELoss() | ||||||
|  |     optimizer = SGD(model.parameters(), lr=learning_rate) | ||||||
|  |  | ||||||
|  |     # set the tracking URI to the model registry | ||||||
|  |     mlflow.set_tracking_uri('http://127.0.0.1:5001') | ||||||
|  |     mlflow.set_experiment('train_fortune_predict_model') | ||||||
|  |  | ||||||
|  |     # start a new MLflow run | ||||||
|  |     with mlflow.start_run(): | ||||||
|  |  | ||||||
|  |         # train the model | ||||||
|  |         loss = train(model, dataloader, criterion, optimizer, epochs) | ||||||
|  |  | ||||||
|  |         # log some additional metrics | ||||||
|  |         mlflow.log_metric('final_loss', loss.item()) | ||||||
|  |         mlflow.log_param('learning_rate', learning_rate) | ||||||
|  |         mlflow.log_param('batch_size', batch_size) | ||||||
|  |  | ||||||
|  |         # create a signature to record model input and output info | ||||||
|  |         input_schema = Schema([ | ||||||
|  |             ColSpec('float', 'age'), | ||||||
|  |             ColSpec('float', 'mood level'), | ||||||
|  |             ColSpec('float', 'health level'), | ||||||
|  |             ColSpec('float', 'hungry level'), | ||||||
|  |             ColSpec('float', 'sexy level') | ||||||
|  |         ]) | ||||||
|  |         output_schema = Schema([ColSpec('float', 'fortune')]) | ||||||
|  |         signature = ModelSignature(inputs=input_schema, outputs=output_schema) | ||||||
|  |  | ||||||
|  |         # log trained model | ||||||
|  |         mlflow.pytorch.log_model(model, 'model', signature=signature) | ||||||
|  |  | ||||||
|  |         # log training code | ||||||
|  |         mlflow.log_artifact('./train.py', 'code') | ||||||
|  |  | ||||||
|  |     print('Completed.') | ||||||
										
											Binary file not shown.
										
									
								
							
		Reference in New Issue
	
	Block a user
	