Accumulate test loss

This commit is contained in:
Alexander Engelsberger 2021-05-20 14:20:23 +02:00
parent 0204f5eab6
commit 969fb34cc3

View File

@ -97,7 +97,14 @@ class GLVQ(AbstractPrototypeModel):
out, test_loss = self.shared_step(batch, batch_idx)
self.log_acc(out, batch[-1], tag="test_acc")
self.log_dict({'test_loss': test_loss})
return test_loss
def test_epoch_end(self, outputs):
total_loss = 0
for batch_loss in outputs:
total_loss += batch_loss.item()
self.log('test_loss', total_loss)
# def predict_step(self, batch, batch_idx, dataloader_idx=None):
# pass