update data storage
This commit is contained in:
		| @ -0,0 +1,98 @@ | ||||
| # train.py | ||||
| # | ||||
| # author: deng | ||||
| # date  : 20230221 | ||||
|  | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from torch.optim import SGD | ||||
| import mlflow | ||||
| from mlflow.models.signature import ModelSignature | ||||
| from mlflow.types.schema import Schema, ColSpec | ||||
| from tqdm import tqdm | ||||
|  | ||||
|  | ||||
| class Net(nn.Module): | ||||
|     """ define a simple neural network model """ | ||||
|     def __init__(self): | ||||
|         super(Net, self).__init__() | ||||
|         self.fc1 = nn.Linear(5, 3) | ||||
|         self.fc2 = nn.Linear(3, 1) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         x = self.fc1(x) | ||||
|         x = torch.relu(x) | ||||
|         x = self.fc2(x) | ||||
|         return x | ||||
|  | ||||
|  | ||||
| def train(model, dataloader, criterion, optimizer, epochs): | ||||
|     """ define the training function """ | ||||
|     for epoch in tqdm(range(epochs), 'Epochs'): | ||||
|  | ||||
|         for batch, (inputs, labels) in enumerate(dataloader): | ||||
|  | ||||
|             # forwarding | ||||
|             outputs = model(inputs) | ||||
|             loss = criterion(outputs, labels) | ||||
|  | ||||
|             # update gradient | ||||
|             optimizer.zero_grad() | ||||
|             loss.backward() | ||||
|             optimizer.step() | ||||
|  | ||||
|         # log loss | ||||
|         mlflow.log_metric('train_loss', loss.item(), step=epoch) | ||||
|  | ||||
|     return loss | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|  | ||||
|     # set hyper parameters | ||||
|     learning_rate = 1e-2 | ||||
|     batch_size = 10 | ||||
|     epochs = 20 | ||||
|  | ||||
|     # create a dataloader with fake data | ||||
|     dataloader = [(torch.randn(5), torch.randn(1)) for _ in range(100)] | ||||
|     dataloader = torch.utils.data.DataLoader(dataloader, batch_size=batch_size) | ||||
|  | ||||
|     # create the model, criterion, and optimizer | ||||
|     model = Net() | ||||
|     criterion = nn.MSELoss() | ||||
|     optimizer = SGD(model.parameters(), lr=learning_rate) | ||||
|  | ||||
|     # set the tracking URI to the model registry | ||||
|     mlflow.set_tracking_uri('http://127.0.0.1:5000') | ||||
|     mlflow.set_experiment('train_fortune_predict_model') | ||||
|  | ||||
|     # start a new MLflow run | ||||
|     with mlflow.start_run(): | ||||
|  | ||||
|         # train the model | ||||
|         loss = train(model, dataloader, criterion, optimizer, epochs) | ||||
|  | ||||
|         # log some additional metrics | ||||
|         mlflow.log_metric('final_loss', loss.item()) | ||||
|         mlflow.log_param('learning_rate', learning_rate) | ||||
|         mlflow.log_param('batch_size', batch_size) | ||||
|  | ||||
|         # create a signature to record model input and output info | ||||
|         input_schema = Schema([ | ||||
|             ColSpec('float', 'age'), | ||||
|             ColSpec('float', 'mood level'), | ||||
|             ColSpec('float', 'health level'), | ||||
|             ColSpec('float', 'hungry level'), | ||||
|             ColSpec('float', 'sexy level') | ||||
|         ]) | ||||
|         output_schema = Schema([ColSpec('float', 'fortune')]) | ||||
|         signature = ModelSignature(inputs=input_schema, outputs=output_schema) | ||||
|  | ||||
|         # log trained model | ||||
|         mlflow.pytorch.log_model(model, 'model', signature=signature) | ||||
|  | ||||
|         # log training code | ||||
|         mlflow.log_artifact('./train.py', 'code') | ||||
|  | ||||
|     print('Completed.') | ||||
| @ -0,0 +1,21 @@ | ||||
| artifact_path: model | ||||
| flavors: | ||||
|   python_function: | ||||
|     data: data | ||||
|     env: conda.yaml | ||||
|     loader_module: mlflow.pytorch | ||||
|     pickle_module_name: mlflow.pytorch.pickle_module | ||||
|     python_version: 3.10.9 | ||||
|   pytorch: | ||||
|     code: null | ||||
|     model_data: data | ||||
|     pytorch_version: 1.13.1 | ||||
| mlflow_version: 1.30.0 | ||||
| model_uuid: 1e929c95d90347419e3e0a49d5d783fd | ||||
| run_id: 128f833fc0a2426db86e5073db557a3e | ||||
| signature: | ||||
|   inputs: '[{"name": "age", "type": "float"}, {"name": "mood level", "type": "float"}, | ||||
|     {"name": "health level", "type": "float"}, {"name": "hungry level", "type": "float"}, | ||||
|     {"name": "sexy level", "type": "float"}]' | ||||
|   outputs: '[{"name": "fortune", "type": "float"}]' | ||||
| utc_time_created: '2023-02-23 01:38:39.421914' | ||||
| @ -0,0 +1,11 @@ | ||||
| channels: | ||||
| - conda-forge | ||||
| dependencies: | ||||
| - python=3.10.9 | ||||
| - pip<=23.0.1 | ||||
| - pip: | ||||
|   - mlflow | ||||
|   - cloudpickle==2.2.1 | ||||
|   - torch==1.13.1 | ||||
|   - tqdm==4.64.1 | ||||
| name: mlflow-env | ||||
										
											Binary file not shown.
										
									
								
							| @ -0,0 +1 @@ | ||||
| mlflow.pytorch.pickle_module | ||||
| @ -0,0 +1,7 @@ | ||||
| python: 3.10.9 | ||||
| build_dependencies: | ||||
| - pip==23.0.1 | ||||
| - setuptools==67.3.2 | ||||
| - wheel==0.38.4 | ||||
| dependencies: | ||||
| - -r requirements.txt | ||||
| @ -0,0 +1,4 @@ | ||||
| mlflow | ||||
| cloudpickle==2.2.1 | ||||
| torch==1.13.1 | ||||
| tqdm==4.64.1 | ||||
		Reference in New Issue
	
	Block a user