fix: CBC example works again

This commit is contained in:
Jensun Ravichandran 2022-03-30 15:10:06 +02:00
parent 41f0e77fc9
commit 9da47b1dba
No known key found for this signature in database
GPG Key ID: 4E9348239810B51F
3 changed files with 23 additions and 19 deletions

View File

@ -109,25 +109,31 @@ class UnsupervisedPrototypeModel(PrototypeModel):
class SupervisedPrototypeModel(PrototypeModel):
def __init__(self, hparams, **kwargs):
def __init__(self, hparams, skip_proto_layer=False, **kwargs):
super().__init__(hparams, **kwargs)
# Layers
distribution = hparams.get("distribution", None)
prototypes_initializer = kwargs.get("prototypes_initializer", None)
labels_initializer = kwargs.get("labels_initializer",
LabelsInitializer())
if not skip_proto_layer:
# when subclasses do not need a customized prototype layer
if prototypes_initializer is not None:
# when building a new model
self.proto_layer = LabeledComponents(
distribution=self.hparams.distribution,
distribution=distribution,
components_initializer=prototypes_initializer,
labels_initializer=labels_initializer,
)
self.hparams.initialized_proto_dims = self.proto_layer.components.shape[
1:]
proto_shape = self.proto_layer.components.shape[1:]
self.hparams.initialized_proto_shape = proto_shape
else:
# when restoring a checkpointed model
self.proto_layer = LabeledComponents(
self.hparams.distribution,
ZerosCompInitializer(self.hparams.initialized_proto_dims),
distribution=distribution,
components_initializer=ZerosCompInitializer(
self.hparams.initialized_proto_shape),
)
self.competition_layer = WTAC()

View File

@ -15,7 +15,7 @@ class CBC(SiameseGLVQ):
"""Classification-By-Components."""
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
super().__init__(hparams, skip_proto_layer=True, **kwargs)
similarity_fn = kwargs.get("similarity_fn", euclidean_similarity)
components_initializer = kwargs.get("components_initializer", None)

View File

@ -222,8 +222,7 @@ class VisCBC2D(Vis2DAbstract):
def visualize(self, pl_module):
x_train, y_train = self.x_train, self.y_train
protos = pl_module.components
ax = self.setup_ax(xlabel="Data dimension 1",
ylabel="Data dimension 2")
ax = self.setup_ax()
self.plot_data(ax, x_train, y_train)
self.plot_protos(ax, protos, "w")
x = np.vstack((x_train, protos))
@ -243,8 +242,7 @@ class VisNG2D(Vis2DAbstract):
protos = pl_module.prototypes
cmat = pl_module.topology_layer.cmat.cpu().numpy()
ax = self.setup_ax(xlabel="Data dimension 1",
ylabel="Data dimension 2")
ax = self.setup_ax()
self.plot_data(ax, x_train, y_train)
self.plot_protos(ax, protos, "w")