[BUGFIX] KNN works again
This commit is contained in:
parent
69e5ff3243
commit
97ec15b76a
@ -88,13 +88,11 @@ class UnsupervisedPrototypeModel(PrototypeModel):
|
|||||||
super().__init__(hparams, **kwargs)
|
super().__init__(hparams, **kwargs)
|
||||||
|
|
||||||
# Layers
|
# Layers
|
||||||
prototype_initializer = kwargs.get("prototype_initializer", None)
|
prototypes_initializer = kwargs.get("prototypes_initializer", None)
|
||||||
initialized_prototypes = kwargs.get("initialized_prototypes", None)
|
if prototypes_initializer is not None:
|
||||||
if prototype_initializer is not None or initialized_prototypes is not None:
|
|
||||||
self.proto_layer = Components(
|
self.proto_layer = Components(
|
||||||
self.hparams.num_prototypes,
|
self.hparams.num_prototypes,
|
||||||
initializer=prototype_initializer,
|
initializer=prototypes_initializer,
|
||||||
initialized_components=initialized_prototypes,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def compute_distances(self, x):
|
def compute_distances(self, x):
|
||||||
@ -112,19 +110,17 @@ class SupervisedPrototypeModel(PrototypeModel):
|
|||||||
super().__init__(hparams, **kwargs)
|
super().__init__(hparams, **kwargs)
|
||||||
|
|
||||||
# Layers
|
# Layers
|
||||||
prototype_initializer = kwargs.get("prototype_initializer", None)
|
prototypes_initializer = kwargs.get("prototypes_initializer", None)
|
||||||
initialized_prototypes = kwargs.get("initialized_prototypes", None)
|
if prototypes_initializer is not None:
|
||||||
if prototype_initializer is not None or initialized_prototypes is not None:
|
|
||||||
self.proto_layer = LabeledComponents(
|
self.proto_layer = LabeledComponents(
|
||||||
distribution=self.hparams.distribution,
|
distribution=self.hparams.distribution,
|
||||||
initializer=prototype_initializer,
|
components_initializer=prototypes_initializer,
|
||||||
initialized_components=initialized_prototypes,
|
|
||||||
)
|
)
|
||||||
self.competition_layer = WTAC()
|
self.competition_layer = WTAC()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def prototype_labels(self):
|
def prototype_labels(self):
|
||||||
return self.proto_layer.component_labels.detach().cpu()
|
return self.proto_layer.labels.detach().cpu()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_classes(self):
|
def num_classes(self):
|
||||||
@ -137,15 +133,14 @@ class SupervisedPrototypeModel(PrototypeModel):
|
|||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
distances = self.compute_distances(x)
|
distances = self.compute_distances(x)
|
||||||
y_pred = self.predict_from_distances(distances)
|
plabels = self.proto_layer.labels
|
||||||
# TODO
|
winning = stratified_min_pooling(distances, plabels)
|
||||||
y_pred = torch.eye(self.num_classes, device=self.device)[
|
y_pred = torch.nn.functional.softmin(winning)
|
||||||
y_pred.long()] # depends on labels {0,...,num_classes}
|
|
||||||
return y_pred
|
return y_pred
|
||||||
|
|
||||||
def predict_from_distances(self, distances):
|
def predict_from_distances(self, distances):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
plabels = self.proto_layer.component_labels
|
plabels = self.proto_layer.labels
|
||||||
y_pred = self.competition_layer(distances, plabels)
|
y_pred = self.competition_layer(distances, plabels)
|
||||||
return y_pred
|
return y_pred
|
||||||
|
|
||||||
|
@ -20,9 +20,13 @@ class KNN(SupervisedPrototypeModel):
|
|||||||
data = kwargs.get("data", None)
|
data = kwargs.get("data", None)
|
||||||
if data is None:
|
if data is None:
|
||||||
raise ValueError("KNN requires data, but was not provided!")
|
raise ValueError("KNN requires data, but was not provided!")
|
||||||
|
data, targets = parse_data_arg(data)
|
||||||
|
|
||||||
# Layers
|
# Layers
|
||||||
self.proto_layer = LabeledComponents(initialized_components=data)
|
self.proto_layer = LabeledComponents(
|
||||||
|
distribution=[],
|
||||||
|
components_initializer=LiteralCompInitializer(data),
|
||||||
|
labels_initializer=LiteralLabelsInitializer(targets))
|
||||||
self.competition_layer = KNNC(k=self.hparams.k)
|
self.competition_layer = KNNC(k=self.hparams.k)
|
||||||
|
|
||||||
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
||||||
|
Loading…
Reference in New Issue
Block a user