fix: example test fixed
commit 634ef86a2c
parent 72e9587a10
@@ -51,8 +51,7 @@ if __name__ == "__main__":
     # Hyperparameters
     hparams = dict(
         distribution=[1, 2, 3],
-        proto_lr=0.01,
-        bb_lr=0.01,
+        lr=0.01,
     )
 
     # Initialize the backbone
@@ -51,8 +51,7 @@ if __name__ == "__main__":
     # Hyperparameters
     hparams = dict(
         distribution=[1, 2, 3],
-        proto_lr=0.01,
-        bb_lr=0.01,
+        lr=0.01,
         input_dim=2,
         latent_dim=1,
     )
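Both example hunks make the same rename; in plain Python the change amounts to the following (values copied from the diff, nothing else added):

# Before: separate learning rates for the prototype layer and the backbone
old_hparams = dict(distribution=[1, 2, 3], proto_lr=0.01, bb_lr=0.01)

# After: a single learning rate shared by all trainable parameters
new_hparams = dict(distribution=[1, 2, 3], lr=0.01)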
@@ -55,7 +55,9 @@ if __name__ == "__main__":
 
     # Setup trainer for GNG
     trainer = pl.Trainer(
-        max_epochs=1000,
+        accelerator="cpu",
+        max_epochs=50 if args.fast_dev_run else
+        1000,  # 10 epochs fast dev run reproducible DIV error.
         callbacks=[
             es,
         ],
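The new max_epochs line assumes an args object with a fast_dev_run attribute. A minimal sketch of how such a flag could be parsed (only the flag name comes from the diff; the argparse setup is illustrative, not necessarily the example's actual CLI handling):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--fast_dev_run", action="store_true",
                    help="Run a shortened training loop (assumed flag).")
args = parser.parse_args()

# Mirrors the conditional used in the hunk above.
max_epochs = 50 if args.fast_dev_run else 1000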
BIN  glvq_iris.ckpt (new file; binary file not shown)
@@ -123,26 +123,6 @@ class SiameseGLVQ(GLVQ):
         self.backbone = backbone
         self.both_path_gradients = both_path_gradients
 
-    def configure_optimizers(self):
-        proto_opt = self.optimizer(self.proto_layer.parameters(),
-                                   lr=self.hparams["proto_lr"])
-        # Only add a backbone optimizer if backbone has trainable parameters
-        bb_params = list(self.backbone.parameters())
-        if (bb_params):
-            bb_opt = self.optimizer(bb_params, lr=self.hparams["bb_lr"])
-            optimizers = [proto_opt, bb_opt]
-        else:
-            optimizers = [proto_opt]
-        if self.lr_scheduler is not None:
-            schedulers = []
-            for optimizer in optimizers:
-                scheduler = self.lr_scheduler(optimizer,
-                                              **self.lr_scheduler_kwargs)
-                schedulers.append(scheduler)
-            return optimizers, schedulers
-        else:
-            return optimizers
-
     def compute_distances(self, x):
         protos, _ = self.proto_layer()
         x, protos = (arr.view(arr.size(0), -1) for arr in (x, protos))
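With this override deleted, SiameseGLVQ falls back to the optimizer setup of its parent class, which is why the example scripts above only need the single lr key. A sketch of the single-optimizer pattern this implies (an illustration of the pattern only, not the actual GLVQ base-class code):

import pytorch_lightning as pl
import torch


class SingleOptimizerSketch(pl.LightningModule):
    """Pattern sketch: one optimizer, one learning rate for all parameters."""

    def __init__(self, hparams):
        super().__init__()
        self.save_hyperparameters(hparams)
        self.layer = torch.nn.Linear(2, 1)  # stand-in for prototypes + backbone

    def configure_optimizers(self):
        # All parameters (prototype layer and backbone alike in the real model)
        # share the single hparams["lr"] introduced by this commit.
        return torch.optim.Adam(self.parameters(), lr=self.hparams["lr"])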
@@ -63,7 +63,7 @@ class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
             strict=False,
         )
 
-    def training_epoch_end(self, training_step_outputs):
+    def on_training_epoch_end(self, training_step_outputs):
         self._sigma = self.hparams.sigma * np.exp(
             -self.current_epoch / self.trainer.max_epochs)
 
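The renamed hook keeps the same neighbourhood-radius schedule, sigma_t = sigma_0 * exp(-epoch / max_epochs). A quick numeric check with illustrative values (sigma_0 is not taken from this commit; max_epochs=1000 matches the example trainer above):

import numpy as np

sigma_0 = 2.0      # illustrative hparams.sigma
max_epochs = 1000  # matches the example trainer

for epoch in (0, 500, 1000):
    sigma = sigma_0 * np.exp(-epoch / max_epochs)
    print(epoch, round(float(sigma), 3))  # 0 -> 2.0, 500 -> 1.213, 1000 -> 0.736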