[BUG] LVQ1 is broken
This commit is contained in:
@@ -9,7 +9,7 @@ class LVQ1(NonGradientMixin, GLVQ):
|
||||
"""Learning Vector Quantization 1."""
|
||||
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
||||
protos = self.proto_layer.components
|
||||
plabels = self.proto_layer.component_labels
|
||||
plabels = self.proto_layer.labels
|
||||
|
||||
x, y = train_batch
|
||||
dis = self.compute_distances(x)
|
||||
@@ -28,6 +28,8 @@ class LVQ1(NonGradientMixin, GLVQ):
|
||||
self.proto_layer.load_state_dict({"_components": updated_protos},
|
||||
strict=False)
|
||||
|
||||
print(f"{dis=}")
|
||||
print(f"{y=}")
|
||||
# Logging
|
||||
self.log_acc(dis, y, tag="train_acc")
|
||||
|
||||
@@ -38,7 +40,7 @@ class LVQ21(NonGradientMixin, GLVQ):
|
||||
"""Learning Vector Quantization 2.1."""
|
||||
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
||||
protos = self.proto_layer.components
|
||||
plabels = self.proto_layer.component_labels
|
||||
plabels = self.proto_layer.labels
|
||||
|
||||
x, y = train_batch
|
||||
dis = self.compute_distances(x)
|
||||
|
Reference in New Issue
Block a user