[WIP] Add Growing Neural Gas Energy
This commit is contained in:
parent
946cda00d2
commit
2722d976f5
@ -1,7 +1,6 @@
|
|||||||
"""ProtoTorch losses."""
|
"""ProtoTorch losses."""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from prototorch.functions.activations import get_activation
|
from prototorch.functions.activations import get_activation
|
||||||
from prototorch.functions.losses import glvq_loss
|
from prototorch.functions.losses import glvq_loss
|
||||||
|
|
||||||
@ -38,3 +37,22 @@ class NeuralGasEnergy(torch.nn.Module):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _nghood_fn(rankings, lm):
|
def _nghood_fn(rankings, lm):
|
||||||
return torch.exp(-rankings / lm)
|
return torch.exp(-rankings / lm)
|
||||||
|
|
||||||
|
|
||||||
|
class GrowingNeuralGasEnergy(NeuralGasEnergy):
|
||||||
|
def __init__(self, topology_layer):
|
||||||
|
super().__init__()
|
||||||
|
self.topology_layer = topology_layer
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _nghood_fn(rankings, topology):
|
||||||
|
winner = rankings[:, 0]
|
||||||
|
|
||||||
|
weights = torch.zeros_like(rankings, dtype=torch.float)
|
||||||
|
weights[torch.arange(rankings.shape[0]), winner] = 1.0
|
||||||
|
|
||||||
|
neighbours = topology.get_neighbours(winner)
|
||||||
|
|
||||||
|
weights[neighbours] = 0.1
|
||||||
|
|
||||||
|
return weights
|
||||||
|
Loading…
Reference in New Issue
Block a user