fix: use multiclass accuracy by default
This commit is contained in:
parent
4cd6aee330
commit
2a665e220f
@ -2,6 +2,7 @@
|
||||
|
||||
import logging
|
||||
|
||||
import prototorch
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@ -186,21 +187,32 @@ class SupervisedPrototypeModel(PrototypeModel):
|
||||
|
||||
def log_acc(self, distances, targets, tag):
|
||||
preds = self.predict_from_distances(distances)
|
||||
accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int())
|
||||
# `.int()` because FloatTensors are assumed to be class probabilities
|
||||
accuracy = torchmetrics.functional.accuracy(
|
||||
preds.int(),
|
||||
targets.int(),
|
||||
"multiclass",
|
||||
num_classes=self.num_classes,
|
||||
)
|
||||
|
||||
self.log(tag,
|
||||
accuracy,
|
||||
on_step=False,
|
||||
on_epoch=True,
|
||||
prog_bar=True,
|
||||
logger=True)
|
||||
self.log(
|
||||
tag,
|
||||
accuracy,
|
||||
on_step=False,
|
||||
on_epoch=True,
|
||||
prog_bar=True,
|
||||
logger=True,
|
||||
)
|
||||
|
||||
def test_step(self, batch, batch_idx):
|
||||
x, targets = batch
|
||||
|
||||
preds = self.predict(x)
|
||||
accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int())
|
||||
accuracy = torchmetrics.functional.accuracy(
|
||||
preds.int(),
|
||||
targets.int(),
|
||||
"multiclass",
|
||||
num_classes=self.num_classes,
|
||||
)
|
||||
|
||||
self.log("test_acc", accuracy)
|
||||
|
||||
|
@ -55,14 +55,20 @@ class CBC(SiameseGLVQ):
|
||||
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
||||
y_pred, train_loss = self.shared_step(batch, batch_idx, optimizer_idx)
|
||||
preds = torch.argmax(y_pred, dim=1)
|
||||
accuracy = torchmetrics.functional.accuracy(preds.int(),
|
||||
batch[1].int())
|
||||
self.log("train_acc",
|
||||
accuracy,
|
||||
on_step=False,
|
||||
on_epoch=True,
|
||||
prog_bar=True,
|
||||
logger=True)
|
||||
accuracy = torchmetrics.functional.accuracy(
|
||||
preds.int(),
|
||||
batch[1].int(),
|
||||
"multiclass",
|
||||
num_classes=self.num_classes,
|
||||
)
|
||||
self.log(
|
||||
"train_acc",
|
||||
accuracy,
|
||||
on_step=False,
|
||||
on_epoch=True,
|
||||
prog_bar=True,
|
||||
logger=True,
|
||||
)
|
||||
return train_loss
|
||||
|
||||
def predict(self, x):
|
||||
|
Loading…
Reference in New Issue
Block a user