[BUGFIX] Growing neural gas
This commit is contained in:
parent
7379c61966
commit
bda88149d4
@ -8,7 +8,7 @@ from .glvq import (GLVQ, GLVQ1, GLVQ21, GMLVQ, GRLVQ, LGMLVQ, LVQMLN,
|
||||
ImageGLVQ, ImageGMLVQ, SiameseGLVQ, SiameseGMLVQ)
|
||||
from .lvq import LVQ1, LVQ21, MedianLVQ
|
||||
from .probabilistic import CELVQ, RSLVQ, LikelihoodRatioLVQ
|
||||
from .unsupervised import KNN, NeuralGas
|
||||
from .unsupervised import KNN, GrowingNeuralGas, NeuralGas
|
||||
from .vis import *
|
||||
|
||||
__version__ = "0.1.7"
|
||||
|
@ -54,7 +54,7 @@ class GNGCallback(Callback):
|
||||
|
||||
# Add component
|
||||
pl_module.proto_layer.add_components(
|
||||
initialized_components=new_component)
|
||||
initialized_components=new_component.unsqueeze(0))
|
||||
|
||||
# Adjust Topology
|
||||
topology.add_prototype()
|
||||
@ -223,8 +223,9 @@ class GrowingNeuralGas(NeuralGas):
|
||||
self.hparams.setdefault("insert_reduction", 0.1)
|
||||
self.hparams.setdefault("insert_freq", 10)
|
||||
|
||||
self.register_buffer("errors",
|
||||
torch.zeros(self.hparams.num_prototypes))
|
||||
self.register_buffer(
|
||||
"errors",
|
||||
torch.zeros(self.hparams.num_prototypes, device=self.device))
|
||||
|
||||
def training_step(self, train_batch, _batch_idx):
|
||||
x = train_batch[0]
|
||||
|
Loading…
Reference in New Issue
Block a user