fix: example test fixed
This commit is contained in:
parent
72e9587a10
commit
634ef86a2c
@ -51,8 +51,7 @@ if __name__ == "__main__":
|
||||
# Hyperparameters
|
||||
hparams = dict(
|
||||
distribution=[1, 2, 3],
|
||||
proto_lr=0.01,
|
||||
bb_lr=0.01,
|
||||
lr=0.01,
|
||||
)
|
||||
|
||||
# Initialize the backbone
|
||||
|
@ -51,8 +51,7 @@ if __name__ == "__main__":
|
||||
# Hyperparameters
|
||||
hparams = dict(
|
||||
distribution=[1, 2, 3],
|
||||
proto_lr=0.01,
|
||||
bb_lr=0.01,
|
||||
lr=0.01,
|
||||
input_dim=2,
|
||||
latent_dim=1,
|
||||
)
|
||||
|
@ -55,7 +55,9 @@ if __name__ == "__main__":
|
||||
|
||||
# Setup trainer for GNG
|
||||
trainer = pl.Trainer(
|
||||
max_epochs=1000,
|
||||
accelerator="cpu",
|
||||
max_epochs=50 if args.fast_dev_run else
|
||||
1000, # 10 epochs fast dev run reproducible DIV error.
|
||||
callbacks=[
|
||||
es,
|
||||
],
|
||||
|
BIN
glvq_iris.ckpt
Normal file
BIN
glvq_iris.ckpt
Normal file
Binary file not shown.
@ -123,26 +123,6 @@ class SiameseGLVQ(GLVQ):
|
||||
self.backbone = backbone
|
||||
self.both_path_gradients = both_path_gradients
|
||||
|
||||
def configure_optimizers(self):
|
||||
proto_opt = self.optimizer(self.proto_layer.parameters(),
|
||||
lr=self.hparams["proto_lr"])
|
||||
# Only add a backbone optimizer if backbone has trainable parameters
|
||||
bb_params = list(self.backbone.parameters())
|
||||
if (bb_params):
|
||||
bb_opt = self.optimizer(bb_params, lr=self.hparams["bb_lr"])
|
||||
optimizers = [proto_opt, bb_opt]
|
||||
else:
|
||||
optimizers = [proto_opt]
|
||||
if self.lr_scheduler is not None:
|
||||
schedulers = []
|
||||
for optimizer in optimizers:
|
||||
scheduler = self.lr_scheduler(optimizer,
|
||||
**self.lr_scheduler_kwargs)
|
||||
schedulers.append(scheduler)
|
||||
return optimizers, schedulers
|
||||
else:
|
||||
return optimizers
|
||||
|
||||
def compute_distances(self, x):
|
||||
protos, _ = self.proto_layer()
|
||||
x, protos = (arr.view(arr.size(0), -1) for arr in (x, protos))
|
||||
|
@ -63,7 +63,7 @@ class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
|
||||
strict=False,
|
||||
)
|
||||
|
||||
def training_epoch_end(self, training_step_outputs):
|
||||
def on_training_epoch_end(self, training_step_outputs):
|
||||
self._sigma = self.hparams.sigma * np.exp(
|
||||
-self.current_epoch / self.trainer.max_epochs)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user