From e62a8e6582bedf53c1ff8a5f78c3f7b47ec3014b Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Fri, 11 Jun 2021 18:50:14 +0200 Subject: [PATCH] [BUGFIX] Log loss in NG and GNG --- prototorch/models/unsupervised.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/prototorch/models/unsupervised.py b/prototorch/models/unsupervised.py index af17fd8..d171115 100644 --- a/prototorch/models/unsupervised.py +++ b/prototorch/models/unsupervised.py @@ -99,9 +99,14 @@ class NeuralGas(UnsupervisedPrototypeModel): # TODO Check if the batch has labels x = train_batch[0] d = self.compute_distances(x) - cost, _ = self.energy_layer(d) + loss, _ = self.energy_layer(d) self.topology_layer(d) - return cost + self.log("loss", loss) + return loss + + # def training_epoch_end(self, training_step_outputs): + # print(f"{self.trainer.lr_schedulers}") + # print(f"{self.trainer.lr_schedulers[0]['scheduler'].optimizer}") class GrowingNeuralGas(NeuralGas): @@ -121,7 +126,7 @@ class GrowingNeuralGas(NeuralGas): # TODO Check if the batch has labels x = train_batch[0] d = self.compute_distances(x) - cost, order = self.energy_layer(d) + loss, order = self.energy_layer(d) winner = order[:, 0] mask = torch.zeros_like(d) mask[torch.arange(len(mask)), winner] = 1.0 @@ -131,7 +136,8 @@ class GrowingNeuralGas(NeuralGas): self.errors *= self.hparams.step_reduction self.topology_layer(d) - return cost + self.log("loss", loss) + return loss def configure_callbacks(self): return [