[FEATURE] Add warm-starting example

This commit is contained in:
Jensun Ravichandran
2021-06-14 20:42:57 +02:00
parent 1911d4b33e
commit 1c658cdc1b
3 changed files with 89 additions and 0 deletions

View File

@@ -9,6 +9,7 @@ import torchmetrics
from ..core.competitions import WTAC
from ..core.components import Components, LabeledComponents
from ..core.distances import euclidean_distance
from ..core.initializers import LabelsInitializer
from ..core.pooling import stratified_min_pooling
from ..nn.wrappers import LambdaLayer
@@ -111,10 +112,13 @@ class SupervisedPrototypeModel(PrototypeModel):
# Layers
prototypes_initializer = kwargs.get("prototypes_initializer", None)
labels_initializer = kwargs.get("labels_initializer",
LabelsInitializer())
if prototypes_initializer is not None:
self.proto_layer = LabeledComponents(
distribution=self.hparams.distribution,
components_initializer=prototypes_initializer,
labels_initializer=labels_initializer,
)
self.competition_layer = WTAC()

View File

@@ -118,6 +118,7 @@ class GNGCallback(pl.Callback):
# Add component
pl_module.proto_layer.add_components(
None,
initializer=LiteralCompInitializer(new_component.unsqueeze(0)))
# Adjust Topology