[refactor] Use functional variant of accuracy

Prevents the `torchmetrics.Accuracy` module from showing up in the `__repr__` of the models (it is no longer stored as a submodule attribute; the functional `torchmetrics.functional.accuracy` is used instead).
This commit is contained in:
Alexander Engelsberger 2021-05-28 21:30:50 +02:00
parent e9d2075fed
commit 0ac4ced85d
3 changed files with 7 additions and 5 deletions

View File

@ -36,6 +36,8 @@ if __name__ == "__main__":
prototype_initializer=pt.components.SMI(train_ds),
)
print(model)
# Callbacks
vis = pt.models.VisGLVQ2D(data=train_ds)

View File

@ -130,9 +130,10 @@ class CBC(SiameseGLVQ):
def training_step(self, batch, batch_idx, optimizer_idx=None):
y_pred, train_loss = self.shared_step(batch, batch_idx, optimizer_idx)
preds = torch.argmax(y_pred, dim=1)
self.acc_metric(preds.int(), batch[1].int())
accuracy = torchmetrics.functional.accuracy(preds.int(),
batch[1].int())
self.log("train_acc",
self.acc_metric,
accuracy,
on_step=False,
on_epoch=True,
prog_bar=True,

View File

@ -33,7 +33,6 @@ class GLVQ(AbstractPrototypeModel):
initializer=self.prototype_initializer(**kwargs))
self.transfer_fn = get_activation(self.hparams.transfer_fn)
self.acc_metric = torchmetrics.Accuracy()
self.loss = glvq_loss
@ -73,11 +72,11 @@ class GLVQ(AbstractPrototypeModel):
def log_acc(self, distances, targets, tag):
preds = self.predict_from_distances(distances)
self.acc_metric(preds.int(), targets.int())
accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int())
# `.int()` because FloatTensors are assumed to be class probabilities
self.log(tag,
self.acc_metric,
accuracy,
on_step=False,
on_epoch=True,
prog_bar=True,