[BUGFIX] Log loss in NG and GNG
This commit is contained in:
		@@ -99,9 +99,14 @@ class NeuralGas(UnsupervisedPrototypeModel):
 | 
				
			|||||||
        # TODO Check if the batch has labels
 | 
					        # TODO Check if the batch has labels
 | 
				
			||||||
        x = train_batch[0]
 | 
					        x = train_batch[0]
 | 
				
			||||||
        d = self.compute_distances(x)
 | 
					        d = self.compute_distances(x)
 | 
				
			||||||
        cost, _ = self.energy_layer(d)
 | 
					        loss, _ = self.energy_layer(d)
 | 
				
			||||||
        self.topology_layer(d)
 | 
					        self.topology_layer(d)
 | 
				
			||||||
        return cost
 | 
					        self.log("loss", loss)
 | 
				
			||||||
 | 
					        return loss
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # def training_epoch_end(self, training_step_outputs):
 | 
				
			||||||
 | 
					    #     print(f"{self.trainer.lr_schedulers}")
 | 
				
			||||||
 | 
					    #     print(f"{self.trainer.lr_schedulers[0]['scheduler'].optimizer}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class GrowingNeuralGas(NeuralGas):
 | 
					class GrowingNeuralGas(NeuralGas):
 | 
				
			||||||
@@ -121,7 +126,7 @@ class GrowingNeuralGas(NeuralGas):
 | 
				
			|||||||
        # TODO Check if the batch has labels
 | 
					        # TODO Check if the batch has labels
 | 
				
			||||||
        x = train_batch[0]
 | 
					        x = train_batch[0]
 | 
				
			||||||
        d = self.compute_distances(x)
 | 
					        d = self.compute_distances(x)
 | 
				
			||||||
        cost, order = self.energy_layer(d)
 | 
					        loss, order = self.energy_layer(d)
 | 
				
			||||||
        winner = order[:, 0]
 | 
					        winner = order[:, 0]
 | 
				
			||||||
        mask = torch.zeros_like(d)
 | 
					        mask = torch.zeros_like(d)
 | 
				
			||||||
        mask[torch.arange(len(mask)), winner] = 1.0
 | 
					        mask[torch.arange(len(mask)), winner] = 1.0
 | 
				
			||||||
@@ -131,7 +136,8 @@ class GrowingNeuralGas(NeuralGas):
 | 
				
			|||||||
        self.errors *= self.hparams.step_reduction
 | 
					        self.errors *= self.hparams.step_reduction
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.topology_layer(d)
 | 
					        self.topology_layer(d)
 | 
				
			||||||
        return cost
 | 
					        self.log("loss", loss)
 | 
				
			||||||
 | 
					        return loss
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def configure_callbacks(self):
 | 
					    def configure_callbacks(self):
 | 
				
			||||||
        return [
 | 
					        return [
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user