[BUGFIX] KNN works again

This commit is contained in:
Jensun Ravichandran 2021-06-14 20:09:41 +02:00
parent 69e5ff3243
commit 97ec15b76a
2 changed files with 16 additions and 17 deletions

View File

@ -88,13 +88,11 @@ class UnsupervisedPrototypeModel(PrototypeModel):
super().__init__(hparams, **kwargs) super().__init__(hparams, **kwargs)
# Layers # Layers
prototype_initializer = kwargs.get("prototype_initializer", None) prototypes_initializer = kwargs.get("prototypes_initializer", None)
initialized_prototypes = kwargs.get("initialized_prototypes", None) if prototypes_initializer is not None:
if prototype_initializer is not None or initialized_prototypes is not None:
self.proto_layer = Components( self.proto_layer = Components(
self.hparams.num_prototypes, self.hparams.num_prototypes,
initializer=prototype_initializer, initializer=prototypes_initializer,
initialized_components=initialized_prototypes,
) )
def compute_distances(self, x): def compute_distances(self, x):
@ -112,19 +110,17 @@ class SupervisedPrototypeModel(PrototypeModel):
super().__init__(hparams, **kwargs) super().__init__(hparams, **kwargs)
# Layers # Layers
prototype_initializer = kwargs.get("prototype_initializer", None) prototypes_initializer = kwargs.get("prototypes_initializer", None)
initialized_prototypes = kwargs.get("initialized_prototypes", None) if prototypes_initializer is not None:
if prototype_initializer is not None or initialized_prototypes is not None:
self.proto_layer = LabeledComponents( self.proto_layer = LabeledComponents(
distribution=self.hparams.distribution, distribution=self.hparams.distribution,
initializer=prototype_initializer, components_initializer=prototypes_initializer,
initialized_components=initialized_prototypes,
) )
self.competition_layer = WTAC() self.competition_layer = WTAC()
@property @property
def prototype_labels(self): def prototype_labels(self):
return self.proto_layer.component_labels.detach().cpu() return self.proto_layer.labels.detach().cpu()
@property @property
def num_classes(self): def num_classes(self):
@ -137,15 +133,14 @@ class SupervisedPrototypeModel(PrototypeModel):
def forward(self, x): def forward(self, x):
distances = self.compute_distances(x) distances = self.compute_distances(x)
y_pred = self.predict_from_distances(distances) plabels = self.proto_layer.labels
# TODO winning = stratified_min_pooling(distances, plabels)
y_pred = torch.eye(self.num_classes, device=self.device)[ y_pred = torch.nn.functional.softmin(winning)
y_pred.long()] # depends on labels {0,...,num_classes}
return y_pred return y_pred
def predict_from_distances(self, distances): def predict_from_distances(self, distances):
with torch.no_grad(): with torch.no_grad():
plabels = self.proto_layer.component_labels plabels = self.proto_layer.labels
y_pred = self.competition_layer(distances, plabels) y_pred = self.competition_layer(distances, plabels)
return y_pred return y_pred

View File

@ -20,9 +20,13 @@ class KNN(SupervisedPrototypeModel):
data = kwargs.get("data", None) data = kwargs.get("data", None)
if data is None: if data is None:
raise ValueError("KNN requires data, but was not provided!") raise ValueError("KNN requires data, but was not provided!")
data, targets = parse_data_arg(data)
# Layers # Layers
self.proto_layer = LabeledComponents(initialized_components=data) self.proto_layer = LabeledComponents(
distribution=[],
components_initializer=LiteralCompInitializer(data),
labels_initializer=LiteralLabelsInitializer(targets))
self.competition_layer = KNNC(k=self.hparams.k) self.competition_layer = KNNC(k=self.hparams.k)
def training_step(self, train_batch, batch_idx, optimizer_idx=None): def training_step(self, train_batch, batch_idx, optimizer_idx=None):