[Bugfix] Remove optimzer_idx from validation and test.

This commit is contained in:
Alexander Engelsberger 2021-05-20 13:17:27 +02:00
parent 5ffbd43a7c
commit faf1a88f99

View File

@ -87,14 +87,14 @@ class GLVQ(AbstractPrototypeModel):
def validation_step(self, batch, batch_idx):
# `model.eval()` and `torch.no_grad()` handled by pl
out, val_loss = self.shared_step(batch, batch_idx, optimizer_idx)
out, val_loss = self.shared_step(batch, batch_idx)
self.log("val_loss", val_loss)
self.log_acc(out, batch[-1], tag="val_acc")
return val_loss
def test_step(self, batch, batch_idx):
# `model.eval()` and `torch.no_grad()` handled by pl
out, test_loss = self.shared_step(batch, batch_idx, optimizer_idx)
out, test_loss = self.shared_step(batch, batch_idx)
return test_loss
# def predict_step(self, batch, batch_idx, dataloader_idx=None):