# train.py
#
# author: deng
# date  : 20230221

from pathlib import Path

import torch
import torch.nn as nn
from torch.optim import SGD
import mlflow
# Import the flavor submodule explicitly rather than relying on
# ``import mlflow`` to expose ``mlflow.pytorch``.
import mlflow.pytorch
from mlflow.models.signature import ModelSignature
from mlflow.types.schema import Schema, ColSpec
from tqdm import tqdm


class Net(nn.Module):
    """A small two-layer MLP mapping 5 input features to 1 output value."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(5, 3)
        self.fc2 = nn.Linear(3, 1)

    def forward(self, x):
        """Run ``x`` through fc1 -> ReLU -> fc2.

        ``x`` must have a trailing dimension of size 5 (``fc1``'s
        ``in_features``); the result has a trailing dimension of size 1.
        """
        x = self.fc1(x)
        x = torch.relu(x)
        return self.fc2(x)


def train(model, dataloader, criterion, optimizer, epochs):
    """Train ``model`` in place and log the per-epoch mean loss to MLflow.

    Args:
        model: the ``nn.Module`` to optimize.
        dataloader: iterable of ``(inputs, labels)`` batches. It must be
            re-iterable, because it is traversed once per epoch.
        criterion: loss function, called as ``criterion(outputs, labels)``.
        optimizer: optimizer holding ``model``'s parameters.
        epochs: number of full passes over ``dataloader``.

    Returns:
        The loss tensor of the last batch processed.

    Raises:
        ValueError: if no batch was processed (``epochs <= 0`` or an empty
            ``dataloader``), since there is then no loss to return.
    """
    model.train()
    loss = None
    for epoch in tqdm(range(epochs), desc='Epochs'):
        epoch_loss_sum = 0.0
        num_batches = 0
        for inputs, labels in dataloader:
            # forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # backward pass and parameter update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss_sum += loss.item()
            num_batches += 1

        if num_batches:
            # Log one point per epoch. Logging every batch under the same
            # ``step=epoch`` would stack many values on one step and send
            # a tracking request per batch.
            mlflow.log_metric('train_loss', epoch_loss_sum / num_batches,
                              step=epoch)

    if loss is None:
        raise ValueError(
            'no batches were processed; check epochs and dataloader')
    return loss


def main():
    """Train ``Net`` on random fake data and record the run in MLflow."""
    # set hyper parameters
    learning_rate = 1e-2
    batch_size = 10
    epochs = 20

    # create a dataloader over fake (5-feature input, 1-value target) pairs
    dataset = [(torch.randn(5), torch.randn(1)) for _ in range(100)]
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)

    # create the model, criterion, and optimizer
    model = Net()
    criterion = nn.MSELoss()
    optimizer = SGD(model.parameters(), lr=learning_rate)

    # point MLflow at the tracking server / model registry
    mlflow.set_tracking_uri('http://127.0.0.1:5001')
    mlflow.set_experiment('train_fortune_predict_model')

    with mlflow.start_run():
        # Log hyperparameters before training so that they are recorded
        # even if training fails partway through.
        mlflow.log_param('learning_rate', learning_rate)
        mlflow.log_param('batch_size', batch_size)
        mlflow.log_param('epochs', epochs)

        loss = train(model, dataloader, criterion, optimizer, epochs)
        mlflow.log_metric('final_loss', loss.item())

        # signature records the model's input/output columns
        input_schema = Schema([
            ColSpec('float', 'age'),
            ColSpec('float', 'mood level'),
            ColSpec('float', 'health level'),
            ColSpec('float', 'hungry level'),
            ColSpec('float', 'sexy level'),
        ])
        output_schema = Schema([ColSpec('float', 'fortune')])
        signature = ModelSignature(inputs=input_schema, outputs=output_schema)

        # log trained model
        mlflow.pytorch.log_model(model, 'model', signature=signature)

        # Log this training script itself. Resolve it from __file__ rather
        # than './train.py', so that it works from any working directory.
        mlflow.log_artifact(str(Path(__file__).resolve()), 'code')

    print('Completed.')


if __name__ == '__main__':
    main()