From bda88149d41419a6e127f0351496afc97fa73522 Mon Sep 17 00:00:00 2001 From: Alexander Engelsberger Date: Thu, 3 Jun 2021 15:13:27 +0200 Subject: [PATCH] [BUGFIX] Growing neural gas --- prototorch/models/__init__.py | 2 +- prototorch/models/unsupervised.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/prototorch/models/__init__.py b/prototorch/models/__init__.py index 251370c..1d8d158 100644 --- a/prototorch/models/__init__.py +++ b/prototorch/models/__init__.py @@ -8,7 +8,7 @@ from .glvq import (GLVQ, GLVQ1, GLVQ21, GMLVQ, GRLVQ, LGMLVQ, LVQMLN, ImageGLVQ, ImageGMLVQ, SiameseGLVQ, SiameseGMLVQ) from .lvq import LVQ1, LVQ21, MedianLVQ from .probabilistic import CELVQ, RSLVQ, LikelihoodRatioLVQ -from .unsupervised import KNN, NeuralGas +from .unsupervised import KNN, GrowingNeuralGas, NeuralGas from .vis import * __version__ = "0.1.7" diff --git a/prototorch/models/unsupervised.py b/prototorch/models/unsupervised.py index 1d31bd3..8f691bf 100644 --- a/prototorch/models/unsupervised.py +++ b/prototorch/models/unsupervised.py @@ -54,7 +54,7 @@ class GNGCallback(Callback): # Add component pl_module.proto_layer.add_components( - initialized_components=new_component) + initialized_components=new_component.unsqueeze(0)) # Adjust Topology topology.add_prototype() @@ -223,8 +223,9 @@ class GrowingNeuralGas(NeuralGas): self.hparams.setdefault("insert_reduction", 0.1) self.hparams.setdefault("insert_freq", 10) - self.register_buffer("errors", - torch.zeros(self.hparams.num_prototypes)) + self.register_buffer( + "errors", + torch.zeros(self.hparams.num_prototypes, device=self.device)) def training_step(self, train_batch, _batch_idx): x = train_batch[0]