[REFACTOR] Use LambdaLayer instead of EuclideanDistance
This commit is contained in:
parent
ef4d70eee0
commit
98c198d463
@ -10,17 +10,13 @@ from prototorch.components import Components, LabeledComponents
|
|||||||
from prototorch.components.initializers import ZerosInitializer, parse_data_arg
|
from prototorch.components.initializers import ZerosInitializer, parse_data_arg
|
||||||
from prototorch.functions.competitions import knnc
|
from prototorch.functions.competitions import knnc
|
||||||
from prototorch.functions.distances import euclidean_distance
|
from prototorch.functions.distances import euclidean_distance
|
||||||
|
from prototorch.modules import LambdaLayer
|
||||||
from prototorch.modules.losses import NeuralGasEnergy
|
from prototorch.modules.losses import NeuralGasEnergy
|
||||||
from pytorch_lightning.callbacks import Callback
|
from pytorch_lightning.callbacks import Callback
|
||||||
|
|
||||||
from .abstract import AbstractPrototypeModel
|
from .abstract import AbstractPrototypeModel
|
||||||
|
|
||||||
|
|
||||||
class EuclideanDistance(torch.nn.Module):
|
|
||||||
def forward(self, x, y):
|
|
||||||
return euclidean_distance(x, y)
|
|
||||||
|
|
||||||
|
|
||||||
class GNGCallback(Callback):
|
class GNGCallback(Callback):
|
||||||
"""GNG Callback.
|
"""GNG Callback.
|
||||||
|
|
||||||
@ -201,7 +197,7 @@ class NeuralGas(AbstractPrototypeModel):
|
|||||||
self.hparams.num_prototypes,
|
self.hparams.num_prototypes,
|
||||||
initializer=self.hparams.prototype_initializer)
|
initializer=self.hparams.prototype_initializer)
|
||||||
|
|
||||||
self.distance_layer = EuclideanDistance()
|
self.distance_layer = LambdaLayer(euclidean_distance)
|
||||||
self.energy_layer = NeuralGasEnergy(lm=self.hparams.lm)
|
self.energy_layer = NeuralGasEnergy(lm=self.hparams.lm)
|
||||||
self.topology_layer = ConnectionTopology(
|
self.topology_layer = ConnectionTopology(
|
||||||
agelimit=self.hparams.agelimit,
|
agelimit=self.hparams.agelimit,
|
||||||
@ -212,8 +208,7 @@ class NeuralGas(AbstractPrototypeModel):
|
|||||||
x = train_batch[0]
|
x = train_batch[0]
|
||||||
protos = self.proto_layer()
|
protos = self.proto_layer()
|
||||||
d = self.distance_layer(x, protos)
|
d = self.distance_layer(x, protos)
|
||||||
cost, order = self.energy_layer(d)
|
cost, _ = self.energy_layer(d)
|
||||||
|
|
||||||
self.topology_layer(d)
|
self.topology_layer(d)
|
||||||
return cost
|
return cost
|
||||||
|
|
||||||
@ -235,9 +230,7 @@ class GrowingNeuralGas(NeuralGas):
|
|||||||
protos = self.proto_layer()
|
protos = self.proto_layer()
|
||||||
d = self.distance_layer(x, protos)
|
d = self.distance_layer(x, protos)
|
||||||
cost, order = self.energy_layer(d)
|
cost, order = self.energy_layer(d)
|
||||||
|
|
||||||
winner = order[:, 0]
|
winner = order[:, 0]
|
||||||
|
|
||||||
mask = torch.zeros_like(d)
|
mask = torch.zeros_like(d)
|
||||||
mask[torch.arange(len(mask)), winner] = 1.0
|
mask[torch.arange(len(mask)), winner] = 1.0
|
||||||
winner_distances = d * mask
|
winner_distances = d * mask
|
||||||
|
Loading…
Reference in New Issue
Block a user