Use squared euclidean distance in GMLVQ
This commit is contained in:
parent
d8e017ae74
commit
a1ac5a70c7
@ -3,7 +3,8 @@ import torch
|
|||||||
import torchmetrics
|
import torchmetrics
|
||||||
from prototorch.components import LabeledComponents
|
from prototorch.components import LabeledComponents
|
||||||
from prototorch.functions.competitions import wtac
|
from prototorch.functions.competitions import wtac
|
||||||
from prototorch.functions.distances import euclidean_distance
|
from prototorch.functions.distances import (euclidean_distance,
|
||||||
|
squared_euclidean_distance)
|
||||||
from prototorch.functions.losses import glvq_loss
|
from prototorch.functions.losses import glvq_loss
|
||||||
from prototorch.modules.prototypes import Prototypes1D
|
from prototorch.modules.prototypes import Prototypes1D
|
||||||
|
|
||||||
@ -151,7 +152,7 @@ class GMLVQ(GLVQ):
|
|||||||
protos, _ = self.proto_layer()
|
protos, _ = self.proto_layer()
|
||||||
latent_x = self.omega_layer(x)
|
latent_x = self.omega_layer(x)
|
||||||
latent_protos = self.omega_layer(protos)
|
latent_protos = self.omega_layer(protos)
|
||||||
dis = euclidean_distance(latent_x, latent_protos)
|
dis = squared_euclidean_distance(latent_x, latent_protos)
|
||||||
return dis
|
return dis
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user