fix: example test fixed
This commit is contained in:
@@ -123,26 +123,6 @@ class SiameseGLVQ(GLVQ):
|
||||
self.backbone = backbone
|
||||
self.both_path_gradients = both_path_gradients
|
||||
|
||||
def configure_optimizers(self):
|
||||
proto_opt = self.optimizer(self.proto_layer.parameters(),
|
||||
lr=self.hparams["proto_lr"])
|
||||
# Only add a backbone optimizer if backbone has trainable parameters
|
||||
bb_params = list(self.backbone.parameters())
|
||||
if (bb_params):
|
||||
bb_opt = self.optimizer(bb_params, lr=self.hparams["bb_lr"])
|
||||
optimizers = [proto_opt, bb_opt]
|
||||
else:
|
||||
optimizers = [proto_opt]
|
||||
if self.lr_scheduler is not None:
|
||||
schedulers = []
|
||||
for optimizer in optimizers:
|
||||
scheduler = self.lr_scheduler(optimizer,
|
||||
**self.lr_scheduler_kwargs)
|
||||
schedulers.append(scheduler)
|
||||
return optimizers, schedulers
|
||||
else:
|
||||
return optimizers
|
||||
|
||||
def compute_distances(self, x):
|
||||
protos, _ = self.proto_layer()
|
||||
x, protos = (arr.view(arr.size(0), -1) for arr in (x, protos))
|
||||
|
@@ -63,7 +63,7 @@ class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
|
||||
strict=False,
|
||||
)
|
||||
|
||||
def training_epoch_end(self, training_step_outputs):
|
||||
def on_training_epoch_end(self, training_step_outputs):
|
||||
self._sigma = self.hparams.sigma * np.exp(
|
||||
-self.current_epoch / self.trainer.max_epochs)
|
||||
|
||||
|
Reference in New Issue
Block a user