Use the add_components API for adding prototypes

This commit is contained in:
Jensun Ravichandran 2021-06-01 23:37:45 +02:00
parent 1b6843dbbb
commit aff6aedd60

View File

@ -7,7 +7,6 @@ import pytorch_lightning as pl
import torch import torch
import torchmetrics import torchmetrics
from prototorch.components import Components, LabeledComponents from prototorch.components import Components, LabeledComponents
from prototorch.components import initializers as cinit
from prototorch.components.initializers import ZerosInitializer, parse_data_arg from prototorch.components.initializers import ZerosInitializer, parse_data_arg
from prototorch.functions.competitions import knnc from prototorch.functions.competitions import knnc
from prototorch.functions.distances import euclidean_distance from prototorch.functions.distances import euclidean_distance
@ -26,7 +25,7 @@ class GNGCallback(Callback):
"""GNG Callback. """GNG Callback.
Applies growing algorithm based on accumulated error and topology. Applies growing algorithm based on accumulated error and topology.
Based on "A Growing Neural Gas Network Learns Topologies" by Bernd Fritzke. Based on "A Growing Neural Gas Network Learns Topologies" by Bernd Fritzke.
""" """
def __init__(self, reduction=0.1, freq=10): def __init__(self, reduction=0.1, freq=10):
@ -55,11 +54,10 @@ class GNGCallback(Callback):
# New Prototype # New Prototype
new_component = 0.5 * (components[worst] + new_component = 0.5 * (components[worst] +
components[worst_neighbour]) components[worst_neighbour])
new_components = torch.vstack([components, new_component])
# Add component # Add component
pl_module.proto_layer.register_parameter( pl_module.proto_layer.add_components(
"_components", torch.nn.parameter.Parameter(new_components)) initialized_components=new_component)
# Adjust Topology # Adjust Topology
topology.add_prototype() topology.add_prototype()