Use squared euclidean distance in GMLVQ

This commit is contained in:
Jensun Ravichandran 2021-05-04 14:34:00 +02:00
parent d8e017ae74
commit a1ac5a70c7

View File

@ -3,7 +3,8 @@ import torch
import torchmetrics import torchmetrics
from prototorch.components import LabeledComponents from prototorch.components import LabeledComponents
from prototorch.functions.competitions import wtac from prototorch.functions.competitions import wtac
from prototorch.functions.distances import euclidean_distance from prototorch.functions.distances import (euclidean_distance,
squared_euclidean_distance)
from prototorch.functions.losses import glvq_loss from prototorch.functions.losses import glvq_loss
from prototorch.modules.prototypes import Prototypes1D from prototorch.modules.prototypes import Prototypes1D
@ -151,7 +152,7 @@ class GMLVQ(GLVQ):
protos, _ = self.proto_layer() protos, _ = self.proto_layer()
latent_x = self.omega_layer(x) latent_x = self.omega_layer(x)
latent_protos = self.omega_layer(protos) latent_protos = self.omega_layer(protos)
dis = euclidean_distance(latent_x, latent_protos) dis = squared_euclidean_distance(latent_x, latent_protos)
return dis return dis