update data storage
@@ -0,0 +1,98 @@
# train.py
#
# author: deng
# date  : 20230221

import torch
import torch.nn as nn
from torch.optim import SGD
import mlflow
import mlflow.pytorch
from mlflow.models.signature import ModelSignature
from mlflow.types.schema import Schema, ColSpec
from tqdm import tqdm


class Net(nn.Module):
    """ define a simple neural network model """
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(5, 3)
        self.fc2 = nn.Linear(3, 1)

    def forward(self, x):
        x = self.fc1(x)
        x = torch.relu(x)
        x = self.fc2(x)
        return x


def train(model, dataloader, criterion, optimizer, epochs):
    """ train the model and log the per-epoch loss to MLflow """
    for epoch in tqdm(range(epochs), 'Epochs'):

        for batch, (inputs, labels) in enumerate(dataloader):

            # forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # backward pass and parameter update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # log the loss of the last batch in this epoch
        mlflow.log_metric('train_loss', loss.item(), step=epoch)

    return loss


if __name__ == '__main__':

    # set hyperparameters
    learning_rate = 1e-2
    batch_size = 10
    epochs = 20

    # create a dataloader with fake data
    dataset = [(torch.randn(5), torch.randn(1)) for _ in range(100)]
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)

    # create the model, criterion, and optimizer
    model = Net()
    criterion = nn.MSELoss()
    optimizer = SGD(model.parameters(), lr=learning_rate)

    # point MLflow at the local tracking server
    mlflow.set_tracking_uri('http://127.0.0.1:5000')
    mlflow.set_experiment('train_fortune_predict_model')

    # start a new MLflow run
    with mlflow.start_run():

        # train the model
        loss = train(model, dataloader, criterion, optimizer, epochs)

        # log the final loss and the hyperparameters
        mlflow.log_metric('final_loss', loss.item())
        mlflow.log_param('learning_rate', learning_rate)
        mlflow.log_param('batch_size', batch_size)

        # create a signature to record model input and output info
        input_schema = Schema([
            ColSpec('float', 'age'),
            ColSpec('float', 'mood level'),
            ColSpec('float', 'health level'),
            ColSpec('float', 'hungry level'),
            ColSpec('float', 'sexy level')
        ])
        output_schema = Schema([ColSpec('float', 'fortune')])
        signature = ModelSignature(inputs=input_schema, outputs=output_schema)

        # log the trained model together with its signature
        mlflow.pytorch.log_model(model, 'model', signature=signature)

        # log the training code itself as an artifact
        mlflow.log_artifact('./train.py', 'code')

    print('Completed.')
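As a quick sanity check (not part of this commit), the model logged by train.py can be pulled back from the tracking server and scored on one fake sample. A minimal sketch: the tracking URI matches the one used above, but the script name and the run id are placeholders you would copy from the MLflow UI.

# load_and_score.py -- hypothetical helper, not included in this commit
import torch
import mlflow
import mlflow.pytorch

mlflow.set_tracking_uri('http://127.0.0.1:5000')

run_id = '<run-id-from-the-mlflow-ui>'  # placeholder, not a real run id
model = mlflow.pytorch.load_model(f'runs:/{run_id}/model')

model.eval()
with torch.no_grad():
    # one fake sample: age, mood level, health level, hungry level, sexy level
    sample = torch.randn(1, 5)
    print(model(sample))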
@@ -0,0 +1,21 @@
artifact_path: model
flavors:
  python_function:
    data: data
    env: conda.yaml
    loader_module: mlflow.pytorch
    pickle_module_name: mlflow.pytorch.pickle_module
    python_version: 3.10.9
  pytorch:
    code: null
    model_data: data
    pytorch_version: 1.13.1
mlflow_version: 1.30.0
model_uuid: 1e929c95d90347419e3e0a49d5d783fd
run_id: 128f833fc0a2426db86e5073db557a3e
signature:
  inputs: '[{"name": "age", "type": "float"}, {"name": "mood level", "type": "float"},
    {"name": "health level", "type": "float"}, {"name": "hungry level", "type": "float"},
    {"name": "sexy level", "type": "float"}]'
  outputs: '[{"name": "fortune", "type": "float"}]'
utc_time_created: '2023-02-23 01:38:39.421914'
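Because this MLmodel file records a python_function flavor alongside the signature, the same artifact can also be scored through MLflow's generic pyfunc API with a pandas DataFrame whose columns match the logged input schema. A minimal sketch, assuming pandas and numpy are installed and using a placeholder run id; the float32 cast is there so the input satisfies the 'float' column types in the signature.

# pyfunc_score.py -- hypothetical helper, not part of this commit
import numpy as np
import pandas as pd
import mlflow
import mlflow.pyfunc

mlflow.set_tracking_uri('http://127.0.0.1:5000')

run_id = '<run-id-from-the-mlflow-ui>'  # placeholder, not a real run id
pyfunc_model = mlflow.pyfunc.load_model(f'runs:/{run_id}/model')

# column names and float32 dtype must match the signature recorded above
df = pd.DataFrame(
    [[30.0, 0.5, 0.8, 0.2, 0.9]],
    columns=['age', 'mood level', 'health level', 'hungry level', 'sexy level'],
).astype(np.float32)

print(pyfunc_model.predict(df))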
@@ -0,0 +1,11 @@
channels:
- conda-forge
dependencies:
- python=3.10.9
- pip<=23.0.1
- pip:
  - mlflow
  - cloudpickle==2.2.1
  - torch==1.13.1
  - tqdm==4.64.1
name: mlflow-env
Binary file not shown.
@@ -0,0 +1 @@
mlflow.pytorch.pickle_module
@@ -0,0 +1,7 @@
python: 3.10.9
build_dependencies:
- pip==23.0.1
- setuptools==67.3.2
- wheel==0.38.4
dependencies:
- -r requirements.txt
@@ -0,0 +1,4 @@
mlflow
cloudpickle==2.2.1
torch==1.13.1
tqdm==4.64.1