Log test accuracy.

This commit is contained in:
Alexander Engelsberger 2021-05-20 14:03:31 +02:00
parent b7fc5df386
commit 0204f5eab6

View File

@ -95,6 +95,8 @@ class GLVQ(AbstractPrototypeModel):
def test_step(self, batch, batch_idx): def test_step(self, batch, batch_idx):
# `model.eval()` and `torch.no_grad()` handled by pl # `model.eval()` and `torch.no_grad()` handled by pl
out, test_loss = self.shared_step(batch, batch_idx) out, test_loss = self.shared_step(batch, batch_idx)
self.log_acc(out, batch[-1], tag="test_acc")
self.log_dict({'test_loss': test_loss}) self.log_dict({'test_loss': test_loss})
# def predict_step(self, batch, batch_idx, dataloader_idx=None): # def predict_step(self, batch, batch_idx, dataloader_idx=None):