fix: CBC example works again

This commit is contained in:
Jensun Ravichandran 2022-03-30 15:10:06 +02:00
parent 41f0e77fc9
commit 9da47b1dba
No known key found for this signature in database
GPG Key ID: 4E9348239810B51F
3 changed files with 23 additions and 19 deletions

View File

@ -109,26 +109,32 @@ class UnsupervisedPrototypeModel(PrototypeModel):
class SupervisedPrototypeModel(PrototypeModel): class SupervisedPrototypeModel(PrototypeModel):
def __init__(self, hparams, **kwargs): def __init__(self, hparams, skip_proto_layer=False, **kwargs):
super().__init__(hparams, **kwargs) super().__init__(hparams, **kwargs)
# Layers # Layers
distribution = hparams.get("distribution", None)
prototypes_initializer = kwargs.get("prototypes_initializer", None) prototypes_initializer = kwargs.get("prototypes_initializer", None)
labels_initializer = kwargs.get("labels_initializer", labels_initializer = kwargs.get("labels_initializer",
LabelsInitializer()) LabelsInitializer())
if prototypes_initializer is not None: if not skip_proto_layer:
self.proto_layer = LabeledComponents( # when subclasses do not need a customized prototype layer
distribution=self.hparams.distribution, if prototypes_initializer is not None:
components_initializer=prototypes_initializer, # when building a new model
labels_initializer=labels_initializer, self.proto_layer = LabeledComponents(
) distribution=distribution,
self.hparams.initialized_proto_dims = self.proto_layer.components.shape[ components_initializer=prototypes_initializer,
1:] labels_initializer=labels_initializer,
else: )
self.proto_layer = LabeledComponents( proto_shape = self.proto_layer.components.shape[1:]
self.hparams.distribution, self.hparams.initialized_proto_shape = proto_shape
ZerosCompInitializer(self.hparams.initialized_proto_dims), else:
) # when restoring a checkpointed model
self.proto_layer = LabeledComponents(
distribution=distribution,
components_initializer=ZerosCompInitializer(
self.hparams.initialized_proto_shape),
)
self.competition_layer = WTAC() self.competition_layer = WTAC()
@property @property

View File

@ -15,7 +15,7 @@ class CBC(SiameseGLVQ):
"""Classification-By-Components.""" """Classification-By-Components."""
def __init__(self, hparams, **kwargs): def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs) super().__init__(hparams, skip_proto_layer=True, **kwargs)
similarity_fn = kwargs.get("similarity_fn", euclidean_similarity) similarity_fn = kwargs.get("similarity_fn", euclidean_similarity)
components_initializer = kwargs.get("components_initializer", None) components_initializer = kwargs.get("components_initializer", None)

View File

@ -222,8 +222,7 @@ class VisCBC2D(Vis2DAbstract):
def visualize(self, pl_module): def visualize(self, pl_module):
x_train, y_train = self.x_train, self.y_train x_train, y_train = self.x_train, self.y_train
protos = pl_module.components protos = pl_module.components
ax = self.setup_ax(xlabel="Data dimension 1", ax = self.setup_ax()
ylabel="Data dimension 2")
self.plot_data(ax, x_train, y_train) self.plot_data(ax, x_train, y_train)
self.plot_protos(ax, protos, "w") self.plot_protos(ax, protos, "w")
x = np.vstack((x_train, protos)) x = np.vstack((x_train, protos))
@ -243,8 +242,7 @@ class VisNG2D(Vis2DAbstract):
protos = pl_module.prototypes protos = pl_module.prototypes
cmat = pl_module.topology_layer.cmat.cpu().numpy() cmat = pl_module.topology_layer.cmat.cpu().numpy()
ax = self.setup_ax(xlabel="Data dimension 1", ax = self.setup_ax()
ylabel="Data dimension 2")
self.plot_data(ax, x_train, y_train) self.plot_data(ax, x_train, y_train)
self.plot_protos(ax, protos, "w") self.plot_protos(ax, protos, "w")