feat: add simple test step

This commit is contained in:
Alexander Engelsberger
2021-09-10 19:19:51 +02:00
parent fa928afe2c
commit d7ea89d47e
2 changed files with 26 additions and 6 deletions

View File

@@ -162,6 +162,14 @@ class SupervisedPrototypeModel(PrototypeModel):
prog_bar=True,
logger=True)
def test_step(self, batch, batch_idx):
x, targets = batch
preds = self.predict(x)
accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int())
self.log("test_acc", accuracy)
class ProtoTorchMixin(object):
"""All mixins are ProtoTorchMixins."""