fix: use multiclass accuracy by default

This commit is contained in:
Alexander Engelsberger 2023-06-20 18:30:18 +02:00
parent 4cd6aee330
commit 2a665e220f
No known key found for this signature in database
2 changed files with 35 additions and 17 deletions

View File

@ -2,6 +2,7 @@
import logging import logging
import prototorch
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@ -186,21 +187,32 @@ class SupervisedPrototypeModel(PrototypeModel):
def log_acc(self, distances, targets, tag): def log_acc(self, distances, targets, tag):
preds = self.predict_from_distances(distances) preds = self.predict_from_distances(distances)
accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int()) accuracy = torchmetrics.functional.accuracy(
# `.int()` because FloatTensors are assumed to be class probabilities preds.int(),
targets.int(),
"multiclass",
num_classes=self.num_classes,
)
self.log(tag, self.log(
tag,
accuracy, accuracy,
on_step=False, on_step=False,
on_epoch=True, on_epoch=True,
prog_bar=True, prog_bar=True,
logger=True) logger=True,
)
def test_step(self, batch, batch_idx): def test_step(self, batch, batch_idx):
x, targets = batch x, targets = batch
preds = self.predict(x) preds = self.predict(x)
accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int()) accuracy = torchmetrics.functional.accuracy(
preds.int(),
targets.int(),
"multiclass",
num_classes=self.num_classes,
)
self.log("test_acc", accuracy) self.log("test_acc", accuracy)

View File

@ -55,14 +55,20 @@ class CBC(SiameseGLVQ):
def training_step(self, batch, batch_idx, optimizer_idx=None): def training_step(self, batch, batch_idx, optimizer_idx=None):
y_pred, train_loss = self.shared_step(batch, batch_idx, optimizer_idx) y_pred, train_loss = self.shared_step(batch, batch_idx, optimizer_idx)
preds = torch.argmax(y_pred, dim=1) preds = torch.argmax(y_pred, dim=1)
accuracy = torchmetrics.functional.accuracy(preds.int(), accuracy = torchmetrics.functional.accuracy(
batch[1].int()) preds.int(),
self.log("train_acc", batch[1].int(),
"multiclass",
num_classes=self.num_classes,
)
self.log(
"train_acc",
accuracy, accuracy,
on_step=False, on_step=False,
on_epoch=True, on_epoch=True,
prog_bar=True, prog_bar=True,
logger=True) logger=True,
)
return train_loss return train_loss
def predict(self, x): def predict(self, x):