[BUGFIX] GNG Example
This commit is contained in:
@@ -64,7 +64,7 @@ class GLVQ(AbstractPrototypeModel):
|
||||
def forward(self, x):
|
||||
distances = self._forward(x)
|
||||
y_pred = self.predict_from_distances(distances)
|
||||
y_pred = torch.eye(self.num_classes, device=self.device)[y_pred.int()]
|
||||
y_pred = torch.eye(self.num_classes, device=self.device)[y_pred.long()]
|
||||
return y_pred
|
||||
|
||||
def predict_from_distances(self, distances):
|
||||
|
Reference in New Issue
Block a user