[BUGFIX] Growing neural gas

This commit is contained in:
Alexander Engelsberger 2021-06-03 15:13:27 +02:00
parent 7379c61966
commit bda88149d4
2 changed files with 5 additions and 4 deletions

View File

@ -8,7 +8,7 @@ from .glvq import (GLVQ, GLVQ1, GLVQ21, GMLVQ, GRLVQ, LGMLVQ, LVQMLN,
ImageGLVQ, ImageGMLVQ, SiameseGLVQ, SiameseGMLVQ) ImageGLVQ, ImageGMLVQ, SiameseGLVQ, SiameseGMLVQ)
from .lvq import LVQ1, LVQ21, MedianLVQ from .lvq import LVQ1, LVQ21, MedianLVQ
from .probabilistic import CELVQ, RSLVQ, LikelihoodRatioLVQ from .probabilistic import CELVQ, RSLVQ, LikelihoodRatioLVQ
from .unsupervised import KNN, NeuralGas from .unsupervised import KNN, GrowingNeuralGas, NeuralGas
from .vis import * from .vis import *
__version__ = "0.1.7" __version__ = "0.1.7"

View File

@ -54,7 +54,7 @@ class GNGCallback(Callback):
# Add component # Add component
pl_module.proto_layer.add_components( pl_module.proto_layer.add_components(
initialized_components=new_component) initialized_components=new_component.unsqueeze(0))
# Adjust Topology # Adjust Topology
topology.add_prototype() topology.add_prototype()
@ -223,8 +223,9 @@ class GrowingNeuralGas(NeuralGas):
self.hparams.setdefault("insert_reduction", 0.1) self.hparams.setdefault("insert_reduction", 0.1)
self.hparams.setdefault("insert_freq", 10) self.hparams.setdefault("insert_freq", 10)
self.register_buffer("errors", self.register_buffer(
torch.zeros(self.hparams.num_prototypes)) "errors",
torch.zeros(self.hparams.num_prototypes, device=self.device))
def training_step(self, train_batch, _batch_idx): def training_step(self, train_batch, _batch_idx):
x = train_batch[0] x = train_batch[0]