[BUG] LVQ1 is broken

This commit is contained in:
Jensun Ravichandran
2021-06-14 21:08:05 +02:00
parent 7ec5528ade
commit 1b420c1f6b
2 changed files with 78 additions and 2 deletions

View File

@@ -9,7 +9,7 @@ class LVQ1(NonGradientMixin, GLVQ):
"""Learning Vector Quantization 1."""
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
protos = self.proto_layer.components
plabels = self.proto_layer.component_labels
plabels = self.proto_layer.labels
x, y = train_batch
dis = self.compute_distances(x)
@@ -28,6 +28,8 @@ class LVQ1(NonGradientMixin, GLVQ):
self.proto_layer.load_state_dict({"_components": updated_protos},
strict=False)
print(f"{dis=}")
print(f"{y=}")
# Logging
self.log_acc(dis, y, tag="train_acc")
@@ -38,7 +40,7 @@ class LVQ21(NonGradientMixin, GLVQ):
"""Learning Vector Quantization 2.1."""
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
protos = self.proto_layer.components
plabels = self.proto_layer.component_labels
plabels = self.proto_layer.labels
x, y = train_batch
dis = self.compute_distances(x)