From c5d39c8e16c8a8b6bd518c3c4ba90e48c9bc4953 Mon Sep 17 00:00:00 2001 From: deng Date: Thu, 18 Jun 2026 11:20:08 +0800 Subject: [PATCH] log learning rate --- quickdraw_bot/train.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/quickdraw_bot/train.py b/quickdraw_bot/train.py index 83f7e61..695c5f0 100644 --- a/quickdraw_bot/train.py +++ b/quickdraw_bot/train.py @@ -23,9 +23,9 @@ class Train: self.config = load_config(config_path)['train'] self._device = torch.device(self.config['device_type']) - self._ensure_deterministic() + self._ensure_reproducibility() - def _ensure_deterministic(self) -> None: + def _ensure_reproducibility(self) -> None: torch.use_deterministic_algorithms(mode=True, warn_only=True) random.seed(self.config['random_seed']) np.random.seed(self.config['random_seed']) @@ -125,6 +125,7 @@ class Train: metrics.update(outputs, targets) train_metrics = metrics.compute() avg_train_loss = total_train_loss / len(train_dataloader) + train_learning_rate = round(optimizer.param_groups[0]['lr'], 6) metrics.reset() model.eval() @@ -141,6 +142,7 @@ class Train: avg_valid_loss = total_valid_loss / len(valid_dataloader) live.log_metric('train/loss', avg_train_loss) + live.log_metric('train/learning_rate', train_learning_rate) live.log_metric('train/accuracy', train_metrics['MulticlassAccuracy'].item()) live.log_metric('train/precision', train_metrics['MulticlassPrecision'].item()) live.log_metric('train/recall', train_metrics['MulticlassRecall'].item())