predict_latent no longer returns numpy

This commit is contained in:
Jensun Ravichandran 2021-05-15 12:52:16 +02:00
parent ebc42a4aa8
commit b7684ae512

View File

@ -171,7 +171,7 @@ class SiameseGLVQ(GLVQ):
latent_protos = self.backbone_dependent(protos)
d = euclidean_distance(x, latent_protos)
y_pred = wtac(d, plabels)
return y_pred.numpy()
return y_pred
class GRLVQ(GLVQ):
@ -206,7 +206,7 @@ class GRLVQ(GLVQ):
latent_protos = protos @ torch.diag(self.relevances)
d = squared_euclidean_distance(x, latent_protos)
y_pred = wtac(d, plabels)
return y_pred.numpy()
return y_pred
class GMLVQ(GLVQ):
@ -261,7 +261,7 @@ class GMLVQ(GLVQ):
latent_protos = self.omega_layer(protos)
d = squared_euclidean_distance(x, latent_protos)
y_pred = wtac(d, plabels)
return y_pred.numpy()
return y_pred
class ImageGMLVQ(GMLVQ, PrototypeImageModel):
@ -306,4 +306,4 @@ class LVQMLN(GLVQ):
latent_protos, plabels = self.proto_layer()
d = euclidean_distance(x, latent_protos)
y_pred = wtac(d, plabels)
return y_pred.numpy()
return y_pred