Route initialized prototypes

This commit is contained in:
Jensun Ravichandran 2021-06-07 21:18:08 +02:00
parent 43fc7d1678
commit 022d791ea5

View File

@ -81,10 +81,12 @@ class UnsupervisedPrototypeModel(PrototypeModel):
# Layers # Layers
prototype_initializer = kwargs.get("prototype_initializer", None) prototype_initializer = kwargs.get("prototype_initializer", None)
if prototype_initializer is not None: initialized_prototypes = kwargs.get("initialized_prototypes", None)
if prototype_initializer is not None or initialized_prototypes is not None:
self.proto_layer = Components( self.proto_layer = Components(
self.hparams.num_prototypes, self.hparams.num_prototypes,
initializer=prototype_initializer, initializer=prototype_initializer,
initialized_components=initialized_prototypes,
) )
def compute_distances(self, x): def compute_distances(self, x):
@ -103,10 +105,12 @@ class SupervisedPrototypeModel(PrototypeModel):
# Layers # Layers
prototype_initializer = kwargs.get("prototype_initializer", None) prototype_initializer = kwargs.get("prototype_initializer", None)
if prototype_initializer is not None: initialized_prototypes = kwargs.get("initialized_prototypes", None)
if prototype_initializer is not None or initialized_prototypes is not None:
self.proto_layer = LabeledComponents( self.proto_layer = LabeledComponents(
distribution=self.hparams.distribution, distribution=self.hparams.distribution,
initializer=prototype_initializer, initializer=prototype_initializer,
initialized_components=initialized_prototypes,
) )
self.competition_layer = WTAC() self.competition_layer = WTAC()