[BUGFIX] examples/gng_iris.py works again
				
					
				
			This commit is contained in:
		@@ -29,7 +29,7 @@ if __name__ == "__main__":
 | 
			
		||||
    # Initialize the model
 | 
			
		||||
    model = pt.models.GrowingNeuralGas(
 | 
			
		||||
        hparams,
 | 
			
		||||
        prototype_initializer=pt.components.Zeros(2),
 | 
			
		||||
        prototypes_initializer=pt.initializers.ZCI(2),
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # Compute intermediate input and output sizes
 | 
			
		||||
 
 | 
			
		||||
@@ -6,6 +6,7 @@ import pytorch_lightning as pl
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
from ..core.components import Components
 | 
			
		||||
from ..core.initializers import LiteralCompInitializer
 | 
			
		||||
from .extras import ConnectionTopology
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@@ -117,7 +118,7 @@ class GNGCallback(pl.Callback):
 | 
			
		||||
 | 
			
		||||
            # Add component
 | 
			
		||||
            pl_module.proto_layer.add_components(
 | 
			
		||||
                initialized_components=new_component.unsqueeze(0))
 | 
			
		||||
                initializer=LiteralCompInitializer(new_component.unsqueeze(0)))
 | 
			
		||||
 | 
			
		||||
            # Adjust Topology
 | 
			
		||||
            topology.add_prototype()
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user