Add prototype_initializer function to GLVQ

This allows overwriting it inside subclasses.
This commit is contained in:
Alexander Engelsberger 2021-05-21 17:11:27 +02:00
parent 7b4f7d84e0
commit 8ce18f83ce

View File

@ -22,7 +22,6 @@ class GLVQ(AbstractPrototypeModel):
self.distance_fn = kwargs.get("distance_fn", euclidean_distance) self.distance_fn = kwargs.get("distance_fn", euclidean_distance)
self.optimizer = kwargs.get("optimizer", torch.optim.Adam) self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
prototype_initializer = kwargs.get("prototype_initializer", None)
# Default Values # Default Values
self.hparams.setdefault("transfer_fn", "identity") self.hparams.setdefault("transfer_fn", "identity")
@ -31,13 +30,16 @@ class GLVQ(AbstractPrototypeModel):
self.proto_layer = LabeledComponents( self.proto_layer = LabeledComponents(
distribution=self.hparams.distribution, distribution=self.hparams.distribution,
initializer=prototype_initializer) initializer=self.prototype_initializer(**kwargs))
self.transfer_fn = get_activation(self.hparams.transfer_fn) self.transfer_fn = get_activation(self.hparams.transfer_fn)
self.acc_metric = torchmetrics.Accuracy() self.acc_metric = torchmetrics.Accuracy()
self.loss = glvq_loss self.loss = glvq_loss
def prototype_initializer(self, **kwargs):
return kwargs.get("prototype_initializer", None)
@property @property
def prototype_labels(self): def prototype_labels(self):
return self.proto_layer.component_labels.detach().cpu() return self.proto_layer.component_labels.detach().cpu()