[BUGFIX] examples/gng_iris.py works again
				
					
				
			This commit is contained in:
		@@ -29,7 +29,7 @@ if __name__ == "__main__":
 | 
				
			|||||||
    # Initialize the model
 | 
					    # Initialize the model
 | 
				
			||||||
    model = pt.models.GrowingNeuralGas(
 | 
					    model = pt.models.GrowingNeuralGas(
 | 
				
			||||||
        hparams,
 | 
					        hparams,
 | 
				
			||||||
        prototype_initializer=pt.components.Zeros(2),
 | 
					        prototypes_initializer=pt.initializers.ZCI(2),
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Compute intermediate input and output sizes
 | 
					    # Compute intermediate input and output sizes
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -6,6 +6,7 @@ import pytorch_lightning as pl
 | 
				
			|||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from ..core.components import Components
 | 
					from ..core.components import Components
 | 
				
			||||||
 | 
					from ..core.initializers import LiteralCompInitializer
 | 
				
			||||||
from .extras import ConnectionTopology
 | 
					from .extras import ConnectionTopology
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -117,7 +118,7 @@ class GNGCallback(pl.Callback):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
            # Add component
 | 
					            # Add component
 | 
				
			||||||
            pl_module.proto_layer.add_components(
 | 
					            pl_module.proto_layer.add_components(
 | 
				
			||||||
                initialized_components=new_component.unsqueeze(0))
 | 
					                initializer=LiteralCompInitializer(new_component.unsqueeze(0)))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            # Adjust Topology
 | 
					            # Adjust Topology
 | 
				
			||||||
            topology.add_prototype()
 | 
					            topology.add_prototype()
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user