[BUGFIX] examples/gng_iris.py
works again
This commit is contained in:
parent
4eafe88dc4
commit
d2856383e2
@ -29,7 +29,7 @@ if __name__ == "__main__":
|
|||||||
# Initialize the model
|
# Initialize the model
|
||||||
model = pt.models.GrowingNeuralGas(
|
model = pt.models.GrowingNeuralGas(
|
||||||
hparams,
|
hparams,
|
||||||
prototype_initializer=pt.components.Zeros(2),
|
prototypes_initializer=pt.initializers.ZCI(2),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Compute intermediate input and output sizes
|
# Compute intermediate input and output sizes
|
||||||
|
@ -6,6 +6,7 @@ import pytorch_lightning as pl
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ..core.components import Components
|
from ..core.components import Components
|
||||||
|
from ..core.initializers import LiteralCompInitializer
|
||||||
from .extras import ConnectionTopology
|
from .extras import ConnectionTopology
|
||||||
|
|
||||||
|
|
||||||
@ -117,7 +118,7 @@ class GNGCallback(pl.Callback):
|
|||||||
|
|
||||||
# Add component
|
# Add component
|
||||||
pl_module.proto_layer.add_components(
|
pl_module.proto_layer.add_components(
|
||||||
initialized_components=new_component.unsqueeze(0))
|
initializer=LiteralCompInitializer(new_component.unsqueeze(0)))
|
||||||
|
|
||||||
# Adjust Topology
|
# Adjust Topology
|
||||||
topology.add_prototype()
|
topology.add_prototype()
|
||||||
|
Loading…
Reference in New Issue
Block a user