predict_latent no longer returns numpy
				
					
				
			This commit is contained in:
		@@ -171,7 +171,7 @@ class SiameseGLVQ(GLVQ):
 | 
				
			|||||||
            latent_protos = self.backbone_dependent(protos)
 | 
					            latent_protos = self.backbone_dependent(protos)
 | 
				
			||||||
            d = euclidean_distance(x, latent_protos)
 | 
					            d = euclidean_distance(x, latent_protos)
 | 
				
			||||||
            y_pred = wtac(d, plabels)
 | 
					            y_pred = wtac(d, plabels)
 | 
				
			||||||
        return y_pred.numpy()
 | 
					        return y_pred
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class GRLVQ(GLVQ):
 | 
					class GRLVQ(GLVQ):
 | 
				
			||||||
@@ -206,7 +206,7 @@ class GRLVQ(GLVQ):
 | 
				
			|||||||
            latent_protos = protos @ torch.diag(self.relevances)
 | 
					            latent_protos = protos @ torch.diag(self.relevances)
 | 
				
			||||||
            d = squared_euclidean_distance(x, latent_protos)
 | 
					            d = squared_euclidean_distance(x, latent_protos)
 | 
				
			||||||
            y_pred = wtac(d, plabels)
 | 
					            y_pred = wtac(d, plabels)
 | 
				
			||||||
        return y_pred.numpy()
 | 
					        return y_pred
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class GMLVQ(GLVQ):
 | 
					class GMLVQ(GLVQ):
 | 
				
			||||||
@@ -261,7 +261,7 @@ class GMLVQ(GLVQ):
 | 
				
			|||||||
            latent_protos = self.omega_layer(protos)
 | 
					            latent_protos = self.omega_layer(protos)
 | 
				
			||||||
            d = squared_euclidean_distance(x, latent_protos)
 | 
					            d = squared_euclidean_distance(x, latent_protos)
 | 
				
			||||||
            y_pred = wtac(d, plabels)
 | 
					            y_pred = wtac(d, plabels)
 | 
				
			||||||
        return y_pred.numpy()
 | 
					        return y_pred
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ImageGMLVQ(GMLVQ, PrototypeImageModel):
 | 
					class ImageGMLVQ(GMLVQ, PrototypeImageModel):
 | 
				
			||||||
@@ -306,4 +306,4 @@ class LVQMLN(GLVQ):
 | 
				
			|||||||
            latent_protos, plabels = self.proto_layer()
 | 
					            latent_protos, plabels = self.proto_layer()
 | 
				
			||||||
            d = euclidean_distance(x, latent_protos)
 | 
					            d = euclidean_distance(x, latent_protos)
 | 
				
			||||||
            y_pred = wtac(d, plabels)
 | 
					            y_pred = wtac(d, plabels)
 | 
				
			||||||
        return y_pred.numpy()
 | 
					        return y_pred
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user