fix: CBC example works again
This commit is contained in:
parent
41f0e77fc9
commit
9da47b1dba
@ -109,25 +109,31 @@ class UnsupervisedPrototypeModel(PrototypeModel):
|
|||||||
|
|
||||||
class SupervisedPrototypeModel(PrototypeModel):
|
class SupervisedPrototypeModel(PrototypeModel):
|
||||||
|
|
||||||
def __init__(self, hparams, **kwargs):
|
def __init__(self, hparams, skip_proto_layer=False, **kwargs):
|
||||||
super().__init__(hparams, **kwargs)
|
super().__init__(hparams, **kwargs)
|
||||||
|
|
||||||
# Layers
|
# Layers
|
||||||
|
distribution = hparams.get("distribution", None)
|
||||||
prototypes_initializer = kwargs.get("prototypes_initializer", None)
|
prototypes_initializer = kwargs.get("prototypes_initializer", None)
|
||||||
labels_initializer = kwargs.get("labels_initializer",
|
labels_initializer = kwargs.get("labels_initializer",
|
||||||
LabelsInitializer())
|
LabelsInitializer())
|
||||||
|
if not skip_proto_layer:
|
||||||
|
# when subclasses do not need a customized prototype layer
|
||||||
if prototypes_initializer is not None:
|
if prototypes_initializer is not None:
|
||||||
|
# when building a new model
|
||||||
self.proto_layer = LabeledComponents(
|
self.proto_layer = LabeledComponents(
|
||||||
distribution=self.hparams.distribution,
|
distribution=distribution,
|
||||||
components_initializer=prototypes_initializer,
|
components_initializer=prototypes_initializer,
|
||||||
labels_initializer=labels_initializer,
|
labels_initializer=labels_initializer,
|
||||||
)
|
)
|
||||||
self.hparams.initialized_proto_dims = self.proto_layer.components.shape[
|
proto_shape = self.proto_layer.components.shape[1:]
|
||||||
1:]
|
self.hparams.initialized_proto_shape = proto_shape
|
||||||
else:
|
else:
|
||||||
|
# when restoring a checkpointed model
|
||||||
self.proto_layer = LabeledComponents(
|
self.proto_layer = LabeledComponents(
|
||||||
self.hparams.distribution,
|
distribution=distribution,
|
||||||
ZerosCompInitializer(self.hparams.initialized_proto_dims),
|
components_initializer=ZerosCompInitializer(
|
||||||
|
self.hparams.initialized_proto_shape),
|
||||||
)
|
)
|
||||||
self.competition_layer = WTAC()
|
self.competition_layer = WTAC()
|
||||||
|
|
||||||
|
@ -15,7 +15,7 @@ class CBC(SiameseGLVQ):
|
|||||||
"""Classification-By-Components."""
|
"""Classification-By-Components."""
|
||||||
|
|
||||||
def __init__(self, hparams, **kwargs):
|
def __init__(self, hparams, **kwargs):
|
||||||
super().__init__(hparams, **kwargs)
|
super().__init__(hparams, skip_proto_layer=True, **kwargs)
|
||||||
|
|
||||||
similarity_fn = kwargs.get("similarity_fn", euclidean_similarity)
|
similarity_fn = kwargs.get("similarity_fn", euclidean_similarity)
|
||||||
components_initializer = kwargs.get("components_initializer", None)
|
components_initializer = kwargs.get("components_initializer", None)
|
||||||
|
@ -222,8 +222,7 @@ class VisCBC2D(Vis2DAbstract):
|
|||||||
def visualize(self, pl_module):
|
def visualize(self, pl_module):
|
||||||
x_train, y_train = self.x_train, self.y_train
|
x_train, y_train = self.x_train, self.y_train
|
||||||
protos = pl_module.components
|
protos = pl_module.components
|
||||||
ax = self.setup_ax(xlabel="Data dimension 1",
|
ax = self.setup_ax()
|
||||||
ylabel="Data dimension 2")
|
|
||||||
self.plot_data(ax, x_train, y_train)
|
self.plot_data(ax, x_train, y_train)
|
||||||
self.plot_protos(ax, protos, "w")
|
self.plot_protos(ax, protos, "w")
|
||||||
x = np.vstack((x_train, protos))
|
x = np.vstack((x_train, protos))
|
||||||
@ -243,8 +242,7 @@ class VisNG2D(Vis2DAbstract):
|
|||||||
protos = pl_module.prototypes
|
protos = pl_module.prototypes
|
||||||
cmat = pl_module.topology_layer.cmat.cpu().numpy()
|
cmat = pl_module.topology_layer.cmat.cpu().numpy()
|
||||||
|
|
||||||
ax = self.setup_ax(xlabel="Data dimension 1",
|
ax = self.setup_ax()
|
||||||
ylabel="Data dimension 2")
|
|
||||||
self.plot_data(ax, x_train, y_train)
|
self.plot_data(ax, x_train, y_train)
|
||||||
self.plot_protos(ax, protos, "w")
|
self.plot_protos(ax, protos, "w")
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user