chore: fix errors for pytorch_lightning>1.6
This commit is contained in:
parent
dbfe315f4f
commit
5911f4dd90
@ -43,7 +43,7 @@ class ProtoTorchBolt(pl.LightningModule):
|
|||||||
return optimizer
|
return optimizer
|
||||||
|
|
||||||
def reconfigure_optimizers(self):
|
def reconfigure_optimizers(self):
|
||||||
self.trainer.accelerator.setup_optimizers(self.trainer)
|
self.trainer.strategy.setup_optimizers(self.trainer)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
surep = super().__repr__()
|
surep = super().__repr__()
|
||||||
|
@ -137,4 +137,4 @@ class GNGCallback(pl.Callback):
|
|||||||
pl_module.errors[
|
pl_module.errors[
|
||||||
worst_neighbor] = errors[worst_neighbor] * self.reduction
|
worst_neighbor] = errors[worst_neighbor] * self.reduction
|
||||||
|
|
||||||
trainer.accelerator.setup_optimizers(trainer)
|
trainer.strategy.setup_optimizers(trainer)
|
||||||
|
@ -148,7 +148,7 @@ class SiameseGLVQ(GLVQ):
|
|||||||
x, protos = [arr.view(arr.size(0), -1) for arr in (x, protos)]
|
x, protos = [arr.view(arr.size(0), -1) for arr in (x, protos)]
|
||||||
latent_x = self.backbone(x)
|
latent_x = self.backbone(x)
|
||||||
|
|
||||||
bb_grad = self.backbone._weights.requires_grad
|
bb_grad = any([el.requires_grad for el in self.backbone.parameters()])
|
||||||
|
|
||||||
self.backbone.requires_grad_(bb_grad and self.both_path_gradients)
|
self.backbone.requires_grad_(bb_grad and self.both_path_gradients)
|
||||||
latent_protos = self.backbone(protos)
|
latent_protos = self.backbone(protos)
|
||||||
|
@ -58,8 +58,10 @@ class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
|
|||||||
diff = x.unsqueeze(dim=1) - protos
|
diff = x.unsqueeze(dim=1) - protos
|
||||||
delta = self._lr * self.hparams.alpha * nh.unsqueeze(-1) * diff
|
delta = self._lr * self.hparams.alpha * nh.unsqueeze(-1) * diff
|
||||||
updated_protos = protos + delta.sum(dim=0)
|
updated_protos = protos + delta.sum(dim=0)
|
||||||
self.proto_layer.load_state_dict({"_components": updated_protos},
|
self.proto_layer.load_state_dict(
|
||||||
strict=False)
|
{"_components": updated_protos},
|
||||||
|
strict=False,
|
||||||
|
)
|
||||||
|
|
||||||
def training_epoch_end(self, training_step_outputs):
|
def training_epoch_end(self, training_step_outputs):
|
||||||
self._sigma = self.hparams.sigma * np.exp(
|
self._sigma = self.hparams.sigma * np.exp(
|
||||||
@ -145,6 +147,8 @@ class GrowingNeuralGas(NeuralGas):
|
|||||||
|
|
||||||
def configure_callbacks(self):
|
def configure_callbacks(self):
|
||||||
return [
|
return [
|
||||||
GNGCallback(reduction=self.hparams.insert_reduction,
|
GNGCallback(
|
||||||
freq=self.hparams.insert_freq)
|
reduction=self.hparams.insert_reduction,
|
||||||
|
freq=self.hparams.insert_freq,
|
||||||
|
)
|
||||||
]
|
]
|
||||||
|
Loading…
Reference in New Issue
Block a user