[REFACTOR] Remove prototype_initializer function from GLVQ

Fixes #9
Alexander Engelsberger 2021-06-03 15:15:22 +02:00
parent bda88149d4
commit 358f27257d
2 changed files with 8 additions and 6 deletions

@@ -53,7 +53,11 @@ def train_on_mnist(batch_size=256) -> type:
     class DataClass(pl.LightningModule):
         datamodule = MNISTDataModule(batch_size=batch_size)
 
-        def prototype_initializer(self, **kwargs):
-            return pt.components.Zeros((28, 28, 1))
+        def __init__(self, *args, **kwargs):
+            prototype_initializer = kwargs.pop(
+                "prototype_initializer", pt.components.Zeros((28, 28, 1)))
+            super().__init__(*args,
+                             prototype_initializer=prototype_initializer,
+                             **kwargs)
 
     return DataClass
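The hunk above replaces the prototype_initializer hook with a plain constructor keyword: the default initializer is popped from kwargs and forwarded through super().__init__, so callers can still override it. A minimal, self-contained sketch of that pattern (the class names and the string default below are placeholders, not the ProtoTorch API):

class Base:
    def __init__(self, *args, prototype_initializer=None, **kwargs):
        self.prototype_initializer = prototype_initializer


class WithDefault(Base):
    def __init__(self, *args, **kwargs):
        # Use the caller's initializer if one was passed, otherwise fall back to a default.
        prototype_initializer = kwargs.pop("prototype_initializer", "zeros-default")
        super().__init__(*args,
                         prototype_initializer=prototype_initializer,
                         **kwargs)


print(WithDefault().prototype_initializer)                                 # zeros-default
print(WithDefault(prototype_initializer="custom").prototype_initializer)   # custom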

@@ -34,9 +34,10 @@ class GLVQ(AbstractPrototypeModel):
         transfer_fn = get_activation(self.hparams.transfer_fn)
 
         # Layers
+        prototype_initializer = kwargs.get("prototype_initializer", None)
         self.proto_layer = LabeledComponents(
             distribution=self.hparams.distribution,
-            initializer=self.prototype_initializer(**kwargs))
+            initializer=prototype_initializer)
 
         self.distance_layer = LambdaLayer(distance_fn)
         self.transfer_layer = LambdaLayer(transfer_fn)
@@ -47,9 +48,6 @@ class GLVQ(AbstractPrototypeModel):
         self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
 
-    def prototype_initializer(self, **kwargs):
-        return kwargs.get("prototype_initializer", None)
-
     @property
     def prototype_labels(self):
         return self.proto_layer.component_labels.detach().cpu()
 
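With the hook method removed, GLVQ reads the initializer straight from kwargs and hands it to LabeledComponents, so callers inject it at construction time instead of subclassing. A hedged usage sketch; the hparams keys and the Zeros shape here are illustrative assumptions, not taken from this commit:

import prototorch as pt
from prototorch.models import GLVQ

# Illustrative hyperparameters; the exact keys follow the repository's examples.
hparams = dict(distribution=(3, 2), lr=0.01)

# The initializer is now an ordinary constructor kwarg picked up via kwargs.get(...).
model = GLVQ(hparams, prototype_initializer=pt.components.Zeros((2, )))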