fix: CBC example works again
This commit is contained in:
parent
41f0e77fc9
commit
9da47b1dba
@ -109,25 +109,31 @@ class UnsupervisedPrototypeModel(PrototypeModel):
|
||||
|
||||
class SupervisedPrototypeModel(PrototypeModel):
|
||||
|
||||
def __init__(self, hparams, **kwargs):
|
||||
def __init__(self, hparams, skip_proto_layer=False, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
|
||||
# Layers
|
||||
distribution = hparams.get("distribution", None)
|
||||
prototypes_initializer = kwargs.get("prototypes_initializer", None)
|
||||
labels_initializer = kwargs.get("labels_initializer",
|
||||
LabelsInitializer())
|
||||
if not skip_proto_layer:
|
||||
# when subclasses do not need a customized prototype layer
|
||||
if prototypes_initializer is not None:
|
||||
# when building a new model
|
||||
self.proto_layer = LabeledComponents(
|
||||
distribution=self.hparams.distribution,
|
||||
distribution=distribution,
|
||||
components_initializer=prototypes_initializer,
|
||||
labels_initializer=labels_initializer,
|
||||
)
|
||||
self.hparams.initialized_proto_dims = self.proto_layer.components.shape[
|
||||
1:]
|
||||
proto_shape = self.proto_layer.components.shape[1:]
|
||||
self.hparams.initialized_proto_shape = proto_shape
|
||||
else:
|
||||
# when restoring a checkpointed model
|
||||
self.proto_layer = LabeledComponents(
|
||||
self.hparams.distribution,
|
||||
ZerosCompInitializer(self.hparams.initialized_proto_dims),
|
||||
distribution=distribution,
|
||||
components_initializer=ZerosCompInitializer(
|
||||
self.hparams.initialized_proto_shape),
|
||||
)
|
||||
self.competition_layer = WTAC()
|
||||
|
||||
|
@ -15,7 +15,7 @@ class CBC(SiameseGLVQ):
|
||||
"""Classification-By-Components."""
|
||||
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
super().__init__(hparams, skip_proto_layer=True, **kwargs)
|
||||
|
||||
similarity_fn = kwargs.get("similarity_fn", euclidean_similarity)
|
||||
components_initializer = kwargs.get("components_initializer", None)
|
||||
|
@ -222,8 +222,7 @@ class VisCBC2D(Vis2DAbstract):
|
||||
def visualize(self, pl_module):
|
||||
x_train, y_train = self.x_train, self.y_train
|
||||
protos = pl_module.components
|
||||
ax = self.setup_ax(xlabel="Data dimension 1",
|
||||
ylabel="Data dimension 2")
|
||||
ax = self.setup_ax()
|
||||
self.plot_data(ax, x_train, y_train)
|
||||
self.plot_protos(ax, protos, "w")
|
||||
x = np.vstack((x_train, protos))
|
||||
@ -243,8 +242,7 @@ class VisNG2D(Vis2DAbstract):
|
||||
protos = pl_module.prototypes
|
||||
cmat = pl_module.topology_layer.cmat.cpu().numpy()
|
||||
|
||||
ax = self.setup_ax(xlabel="Data dimension 1",
|
||||
ylabel="Data dimension 2")
|
||||
ax = self.setup_ax()
|
||||
self.plot_data(ax, x_train, y_train)
|
||||
self.plot_protos(ax, protos, "w")
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user