[BUGFIX] KNN works again
This commit is contained in:
		@@ -88,13 +88,11 @@ class UnsupervisedPrototypeModel(PrototypeModel):
 | 
				
			|||||||
        super().__init__(hparams, **kwargs)
 | 
					        super().__init__(hparams, **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Layers
 | 
					        # Layers
 | 
				
			||||||
        prototype_initializer = kwargs.get("prototype_initializer", None)
 | 
					        prototypes_initializer = kwargs.get("prototypes_initializer", None)
 | 
				
			||||||
        initialized_prototypes = kwargs.get("initialized_prototypes", None)
 | 
					        if prototypes_initializer is not None:
 | 
				
			||||||
        if prototype_initializer is not None or initialized_prototypes is not None:
 | 
					 | 
				
			||||||
            self.proto_layer = Components(
 | 
					            self.proto_layer = Components(
 | 
				
			||||||
                self.hparams.num_prototypes,
 | 
					                self.hparams.num_prototypes,
 | 
				
			||||||
                initializer=prototype_initializer,
 | 
					                initializer=prototypes_initializer,
 | 
				
			||||||
                initialized_components=initialized_prototypes,
 | 
					 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def compute_distances(self, x):
 | 
					    def compute_distances(self, x):
 | 
				
			||||||
@@ -112,19 +110,17 @@ class SupervisedPrototypeModel(PrototypeModel):
 | 
				
			|||||||
        super().__init__(hparams, **kwargs)
 | 
					        super().__init__(hparams, **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Layers
 | 
					        # Layers
 | 
				
			||||||
        prototype_initializer = kwargs.get("prototype_initializer", None)
 | 
					        prototypes_initializer = kwargs.get("prototypes_initializer", None)
 | 
				
			||||||
        initialized_prototypes = kwargs.get("initialized_prototypes", None)
 | 
					        if prototypes_initializer is not None:
 | 
				
			||||||
        if prototype_initializer is not None or initialized_prototypes is not None:
 | 
					 | 
				
			||||||
            self.proto_layer = LabeledComponents(
 | 
					            self.proto_layer = LabeledComponents(
 | 
				
			||||||
                distribution=self.hparams.distribution,
 | 
					                distribution=self.hparams.distribution,
 | 
				
			||||||
                initializer=prototype_initializer,
 | 
					                components_initializer=prototypes_initializer,
 | 
				
			||||||
                initialized_components=initialized_prototypes,
 | 
					 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
        self.competition_layer = WTAC()
 | 
					        self.competition_layer = WTAC()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def prototype_labels(self):
 | 
					    def prototype_labels(self):
 | 
				
			||||||
        return self.proto_layer.component_labels.detach().cpu()
 | 
					        return self.proto_layer.labels.detach().cpu()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def num_classes(self):
 | 
					    def num_classes(self):
 | 
				
			||||||
@@ -137,15 +133,14 @@ class SupervisedPrototypeModel(PrototypeModel):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def forward(self, x):
 | 
					    def forward(self, x):
 | 
				
			||||||
        distances = self.compute_distances(x)
 | 
					        distances = self.compute_distances(x)
 | 
				
			||||||
        y_pred = self.predict_from_distances(distances)
 | 
					        plabels = self.proto_layer.labels
 | 
				
			||||||
        # TODO
 | 
					        winning = stratified_min_pooling(distances, plabels)
 | 
				
			||||||
        y_pred = torch.eye(self.num_classes, device=self.device)[
 | 
					        y_pred = torch.nn.functional.softmin(winning)
 | 
				
			||||||
            y_pred.long()]  # depends on labels {0,...,num_classes}
 | 
					 | 
				
			||||||
        return y_pred
 | 
					        return y_pred
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def predict_from_distances(self, distances):
 | 
					    def predict_from_distances(self, distances):
 | 
				
			||||||
        with torch.no_grad():
 | 
					        with torch.no_grad():
 | 
				
			||||||
            plabels = self.proto_layer.component_labels
 | 
					            plabels = self.proto_layer.labels
 | 
				
			||||||
            y_pred = self.competition_layer(distances, plabels)
 | 
					            y_pred = self.competition_layer(distances, plabels)
 | 
				
			||||||
        return y_pred
 | 
					        return y_pred
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -20,9 +20,13 @@ class KNN(SupervisedPrototypeModel):
 | 
				
			|||||||
        data = kwargs.get("data", None)
 | 
					        data = kwargs.get("data", None)
 | 
				
			||||||
        if data is None:
 | 
					        if data is None:
 | 
				
			||||||
            raise ValueError("KNN requires data, but was not provided!")
 | 
					            raise ValueError("KNN requires data, but was not provided!")
 | 
				
			||||||
 | 
					        data, targets = parse_data_arg(data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Layers
 | 
					        # Layers
 | 
				
			||||||
        self.proto_layer = LabeledComponents(initialized_components=data)
 | 
					        self.proto_layer = LabeledComponents(
 | 
				
			||||||
 | 
					            distribution=[],
 | 
				
			||||||
 | 
					            components_initializer=LiteralCompInitializer(data),
 | 
				
			||||||
 | 
					            labels_initializer=LiteralLabelsInitializer(targets))
 | 
				
			||||||
        self.competition_layer = KNNC(k=self.hparams.k)
 | 
					        self.competition_layer = KNNC(k=self.hparams.k)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
 | 
					    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user