Use the add_components
API for adding prototypes
This commit is contained in:
parent
1b6843dbbb
commit
aff6aedd60
@ -7,7 +7,6 @@ import pytorch_lightning as pl
|
||||
import torch
|
||||
import torchmetrics
|
||||
from prototorch.components import Components, LabeledComponents
|
||||
from prototorch.components import initializers as cinit
|
||||
from prototorch.components.initializers import ZerosInitializer, parse_data_arg
|
||||
from prototorch.functions.competitions import knnc
|
||||
from prototorch.functions.distances import euclidean_distance
|
||||
@ -26,7 +25,7 @@ class GNGCallback(Callback):
|
||||
"""GNG Callback.
|
||||
|
||||
Applies growing algorithm based on accumulated error and topology.
|
||||
|
||||
|
||||
Based on "A Growing Neural Gas Network Learns Topologies" by Bernd Fritzke.
|
||||
"""
|
||||
def __init__(self, reduction=0.1, freq=10):
|
||||
@ -55,11 +54,10 @@ class GNGCallback(Callback):
|
||||
# New Prototype
|
||||
new_component = 0.5 * (components[worst] +
|
||||
components[worst_neighbour])
|
||||
new_components = torch.vstack([components, new_component])
|
||||
|
||||
# Add component
|
||||
pl_module.proto_layer.register_parameter(
|
||||
"_components", torch.nn.parameter.Parameter(new_components))
|
||||
pl_module.proto_layer.add_components(
|
||||
initialized_components=new_component)
|
||||
|
||||
# Adjust Topology
|
||||
topology.add_prototype()
|
||||
|
Loading…
Reference in New Issue
Block a user