Use the add_components API for adding prototypes
				
					
				
			This commit is contained in:
		| @@ -7,7 +7,6 @@ import pytorch_lightning as pl | |||||||
| import torch | import torch | ||||||
| import torchmetrics | import torchmetrics | ||||||
| from prototorch.components import Components, LabeledComponents | from prototorch.components import Components, LabeledComponents | ||||||
| from prototorch.components import initializers as cinit |  | ||||||
| from prototorch.components.initializers import ZerosInitializer, parse_data_arg | from prototorch.components.initializers import ZerosInitializer, parse_data_arg | ||||||
| from prototorch.functions.competitions import knnc | from prototorch.functions.competitions import knnc | ||||||
| from prototorch.functions.distances import euclidean_distance | from prototorch.functions.distances import euclidean_distance | ||||||
| @@ -55,11 +54,10 @@ class GNGCallback(Callback): | |||||||
|             # New Prototype |             # New Prototype | ||||||
|             new_component = 0.5 * (components[worst] + |             new_component = 0.5 * (components[worst] + | ||||||
|                                    components[worst_neighbour]) |                                    components[worst_neighbour]) | ||||||
|             new_components = torch.vstack([components, new_component]) |  | ||||||
|  |  | ||||||
|             # Add component |             # Add component | ||||||
|             pl_module.proto_layer.register_parameter( |             pl_module.proto_layer.add_components( | ||||||
|                 "_components", torch.nn.parameter.Parameter(new_components)) |                 initialized_components=new_component) | ||||||
|  |  | ||||||
|             # Adjust Topology |             # Adjust Topology | ||||||
|             topology.add_prototype() |             topology.add_prototype() | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user