[REFACTOR] Clean up GLVQ-types
This commit is contained in:
parent
34ffeb95bc
commit
d558fa6a4a
@ -119,28 +119,25 @@ class SiameseGLVQ(GLVQ):
|
||||
def configure_optimizers(self):
|
||||
proto_opt = self.optimizer(self.proto_layer.parameters(),
|
||||
lr=self.hparams.proto_lr)
|
||||
optimizer = None
|
||||
if list(self.backbone.parameters()):
|
||||
# only add an optimizer is the backbone has trainable parameters
|
||||
# otherwise, the next line fails
|
||||
bb_opt = self.optimizer(self.backbone.parameters(),
|
||||
lr=self.hparams.bb_lr)
|
||||
optimizer = [proto_opt, bb_opt]
|
||||
# Only add a backbone optimizer if backbone has trainable parameters
|
||||
if (bb_params := list(self.backbone.parameters())):
|
||||
bb_opt = self.optimizer(bb_params, lr=self.hparams.bb_lr)
|
||||
optimizers = [proto_opt, bb_opt]
|
||||
else:
|
||||
optimizer = proto_opt
|
||||
optimizers = [proto_opt]
|
||||
if self.lr_scheduler is not None:
|
||||
schedulers = []
|
||||
for optimizer in optimizers:
|
||||
scheduler = self.lr_scheduler(optimizer,
|
||||
**self.lr_scheduler_kwargs)
|
||||
sch = {
|
||||
"scheduler": scheduler,
|
||||
"interval": "step",
|
||||
} # called after each training step
|
||||
return optimizer, [sch]
|
||||
schedulers.append(scheduler)
|
||||
return optimizers, schedulers
|
||||
else:
|
||||
return optimizer
|
||||
return optimizers
|
||||
|
||||
def compute_distances(self, x):
|
||||
protos, _ = self.proto_layer()
|
||||
x, protos = get_flat(x, protos)
|
||||
latent_x = self.backbone(x)
|
||||
self.backbone.requires_grad_(self.both_path_gradients)
|
||||
latent_protos = self.backbone(protos)
|
||||
@ -165,64 +162,6 @@ class SiameseGLVQ(GLVQ):
|
||||
return y_pred
|
||||
|
||||
|
||||
class GRLVQ(SiameseGLVQ):
|
||||
"""Generalized Relevance Learning Vector Quantization.
|
||||
|
||||
TODO Make a RelevanceLayer. `bb_lr` is ignored otherwise.
|
||||
"""
|
||||
def __init__(self, hparams, **kwargs):
|
||||
distance_fn = kwargs.pop("distance_fn", omega_distance)
|
||||
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
|
||||
relevances = torch.ones(self.hparams.input_dim, device=self.device)
|
||||
self.register_parameter("_relevances", Parameter(relevances))
|
||||
# Override the backbone.
|
||||
self.backbone = LambdaLayer(lambda x: x @ torch.diag(self.relevances),
|
||||
name="relevances")
|
||||
|
||||
@property
|
||||
def relevance_profile(self):
|
||||
return self.relevances.detach().cpu()
|
||||
|
||||
def compute_distances(self, x):
|
||||
protos, _ = self.proto_layer()
|
||||
distances = self.distance_layer(x, protos, torch.diag(self.relevances))
|
||||
return distances
|
||||
|
||||
|
||||
class SiameseGMLVQ(SiameseGLVQ):
|
||||
"""Generalized Matrix Learning Vector Quantization.
|
||||
|
||||
Implemented as a Siamese network with a linear transformation backbone.
|
||||
|
||||
"""
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
# Override the backbone.
|
||||
self.backbone = torch.nn.Linear(self.hparams.input_dim,
|
||||
self.hparams.latent_dim,
|
||||
bias=False)
|
||||
|
||||
@property
|
||||
def omega_matrix(self):
|
||||
return self.backbone.weight.detach().cpu()
|
||||
|
||||
@property
|
||||
def lambda_matrix(self):
|
||||
omega = self.backbone.weight # (latent_dim, input_dim)
|
||||
lam = omega.T @ omega
|
||||
return lam.detach().cpu()
|
||||
|
||||
def compute_distances(self, x):
|
||||
protos, _ = self.proto_layer()
|
||||
x, protos = get_flat(x, protos)
|
||||
latent_x = self.backbone(x)
|
||||
self.backbone.requires_grad_(self.both_path_gradients)
|
||||
latent_protos = self.backbone(protos)
|
||||
self.backbone.requires_grad_(True)
|
||||
distances = self.distance_layer(latent_x, latent_protos)
|
||||
return distances
|
||||
|
||||
|
||||
class LVQMLN(SiameseGLVQ):
|
||||
"""Learning Vector Quantization Multi-Layer Network.
|
||||
|
||||
@ -239,21 +178,79 @@ class LVQMLN(SiameseGLVQ):
|
||||
return distances
|
||||
|
||||
|
||||
class GRLVQ(SiameseGLVQ):
|
||||
"""Generalized Relevance Learning Vector Quantization.
|
||||
|
||||
Implemented as a Siamese network with a linear transformation backbone.
|
||||
|
||||
TODO Make a RelevanceLayer. `bb_lr` is ignored otherwise.
|
||||
|
||||
"""
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
|
||||
# Additional parameters
|
||||
relevances = torch.ones(self.hparams.input_dim, device=self.device)
|
||||
self.register_parameter("_relevances", Parameter(relevances))
|
||||
|
||||
# Override the backbone
|
||||
self.backbone = LambdaLayer(lambda x: x @ torch.diag(self._relevances),
|
||||
name="relevance scaling")
|
||||
|
||||
@property
|
||||
def relevance_profile(self):
|
||||
return self._relevances.detach().cpu()
|
||||
|
||||
def extra_repr(self):
|
||||
return f"(relevances): (shape: {tuple(self._relevances.shape)})"
|
||||
|
||||
|
||||
class SiameseGMLVQ(SiameseGLVQ):
|
||||
"""Generalized Matrix Learning Vector Quantization.
|
||||
|
||||
Implemented as a Siamese network with a linear transformation backbone.
|
||||
|
||||
"""
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
|
||||
# Override the backbone
|
||||
self.backbone = torch.nn.Linear(self.hparams.input_dim,
|
||||
self.hparams.latent_dim,
|
||||
bias=False)
|
||||
|
||||
@property
|
||||
def omega_matrix(self):
|
||||
return self.backbone.weight.detach().cpu()
|
||||
|
||||
@property
|
||||
def lambda_matrix(self):
|
||||
omega = self.backbone.weight # (latent_dim, input_dim)
|
||||
lam = omega.T @ omega
|
||||
return lam.detach().cpu()
|
||||
|
||||
|
||||
class GMLVQ(GLVQ):
|
||||
"""Generalized Matrix Learning Vector Quantization.
|
||||
|
||||
Implemented as a regular GLVQ network that simply uses a different distance
|
||||
function.
|
||||
function. This makes it easier to implement a localized variant.
|
||||
|
||||
"""
|
||||
def __init__(self, hparams, **kwargs):
|
||||
distance_fn = kwargs.pop("distance_fn", omega_distance)
|
||||
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
|
||||
|
||||
# Additional parameters
|
||||
omega = torch.randn(self.hparams.input_dim,
|
||||
self.hparams.latent_dim,
|
||||
device=self.device)
|
||||
self.register_parameter("_omega", Parameter(omega))
|
||||
|
||||
@property
|
||||
def omega_matrix(self):
|
||||
return self._omega.detach().cpu()
|
||||
|
||||
def compute_distances(self, x):
|
||||
protos, _ = self.proto_layer()
|
||||
distances = self.distance_layer(x, protos, self._omega)
|
||||
@ -268,6 +265,7 @@ class LGMLVQ(GMLVQ):
|
||||
def __init__(self, hparams, **kwargs):
|
||||
distance_fn = kwargs.pop("distance_fn", lomega_distance)
|
||||
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
|
||||
|
||||
# Re-register `_omega` to override the one from the super class.
|
||||
omega = torch.randn(
|
||||
self.num_prototypes,
|
||||
|
Loading…
Reference in New Issue
Block a user