[refactor] Use functional variant of accuracy
Prevents Accuracy in `__repr__` of the models.
This commit is contained in:
parent
e9d2075fed
commit
0ac4ced85d
@ -36,6 +36,8 @@ if __name__ == "__main__":
|
|||||||
prototype_initializer=pt.components.SMI(train_ds),
|
prototype_initializer=pt.components.SMI(train_ds),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
print(model)
|
||||||
|
|
||||||
# Callbacks
|
# Callbacks
|
||||||
vis = pt.models.VisGLVQ2D(data=train_ds)
|
vis = pt.models.VisGLVQ2D(data=train_ds)
|
||||||
|
|
||||||
|
@ -130,9 +130,10 @@ class CBC(SiameseGLVQ):
|
|||||||
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
||||||
y_pred, train_loss = self.shared_step(batch, batch_idx, optimizer_idx)
|
y_pred, train_loss = self.shared_step(batch, batch_idx, optimizer_idx)
|
||||||
preds = torch.argmax(y_pred, dim=1)
|
preds = torch.argmax(y_pred, dim=1)
|
||||||
self.acc_metric(preds.int(), batch[1].int())
|
accuracy = torchmetrics.functional.accuracy(preds.int(),
|
||||||
|
batch[1].int())
|
||||||
self.log("train_acc",
|
self.log("train_acc",
|
||||||
self.acc_metric,
|
accuracy,
|
||||||
on_step=False,
|
on_step=False,
|
||||||
on_epoch=True,
|
on_epoch=True,
|
||||||
prog_bar=True,
|
prog_bar=True,
|
||||||
|
@ -33,7 +33,6 @@ class GLVQ(AbstractPrototypeModel):
|
|||||||
initializer=self.prototype_initializer(**kwargs))
|
initializer=self.prototype_initializer(**kwargs))
|
||||||
|
|
||||||
self.transfer_fn = get_activation(self.hparams.transfer_fn)
|
self.transfer_fn = get_activation(self.hparams.transfer_fn)
|
||||||
self.acc_metric = torchmetrics.Accuracy()
|
|
||||||
|
|
||||||
self.loss = glvq_loss
|
self.loss = glvq_loss
|
||||||
|
|
||||||
@ -73,11 +72,11 @@ class GLVQ(AbstractPrototypeModel):
|
|||||||
|
|
||||||
def log_acc(self, distances, targets, tag):
|
def log_acc(self, distances, targets, tag):
|
||||||
preds = self.predict_from_distances(distances)
|
preds = self.predict_from_distances(distances)
|
||||||
self.acc_metric(preds.int(), targets.int())
|
accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int())
|
||||||
# `.int()` because FloatTensors are assumed to be class probabilities
|
# `.int()` because FloatTensors are assumed to be class probabilities
|
||||||
|
|
||||||
self.log(tag,
|
self.log(tag,
|
||||||
self.acc_metric,
|
accuracy,
|
||||||
on_step=False,
|
on_step=False,
|
||||||
on_epoch=True,
|
on_epoch=True,
|
||||||
prog_bar=True,
|
prog_bar=True,
|
||||||
|
Loading…
Reference in New Issue
Block a user