log learning rate

This commit is contained in:
2026-06-18 11:20:08 +08:00
parent 253d16f84e
commit c5d39c8e16

View File

@ -23,9 +23,9 @@ class Train:
self.config = load_config(config_path)['train'] self.config = load_config(config_path)['train']
self._device = torch.device(self.config['device_type']) self._device = torch.device(self.config['device_type'])
self._ensure_deterministic() self._ensure_reproducibility()
def _ensure_deterministic(self) -> None: def _ensure_reproducibility(self) -> None:
torch.use_deterministic_algorithms(mode=True, warn_only=True) torch.use_deterministic_algorithms(mode=True, warn_only=True)
random.seed(self.config['random_seed']) random.seed(self.config['random_seed'])
np.random.seed(self.config['random_seed']) np.random.seed(self.config['random_seed'])
@ -125,6 +125,7 @@ class Train:
metrics.update(outputs, targets) metrics.update(outputs, targets)
train_metrics = metrics.compute() train_metrics = metrics.compute()
avg_train_loss = total_train_loss / len(train_dataloader) avg_train_loss = total_train_loss / len(train_dataloader)
train_learning_rate = round(optimizer.param_groups[0]['lr'], 6)
metrics.reset() metrics.reset()
model.eval() model.eval()
@ -141,6 +142,7 @@ class Train:
avg_valid_loss = total_valid_loss / len(valid_dataloader) avg_valid_loss = total_valid_loss / len(valid_dataloader)
live.log_metric('train/loss', avg_train_loss) live.log_metric('train/loss', avg_train_loss)
live.log_metric('train/learning_rate', train_learning_rate)
live.log_metric('train/accuracy', train_metrics['MulticlassAccuracy'].item()) live.log_metric('train/accuracy', train_metrics['MulticlassAccuracy'].item())
live.log_metric('train/precision', train_metrics['MulticlassPrecision'].item()) live.log_metric('train/precision', train_metrics['MulticlassPrecision'].item())
live.log_metric('train/recall', train_metrics['MulticlassRecall'].item()) live.log_metric('train/recall', train_metrics['MulticlassRecall'].item())