[BUGFIX] Growing neural gas

This commit is contained in:
Alexander Engelsberger 2021-06-03 15:13:27 +02:00
parent 7379c61966
commit bda88149d4
2 changed files with 5 additions and 4 deletions

View File

@ -8,7 +8,7 @@ from .glvq import (GLVQ, GLVQ1, GLVQ21, GMLVQ, GRLVQ, LGMLVQ, LVQMLN,
ImageGLVQ, ImageGMLVQ, SiameseGLVQ, SiameseGMLVQ)
from .lvq import LVQ1, LVQ21, MedianLVQ
from .probabilistic import CELVQ, RSLVQ, LikelihoodRatioLVQ
from .unsupervised import KNN, NeuralGas
from .unsupervised import KNN, GrowingNeuralGas, NeuralGas
from .vis import *
__version__ = "0.1.7"

View File

@ -54,7 +54,7 @@ class GNGCallback(Callback):
# Add component
pl_module.proto_layer.add_components(
initialized_components=new_component)
initialized_components=new_component.unsqueeze(0))
# Adjust Topology
topology.add_prototype()
@ -223,8 +223,9 @@ class GrowingNeuralGas(NeuralGas):
self.hparams.setdefault("insert_reduction", 0.1)
self.hparams.setdefault("insert_freq", 10)
self.register_buffer("errors",
torch.zeros(self.hparams.num_prototypes))
self.register_buffer(
"errors",
torch.zeros(self.hparams.num_prototypes, device=self.device))
def training_step(self, train_batch, _batch_idx):
x = train_batch[0]