diff --git a/prototorch/models/glvq.py b/prototorch/models/glvq.py index efd76b7..ff1da61 100644 --- a/prototorch/models/glvq.py +++ b/prototorch/models/glvq.py @@ -52,6 +52,8 @@ class GLVQ(AbstractPrototypeModel): # Compute training accuracy with torch.no_grad(): preds = wtac(dis, plabels) + + self.train_acc(preds.int(), y.int()) # `.int()` because FloatTensors are assumed to be class probabilities # Logging