fix: example test fixed

This commit is contained in:
Alexander Engelsberger 2023-06-20 17:42:36 +02:00
parent 72e9587a10
commit 634ef86a2c
No known key found for this signature in database
7 changed files with 6 additions and 26 deletions

View File

@ -51,8 +51,7 @@ if __name__ == "__main__":
# Hyperparameters # Hyperparameters
hparams = dict( hparams = dict(
distribution=[1, 2, 3], distribution=[1, 2, 3],
proto_lr=0.01, lr=0.01,
bb_lr=0.01,
) )
# Initialize the backbone # Initialize the backbone

View File

@ -51,8 +51,7 @@ if __name__ == "__main__":
# Hyperparameters # Hyperparameters
hparams = dict( hparams = dict(
distribution=[1, 2, 3], distribution=[1, 2, 3],
proto_lr=0.01, lr=0.01,
bb_lr=0.01,
input_dim=2, input_dim=2,
latent_dim=1, latent_dim=1,
) )

View File

@ -55,7 +55,9 @@ if __name__ == "__main__":
# Setup trainer for GNG # Setup trainer for GNG
trainer = pl.Trainer( trainer = pl.Trainer(
max_epochs=1000, accelerator="cpu",
max_epochs=50 if args.fast_dev_run else
1000, # 10 epochs fast dev run reproducible DIV error.
callbacks=[ callbacks=[
es, es,
], ],

BIN
glvq_iris.ckpt Normal file

Binary file not shown.

BIN
iris.pth Normal file

Binary file not shown.

View File

@ -123,26 +123,6 @@ class SiameseGLVQ(GLVQ):
self.backbone = backbone self.backbone = backbone
self.both_path_gradients = both_path_gradients self.both_path_gradients = both_path_gradients
def configure_optimizers(self):
proto_opt = self.optimizer(self.proto_layer.parameters(),
lr=self.hparams["proto_lr"])
# Only add a backbone optimizer if backbone has trainable parameters
bb_params = list(self.backbone.parameters())
if (bb_params):
bb_opt = self.optimizer(bb_params, lr=self.hparams["bb_lr"])
optimizers = [proto_opt, bb_opt]
else:
optimizers = [proto_opt]
if self.lr_scheduler is not None:
schedulers = []
for optimizer in optimizers:
scheduler = self.lr_scheduler(optimizer,
**self.lr_scheduler_kwargs)
schedulers.append(scheduler)
return optimizers, schedulers
else:
return optimizers
def compute_distances(self, x): def compute_distances(self, x):
protos, _ = self.proto_layer() protos, _ = self.proto_layer()
x, protos = (arr.view(arr.size(0), -1) for arr in (x, protos)) x, protos = (arr.view(arr.size(0), -1) for arr in (x, protos))

View File

@ -63,7 +63,7 @@ class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
strict=False, strict=False,
) )
def training_epoch_end(self, training_step_outputs): def on_training_epoch_end(self, training_step_outputs):
self._sigma = self.hparams.sigma * np.exp( self._sigma = self.hparams.sigma * np.exp(
-self.current_epoch / self.trainer.max_epochs) -self.current_epoch / self.trainer.max_epochs)