Use the add_components
API for adding prototypes
This commit is contained in:
parent
1b6843dbbb
commit
aff6aedd60
@ -7,7 +7,6 @@ import pytorch_lightning as pl
|
|||||||
import torch
|
import torch
|
||||||
import torchmetrics
|
import torchmetrics
|
||||||
from prototorch.components import Components, LabeledComponents
|
from prototorch.components import Components, LabeledComponents
|
||||||
from prototorch.components import initializers as cinit
|
|
||||||
from prototorch.components.initializers import ZerosInitializer, parse_data_arg
|
from prototorch.components.initializers import ZerosInitializer, parse_data_arg
|
||||||
from prototorch.functions.competitions import knnc
|
from prototorch.functions.competitions import knnc
|
||||||
from prototorch.functions.distances import euclidean_distance
|
from prototorch.functions.distances import euclidean_distance
|
||||||
@ -26,7 +25,7 @@ class GNGCallback(Callback):
|
|||||||
"""GNG Callback.
|
"""GNG Callback.
|
||||||
|
|
||||||
Applies growing algorithm based on accumulated error and topology.
|
Applies growing algorithm based on accumulated error and topology.
|
||||||
|
|
||||||
Based on "A Growing Neural Gas Network Learns Topologies" by Bernd Fritzke.
|
Based on "A Growing Neural Gas Network Learns Topologies" by Bernd Fritzke.
|
||||||
"""
|
"""
|
||||||
def __init__(self, reduction=0.1, freq=10):
|
def __init__(self, reduction=0.1, freq=10):
|
||||||
@ -55,11 +54,10 @@ class GNGCallback(Callback):
|
|||||||
# New Prototype
|
# New Prototype
|
||||||
new_component = 0.5 * (components[worst] +
|
new_component = 0.5 * (components[worst] +
|
||||||
components[worst_neighbour])
|
components[worst_neighbour])
|
||||||
new_components = torch.vstack([components, new_component])
|
|
||||||
|
|
||||||
# Add component
|
# Add component
|
||||||
pl_module.proto_layer.register_parameter(
|
pl_module.proto_layer.add_components(
|
||||||
"_components", torch.nn.parameter.Parameter(new_components))
|
initialized_components=new_component)
|
||||||
|
|
||||||
# Adjust Topology
|
# Adjust Topology
|
||||||
topology.add_prototype()
|
topology.add_prototype()
|
||||||
|
Loading…
Reference in New Issue
Block a user