Use squared euclidean distance in GMLVQ
This commit is contained in:
parent
d8e017ae74
commit
a1ac5a70c7
@ -3,7 +3,8 @@ import torch
|
||||
import torchmetrics
|
||||
from prototorch.components import LabeledComponents
|
||||
from prototorch.functions.competitions import wtac
|
||||
from prototorch.functions.distances import euclidean_distance
|
||||
from prototorch.functions.distances import (euclidean_distance,
|
||||
squared_euclidean_distance)
|
||||
from prototorch.functions.losses import glvq_loss
|
||||
from prototorch.modules.prototypes import Prototypes1D
|
||||
|
||||
@ -151,7 +152,7 @@ class GMLVQ(GLVQ):
|
||||
protos, _ = self.proto_layer()
|
||||
latent_x = self.omega_layer(x)
|
||||
latent_protos = self.omega_layer(protos)
|
||||
dis = euclidean_distance(latent_x, latent_protos)
|
||||
dis = squared_euclidean_distance(latent_x, latent_protos)
|
||||
return dis
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user