chore: remove optimizer_idx from all steps
commit 71167a8f77
parent 60990f42d2
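Background note (editorial, not part of the commit): PyTorch Lightning deprecated and, in 2.0, removed the optimizer_idx argument from training_step; models with several optimizers are expected to switch to manual optimization and fetch their optimizers explicitly, which is presumably the motivation for this cleanup. A minimal sketch of the replacement pattern, assuming Lightning >= 2.0; TwoOptimizerDemo and its layers are hypothetical, not repository code:

    import pytorch_lightning as pl
    import torch

    class TwoOptimizerDemo(pl.LightningModule):
        """Where optimizer_idx used to dispatch between optimizers,
        manual optimization now drives both explicitly."""

        def __init__(self):
            super().__init__()
            self.automatic_optimization = False  # replaces optimizer_idx dispatch
            self.encoder = torch.nn.Linear(4, 4)
            self.head = torch.nn.Linear(4, 2)

        def training_step(self, batch, batch_idx):  # note: no optimizer_idx
            opt_enc, opt_head = self.optimizers()
            x, y = batch
            loss = torch.nn.functional.cross_entropy(
                self.head(self.encoder(x)), y)
            opt_enc.zero_grad()
            opt_head.zero_grad()
            self.manual_backward(loss)
            opt_enc.step()
            opt_head.step()
            return loss

        def configure_optimizers(self):
            return (
                torch.optim.SGD(self.encoder.parameters(), lr=0.1),
                torch.optim.SGD(self.head.parameters(), lr=0.1),
            )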
@@ -2,7 +2,6 @@
 import logging

 import prototorch
 import pytorch_lightning as pl
 import torch
 import torch.nn.functional as F
@@ -228,7 +227,7 @@ class NonGradientMixin(ProtoTorchMixin):
         super().__init__(*args, **kwargs)
         self.automatic_optimization = False

-    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
+    def training_step(self, train_batch, batch_idx):
         raise NotImplementedError

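NonGradientMixin sets self.automatic_optimization = False: its subclasses (LVQ1, LVQ21, and MedianLVQ below) update prototypes in place inside training_step without any backward pass, so the optimizer_idx parameter was never used there. A self-contained sketch of that non-gradient style; TinyLVQ1 and its winner-take-all update rule are illustrative, not repository code:

    import pytorch_lightning as pl
    import torch

    class TinyLVQ1(pl.LightningModule):
        """Prototypes are edited in place; no optimizer is involved."""

        def __init__(self, protos, plabels, lr=0.1):
            super().__init__()
            self.automatic_optimization = False  # no backward pass at all
            self.protos = torch.nn.Parameter(protos, requires_grad=False)
            self.plabels = plabels
            self.lr = lr

        def training_step(self, train_batch, batch_idx):
            x, y = train_batch
            dis = torch.cdist(x, self.protos)  # distances to all prototypes
            winners = dis.argmin(dim=1)        # closest prototype per sample
            for xi, yi, w in zip(x, y, winners):
                # LVQ1 rule: attract on label match, repel on mismatch.
                sign = 1.0 if self.plabels[w] == yi else -1.0
                self.protos.data[w] += sign * self.lr * (xi - self.protos.data[w])

        def configure_optimizers(self):
            return None  # nothing for Lightning to optimize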
@@ -44,7 +44,7 @@ class CBC(SiameseGLVQ):
         probs = self.competition_layer(detections, reasonings)
         return probs

-    def shared_step(self, batch, batch_idx, optimizer_idx=None):
+    def shared_step(self, batch, batch_idx):
         x, y = batch
         y_pred = self(x)
         num_classes = self.num_classes
@@ -52,8 +52,8 @@ class CBC(SiameseGLVQ):
         loss = self.loss(y_pred, y_true).mean()
         return y_pred, loss

-    def training_step(self, batch, batch_idx, optimizer_idx=None):
-        y_pred, train_loss = self.shared_step(batch, batch_idx, optimizer_idx)
+    def training_step(self, batch, batch_idx):
+        y_pred, train_loss = self.shared_step(batch, batch_idx)
         preds = torch.argmax(y_pred, dim=1)
         accuracy = torchmetrics.functional.accuracy(
             preds.int(),
@@ -66,15 +66,15 @@ class GLVQ(SupervisedPrototypeModel):
             prototype_wr,
         ])

-    def shared_step(self, batch, batch_idx, optimizer_idx=None):
+    def shared_step(self, batch, batch_idx):
         x, y = batch
         out = self.compute_distances(x)
         _, plabels = self.proto_layer()
         loss = self.loss(out, y, plabels)
         return out, loss

-    def training_step(self, batch, batch_idx, optimizer_idx=None):
-        out, train_loss = self.shared_step(batch, batch_idx, optimizer_idx)
+    def training_step(self, batch, batch_idx):
+        out, train_loss = self.shared_step(batch, batch_idx)
         self.log_prototype_win_ratios(out)
         self.log("train_loss", train_loss)
         self.log_acc(out, batch[-1], tag="train_acc")
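GLVQ is the base class of most models in this diff (SiameseGLVQ, CBC, CELVQ, ProbabilisticLVQ, and the LVQ variants), so trimming shared_step here is what forces the matching one-line call-site changes everywhere else. A minimal sketch of the shared_step pattern with the slimmed signatures, under the same Lightning >= 2.0 assumption; SharedStepDemo is hypothetical, not repository code:

    import pytorch_lightning as pl
    import torch

    class SharedStepDemo(pl.LightningModule):

        def __init__(self):
            super().__init__()
            self.net = torch.nn.Linear(4, 3)

        def shared_step(self, batch, batch_idx):
            # One helper computes the forward pass and loss for every phase.
            x, y = batch
            out = self.net(x)
            loss = torch.nn.functional.cross_entropy(out, y)
            return out, loss

        def training_step(self, batch, batch_idx):
            out, train_loss = self.shared_step(batch, batch_idx)
            self.log("train_loss", train_loss)
            return train_loss

        def validation_step(self, batch, batch_idx):
            out, val_loss = self.shared_step(batch, batch_idx)
            self.log("val_loss", val_loss)

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)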
@@ -99,10 +99,6 @@ class GLVQ(SupervisedPrototypeModel):
             test_loss += batch_loss.item()
         self.log("test_loss", test_loss)

-    # TODO
-    # def predict_step(self, batch, batch_idx, dataloader_idx=None):
-    #     pass
-

 class SiameseGLVQ(GLVQ):
     """GLVQ in a Siamese setting.
@@ -34,7 +34,7 @@ class KNN(SupervisedPrototypeModel):
             labels_initializer=LiteralLabelsInitializer(targets))
         self.competition_layer = KNNC(k=self.hparams.k)

-    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
+    def training_step(self, train_batch, batch_idx):
         return 1  # skip training step

     def on_train_batch_start(self, train_batch, batch_idx):
@@ -13,7 +13,7 @@ from .glvq import GLVQ
 class LVQ1(NonGradientMixin, GLVQ):
     """Learning Vector Quantization 1."""

-    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
+    def training_step(self, train_batch, batch_idx):
         protos, plables = self.proto_layer()
         x, y = train_batch
         dis = self.compute_distances(x)
@@ -43,7 +43,7 @@ class LVQ1(NonGradientMixin, GLVQ):
 class LVQ21(NonGradientMixin, GLVQ):
     """Learning Vector Quantization 2.1."""

-    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
+    def training_step(self, train_batch, batch_idx):
         protos, plabels = self.proto_layer()

         x, y = train_batch
@@ -100,7 +100,7 @@ class MedianLVQ(NonGradientMixin, GLVQ):
         lower_bound = (gamma * f.log()).sum()
         return lower_bound

-    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
+    def training_step(self, train_batch, batch_idx):
         protos, plabels = self.proto_layer()

         x, y = train_batch
@@ -21,7 +21,7 @@ class CELVQ(GLVQ):
         # Loss
         self.loss = torch.nn.CrossEntropyLoss()

-    def shared_step(self, batch, batch_idx, optimizer_idx=None):
+    def shared_step(self, batch, batch_idx):
         x, y = batch
         out = self.compute_distances(x)  # [None, num_protos]
         _, plabels = self.proto_layer()
@@ -63,7 +63,7 @@ class ProbabilisticLVQ(GLVQ):
         prediction[confidence < self.rejection_confidence] = -1
         return prediction

-    def training_step(self, batch, batch_idx, optimizer_idx=None):
+    def training_step(self, batch, batch_idx):
         x, y = batch
         out = self.forward(x)
         _, plabels = self.proto_layer()
@@ -123,7 +123,7 @@ class PLVQ(ProbabilisticLVQ, SiameseGMLVQ):
         self.loss = torch.nn.KLDivLoss()

     # FIXME
-    # def training_step(self, batch, batch_idx, optimizer_idx=None):
+    # def training_step(self, batch, batch_idx):
     #     x, y = batch
     #     y_pred = self(x)
     #     batch_loss = self.loss(y_pred, y)