chore: fix errors for pytorch_lightning>1.6

This commit is contained in:
Alexander Engelsberger 2022-04-27 09:25:42 +02:00
parent dbfe315f4f
commit 5911f4dd90
4 changed files with 11 additions and 7 deletions

View File

@ -43,7 +43,7 @@ class ProtoTorchBolt(pl.LightningModule):
return optimizer
def reconfigure_optimizers(self):
self.trainer.accelerator.setup_optimizers(self.trainer)
self.trainer.strategy.setup_optimizers(self.trainer)
def __repr__(self):
surep = super().__repr__()

View File

@ -137,4 +137,4 @@ class GNGCallback(pl.Callback):
pl_module.errors[
worst_neighbor] = errors[worst_neighbor] * self.reduction
trainer.accelerator.setup_optimizers(trainer)
trainer.strategy.setup_optimizers(trainer)

View File

@ -148,7 +148,7 @@ class SiameseGLVQ(GLVQ):
x, protos = [arr.view(arr.size(0), -1) for arr in (x, protos)]
latent_x = self.backbone(x)
bb_grad = self.backbone._weights.requires_grad
bb_grad = any([el.requires_grad for el in self.backbone.parameters()])
self.backbone.requires_grad_(bb_grad and self.both_path_gradients)
latent_protos = self.backbone(protos)

View File

@ -58,8 +58,10 @@ class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
diff = x.unsqueeze(dim=1) - protos
delta = self._lr * self.hparams.alpha * nh.unsqueeze(-1) * diff
updated_protos = protos + delta.sum(dim=0)
self.proto_layer.load_state_dict({"_components": updated_protos},
strict=False)
self.proto_layer.load_state_dict(
{"_components": updated_protos},
strict=False,
)
def training_epoch_end(self, training_step_outputs):
self._sigma = self.hparams.sigma * np.exp(
@ -145,6 +147,8 @@ class GrowingNeuralGas(NeuralGas):
def configure_callbacks(self):
return [
GNGCallback(reduction=self.hparams.insert_reduction,
freq=self.hparams.insert_freq)
GNGCallback(
reduction=self.hparams.insert_reduction,
freq=self.hparams.insert_freq,
)
]