Add prototype_initializer function to GLVQ
This allows overwriting it inside subclasses.
This commit is contained in:
parent
7b4f7d84e0
commit
8ce18f83ce
@ -22,7 +22,6 @@ class GLVQ(AbstractPrototypeModel):
|
||||
|
||||
self.distance_fn = kwargs.get("distance_fn", euclidean_distance)
|
||||
self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
|
||||
prototype_initializer = kwargs.get("prototype_initializer", None)
|
||||
|
||||
# Default Values
|
||||
self.hparams.setdefault("transfer_fn", "identity")
|
||||
@ -31,13 +30,16 @@ class GLVQ(AbstractPrototypeModel):
|
||||
|
||||
self.proto_layer = LabeledComponents(
|
||||
distribution=self.hparams.distribution,
|
||||
initializer=prototype_initializer)
|
||||
initializer=self.prototype_initializer(**kwargs))
|
||||
|
||||
self.transfer_fn = get_activation(self.hparams.transfer_fn)
|
||||
self.acc_metric = torchmetrics.Accuracy()
|
||||
|
||||
self.loss = glvq_loss
|
||||
|
||||
def prototype_initializer(self, **kwargs):
|
||||
return kwargs.get("prototype_initializer", None)
|
||||
|
||||
@property
|
||||
def prototype_labels(self):
|
||||
return self.proto_layer.component_labels.detach().cpu()
|
||||
|
Loading…
Reference in New Issue
Block a user