Stop passing component initializers as hparams

Pass the component initializer as an hparam slows down the script very much. The
API has now been changed to pass it as a kwarg to the models instead.

The example scripts have also been updated to reflect the new changes.

Also, ImageGMLVQ and an example script `gmlvq_mnist.py` that uses it have also
been added.
This commit is contained in:
Jensun Ravichandran
2021-05-12 16:36:22 +02:00
parent 1498c4bde5
commit ca39aa00d5
11 changed files with 172 additions and 21 deletions

View File

@@ -8,6 +8,11 @@ class AbstractPrototypeModel(pl.LightningModule):
def prototypes(self):
return self.proto_layer.components.detach().cpu()
@property
def components(self):
"""Only an alias for the prototypes."""
return self.prototypes
def configure_optimizers(self):
optimizer = self.optimizer(self.parameters(), lr=self.hparams.lr)
scheduler = ExponentialLR(optimizer,
@@ -19,3 +24,8 @@ class AbstractPrototypeModel(pl.LightningModule):
"interval": "step",
} # called after each training step
return [optimizer], [sch]
class PrototypeImageModel(pl.LightningModule):
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
self.proto_layer.components.data.clamp_(0.0, 1.0)