fix: CBC example works again
This commit is contained in:
		@@ -109,26 +109,32 @@ class UnsupervisedPrototypeModel(PrototypeModel):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
class SupervisedPrototypeModel(PrototypeModel):
 | 
					class SupervisedPrototypeModel(PrototypeModel):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __init__(self, hparams, **kwargs):
 | 
					    def __init__(self, hparams, skip_proto_layer=False, **kwargs):
 | 
				
			||||||
        super().__init__(hparams, **kwargs)
 | 
					        super().__init__(hparams, **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Layers
 | 
					        # Layers
 | 
				
			||||||
 | 
					        distribution = hparams.get("distribution", None)
 | 
				
			||||||
        prototypes_initializer = kwargs.get("prototypes_initializer", None)
 | 
					        prototypes_initializer = kwargs.get("prototypes_initializer", None)
 | 
				
			||||||
        labels_initializer = kwargs.get("labels_initializer",
 | 
					        labels_initializer = kwargs.get("labels_initializer",
 | 
				
			||||||
                                        LabelsInitializer())
 | 
					                                        LabelsInitializer())
 | 
				
			||||||
        if prototypes_initializer is not None:
 | 
					        if not skip_proto_layer:
 | 
				
			||||||
            self.proto_layer = LabeledComponents(
 | 
					            # when subclasses do not need a customized prototype layer
 | 
				
			||||||
                distribution=self.hparams.distribution,
 | 
					            if prototypes_initializer is not None:
 | 
				
			||||||
                components_initializer=prototypes_initializer,
 | 
					                # when building a new model
 | 
				
			||||||
                labels_initializer=labels_initializer,
 | 
					                self.proto_layer = LabeledComponents(
 | 
				
			||||||
            )
 | 
					                    distribution=distribution,
 | 
				
			||||||
            self.hparams.initialized_proto_dims = self.proto_layer.components.shape[
 | 
					                    components_initializer=prototypes_initializer,
 | 
				
			||||||
                1:]
 | 
					                    labels_initializer=labels_initializer,
 | 
				
			||||||
        else:
 | 
					                )
 | 
				
			||||||
            self.proto_layer = LabeledComponents(
 | 
					                proto_shape = self.proto_layer.components.shape[1:]
 | 
				
			||||||
                self.hparams.distribution,
 | 
					                self.hparams.initialized_proto_shape = proto_shape
 | 
				
			||||||
                ZerosCompInitializer(self.hparams.initialized_proto_dims),
 | 
					            else:
 | 
				
			||||||
            )
 | 
					                # when restoring a checkpointed model
 | 
				
			||||||
 | 
					                self.proto_layer = LabeledComponents(
 | 
				
			||||||
 | 
					                    distribution=distribution,
 | 
				
			||||||
 | 
					                    components_initializer=ZerosCompInitializer(
 | 
				
			||||||
 | 
					                        self.hparams.initialized_proto_shape),
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
        self.competition_layer = WTAC()
 | 
					        self.competition_layer = WTAC()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -15,7 +15,7 @@ class CBC(SiameseGLVQ):
 | 
				
			|||||||
    """Classification-By-Components."""
 | 
					    """Classification-By-Components."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __init__(self, hparams, **kwargs):
 | 
					    def __init__(self, hparams, **kwargs):
 | 
				
			||||||
        super().__init__(hparams, **kwargs)
 | 
					        super().__init__(hparams, skip_proto_layer=True, **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        similarity_fn = kwargs.get("similarity_fn", euclidean_similarity)
 | 
					        similarity_fn = kwargs.get("similarity_fn", euclidean_similarity)
 | 
				
			||||||
        components_initializer = kwargs.get("components_initializer", None)
 | 
					        components_initializer = kwargs.get("components_initializer", None)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -222,8 +222,7 @@ class VisCBC2D(Vis2DAbstract):
 | 
				
			|||||||
    def visualize(self, pl_module):
 | 
					    def visualize(self, pl_module):
 | 
				
			||||||
        x_train, y_train = self.x_train, self.y_train
 | 
					        x_train, y_train = self.x_train, self.y_train
 | 
				
			||||||
        protos = pl_module.components
 | 
					        protos = pl_module.components
 | 
				
			||||||
        ax = self.setup_ax(xlabel="Data dimension 1",
 | 
					        ax = self.setup_ax()
 | 
				
			||||||
                           ylabel="Data dimension 2")
 | 
					 | 
				
			||||||
        self.plot_data(ax, x_train, y_train)
 | 
					        self.plot_data(ax, x_train, y_train)
 | 
				
			||||||
        self.plot_protos(ax, protos, "w")
 | 
					        self.plot_protos(ax, protos, "w")
 | 
				
			||||||
        x = np.vstack((x_train, protos))
 | 
					        x = np.vstack((x_train, protos))
 | 
				
			||||||
@@ -243,8 +242,7 @@ class VisNG2D(Vis2DAbstract):
 | 
				
			|||||||
        protos = pl_module.prototypes
 | 
					        protos = pl_module.prototypes
 | 
				
			||||||
        cmat = pl_module.topology_layer.cmat.cpu().numpy()
 | 
					        cmat = pl_module.topology_layer.cmat.cpu().numpy()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        ax = self.setup_ax(xlabel="Data dimension 1",
 | 
					        ax = self.setup_ax()
 | 
				
			||||||
                           ylabel="Data dimension 2")
 | 
					 | 
				
			||||||
        self.plot_data(ax, x_train, y_train)
 | 
					        self.plot_data(ax, x_train, y_train)
 | 
				
			||||||
        self.plot_protos(ax, protos, "w")
 | 
					        self.plot_protos(ax, protos, "w")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user