diff --git a/prototorch/models/abstract.py b/prototorch/models/abstract.py index 42b0bb3..9198bfb 100644 --- a/prototorch/models/abstract.py +++ b/prototorch/models/abstract.py @@ -2,6 +2,7 @@ import logging +import prototorch import pytorch_lightning as pl import torch import torch.nn.functional as F @@ -186,21 +187,32 @@ class SupervisedPrototypeModel(PrototypeModel): def log_acc(self, distances, targets, tag): preds = self.predict_from_distances(distances) - accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int()) - # `.int()` because FloatTensors are assumed to be class probabilities + accuracy = torchmetrics.functional.accuracy( + preds.int(), + targets.int(), + "multiclass", + num_classes=self.num_classes, + ) - self.log(tag, - accuracy, - on_step=False, - on_epoch=True, - prog_bar=True, - logger=True) + self.log( + tag, + accuracy, + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + ) def test_step(self, batch, batch_idx): x, targets = batch preds = self.predict(x) - accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int()) + accuracy = torchmetrics.functional.accuracy( + preds.int(), + targets.int(), + "multiclass", + num_classes=self.num_classes, + ) self.log("test_acc", accuracy) diff --git a/prototorch/models/cbc.py b/prototorch/models/cbc.py index 2e38394..6e3bd9a 100644 --- a/prototorch/models/cbc.py +++ b/prototorch/models/cbc.py @@ -55,14 +55,20 @@ class CBC(SiameseGLVQ): def training_step(self, batch, batch_idx, optimizer_idx=None): y_pred, train_loss = self.shared_step(batch, batch_idx, optimizer_idx) preds = torch.argmax(y_pred, dim=1) - accuracy = torchmetrics.functional.accuracy(preds.int(), - batch[1].int()) - self.log("train_acc", - accuracy, - on_step=False, - on_epoch=True, - prog_bar=True, - logger=True) + accuracy = torchmetrics.functional.accuracy( + preds.int(), + batch[1].int(), + "multiclass", + num_classes=self.num_classes, + ) + self.log( + "train_acc", + accuracy, + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + ) return train_loss def predict(self, x):