predict_latent
no longer returns numpy
This commit is contained in:
parent
ebc42a4aa8
commit
b7684ae512
@ -171,7 +171,7 @@ class SiameseGLVQ(GLVQ):
|
|||||||
latent_protos = self.backbone_dependent(protos)
|
latent_protos = self.backbone_dependent(protos)
|
||||||
d = euclidean_distance(x, latent_protos)
|
d = euclidean_distance(x, latent_protos)
|
||||||
y_pred = wtac(d, plabels)
|
y_pred = wtac(d, plabels)
|
||||||
return y_pred.numpy()
|
return y_pred
|
||||||
|
|
||||||
|
|
||||||
class GRLVQ(GLVQ):
|
class GRLVQ(GLVQ):
|
||||||
@ -206,7 +206,7 @@ class GRLVQ(GLVQ):
|
|||||||
latent_protos = protos @ torch.diag(self.relevances)
|
latent_protos = protos @ torch.diag(self.relevances)
|
||||||
d = squared_euclidean_distance(x, latent_protos)
|
d = squared_euclidean_distance(x, latent_protos)
|
||||||
y_pred = wtac(d, plabels)
|
y_pred = wtac(d, plabels)
|
||||||
return y_pred.numpy()
|
return y_pred
|
||||||
|
|
||||||
|
|
||||||
class GMLVQ(GLVQ):
|
class GMLVQ(GLVQ):
|
||||||
@ -261,7 +261,7 @@ class GMLVQ(GLVQ):
|
|||||||
latent_protos = self.omega_layer(protos)
|
latent_protos = self.omega_layer(protos)
|
||||||
d = squared_euclidean_distance(x, latent_protos)
|
d = squared_euclidean_distance(x, latent_protos)
|
||||||
y_pred = wtac(d, plabels)
|
y_pred = wtac(d, plabels)
|
||||||
return y_pred.numpy()
|
return y_pred
|
||||||
|
|
||||||
|
|
||||||
class ImageGMLVQ(GMLVQ, PrototypeImageModel):
|
class ImageGMLVQ(GMLVQ, PrototypeImageModel):
|
||||||
@ -306,4 +306,4 @@ class LVQMLN(GLVQ):
|
|||||||
latent_protos, plabels = self.proto_layer()
|
latent_protos, plabels = self.proto_layer()
|
||||||
d = euclidean_distance(x, latent_protos)
|
d = euclidean_distance(x, latent_protos)
|
||||||
y_pred = wtac(d, plabels)
|
y_pred = wtac(d, plabels)
|
||||||
return y_pred.numpy()
|
return y_pred
|
||||||
|
Loading…
Reference in New Issue
Block a user