chore: fix errors for pytorch_lightning>1.6
This commit is contained in:
parent
dbfe315f4f
commit
5911f4dd90
@ -43,7 +43,7 @@ class ProtoTorchBolt(pl.LightningModule):
|
||||
return optimizer
|
||||
|
||||
def reconfigure_optimizers(self):
|
||||
self.trainer.accelerator.setup_optimizers(self.trainer)
|
||||
self.trainer.strategy.setup_optimizers(self.trainer)
|
||||
|
||||
def __repr__(self):
|
||||
surep = super().__repr__()
|
||||
|
@ -137,4 +137,4 @@ class GNGCallback(pl.Callback):
|
||||
pl_module.errors[
|
||||
worst_neighbor] = errors[worst_neighbor] * self.reduction
|
||||
|
||||
trainer.accelerator.setup_optimizers(trainer)
|
||||
trainer.strategy.setup_optimizers(trainer)
|
||||
|
@ -148,7 +148,7 @@ class SiameseGLVQ(GLVQ):
|
||||
x, protos = [arr.view(arr.size(0), -1) for arr in (x, protos)]
|
||||
latent_x = self.backbone(x)
|
||||
|
||||
bb_grad = self.backbone._weights.requires_grad
|
||||
bb_grad = any([el.requires_grad for el in self.backbone.parameters()])
|
||||
|
||||
self.backbone.requires_grad_(bb_grad and self.both_path_gradients)
|
||||
latent_protos = self.backbone(protos)
|
||||
|
@ -58,8 +58,10 @@ class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
|
||||
diff = x.unsqueeze(dim=1) - protos
|
||||
delta = self._lr * self.hparams.alpha * nh.unsqueeze(-1) * diff
|
||||
updated_protos = protos + delta.sum(dim=0)
|
||||
self.proto_layer.load_state_dict({"_components": updated_protos},
|
||||
strict=False)
|
||||
self.proto_layer.load_state_dict(
|
||||
{"_components": updated_protos},
|
||||
strict=False,
|
||||
)
|
||||
|
||||
def training_epoch_end(self, training_step_outputs):
|
||||
self._sigma = self.hparams.sigma * np.exp(
|
||||
@ -145,6 +147,8 @@ class GrowingNeuralGas(NeuralGas):
|
||||
|
||||
def configure_callbacks(self):
|
||||
return [
|
||||
GNGCallback(reduction=self.hparams.insert_reduction,
|
||||
freq=self.hparams.insert_freq)
|
||||
GNGCallback(
|
||||
reduction=self.hparams.insert_reduction,
|
||||
freq=self.hparams.insert_freq,
|
||||
)
|
||||
]
|
||||
|
Loading…
Reference in New Issue
Block a user