From 9da47b1dbad3f74c4288ee7cf3be15c03df51369 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Wed, 30 Mar 2022 15:10:06 +0200 Subject: [PATCH] fix: CBC example works again --- prototorch/models/abstract.py | 34 ++++++++++++++++++++-------------- prototorch/models/cbc.py | 2 +- prototorch/models/vis.py | 6 ++---- 3 files changed, 23 insertions(+), 19 deletions(-) diff --git a/prototorch/models/abstract.py b/prototorch/models/abstract.py index f90f0a1..b90ca88 100644 --- a/prototorch/models/abstract.py +++ b/prototorch/models/abstract.py @@ -109,26 +109,32 @@ class UnsupervisedPrototypeModel(PrototypeModel): class SupervisedPrototypeModel(PrototypeModel): - def __init__(self, hparams, **kwargs): + def __init__(self, hparams, skip_proto_layer=False, **kwargs): super().__init__(hparams, **kwargs) # Layers + distribution = hparams.get("distribution", None) prototypes_initializer = kwargs.get("prototypes_initializer", None) labels_initializer = kwargs.get("labels_initializer", LabelsInitializer()) - if prototypes_initializer is not None: - self.proto_layer = LabeledComponents( - distribution=self.hparams.distribution, - components_initializer=prototypes_initializer, - labels_initializer=labels_initializer, - ) - self.hparams.initialized_proto_dims = self.proto_layer.components.shape[ - 1:] - else: - self.proto_layer = LabeledComponents( - self.hparams.distribution, - ZerosCompInitializer(self.hparams.initialized_proto_dims), - ) + if not skip_proto_layer: + # when subclasses do not need a customized prototype layer + if prototypes_initializer is not None: + # when building a new model + self.proto_layer = LabeledComponents( + distribution=distribution, + components_initializer=prototypes_initializer, + labels_initializer=labels_initializer, + ) + proto_shape = self.proto_layer.components.shape[1:] + self.hparams.initialized_proto_shape = proto_shape + else: + # when restoring a checkpointed model + self.proto_layer = LabeledComponents( + distribution=distribution, + components_initializer=ZerosCompInitializer( + self.hparams.initialized_proto_shape), + ) self.competition_layer = WTAC() @property diff --git a/prototorch/models/cbc.py b/prototorch/models/cbc.py index 8eeb554..fe8aa41 100644 --- a/prototorch/models/cbc.py +++ b/prototorch/models/cbc.py @@ -15,7 +15,7 @@ class CBC(SiameseGLVQ): """Classification-By-Components.""" def __init__(self, hparams, **kwargs): - super().__init__(hparams, **kwargs) + super().__init__(hparams, skip_proto_layer=True, **kwargs) similarity_fn = kwargs.get("similarity_fn", euclidean_similarity) components_initializer = kwargs.get("components_initializer", None) diff --git a/prototorch/models/vis.py b/prototorch/models/vis.py index 66d19f4..1a04888 100644 --- a/prototorch/models/vis.py +++ b/prototorch/models/vis.py @@ -222,8 +222,7 @@ class VisCBC2D(Vis2DAbstract): def visualize(self, pl_module): x_train, y_train = self.x_train, self.y_train protos = pl_module.components - ax = self.setup_ax(xlabel="Data dimension 1", - ylabel="Data dimension 2") + ax = self.setup_ax() self.plot_data(ax, x_train, y_train) self.plot_protos(ax, protos, "w") x = np.vstack((x_train, protos)) @@ -243,8 +242,7 @@ class VisNG2D(Vis2DAbstract): protos = pl_module.prototypes cmat = pl_module.topology_layer.cmat.cpu().numpy() - ax = self.setup_ax(xlabel="Data dimension 1", - ylabel="Data dimension 2") + ax = self.setup_ax() self.plot_data(ax, x_train, y_train) self.plot_protos(ax, protos, "w")