diff --git a/prototorch/models/glvq.py b/prototorch/models/glvq.py index 1f959fb..65b291f 100644 --- a/prototorch/models/glvq.py +++ b/prototorch/models/glvq.py @@ -87,14 +87,14 @@ class GLVQ(AbstractPrototypeModel): def validation_step(self, batch, batch_idx): # `model.eval()` and `torch.no_grad()` handled by pl - out, val_loss = self.shared_step(batch, batch_idx, optimizer_idx) + out, val_loss = self.shared_step(batch, batch_idx) self.log("val_loss", val_loss) self.log_acc(out, batch[-1], tag="val_acc") return val_loss def test_step(self, batch, batch_idx): # `model.eval()` and `torch.no_grad()` handled by pl - out, test_loss = self.shared_step(batch, batch_idx, optimizer_idx) + out, test_loss = self.shared_step(batch, batch_idx) return test_loss # def predict_step(self, batch, batch_idx, dataloader_idx=None):