[BUGFIX] Growing neural gas
This commit is contained in:
parent
7379c61966
commit
bda88149d4
@ -8,7 +8,7 @@ from .glvq import (GLVQ, GLVQ1, GLVQ21, GMLVQ, GRLVQ, LGMLVQ, LVQMLN,
|
|||||||
ImageGLVQ, ImageGMLVQ, SiameseGLVQ, SiameseGMLVQ)
|
ImageGLVQ, ImageGMLVQ, SiameseGLVQ, SiameseGMLVQ)
|
||||||
from .lvq import LVQ1, LVQ21, MedianLVQ
|
from .lvq import LVQ1, LVQ21, MedianLVQ
|
||||||
from .probabilistic import CELVQ, RSLVQ, LikelihoodRatioLVQ
|
from .probabilistic import CELVQ, RSLVQ, LikelihoodRatioLVQ
|
||||||
from .unsupervised import KNN, NeuralGas
|
from .unsupervised import KNN, GrowingNeuralGas, NeuralGas
|
||||||
from .vis import *
|
from .vis import *
|
||||||
|
|
||||||
__version__ = "0.1.7"
|
__version__ = "0.1.7"
|
||||||
|
@ -54,7 +54,7 @@ class GNGCallback(Callback):
|
|||||||
|
|
||||||
# Add component
|
# Add component
|
||||||
pl_module.proto_layer.add_components(
|
pl_module.proto_layer.add_components(
|
||||||
initialized_components=new_component)
|
initialized_components=new_component.unsqueeze(0))
|
||||||
|
|
||||||
# Adjust Topology
|
# Adjust Topology
|
||||||
topology.add_prototype()
|
topology.add_prototype()
|
||||||
@ -223,8 +223,9 @@ class GrowingNeuralGas(NeuralGas):
|
|||||||
self.hparams.setdefault("insert_reduction", 0.1)
|
self.hparams.setdefault("insert_reduction", 0.1)
|
||||||
self.hparams.setdefault("insert_freq", 10)
|
self.hparams.setdefault("insert_freq", 10)
|
||||||
|
|
||||||
self.register_buffer("errors",
|
self.register_buffer(
|
||||||
torch.zeros(self.hparams.num_prototypes))
|
"errors",
|
||||||
|
torch.zeros(self.hparams.num_prototypes, device=self.device))
|
||||||
|
|
||||||
def training_step(self, train_batch, _batch_idx):
|
def training_step(self, train_batch, _batch_idx):
|
||||||
x = train_batch[0]
|
x = train_batch[0]
|
||||||
|
Loading…
Reference in New Issue
Block a user