fix: use multiclass accuracy by default

This commit is contained in:
Alexander Engelsberger 2023-06-20 18:30:18 +02:00
parent 4cd6aee330
commit 2a665e220f
No known key found for this signature in database
2 changed files with 35 additions and 17 deletions

View File

@ -2,6 +2,7 @@
import logging
import prototorch
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
@ -186,21 +187,32 @@ class SupervisedPrototypeModel(PrototypeModel):
def log_acc(self, distances, targets, tag):
preds = self.predict_from_distances(distances)
accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int())
# `.int()` because FloatTensors are assumed to be class probabilities
accuracy = torchmetrics.functional.accuracy(
preds.int(),
targets.int(),
"multiclass",
num_classes=self.num_classes,
)
self.log(tag,
self.log(
tag,
accuracy,
on_step=False,
on_epoch=True,
prog_bar=True,
logger=True)
logger=True,
)
def test_step(self, batch, batch_idx):
x, targets = batch
preds = self.predict(x)
accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int())
accuracy = torchmetrics.functional.accuracy(
preds.int(),
targets.int(),
"multiclass",
num_classes=self.num_classes,
)
self.log("test_acc", accuracy)

View File

@ -55,14 +55,20 @@ class CBC(SiameseGLVQ):
def training_step(self, batch, batch_idx, optimizer_idx=None):
y_pred, train_loss = self.shared_step(batch, batch_idx, optimizer_idx)
preds = torch.argmax(y_pred, dim=1)
accuracy = torchmetrics.functional.accuracy(preds.int(),
batch[1].int())
self.log("train_acc",
accuracy = torchmetrics.functional.accuracy(
preds.int(),
batch[1].int(),
"multiclass",
num_classes=self.num_classes,
)
self.log(
"train_acc",
accuracy,
on_step=False,
on_epoch=True,
prog_bar=True,
logger=True)
logger=True,
)
return train_loss
def predict(self, x):