record model input&output, save model to file system
This commit is contained in:
		
							
								
								
									
										9
									
								
								env.yaml
									
									
									
									
									
								
							
							
						
						
									
										9
									
								
								env.yaml
									
									
									
									
									
								
							| @ -1,6 +1,7 @@ | |||||||
| name: torch | name: torch | ||||||
| channels: | channels: | ||||||
|   - pytorch |   - pytorch | ||||||
|  |   - anaconda | ||||||
|   - conda-forge |   - conda-forge | ||||||
| dependencies: | dependencies: | ||||||
|   - alembic=1.9.4 |   - alembic=1.9.4 | ||||||
| @ -8,10 +9,11 @@ dependencies: | |||||||
|   - appdirs=1.4.4 |   - appdirs=1.4.4 | ||||||
|   - bcrypt=3.2.2 |   - bcrypt=3.2.2 | ||||||
|   - blinker=1.5 |   - blinker=1.5 | ||||||
|  |   - bottleneck=1.3.4 | ||||||
|   - brotlipy=0.7.0 |   - brotlipy=0.7.0 | ||||||
|   - bzip2=1.0.8 |   - bzip2=1.0.8 | ||||||
|   - ca-certificates=2022.12.7 |   - ca-certificates=2022.4.26 | ||||||
|   - certifi=2022.12.7 |   - certifi=2022.6.15 | ||||||
|   - cffi=1.15.1 |   - cffi=1.15.1 | ||||||
|   - charset-normalizer=2.1.1 |   - charset-normalizer=2.1.1 | ||||||
|   - click=8.1.3 |   - click=8.1.3 | ||||||
| @ -80,6 +82,7 @@ dependencies: | |||||||
|   - mlflow=1.30.0 |   - mlflow=1.30.0 | ||||||
|   - ncurses=6.3 |   - ncurses=6.3 | ||||||
|   - nettle=3.8.1 |   - nettle=3.8.1 | ||||||
|  |   - numexpr=2.8.1 | ||||||
|   - numpy=1.24.2 |   - numpy=1.24.2 | ||||||
|   - oauthlib=3.2.2 |   - oauthlib=3.2.2 | ||||||
|   - openh264=2.3.1 |   - openh264=2.3.1 | ||||||
| @ -87,7 +90,7 @@ dependencies: | |||||||
|   - openssl=3.0.8 |   - openssl=3.0.8 | ||||||
|   - p11-kit=0.24.1 |   - p11-kit=0.24.1 | ||||||
|   - packaging=21.3 |   - packaging=21.3 | ||||||
|   - pandas=1.5.3 |   - pandas=1.4.2 | ||||||
|   - paramiko=3.0.0 |   - paramiko=3.0.0 | ||||||
|   - pillow=9.4.0 |   - pillow=9.4.0 | ||||||
|   - pip=23.0.1 |   - pip=23.0.1 | ||||||
|  | |||||||
							
								
								
									
										11
									
								
								predict.py
									
									
									
									
									
								
							
							
						
						
									
										11
									
								
								predict.py
									
									
									
									
									
								
							| @ -13,9 +13,12 @@ if __name__ == '__main__': | |||||||
|     mlflow.set_tracking_uri('http://127.0.0.1:5000') |     mlflow.set_tracking_uri('http://127.0.0.1:5000') | ||||||
|  |  | ||||||
|     # load production model |     # load production model | ||||||
|     model = mlflow.pytorch.load_model('models:/cls_model/production') |     model = mlflow.pytorch.load_model('models:/fortune_predict_model/production') | ||||||
|  |  | ||||||
|     # predict |     # predict | ||||||
|     fake_data = torch.randn(10) |     my_personal_info = torch.randn(5) | ||||||
|     output = model(fake_data) |     my_fortune = model(my_personal_info) | ||||||
|     print(output) |     print(my_fortune) | ||||||
|  |  | ||||||
|  |     # save model and env to local file system | ||||||
|  |     mlflow.pytorch.save_model(model, './fortune_predict_model') | ||||||
| @ -4,4 +4,4 @@ | |||||||
| # author: deng | # author: deng | ||||||
| # date  : 20230221 | # date  : 20230221 | ||||||
|  |  | ||||||
| mlflow server --backend-store-uri sqlite:///mlflow.db --default-artifact-root ./artifacts | mlflow server --backend-store-uri sqlite:///mlflow.db --default-artifact-root ./artifacts --port 5000 | ||||||
							
								
								
									
										39
									
								
								train.py
									
									
									
									
									
								
							
							
						
						
									
										39
									
								
								train.py
									
									
									
									
									
								
							| @ -6,16 +6,18 @@ | |||||||
| import torch | import torch | ||||||
| import torch.nn as nn | import torch.nn as nn | ||||||
| from torch.optim import SGD | from torch.optim import SGD | ||||||
| from tqdm import tqdm |  | ||||||
| import mlflow | import mlflow | ||||||
|  | from mlflow.models.signature import ModelSignature | ||||||
|  | from mlflow.types.schema import Schema, ColSpec | ||||||
|  | from tqdm import tqdm | ||||||
|  |  | ||||||
|  |  | ||||||
| class Net(nn.Module): | class Net(nn.Module): | ||||||
|     """ define a simple neural network model """ |     """ define a simple neural network model """ | ||||||
|     def __init__(self): |     def __init__(self): | ||||||
|         super(Net, self).__init__() |         super(Net, self).__init__() | ||||||
|         self.fc1 = nn.Linear(10, 5) |         self.fc1 = nn.Linear(5, 3) | ||||||
|         self.fc2 = nn.Linear(5, 1) |         self.fc2 = nn.Linear(3, 1) | ||||||
|  |  | ||||||
|     def forward(self, x): |     def forward(self, x): | ||||||
|         x = self.fc1(x) |         x = self.fc1(x) | ||||||
| @ -28,7 +30,7 @@ def train(model, dataloader, criterion, optimizer, epochs): | |||||||
|     """ define the training function """ |     """ define the training function """ | ||||||
|     for epoch in tqdm(range(epochs), 'Epochs'): |     for epoch in tqdm(range(epochs), 'Epochs'): | ||||||
|  |  | ||||||
|         for i, (inputs, labels) in enumerate(dataloader): |         for batch, (inputs, labels) in enumerate(dataloader): | ||||||
|  |  | ||||||
|             # forwarding |             # forwarding | ||||||
|             outputs = model(inputs) |             outputs = model(inputs) | ||||||
| @ -40,7 +42,7 @@ def train(model, dataloader, criterion, optimizer, epochs): | |||||||
|             optimizer.step() |             optimizer.step() | ||||||
|  |  | ||||||
|         # log loss |         # log loss | ||||||
|             mlflow.log_metric('train_loss', loss.item(), step=i) |         mlflow.log_metric('train_loss', loss.item(), step=epoch) | ||||||
|  |  | ||||||
|     return loss |     return loss | ||||||
|  |  | ||||||
| @ -49,11 +51,12 @@ if __name__ == '__main__': | |||||||
|  |  | ||||||
|     # set hyper parameters |     # set hyper parameters | ||||||
|     learning_rate = 1e-2 |     learning_rate = 1e-2 | ||||||
|  |     batch_size = 10 | ||||||
|     epochs = 20 |     epochs = 20 | ||||||
|  |  | ||||||
|     # create a dataloader with fake data |     # create a dataloader with fake data | ||||||
|     dataloader = [(torch.randn(10), torch.randn(1)) for _ in range(100)] |     dataloader = [(torch.randn(5), torch.randn(1)) for _ in range(100)] | ||||||
|     dataloader = torch.utils.data.DataLoader(dataloader, batch_size=10) |     dataloader = torch.utils.data.DataLoader(dataloader, batch_size=batch_size) | ||||||
|  |  | ||||||
|     # create the model, criterion, and optimizer |     # create the model, criterion, and optimizer | ||||||
|     model = Net() |     model = Net() | ||||||
| @ -62,18 +65,34 @@ if __name__ == '__main__': | |||||||
|  |  | ||||||
|     # set the tracking URI to the model registry |     # set the tracking URI to the model registry | ||||||
|     mlflow.set_tracking_uri('http://127.0.0.1:5000') |     mlflow.set_tracking_uri('http://127.0.0.1:5000') | ||||||
|  |     mlflow.set_experiment('train_fortune_predict_model') | ||||||
|  |  | ||||||
|     # start the MLflow run |     # start a new MLflow run | ||||||
|     with mlflow.start_run(): |     with mlflow.start_run(): | ||||||
|  |  | ||||||
|         # train the model and log the loss |         # train the model | ||||||
|         loss = train(model, dataloader, criterion, optimizer, epochs) |         loss = train(model, dataloader, criterion, optimizer, epochs) | ||||||
|  |  | ||||||
|         # log some additional metrics |         # log some additional metrics | ||||||
|         mlflow.log_metric('final_loss', loss.item()) |         mlflow.log_metric('final_loss', loss.item()) | ||||||
|         mlflow.log_param('learning_rate', learning_rate) |         mlflow.log_param('learning_rate', learning_rate) | ||||||
|  |         mlflow.log_param('batch_size', batch_size) | ||||||
|  |  | ||||||
|  |         # create a signature to record model input and output info | ||||||
|  |         input_schema = Schema([ | ||||||
|  |             ColSpec('float', 'age'), | ||||||
|  |             ColSpec('float', 'mood level'), | ||||||
|  |             ColSpec('float', 'health level'), | ||||||
|  |             ColSpec('float', 'hungry level'), | ||||||
|  |             ColSpec('float', 'sexy level') | ||||||
|  |         ]) | ||||||
|  |         output_schema = Schema([ColSpec('float', 'fortune')]) | ||||||
|  |         signature = ModelSignature(inputs=input_schema, outputs=output_schema) | ||||||
|  |  | ||||||
|         # log trained model |         # log trained model | ||||||
|         mlflow.pytorch.log_model(model, 'model') |         mlflow.pytorch.log_model(model, 'model', signature=signature) | ||||||
|  |  | ||||||
|  |         # log training code | ||||||
|  |         mlflow.log_artifact('./train.py', 'code') | ||||||
|  |  | ||||||
|     print('Completed.') |     print('Completed.') | ||||||
		Reference in New Issue
	
	Block a user