[BUGFIX] GNG Example

This commit is contained in:
Alexander Engelsberger
2021-06-03 15:42:54 +02:00
parent 0bc385fe7b
commit 47db1965ee
2 changed files with 10 additions and 6 deletions

View File

@@ -64,7 +64,7 @@ class GLVQ(AbstractPrototypeModel):
def forward(self, x):
distances = self._forward(x)
y_pred = self.predict_from_distances(distances)
y_pred = torch.eye(self.num_classes, device=self.device)[y_pred.int()]
y_pred = torch.eye(self.num_classes, device=self.device)[y_pred.long()]
return y_pred
def predict_from_distances(self, distances):