parent
bda88149d4
commit
358f27257d
@ -53,7 +53,11 @@ def train_on_mnist(batch_size=256) -> type:
|
|||||||
class DataClass(pl.LightningModule):
|
class DataClass(pl.LightningModule):
|
||||||
datamodule = MNISTDataModule(batch_size=batch_size)
|
datamodule = MNISTDataModule(batch_size=batch_size)
|
||||||
|
|
||||||
def prototype_initializer(self, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
return pt.components.Zeros((28, 28, 1))
|
prototype_initializer = kwargs.pop(
|
||||||
|
"prototype_initializer", pt.components.Zeros((28, 28, 1)))
|
||||||
|
super().__init__(*args,
|
||||||
|
prototype_initializer=prototype_initializer,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
return DataClass
|
return DataClass
|
||||||
|
@ -34,9 +34,10 @@ class GLVQ(AbstractPrototypeModel):
|
|||||||
transfer_fn = get_activation(self.hparams.transfer_fn)
|
transfer_fn = get_activation(self.hparams.transfer_fn)
|
||||||
|
|
||||||
# Layers
|
# Layers
|
||||||
|
prototype_initializer = kwargs.get("prototype_initializer", None)
|
||||||
self.proto_layer = LabeledComponents(
|
self.proto_layer = LabeledComponents(
|
||||||
distribution=self.hparams.distribution,
|
distribution=self.hparams.distribution,
|
||||||
initializer=self.prototype_initializer(**kwargs))
|
initializer=prototype_initializer)
|
||||||
|
|
||||||
self.distance_layer = LambdaLayer(distance_fn)
|
self.distance_layer = LambdaLayer(distance_fn)
|
||||||
self.transfer_layer = LambdaLayer(transfer_fn)
|
self.transfer_layer = LambdaLayer(transfer_fn)
|
||||||
@ -47,9 +48,6 @@ class GLVQ(AbstractPrototypeModel):
|
|||||||
|
|
||||||
self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
|
self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
|
||||||
|
|
||||||
def prototype_initializer(self, **kwargs):
|
|
||||||
return kwargs.get("prototype_initializer", None)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def prototype_labels(self):
|
def prototype_labels(self):
|
||||||
return self.proto_layer.component_labels.detach().cpu()
|
return self.proto_layer.component_labels.detach().cpu()
|
||||||
|
Loading…
Reference in New Issue
Block a user