diff --git a/prototorch/models/glvq.py b/prototorch/models/glvq.py index 5b48f7c..48c195a 100644 --- a/prototorch/models/glvq.py +++ b/prototorch/models/glvq.py @@ -171,7 +171,7 @@ class SiameseGLVQ(GLVQ): latent_protos = self.backbone_dependent(protos) d = euclidean_distance(x, latent_protos) y_pred = wtac(d, plabels) - return y_pred.numpy() + return y_pred class GRLVQ(GLVQ): @@ -206,7 +206,7 @@ class GRLVQ(GLVQ): latent_protos = protos @ torch.diag(self.relevances) d = squared_euclidean_distance(x, latent_protos) y_pred = wtac(d, plabels) - return y_pred.numpy() + return y_pred class GMLVQ(GLVQ): @@ -261,7 +261,7 @@ class GMLVQ(GLVQ): latent_protos = self.omega_layer(protos) d = squared_euclidean_distance(x, latent_protos) y_pred = wtac(d, plabels) - return y_pred.numpy() + return y_pred class ImageGMLVQ(GMLVQ, PrototypeImageModel): @@ -306,4 +306,4 @@ class LVQMLN(GLVQ): latent_protos, plabels = self.proto_layer() d = euclidean_distance(x, latent_protos) y_pred = wtac(d, plabels) - return y_pred.numpy() + return y_pred