diff --git a/prototorch/models/abstract.py b/prototorch/models/abstract.py
index b90ca88..7642228 100644
--- a/prototorch/models/abstract.py
+++ b/prototorch/models/abstract.py
@@ -43,7 +43,7 @@ class ProtoTorchBolt(pl.LightningModule):
         return optimizer
 
     def reconfigure_optimizers(self):
-        self.trainer.accelerator.setup_optimizers(self.trainer)
+        self.trainer.strategy.setup_optimizers(self.trainer)
 
     def __repr__(self):
         surep = super().__repr__()
diff --git a/prototorch/models/callbacks.py b/prototorch/models/callbacks.py
index f12d162..2489cb4 100644
--- a/prototorch/models/callbacks.py
+++ b/prototorch/models/callbacks.py
@@ -137,4 +137,4 @@ class GNGCallback(pl.Callback):
             pl_module.errors[
                 worst_neighbor] = errors[worst_neighbor] * self.reduction
 
-        trainer.accelerator.setup_optimizers(trainer)
+        trainer.strategy.setup_optimizers(trainer)
diff --git a/prototorch/models/glvq.py b/prototorch/models/glvq.py
index c5556ca..089ae30 100644
--- a/prototorch/models/glvq.py
+++ b/prototorch/models/glvq.py
@@ -148,7 +148,7 @@ class SiameseGLVQ(GLVQ):
         x, protos = [arr.view(arr.size(0), -1) for arr in (x, protos)]
         latent_x = self.backbone(x)
 
-        bb_grad = self.backbone._weights.requires_grad
+        bb_grad = any([el.requires_grad for el in self.backbone.parameters()])
 
         self.backbone.requires_grad_(bb_grad and self.both_path_gradients)
         latent_protos = self.backbone(protos)
diff --git a/prototorch/models/unsupervised.py b/prototorch/models/unsupervised.py
index f59157d..64a324a 100644
--- a/prototorch/models/unsupervised.py
+++ b/prototorch/models/unsupervised.py
@@ -58,8 +58,10 @@ class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
         diff = x.unsqueeze(dim=1) - protos
         delta = self._lr * self.hparams.alpha * nh.unsqueeze(-1) * diff
         updated_protos = protos + delta.sum(dim=0)
-        self.proto_layer.load_state_dict({"_components": updated_protos},
-                                         strict=False)
+        self.proto_layer.load_state_dict(
+            {"_components": updated_protos},
+            strict=False,
+        )
 
     def training_epoch_end(self, training_step_outputs):
         self._sigma = self.hparams.sigma * np.exp(
@@ -145,6 +147,8 @@ class GrowingNeuralGas(NeuralGas):
 
     def configure_callbacks(self):
         return [
-            GNGCallback(reduction=self.hparams.insert_reduction,
-                        freq=self.hparams.insert_freq)
+            GNGCallback(
+                reduction=self.hparams.insert_reduction,
+                freq=self.hparams.insert_freq,
+            )
         ]