update server data

This commit is contained in:
2023-02-22 16:26:02 +08:00
parent 7015b5c1a5
commit 2dd734b87b
134 changed files with 2502 additions and 0 deletions

View File

@ -0,0 +1,83 @@
# train.py
#
# author: deng
# date : 20230221
import torch
import torch.nn as nn
from torch.optim import SGD
from tqdm import tqdm
import mlflow
class Net(nn.Module):
""" define a simple neural network model """
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(10, 5)
self.fc2 = nn.Linear(5, 1)
def forward(self, x):
x = self.fc1(x)
x = torch.relu(x)
x = self.fc2(x)
return x
def train(model, dataloader, criterion, optimizer, epochs):
""" define the training function """
for epoch in tqdm(range(epochs), 'Epochs'):
for i, (inputs, labels) in enumerate(dataloader):
# forwarding
outputs = model(inputs)
loss = criterion(outputs, labels)
# update gradient
optimizer.zero_grad()
loss.backward()
optimizer.step()
# log loss
mlflow.log_metric('train_loss', loss.item(), step=i)
return loss
if __name__ == '__main__':
# set hyper parameters
learning_rate = 1e-2
epochs = 20
# create a dataloader with fake data
dataloader = [(torch.randn(10), torch.randn(1)) for _ in range(100)]
dataloader = torch.utils.data.DataLoader(dataloader, batch_size=10)
# create the model, criterion, and optimizer
model = Net()
criterion = nn.MSELoss()
optimizer = SGD(model.parameters(), lr=learning_rate)
# set the tracking URI to the model registry
mlflow.set_tracking_uri('http://127.0.0.1:5000')
mlflow.set_experiment('/mlflow_testing')
# start the MLflow run
with mlflow.start_run():
# train the model
loss = train(model, dataloader, criterion, optimizer, epochs)
# log some additional metrics
mlflow.log_metric('final_loss', loss.item())
mlflow.log_param('learning_rate', learning_rate)
# log trained model
mlflow.pytorch.log_model(model, 'model')
# log training code
mlflow.log_artifact('./train.py', 'code')
print('Completed.')

View File

@ -0,0 +1,16 @@
artifact_path: model
flavors:
python_function:
data: data
env: conda.yaml
loader_module: mlflow.pytorch
pickle_module_name: mlflow.pytorch.pickle_module
python_version: 3.10.9
pytorch:
code: null
model_data: data
pytorch_version: 1.13.1
mlflow_version: 1.30.0
model_uuid: ff8b845d6a174ffabfc49a18673c6c04
run_id: c248a4299f97423987a9496a2241ab1a
utc_time_created: '2023-02-22 01:10:55.971443'

View File

@ -0,0 +1,11 @@
channels:
- conda-forge
dependencies:
- python=3.10.9
- pip<=23.0.1
- pip:
- mlflow
- cloudpickle==2.2.1
- torch==1.13.1
- tqdm==4.64.1
name: mlflow-env

View File

@ -0,0 +1 @@
mlflow.pytorch.pickle_module

View File

@ -0,0 +1,7 @@
python: 3.10.9
build_dependencies:
- pip==23.0.1
- setuptools==67.3.2
- wheel==0.38.4
dependencies:
- -r requirements.txt

View File

@ -0,0 +1,4 @@
mlflow
cloudpickle==2.2.1
torch==1.13.1
tqdm==4.64.1