fix: use multiclass accuracy by default
parent 4cd6aee330
commit 2a665e220f
@@ -2,6 +2,7 @@
 import logging
 
+import prototorch
 import pytorch_lightning as pl
 import torch
 import torch.nn.functional as F
 import torchmetrics
@@ -186,21 +187,32 @@ class SupervisedPrototypeModel(PrototypeModel):
 
     def log_acc(self, distances, targets, tag):
         preds = self.predict_from_distances(distances)
-        accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int())
-        # `.int()` because FloatTensors are assumed to be class probabilities
+        accuracy = torchmetrics.functional.accuracy(
+            preds.int(),
+            targets.int(),
+            "multiclass",
+            num_classes=self.num_classes,
+        )
 
-        self.log(tag,
+        self.log(
+            tag,
             accuracy,
             on_step=False,
             on_epoch=True,
             prog_bar=True,
-            logger=True)
+            logger=True,
+        )
 
     def test_step(self, batch, batch_idx):
         x, targets = batch
 
         preds = self.predict(x)
-        accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int())
+        accuracy = torchmetrics.functional.accuracy(
+            preds.int(),
+            targets.int(),
+            "multiclass",
+            num_classes=self.num_classes,
+        )
 
         self.log("test_acc", accuracy)
 
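Note on the new call signature: recent torchmetrics releases (0.11 and later, as far as I can tell) require the `task` argument, and the multiclass task also needs an explicit `num_classes`. A minimal standalone sketch of the call pattern the diff switches to, using toy tensors rather than anything from this repository:

```python
import torch
import torchmetrics

# Toy integer class labels for a 3-class problem (made up for illustration).
preds = torch.tensor([0, 1, 1, 2, 2, 0])
targets = torch.tensor([0, 0, 1, 1, 2, 2])

# Same call pattern as in log_acc()/test_step(): the task is passed
# positionally as "multiclass" and num_classes is given explicitly.
accuracy = torchmetrics.functional.accuracy(
    preds.int(),
    targets.int(),
    "multiclass",
    num_classes=3,
)
print(accuracy)  # tensor(0.5000)
```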
@@ -55,14 +55,20 @@ class CBC(SiameseGLVQ):
     def training_step(self, batch, batch_idx, optimizer_idx=None):
         y_pred, train_loss = self.shared_step(batch, batch_idx, optimizer_idx)
         preds = torch.argmax(y_pred, dim=1)
-        accuracy = torchmetrics.functional.accuracy(preds.int(),
-                                                    batch[1].int())
-        self.log("train_acc",
+        accuracy = torchmetrics.functional.accuracy(
+            preds.int(),
+            batch[1].int(),
+            "multiclass",
+            num_classes=self.num_classes,
+        )
+        self.log(
+            "train_acc",
             accuracy,
             on_step=False,
             on_epoch=True,
             prog_bar=True,
-            logger=True)
+            logger=True,
+        )
         return train_loss
 
     def predict(self, x):
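Side note on the `torch.argmax` in `training_step` and the `.int()` casts: as far as I know, the multiclass task also accepts a float `(N, C)` tensor of class scores directly and takes the argmax over the class dimension itself, so converting to integer labels first should give the same result. A small sketch with made-up tensors (not code from this repository):

```python
import torch
import torchmetrics

# Toy per-class scores for 4 samples and 3 classes, plus integer targets.
y_pred = torch.tensor([
    [0.8, 0.1, 0.1],
    [0.2, 0.7, 0.1],
    [0.1, 0.2, 0.7],
    [0.6, 0.3, 0.1],
])
targets = torch.tensor([0, 1, 2, 1])

# Explicit argmax into integer labels, as CBC.training_step does.
preds = torch.argmax(y_pred, dim=1)
acc_from_labels = torchmetrics.functional.accuracy(
    preds.int(),
    targets.int(),
    "multiclass",
    num_classes=3,
)

# Handing the float (N, C) scores to torchmetrics and letting it argmax.
acc_from_scores = torchmetrics.functional.accuracy(
    y_pred,
    targets,
    "multiclass",
    num_classes=3,
)

assert torch.isclose(acc_from_labels, acc_from_scores)
```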