From 71167a8f777406bb81fafcad506b4c67f1e28754 Mon Sep 17 00:00:00 2001
From: Alexander Engelsberger
Date: Wed, 25 Oct 2023 15:03:13 +0200
Subject: [PATCH] chore: remove optimizer_idx from all steps

---
 src/prototorch/models/abstract.py      |  3 +--
 src/prototorch/models/cbc.py           |  6 +++---
 src/prototorch/models/glvq.py          | 10 +++-------
 src/prototorch/models/knn.py           |  2 +-
 src/prototorch/models/lvq.py           |  6 +++---
 src/prototorch/models/probabilistic.py |  6 +++---
 6 files changed, 14 insertions(+), 19 deletions(-)

diff --git a/src/prototorch/models/abstract.py b/src/prototorch/models/abstract.py
index 9198bfb..0b69f63 100644
--- a/src/prototorch/models/abstract.py
+++ b/src/prototorch/models/abstract.py
@@ -2,7 +2,6 @@
 
 import logging
 
-import prototorch
 import pytorch_lightning as pl
 import torch
 import torch.nn.functional as F
@@ -228,7 +227,7 @@ class NonGradientMixin(ProtoTorchMixin):
         super().__init__(*args, **kwargs)
         self.automatic_optimization = False
 
-    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
+    def training_step(self, train_batch, batch_idx):
         raise NotImplementedError
 
 
diff --git a/src/prototorch/models/cbc.py b/src/prototorch/models/cbc.py
index 6e3bd9a..d114024 100644
--- a/src/prototorch/models/cbc.py
+++ b/src/prototorch/models/cbc.py
@@ -44,7 +44,7 @@ class CBC(SiameseGLVQ):
         probs = self.competition_layer(detections, reasonings)
         return probs
 
-    def shared_step(self, batch, batch_idx, optimizer_idx=None):
+    def shared_step(self, batch, batch_idx):
         x, y = batch
         y_pred = self(x)
         num_classes = self.num_classes
@@ -52,8 +52,8 @@ class CBC(SiameseGLVQ):
         loss = self.loss(y_pred, y_true).mean()
         return y_pred, loss
 
-    def training_step(self, batch, batch_idx, optimizer_idx=None):
-        y_pred, train_loss = self.shared_step(batch, batch_idx, optimizer_idx)
+    def training_step(self, batch, batch_idx):
+        y_pred, train_loss = self.shared_step(batch, batch_idx)
         preds = torch.argmax(y_pred, dim=1)
         accuracy = torchmetrics.functional.accuracy(
             preds.int(),
diff --git a/src/prototorch/models/glvq.py b/src/prototorch/models/glvq.py
index 8f77fdd..4328b10 100644
--- a/src/prototorch/models/glvq.py
+++ b/src/prototorch/models/glvq.py
@@ -66,15 +66,15 @@ class GLVQ(SupervisedPrototypeModel):
             prototype_wr,
         ])
 
-    def shared_step(self, batch, batch_idx, optimizer_idx=None):
+    def shared_step(self, batch, batch_idx):
         x, y = batch
         out = self.compute_distances(x)
         _, plabels = self.proto_layer()
         loss = self.loss(out, y, plabels)
         return out, loss
 
-    def training_step(self, batch, batch_idx, optimizer_idx=None):
-        out, train_loss = self.shared_step(batch, batch_idx, optimizer_idx)
+    def training_step(self, batch, batch_idx):
+        out, train_loss = self.shared_step(batch, batch_idx)
         self.log_prototype_win_ratios(out)
         self.log("train_loss", train_loss)
         self.log_acc(out, batch[-1], tag="train_acc")
@@ -99,10 +99,6 @@ class GLVQ(SupervisedPrototypeModel):
             test_loss += batch_loss.item()
         self.log("test_loss", test_loss)
 
-    # TODO
-    # def predict_step(self, batch, batch_idx, dataloader_idx=None):
-    #     pass
-
 
 class SiameseGLVQ(GLVQ):
     """GLVQ in a Siamese setting.
diff --git a/src/prototorch/models/knn.py b/src/prototorch/models/knn.py
index d277206..a8d3152 100644
--- a/src/prototorch/models/knn.py
+++ b/src/prototorch/models/knn.py
@@ -34,7 +34,7 @@ class KNN(SupervisedPrototypeModel):
             labels_initializer=LiteralLabelsInitializer(targets))
         self.competition_layer = KNNC(k=self.hparams.k)
 
-    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
+    def training_step(self, train_batch, batch_idx):
         return 1  # skip training step
 
     def on_train_batch_start(self, train_batch, batch_idx):
diff --git a/src/prototorch/models/lvq.py b/src/prototorch/models/lvq.py
index aa893a2..19e4440 100644
--- a/src/prototorch/models/lvq.py
+++ b/src/prototorch/models/lvq.py
@@ -13,7 +13,7 @@ from .glvq import GLVQ
 class LVQ1(NonGradientMixin, GLVQ):
     """Learning Vector Quantization 1."""
 
-    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
+    def training_step(self, train_batch, batch_idx):
         protos, plables = self.proto_layer()
         x, y = train_batch
         dis = self.compute_distances(x)
@@ -43,7 +43,7 @@ class LVQ1(NonGradientMixin, GLVQ):
 class LVQ21(NonGradientMixin, GLVQ):
     """Learning Vector Quantization 2.1."""
 
-    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
+    def training_step(self, train_batch, batch_idx):
         protos, plabels = self.proto_layer()
         x, y = train_batch
 
@@ -100,7 +100,7 @@ class MedianLVQ(NonGradientMixin, GLVQ):
             lower_bound = (gamma * f.log()).sum()
         return lower_bound
 
-    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
+    def training_step(self, train_batch, batch_idx):
         protos, plabels = self.proto_layer()
         x, y = train_batch
 
diff --git a/src/prototorch/models/probabilistic.py b/src/prototorch/models/probabilistic.py
index 79da5d9..d060fba 100644
--- a/src/prototorch/models/probabilistic.py
+++ b/src/prototorch/models/probabilistic.py
@@ -21,7 +21,7 @@ class CELVQ(GLVQ):
         # Loss
         self.loss = torch.nn.CrossEntropyLoss()
 
-    def shared_step(self, batch, batch_idx, optimizer_idx=None):
+    def shared_step(self, batch, batch_idx):
         x, y = batch
         out = self.compute_distances(x)  # [None, num_protos]
         _, plabels = self.proto_layer()
@@ -63,7 +63,7 @@ class ProbabilisticLVQ(GLVQ):
             prediction[confidence < self.rejection_confidence] = -1
         return prediction
 
-    def training_step(self, batch, batch_idx, optimizer_idx=None):
+    def training_step(self, batch, batch_idx):
         x, y = batch
         out = self.forward(x)
         _, plabels = self.proto_layer()
@@ -123,7 +123,7 @@ class PLVQ(ProbabilisticLVQ, SiameseGMLVQ):
         self.loss = torch.nn.KLDivLoss()
 
     # FIXME
-    # def training_step(self, batch, batch_idx, optimizer_idx=None):
+    # def training_step(self, batch, batch_idx):
     #     x, y = batch
     #     y_pred = self(x)
     #     batch_loss = self.loss(y_pred, y)
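Background for the patch above: PyTorch Lightning 2.0 removed the optimizer_idx argument from training_step and the other step hooks, so every override here drops it from its signature. The following is a minimal sketch of the post-2.0 step signature, assuming pytorch_lightning >= 2.0; TinyClassifier and its layer sizes are hypothetical and not part of prototorch.

# Minimal sketch of the Lightning >= 2.0 step signature this patch migrates to.
# TinyClassifier is a hypothetical module for illustration only.
import pytorch_lightning as pl
import torch


class TinyClassifier(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(2, 3)
        self.loss = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        # No optimizer_idx parameter anymore: with a single optimizer under
        # automatic optimization, Lightning 2.x passes only batch and batch_idx.
        x, y = batch
        return self.loss(self(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

Models that previously relied on optimizer_idx to coordinate several optimizers are expected to use manual optimization instead, the same pattern NonGradientMixin follows by setting self.automatic_optimization = False in its __init__.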