82 lines
2.0 KiB
Python
82 lines
2.0 KiB
Python
# train.py
#
# author: deng
# date : 20230221
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch.optim import SGD
|
|
from tqdm import tqdm
|
|
import mlflow
|
|
|
|
|
|
class Net(nn.Module):
    """Small fully connected network: Linear(10, 5) -> ReLU -> Linear(5, 1)."""

    def __init__(self):
        super().__init__()
        # fc1 is built before fc2 so parameters register (and are randomly
        # initialized) in the same order as before.
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 1)

    def forward(self, x):
        """Apply fc1, ReLU, then fc2; the last dimension of ``x`` must be 10."""
        hidden = torch.relu(self.fc1(x))
        return self.fc2(hidden)
|
|
|
|
|
|
def train(model, dataloader, criterion, optimizer, epochs):
    """Train ``model`` for ``epochs`` passes over ``dataloader``.

    After every optimizer step the batch loss is logged to the active MLflow
    run as the ``train_loss`` metric.

    Args:
        model: module whose parameters ``optimizer`` updates.
        dataloader: iterable of ``(inputs, labels)`` batches. It is iterated
            once per epoch, so it must be re-iterable (e.g. a ``DataLoader``
            or a list); a one-shot generator would only yield data once.
        criterion: loss function, called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        epochs: number of passes over ``dataloader``.

    Returns:
        The loss tensor of the last processed batch.

    Raises:
        ValueError: if no batch was processed (``epochs <= 0`` or an empty
            ``dataloader``), so there is no loss to return.
    """
    loss = None
    # Monotonic step counter across all epochs. Using the per-epoch batch
    # index as the MLflow step made every epoch re-log steps 0..N-1, folding
    # the loss history of all epochs onto the same steps.
    global_step = 0
    for _ in tqdm(range(epochs), 'Epochs'):
        for inputs, labels in dataloader:
            # forwarding
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # update gradient
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # log loss
            mlflow.log_metric('train_loss', loss.item(), step=global_step)
            global_step += 1

    if loss is None:
        raise ValueError(
            'train() processed no batches; check epochs and dataloader'
        )
    return loss
|
|
|
|
|
|
if __name__ == '__main__':
    import os

    # set hyper parameters
    learning_rate = 1e-2
    epochs = 20

    # create a dataloader with fake data: 100 samples of
    # (10-element input, 1-element target), batched 10 at a time
    dataset = [(torch.randn(10), torch.randn(1)) for _ in range(100)]
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=10)

    # create the model, criterion, and optimizer
    model = Net()
    criterion = nn.MSELoss()
    optimizer = SGD(model.parameters(), lr=learning_rate)

    # send runs to the MLflow tracking server at this address
    mlflow.set_tracking_uri('http://127.0.0.1:5000')

    # start the MLflow run
    with mlflow.start_run():
        # log hyper parameters up front so they are recorded even if
        # training raises part-way through the run
        mlflow.log_param('learning_rate', learning_rate)

        # train the model and log the loss
        loss = train(model, dataloader, criterion, optimizer, epochs)

        # log some additional metrics
        mlflow.log_metric('final_loss', loss.item())

        # log trained model
        mlflow.pytorch.log_model(model, 'model')

        # log training code; resolve this script's own path instead of
        # './train.py', which depended on the current working directory
        mlflow.log_artifact(os.path.abspath(__file__))

    print('Completed.')