From 022d791ea55d980a2f79a0b6177724a24078c0e2 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Mon, 7 Jun 2021 21:18:08 +0200 Subject: [PATCH] Route initialized prototypes --- prototorch/models/abstract.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/prototorch/models/abstract.py b/prototorch/models/abstract.py index 6e3deae..03d91e9 100644 --- a/prototorch/models/abstract.py +++ b/prototorch/models/abstract.py @@ -81,10 +81,12 @@ class UnsupervisedPrototypeModel(PrototypeModel): # Layers prototype_initializer = kwargs.get("prototype_initializer", None) - if prototype_initializer is not None: + initialized_prototypes = kwargs.get("initialized_prototypes", None) + if prototype_initializer is not None or initialized_prototypes is not None: self.proto_layer = Components( self.hparams.num_prototypes, initializer=prototype_initializer, + initialized_components=initialized_prototypes, ) def compute_distances(self, x): @@ -103,10 +105,12 @@ class SupervisedPrototypeModel(PrototypeModel): # Layers prototype_initializer = kwargs.get("prototype_initializer", None) - if prototype_initializer is not None: + initialized_prototypes = kwargs.get("initialized_prototypes", None) + if prototype_initializer is not None or initialized_prototypes is not None: self.proto_layer = LabeledComponents( distribution=self.hparams.distribution, initializer=prototype_initializer, + initialized_components=initialized_prototypes, ) self.competition_layer = WTAC()