[BUGFIX] Growing neural gas
This commit is contained in:
		| @@ -8,7 +8,7 @@ from .glvq import (GLVQ, GLVQ1, GLVQ21, GMLVQ, GRLVQ, LGMLVQ, LVQMLN, | |||||||
|                    ImageGLVQ, ImageGMLVQ, SiameseGLVQ, SiameseGMLVQ) |                    ImageGLVQ, ImageGMLVQ, SiameseGLVQ, SiameseGMLVQ) | ||||||
| from .lvq import LVQ1, LVQ21, MedianLVQ | from .lvq import LVQ1, LVQ21, MedianLVQ | ||||||
| from .probabilistic import CELVQ, RSLVQ, LikelihoodRatioLVQ | from .probabilistic import CELVQ, RSLVQ, LikelihoodRatioLVQ | ||||||
| from .unsupervised import KNN, NeuralGas | from .unsupervised import KNN, GrowingNeuralGas, NeuralGas | ||||||
| from .vis import * | from .vis import * | ||||||
|  |  | ||||||
| __version__ = "0.1.7" | __version__ = "0.1.7" | ||||||
|   | |||||||
| @@ -54,7 +54,7 @@ class GNGCallback(Callback): | |||||||
|  |  | ||||||
|             # Add component |             # Add component | ||||||
|             pl_module.proto_layer.add_components( |             pl_module.proto_layer.add_components( | ||||||
|                 initialized_components=new_component) |                 initialized_components=new_component.unsqueeze(0)) | ||||||
|  |  | ||||||
|             # Adjust Topology |             # Adjust Topology | ||||||
|             topology.add_prototype() |             topology.add_prototype() | ||||||
| @@ -223,8 +223,9 @@ class GrowingNeuralGas(NeuralGas): | |||||||
|         self.hparams.setdefault("insert_reduction", 0.1) |         self.hparams.setdefault("insert_reduction", 0.1) | ||||||
|         self.hparams.setdefault("insert_freq", 10) |         self.hparams.setdefault("insert_freq", 10) | ||||||
|  |  | ||||||
|         self.register_buffer("errors", |         self.register_buffer( | ||||||
|                              torch.zeros(self.hparams.num_prototypes)) |             "errors", | ||||||
|  |             torch.zeros(self.hparams.num_prototypes, device=self.device)) | ||||||
|  |  | ||||||
|     def training_step(self, train_batch, _batch_idx): |     def training_step(self, train_batch, _batch_idx): | ||||||
|         x = train_batch[0] |         x = train_batch[0] | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user