fix: example test fixed

Alexander Engelsberger 2023-06-20 17:42:36 +02:00
parent 72e9587a10
commit 634ef86a2c
7 changed files with 6 additions and 26 deletions


@@ -51,8 +51,7 @@ if __name__ == "__main__":
     # Hyperparameters
     hparams = dict(
         distribution=[1, 2, 3],
-        proto_lr=0.01,
-        bb_lr=0.01,
+        lr=0.01,
     )
 
     # Initialize the backbone


@@ -51,8 +51,7 @@ if __name__ == "__main__":
     # Hyperparameters
     hparams = dict(
         distribution=[1, 2, 3],
-        proto_lr=0.01,
-        bb_lr=0.01,
+        lr=0.01,
         input_dim=2,
         latent_dim=1,
     )

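Both hunks above collapse the separate proto_lr/bb_lr entries into a single lr key, so the prototype layer and the backbone now train under one learning rate. A generic PyTorch sketch of that consolidation, with placeholder modules standing in for the real layers (nothing below is taken from this commit):

import torch

proto_layer = torch.nn.Linear(2, 3)  # stand-in for the prototype layer
backbone = torch.nn.Linear(4, 2)     # stand-in for the backbone

# Before: two optimizers with separate rates (proto_lr, bb_lr).
# After: one optimizer with one rate (lr) over all parameters.
optimizer = torch.optim.Adam(
    list(proto_layer.parameters()) + list(backbone.parameters()),
    lr=0.01,
)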

@@ -55,7 +55,9 @@ if __name__ == "__main__":
 
     # Setup trainer for GNG
     trainer = pl.Trainer(
-        max_epochs=1000,
+        accelerator="cpu",
+        max_epochs=50 if args.fast_dev_run else
+        1000,  # 10 epochs fast dev run reproducible DIV error.
         callbacks=[
             es,
         ],

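The args.fast_dev_run switch in this hunk implies an argparse flag defined earlier in the example. A self-contained sketch of that wiring; everything outside the hunk (the parser setup in particular) is assumed:

import argparse
import pytorch_lightning as pl

parser = argparse.ArgumentParser()
parser.add_argument("--fast_dev_run", action="store_true")
args = parser.parse_args()

trainer = pl.Trainer(
    accelerator="cpu",
    # Short budget for smoke tests, full budget otherwise.
    max_epochs=50 if args.fast_dev_run else 1000,
)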
BIN glvq_iris.ckpt (new binary file, contents not shown)

BIN iris.pth (new binary file, contents not shown)


@@ -123,26 +123,6 @@ class SiameseGLVQ(GLVQ):
         self.backbone = backbone
         self.both_path_gradients = both_path_gradients
-
-    def configure_optimizers(self):
-        proto_opt = self.optimizer(self.proto_layer.parameters(),
-                                   lr=self.hparams["proto_lr"])
-        # Only add a backbone optimizer if backbone has trainable parameters
-        bb_params = list(self.backbone.parameters())
-        if (bb_params):
-            bb_opt = self.optimizer(bb_params, lr=self.hparams["bb_lr"])
-            optimizers = [proto_opt, bb_opt]
-        else:
-            optimizers = [proto_opt]
-        if self.lr_scheduler is not None:
-            schedulers = []
-            for optimizer in optimizers:
-                scheduler = self.lr_scheduler(optimizer,
-                                              **self.lr_scheduler_kwargs)
-                schedulers.append(scheduler)
-            return optimizers, schedulers
-        else:
-            return optimizers
 
     def compute_distances(self, x):
         protos, _ = self.proto_layer()
         x, protos = (arr.view(arr.size(0), -1) for arr in (x, protos))

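With the override removed, SiameseGLVQ falls back to the configure_optimizers it inherits, which only needs the single hparams["lr"] introduced above. A hedged sketch of what that inherited method plausibly looks like; this is an assumption about the parent class, not part of this diff:

def configure_optimizers(self):
    # Assumed parent-class behavior: one optimizer over all
    # parameters, driven by the unified hparams["lr"].
    optimizer = self.optimizer(self.parameters(), lr=self.hparams["lr"])
    if self.lr_scheduler is not None:
        scheduler = self.lr_scheduler(optimizer, **self.lr_scheduler_kwargs)
        return [optimizer], [scheduler]
    return optimizer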

@@ -63,7 +63,7 @@ class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
             strict=False,
         )
 
-    def training_epoch_end(self, training_step_outputs):
+    def on_training_epoch_end(self, training_step_outputs):
         self._sigma = self.hparams.sigma * np.exp(
             -self.current_epoch / self.trainer.max_epochs)
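For comparison, the hook PyTorch Lightning 2.x itself dispatches is on_train_epoch_end, which takes no outputs argument. A minimal sketch of the same sigma decay under that signature; SOMSketch is a placeholder class, not the ProtoTorch model:

import numpy as np
import pytorch_lightning as pl

class SOMSketch(pl.LightningModule):
    # Sketch only: Lightning 2.x calls `on_train_epoch_end`
    # with no arguments besides self.
    def on_train_epoch_end(self):
        # Anneal the neighborhood radius once per epoch,
        # mirroring the body shown in the hunk above.
        self._sigma = self.hparams.sigma * np.exp(
            -self.current_epoch / self.trainer.max_epochs)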