log learning rate
This commit is contained in:
@ -23,9 +23,9 @@ class Train:
|
|||||||
self.config = load_config(config_path)['train']
|
self.config = load_config(config_path)['train']
|
||||||
self._device = torch.device(self.config['device_type'])
|
self._device = torch.device(self.config['device_type'])
|
||||||
|
|
||||||
self._ensure_deterministic()
|
self._ensure_reproducibility()
|
||||||
|
|
||||||
def _ensure_deterministic(self) -> None:
|
def _ensure_reproducibility(self) -> None:
|
||||||
torch.use_deterministic_algorithms(mode=True, warn_only=True)
|
torch.use_deterministic_algorithms(mode=True, warn_only=True)
|
||||||
random.seed(self.config['random_seed'])
|
random.seed(self.config['random_seed'])
|
||||||
np.random.seed(self.config['random_seed'])
|
np.random.seed(self.config['random_seed'])
|
||||||
@ -125,6 +125,7 @@ class Train:
|
|||||||
metrics.update(outputs, targets)
|
metrics.update(outputs, targets)
|
||||||
train_metrics = metrics.compute()
|
train_metrics = metrics.compute()
|
||||||
avg_train_loss = total_train_loss / len(train_dataloader)
|
avg_train_loss = total_train_loss / len(train_dataloader)
|
||||||
|
train_learning_rate = round(optimizer.param_groups[0]['lr'], 6)
|
||||||
|
|
||||||
metrics.reset()
|
metrics.reset()
|
||||||
model.eval()
|
model.eval()
|
||||||
@ -141,6 +142,7 @@ class Train:
|
|||||||
avg_valid_loss = total_valid_loss / len(valid_dataloader)
|
avg_valid_loss = total_valid_loss / len(valid_dataloader)
|
||||||
|
|
||||||
live.log_metric('train/loss', avg_train_loss)
|
live.log_metric('train/loss', avg_train_loss)
|
||||||
|
live.log_metric('train/learning_rate', train_learning_rate)
|
||||||
live.log_metric('train/accuracy', train_metrics['MulticlassAccuracy'].item())
|
live.log_metric('train/accuracy', train_metrics['MulticlassAccuracy'].item())
|
||||||
live.log_metric('train/precision', train_metrics['MulticlassPrecision'].item())
|
live.log_metric('train/precision', train_metrics['MulticlassPrecision'].item())
|
||||||
live.log_metric('train/recall', train_metrics['MulticlassRecall'].item())
|
live.log_metric('train/recall', train_metrics['MulticlassRecall'].item())
|
||||||
|
|||||||
Reference in New Issue
Block a user