From aff6aedd60724d4ce724825ee92ef4dd5f4a6ca2 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Tue, 1 Jun 2021 23:37:45 +0200 Subject: [PATCH] Use the `add_components` API for adding prototypes --- prototorch/models/unsupervised.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/prototorch/models/unsupervised.py b/prototorch/models/unsupervised.py index fb1b1de..0060c6a 100644 --- a/prototorch/models/unsupervised.py +++ b/prototorch/models/unsupervised.py @@ -7,7 +7,6 @@ import pytorch_lightning as pl import torch import torchmetrics from prototorch.components import Components, LabeledComponents -from prototorch.components import initializers as cinit from prototorch.components.initializers import ZerosInitializer, parse_data_arg from prototorch.functions.competitions import knnc from prototorch.functions.distances import euclidean_distance @@ -26,7 +25,7 @@ class GNGCallback(Callback): """GNG Callback. Applies growing algorithm based on accumulated error and topology. - + Based on "A Growing Neural Gas Network Learns Topologies" by Bernd Fritzke. """ def __init__(self, reduction=0.1, freq=10): @@ -55,11 +54,10 @@ class GNGCallback(Callback): # New Prototype new_component = 0.5 * (components[worst] + components[worst_neighbour]) - new_components = torch.vstack([components, new_component]) # Add component - pl_module.proto_layer.register_parameter( - "_components", torch.nn.parameter.Parameter(new_components)) + pl_module.proto_layer.add_components( + initialized_components=new_component) # Adjust Topology topology.add_prototype()