[BUGFIX] examples/gng_iris.py works again

This commit is contained in:
Jensun Ravichandran 2021-06-14 20:29:31 +02:00
parent 4eafe88dc4
commit d2856383e2
2 changed files with 3 additions and 2 deletions

View File

@ -29,7 +29,7 @@ if __name__ == "__main__":
# Initialize the model # Initialize the model
model = pt.models.GrowingNeuralGas( model = pt.models.GrowingNeuralGas(
hparams, hparams,
prototype_initializer=pt.components.Zeros(2), prototypes_initializer=pt.initializers.ZCI(2),
) )
# Compute intermediate input and output sizes # Compute intermediate input and output sizes

View File

@ -6,6 +6,7 @@ import pytorch_lightning as pl
import torch import torch
from ..core.components import Components from ..core.components import Components
from ..core.initializers import LiteralCompInitializer
from .extras import ConnectionTopology from .extras import ConnectionTopology
@ -117,7 +118,7 @@ class GNGCallback(pl.Callback):
# Add component # Add component
pl_module.proto_layer.add_components( pl_module.proto_layer.add_components(
initialized_components=new_component.unsqueeze(0)) initializer=LiteralCompInitializer(new_component.unsqueeze(0)))
# Adjust Topology # Adjust Topology
topology.add_prototype() topology.add_prototype()