Use the add_components API for adding prototypes

This commit is contained in:
Jensun Ravichandran 2021-06-01 23:37:45 +02:00
parent 1b6843dbbb
commit aff6aedd60

View File

@ -7,7 +7,6 @@ import pytorch_lightning as pl
import torch
import torchmetrics
from prototorch.components import Components, LabeledComponents
from prototorch.components import initializers as cinit
from prototorch.components.initializers import ZerosInitializer, parse_data_arg
from prototorch.functions.competitions import knnc
from prototorch.functions.distances import euclidean_distance
@ -26,7 +25,7 @@ class GNGCallback(Callback):
"""GNG Callback.
Applies growing algorithm based on accumulated error and topology.
Based on "A Growing Neural Gas Network Learns Topologies" by Bernd Fritzke.
"""
def __init__(self, reduction=0.1, freq=10):
@ -55,11 +54,10 @@ class GNGCallback(Callback):
# New Prototype
new_component = 0.5 * (components[worst] +
components[worst_neighbour])
new_components = torch.vstack([components, new_component])
# Add component
pl_module.proto_layer.register_parameter(
"_components", torch.nn.parameter.Parameter(new_components))
pl_module.proto_layer.add_components(
initialized_components=new_component)
# Adjust Topology
topology.add_prototype()