[refactor] Use functional variant of accuracy
Prevents Accuracy in `__repr__` of the models.
This commit is contained in:
parent
e9d2075fed
commit
0ac4ced85d
@ -36,6 +36,8 @@ if __name__ == "__main__":
|
||||
prototype_initializer=pt.components.SMI(train_ds),
|
||||
)
|
||||
|
||||
print(model)
|
||||
|
||||
# Callbacks
|
||||
vis = pt.models.VisGLVQ2D(data=train_ds)
|
||||
|
||||
|
@ -130,9 +130,10 @@ class CBC(SiameseGLVQ):
|
||||
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
||||
y_pred, train_loss = self.shared_step(batch, batch_idx, optimizer_idx)
|
||||
preds = torch.argmax(y_pred, dim=1)
|
||||
self.acc_metric(preds.int(), batch[1].int())
|
||||
accuracy = torchmetrics.functional.accuracy(preds.int(),
|
||||
batch[1].int())
|
||||
self.log("train_acc",
|
||||
self.acc_metric,
|
||||
accuracy,
|
||||
on_step=False,
|
||||
on_epoch=True,
|
||||
prog_bar=True,
|
||||
|
@ -33,7 +33,6 @@ class GLVQ(AbstractPrototypeModel):
|
||||
initializer=self.prototype_initializer(**kwargs))
|
||||
|
||||
self.transfer_fn = get_activation(self.hparams.transfer_fn)
|
||||
self.acc_metric = torchmetrics.Accuracy()
|
||||
|
||||
self.loss = glvq_loss
|
||||
|
||||
@ -73,11 +72,11 @@ class GLVQ(AbstractPrototypeModel):
|
||||
|
||||
def log_acc(self, distances, targets, tag):
|
||||
preds = self.predict_from_distances(distances)
|
||||
self.acc_metric(preds.int(), targets.int())
|
||||
accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int())
|
||||
# `.int()` because FloatTensors are assumed to be class probabilities
|
||||
|
||||
self.log(tag,
|
||||
self.acc_metric,
|
||||
accuracy,
|
||||
on_step=False,
|
||||
on_epoch=True,
|
||||
prog_bar=True,
|
||||
|
Loading…
Reference in New Issue
Block a user