[refactor] Use functional variant of accuracy
Prevents the stateful `torchmetrics.Accuracy` module from appearing in the `__repr__` of the models, by computing accuracy with the functional API instead of keeping the metric as a submodule attribute.
This commit is contained in:
@@ -33,7 +33,6 @@ class GLVQ(AbstractPrototypeModel):
|
||||
initializer=self.prototype_initializer(**kwargs))
|
||||
|
||||
self.transfer_fn = get_activation(self.hparams.transfer_fn)
|
||||
self.acc_metric = torchmetrics.Accuracy()
|
||||
|
||||
self.loss = glvq_loss
|
||||
|
||||
@@ -73,11 +72,11 @@ class GLVQ(AbstractPrototypeModel):
|
||||
|
||||
def log_acc(self, distances, targets, tag):
|
||||
preds = self.predict_from_distances(distances)
|
||||
self.acc_metric(preds.int(), targets.int())
|
||||
accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int())
|
||||
# `.int()` because FloatTensors are assumed to be class probabilities
|
||||
|
||||
self.log(tag,
|
||||
self.acc_metric,
|
||||
accuracy,
|
||||
on_step=False,
|
||||
on_epoch=True,
|
||||
prog_bar=True,
|
||||
|
Reference in New Issue
Block a user