[BUGFIX] Log loss in NG and GNG

This commit is contained in:
Jensun Ravichandran 2021-06-11 18:50:14 +02:00
parent ea33196a50
commit e62a8e6582

View File

@ -99,9 +99,14 @@ class NeuralGas(UnsupervisedPrototypeModel):
# TODO Check if the batch has labels # TODO Check if the batch has labels
x = train_batch[0] x = train_batch[0]
d = self.compute_distances(x) d = self.compute_distances(x)
cost, _ = self.energy_layer(d) loss, _ = self.energy_layer(d)
self.topology_layer(d) self.topology_layer(d)
return cost self.log("loss", loss)
return loss
# def training_epoch_end(self, training_step_outputs):
# print(f"{self.trainer.lr_schedulers}")
# print(f"{self.trainer.lr_schedulers[0]['scheduler'].optimizer}")
class GrowingNeuralGas(NeuralGas): class GrowingNeuralGas(NeuralGas):
@ -121,7 +126,7 @@ class GrowingNeuralGas(NeuralGas):
# TODO Check if the batch has labels # TODO Check if the batch has labels
x = train_batch[0] x = train_batch[0]
d = self.compute_distances(x) d = self.compute_distances(x)
cost, order = self.energy_layer(d) loss, order = self.energy_layer(d)
winner = order[:, 0] winner = order[:, 0]
mask = torch.zeros_like(d) mask = torch.zeros_like(d)
mask[torch.arange(len(mask)), winner] = 1.0 mask[torch.arange(len(mask)), winner] = 1.0
@ -131,7 +136,8 @@ class GrowingNeuralGas(NeuralGas):
self.errors *= self.hparams.step_reduction self.errors *= self.hparams.step_reduction
self.topology_layer(d) self.topology_layer(d)
return cost self.log("loss", loss)
return loss
def configure_callbacks(self): def configure_callbacks(self):
return [ return [