chore: remove optimizer_idx from all steps
This commit is contained in:
parent
60990f42d2
commit
71167a8f77
@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import prototorch
|
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@ -228,7 +227,7 @@ class NonGradientMixin(ProtoTorchMixin):
|
|||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.automatic_optimization = False
|
self.automatic_optimization = False
|
||||||
|
|
||||||
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
def training_step(self, train_batch, batch_idx):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
@ -44,7 +44,7 @@ class CBC(SiameseGLVQ):
|
|||||||
probs = self.competition_layer(detections, reasonings)
|
probs = self.competition_layer(detections, reasonings)
|
||||||
return probs
|
return probs
|
||||||
|
|
||||||
def shared_step(self, batch, batch_idx, optimizer_idx=None):
|
def shared_step(self, batch, batch_idx):
|
||||||
x, y = batch
|
x, y = batch
|
||||||
y_pred = self(x)
|
y_pred = self(x)
|
||||||
num_classes = self.num_classes
|
num_classes = self.num_classes
|
||||||
@ -52,8 +52,8 @@ class CBC(SiameseGLVQ):
|
|||||||
loss = self.loss(y_pred, y_true).mean()
|
loss = self.loss(y_pred, y_true).mean()
|
||||||
return y_pred, loss
|
return y_pred, loss
|
||||||
|
|
||||||
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
def training_step(self, batch, batch_idx):
|
||||||
y_pred, train_loss = self.shared_step(batch, batch_idx, optimizer_idx)
|
y_pred, train_loss = self.shared_step(batch, batch_idx)
|
||||||
preds = torch.argmax(y_pred, dim=1)
|
preds = torch.argmax(y_pred, dim=1)
|
||||||
accuracy = torchmetrics.functional.accuracy(
|
accuracy = torchmetrics.functional.accuracy(
|
||||||
preds.int(),
|
preds.int(),
|
||||||
|
@ -66,15 +66,15 @@ class GLVQ(SupervisedPrototypeModel):
|
|||||||
prototype_wr,
|
prototype_wr,
|
||||||
])
|
])
|
||||||
|
|
||||||
def shared_step(self, batch, batch_idx, optimizer_idx=None):
|
def shared_step(self, batch, batch_idx):
|
||||||
x, y = batch
|
x, y = batch
|
||||||
out = self.compute_distances(x)
|
out = self.compute_distances(x)
|
||||||
_, plabels = self.proto_layer()
|
_, plabels = self.proto_layer()
|
||||||
loss = self.loss(out, y, plabels)
|
loss = self.loss(out, y, plabels)
|
||||||
return out, loss
|
return out, loss
|
||||||
|
|
||||||
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
def training_step(self, batch, batch_idx):
|
||||||
out, train_loss = self.shared_step(batch, batch_idx, optimizer_idx)
|
out, train_loss = self.shared_step(batch, batch_idx)
|
||||||
self.log_prototype_win_ratios(out)
|
self.log_prototype_win_ratios(out)
|
||||||
self.log("train_loss", train_loss)
|
self.log("train_loss", train_loss)
|
||||||
self.log_acc(out, batch[-1], tag="train_acc")
|
self.log_acc(out, batch[-1], tag="train_acc")
|
||||||
@ -99,10 +99,6 @@ class GLVQ(SupervisedPrototypeModel):
|
|||||||
test_loss += batch_loss.item()
|
test_loss += batch_loss.item()
|
||||||
self.log("test_loss", test_loss)
|
self.log("test_loss", test_loss)
|
||||||
|
|
||||||
# TODO
|
|
||||||
# def predict_step(self, batch, batch_idx, dataloader_idx=None):
|
|
||||||
# pass
|
|
||||||
|
|
||||||
|
|
||||||
class SiameseGLVQ(GLVQ):
|
class SiameseGLVQ(GLVQ):
|
||||||
"""GLVQ in a Siamese setting.
|
"""GLVQ in a Siamese setting.
|
||||||
|
@ -34,7 +34,7 @@ class KNN(SupervisedPrototypeModel):
|
|||||||
labels_initializer=LiteralLabelsInitializer(targets))
|
labels_initializer=LiteralLabelsInitializer(targets))
|
||||||
self.competition_layer = KNNC(k=self.hparams.k)
|
self.competition_layer = KNNC(k=self.hparams.k)
|
||||||
|
|
||||||
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
def training_step(self, train_batch, batch_idx):
|
||||||
return 1 # skip training step
|
return 1 # skip training step
|
||||||
|
|
||||||
def on_train_batch_start(self, train_batch, batch_idx):
|
def on_train_batch_start(self, train_batch, batch_idx):
|
||||||
|
@ -13,7 +13,7 @@ from .glvq import GLVQ
|
|||||||
class LVQ1(NonGradientMixin, GLVQ):
|
class LVQ1(NonGradientMixin, GLVQ):
|
||||||
"""Learning Vector Quantization 1."""
|
"""Learning Vector Quantization 1."""
|
||||||
|
|
||||||
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
def training_step(self, train_batch, batch_idx):
|
||||||
protos, plables = self.proto_layer()
|
protos, plables = self.proto_layer()
|
||||||
x, y = train_batch
|
x, y = train_batch
|
||||||
dis = self.compute_distances(x)
|
dis = self.compute_distances(x)
|
||||||
@ -43,7 +43,7 @@ class LVQ1(NonGradientMixin, GLVQ):
|
|||||||
class LVQ21(NonGradientMixin, GLVQ):
|
class LVQ21(NonGradientMixin, GLVQ):
|
||||||
"""Learning Vector Quantization 2.1."""
|
"""Learning Vector Quantization 2.1."""
|
||||||
|
|
||||||
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
def training_step(self, train_batch, batch_idx):
|
||||||
protos, plabels = self.proto_layer()
|
protos, plabels = self.proto_layer()
|
||||||
|
|
||||||
x, y = train_batch
|
x, y = train_batch
|
||||||
@ -100,7 +100,7 @@ class MedianLVQ(NonGradientMixin, GLVQ):
|
|||||||
lower_bound = (gamma * f.log()).sum()
|
lower_bound = (gamma * f.log()).sum()
|
||||||
return lower_bound
|
return lower_bound
|
||||||
|
|
||||||
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
def training_step(self, train_batch, batch_idx):
|
||||||
protos, plabels = self.proto_layer()
|
protos, plabels = self.proto_layer()
|
||||||
|
|
||||||
x, y = train_batch
|
x, y = train_batch
|
||||||
|
@ -21,7 +21,7 @@ class CELVQ(GLVQ):
|
|||||||
# Loss
|
# Loss
|
||||||
self.loss = torch.nn.CrossEntropyLoss()
|
self.loss = torch.nn.CrossEntropyLoss()
|
||||||
|
|
||||||
def shared_step(self, batch, batch_idx, optimizer_idx=None):
|
def shared_step(self, batch, batch_idx):
|
||||||
x, y = batch
|
x, y = batch
|
||||||
out = self.compute_distances(x) # [None, num_protos]
|
out = self.compute_distances(x) # [None, num_protos]
|
||||||
_, plabels = self.proto_layer()
|
_, plabels = self.proto_layer()
|
||||||
@ -63,7 +63,7 @@ class ProbabilisticLVQ(GLVQ):
|
|||||||
prediction[confidence < self.rejection_confidence] = -1
|
prediction[confidence < self.rejection_confidence] = -1
|
||||||
return prediction
|
return prediction
|
||||||
|
|
||||||
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
def training_step(self, batch, batch_idx):
|
||||||
x, y = batch
|
x, y = batch
|
||||||
out = self.forward(x)
|
out = self.forward(x)
|
||||||
_, plabels = self.proto_layer()
|
_, plabels = self.proto_layer()
|
||||||
@ -123,7 +123,7 @@ class PLVQ(ProbabilisticLVQ, SiameseGMLVQ):
|
|||||||
self.loss = torch.nn.KLDivLoss()
|
self.loss = torch.nn.KLDivLoss()
|
||||||
|
|
||||||
# FIXME
|
# FIXME
|
||||||
# def training_step(self, batch, batch_idx, optimizer_idx=None):
|
# def training_step(self, batch, batch_idx):
|
||||||
# x, y = batch
|
# x, y = batch
|
||||||
# y_pred = self(x)
|
# y_pred = self(x)
|
||||||
# batch_loss = self.loss(y_pred, y)
|
# batch_loss = self.loss(y_pred, y)
|
||||||
|
Loading…
Reference in New Issue
Block a user