Route initialized prototypes
This commit is contained in:
parent
43fc7d1678
commit
022d791ea5
@ -81,10 +81,12 @@ class UnsupervisedPrototypeModel(PrototypeModel):
|
|||||||
|
|
||||||
# Layers
|
# Layers
|
||||||
prototype_initializer = kwargs.get("prototype_initializer", None)
|
prototype_initializer = kwargs.get("prototype_initializer", None)
|
||||||
if prototype_initializer is not None:
|
initialized_prototypes = kwargs.get("initialized_prototypes", None)
|
||||||
|
if prototype_initializer is not None or initialized_prototypes is not None:
|
||||||
self.proto_layer = Components(
|
self.proto_layer = Components(
|
||||||
self.hparams.num_prototypes,
|
self.hparams.num_prototypes,
|
||||||
initializer=prototype_initializer,
|
initializer=prototype_initializer,
|
||||||
|
initialized_components=initialized_prototypes,
|
||||||
)
|
)
|
||||||
|
|
||||||
def compute_distances(self, x):
|
def compute_distances(self, x):
|
||||||
@ -103,10 +105,12 @@ class SupervisedPrototypeModel(PrototypeModel):
|
|||||||
|
|
||||||
# Layers
|
# Layers
|
||||||
prototype_initializer = kwargs.get("prototype_initializer", None)
|
prototype_initializer = kwargs.get("prototype_initializer", None)
|
||||||
if prototype_initializer is not None:
|
initialized_prototypes = kwargs.get("initialized_prototypes", None)
|
||||||
|
if prototype_initializer is not None or initialized_prototypes is not None:
|
||||||
self.proto_layer = LabeledComponents(
|
self.proto_layer = LabeledComponents(
|
||||||
distribution=self.hparams.distribution,
|
distribution=self.hparams.distribution,
|
||||||
initializer=prototype_initializer,
|
initializer=prototype_initializer,
|
||||||
|
initialized_components=initialized_prototypes,
|
||||||
)
|
)
|
||||||
self.competition_layer = WTAC()
|
self.competition_layer = WTAC()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user