[REFACTOR] Clean up GLVQ-types

parent 34ffeb95bc
commit d558fa6a4a
@@ -119,28 +119,25 @@ class SiameseGLVQ(GLVQ):
     def configure_optimizers(self):
         proto_opt = self.optimizer(self.proto_layer.parameters(),
                                    lr=self.hparams.proto_lr)
-        optimizer = None
-        if list(self.backbone.parameters()):
-            # only add an optimizer is the backbone has trainable parameters
-            # otherwise, the next line fails
-            bb_opt = self.optimizer(self.backbone.parameters(),
-                                    lr=self.hparams.bb_lr)
-            optimizer = [proto_opt, bb_opt]
+        # Only add a backbone optimizer if backbone has trainable parameters
+        if (bb_params := list(self.backbone.parameters())):
+            bb_opt = self.optimizer(bb_params, lr=self.hparams.bb_lr)
+            optimizers = [proto_opt, bb_opt]
         else:
-            optimizer = proto_opt
+            optimizers = [proto_opt]
         if self.lr_scheduler is not None:
-            scheduler = self.lr_scheduler(optimizer,
-                                          **self.lr_scheduler_kwargs)
-            sch = {
-                "scheduler": scheduler,
-                "interval": "step",
-            } # called after each training step
-            return optimizer, [sch]
+            schedulers = []
+            for optimizer in optimizers:
+                scheduler = self.lr_scheduler(optimizer,
+                                              **self.lr_scheduler_kwargs)
+                schedulers.append(scheduler)
+            return optimizers, schedulers
         else:
-            return optimizer
+            return optimizers
 
     def compute_distances(self, x):
         protos, _ = self.proto_layer()
+        x, protos = get_flat(x, protos)
         latent_x = self.backbone(x)
         self.backbone.requires_grad_(self.both_path_gradients)
         latent_protos = self.backbone(protos)
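The rewritten `configure_optimizers` now always returns a list of optimizers, plus a matching list of schedulers when one is configured, instead of switching between a bare optimizer and a single scheduler dict. Below is a minimal standalone sketch of the same pattern with plain torch objects; the layers here are stand-ins, not the library's classes.

```python
import torch

proto_layer = torch.nn.Linear(2, 3, bias=False)  # stand-in for the prototype layer
backbone = torch.nn.Sequential()                 # empty backbone -> no trainable parameters

proto_opt = torch.optim.Adam(proto_layer.parameters(), lr=0.05)

# Walrus assignment keeps the parameter list around for the optimizer call.
if (bb_params := list(backbone.parameters())):
    bb_opt = torch.optim.Adam(bb_params, lr=0.01)
    optimizers = [proto_opt, bb_opt]
else:
    optimizers = [proto_opt]

print(len(optimizers))  # 1, because the empty backbone contributes no parameters
```

With an empty backbone the walrus test is falsy, so only the prototype optimizer is returned, which is the case the old "otherwise, the next line fails" comment was guarding against.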
@@ -165,64 +162,6 @@ class SiameseGLVQ(GLVQ):
         return y_pred
 
 
-class GRLVQ(SiameseGLVQ):
-    """Generalized Relevance Learning Vector Quantization.
-
-    TODO Make a RelevanceLayer. `bb_lr` is ignored otherwise.
-    """
-    def __init__(self, hparams, **kwargs):
-        distance_fn = kwargs.pop("distance_fn", omega_distance)
-        super().__init__(hparams, distance_fn=distance_fn, **kwargs)
-        relevances = torch.ones(self.hparams.input_dim, device=self.device)
-        self.register_parameter("_relevances", Parameter(relevances))
-        # Override the backbone.
-        self.backbone = LambdaLayer(lambda x: x @ torch.diag(self.relevances),
-                                    name="relevances")
-
-    @property
-    def relevance_profile(self):
-        return self.relevances.detach().cpu()
-
-    def compute_distances(self, x):
-        protos, _ = self.proto_layer()
-        distances = self.distance_layer(x, protos, torch.diag(self.relevances))
-        return distances
-
-
-class SiameseGMLVQ(SiameseGLVQ):
-    """Generalized Matrix Learning Vector Quantization.
-
-    Implemented as a Siamese network with a linear transformation backbone.
-
-    """
-    def __init__(self, hparams, **kwargs):
-        super().__init__(hparams, **kwargs)
-        # Override the backbone.
-        self.backbone = torch.nn.Linear(self.hparams.input_dim,
-                                        self.hparams.latent_dim,
-                                        bias=False)
-
-    @property
-    def omega_matrix(self):
-        return self.backbone.weight.detach().cpu()
-
-    @property
-    def lambda_matrix(self):
-        omega = self.backbone.weight # (latent_dim, input_dim)
-        lam = omega.T @ omega
-        return lam.detach().cpu()
-
-    def compute_distances(self, x):
-        protos, _ = self.proto_layer()
-        x, protos = get_flat(x, protos)
-        latent_x = self.backbone(x)
-        self.backbone.requires_grad_(self.both_path_gradients)
-        latent_protos = self.backbone(protos)
-        self.backbone.requires_grad_(True)
-        distances = self.distance_layer(latent_x, latent_protos)
-        return distances
-
-
 class LVQMLN(SiameseGLVQ):
     """Learning Vector Quantization Multi-Layer Network.
 
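For context on the classes removed here (and re-added after `LVQMLN` in the next hunk): the GRLVQ backbone is a `LambdaLayer` that multiplies the input by `torch.diag(relevances)`, which is the same as scaling each input dimension by its relevance. A small standalone check of that equivalence, with illustrative values:

```python
import torch

relevances = torch.tensor([1.0, 0.5, 2.0])  # one relevance per input dimension
x = torch.randn(4, 3)

scaled_via_diag = x @ torch.diag(relevances)  # what the LambdaLayer backbone computes
scaled_elementwise = x * relevances           # equivalent elementwise scaling

print(torch.allclose(scaled_via_diag, scaled_elementwise))  # True
```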
@@ -239,21 +178,79 @@ class LVQMLN(SiameseGLVQ):
         return distances
 
 
+class GRLVQ(SiameseGLVQ):
+    """Generalized Relevance Learning Vector Quantization.
+
+    Implemented as a Siamese network with a linear transformation backbone.
+
+    TODO Make a RelevanceLayer. `bb_lr` is ignored otherwise.
+
+    """
+    def __init__(self, hparams, **kwargs):
+        super().__init__(hparams, **kwargs)
+
+        # Additional parameters
+        relevances = torch.ones(self.hparams.input_dim, device=self.device)
+        self.register_parameter("_relevances", Parameter(relevances))
+
+        # Override the backbone
+        self.backbone = LambdaLayer(lambda x: x @ torch.diag(self._relevances),
+                                    name="relevance scaling")
+
+    @property
+    def relevance_profile(self):
+        return self._relevances.detach().cpu()
+
+    def extra_repr(self):
+        return f"(relevances): (shape: {tuple(self._relevances.shape)})"
+
+
+class SiameseGMLVQ(SiameseGLVQ):
+    """Generalized Matrix Learning Vector Quantization.
+
+    Implemented as a Siamese network with a linear transformation backbone.
+
+    """
+    def __init__(self, hparams, **kwargs):
+        super().__init__(hparams, **kwargs)
+
+        # Override the backbone
+        self.backbone = torch.nn.Linear(self.hparams.input_dim,
+                                        self.hparams.latent_dim,
+                                        bias=False)
+
+    @property
+    def omega_matrix(self):
+        return self.backbone.weight.detach().cpu()
+
+    @property
+    def lambda_matrix(self):
+        omega = self.backbone.weight # (latent_dim, input_dim)
+        lam = omega.T @ omega
+        return lam.detach().cpu()
+
+
 class GMLVQ(GLVQ):
     """Generalized Matrix Learning Vector Quantization.
 
     Implemented as a regular GLVQ network that simply uses a different distance
-    function.
+    function. This makes it easier to implement a localized variant.
 
     """
     def __init__(self, hparams, **kwargs):
         distance_fn = kwargs.pop("distance_fn", omega_distance)
         super().__init__(hparams, distance_fn=distance_fn, **kwargs)
+
+        # Additional parameters
         omega = torch.randn(self.hparams.input_dim,
                             self.hparams.latent_dim,
                             device=self.device)
         self.register_parameter("_omega", Parameter(omega))
 
+    @property
+    def omega_matrix(self):
+        return self._omega.detach().cpu()
+
     def compute_distances(self, x):
         protos, _ = self.proto_layer()
         distances = self.distance_layer(x, protos, self._omega)
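The `lambda_matrix` property of `SiameseGMLVQ` exposes Λ = ΩᵀΩ, the input-space relevance matrix induced by the linear backbone; it is square, symmetric, and positive semi-definite by construction. A standalone check with illustrative dimensions (not the library's own test code):

```python
import torch

input_dim, latent_dim = 5, 2
backbone = torch.nn.Linear(input_dim, latent_dim, bias=False)

omega = backbone.weight           # (latent_dim, input_dim)
lam = (omega.T @ omega).detach()  # (input_dim, input_dim)

print(lam.shape)                                         # torch.Size([5, 5])
print(torch.allclose(lam, lam.T))                        # True: symmetric
print(bool(torch.linalg.eigvalsh(lam).min() >= -1e-6))   # True: eigenvalues non-negative
```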
@@ -268,6 +265,7 @@ class LGMLVQ(GMLVQ):
     def __init__(self, hparams, **kwargs):
         distance_fn = kwargs.pop("distance_fn", lomega_distance)
         super().__init__(hparams, distance_fn=distance_fn, **kwargs)
+
         # Re-register `_omega` to override the one from the super class.
         omega = torch.randn(
             self.num_prototypes,
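LGMLVQ swaps in `lomega_distance` and re-registers `_omega` with a leading `num_prototypes` dimension, so each prototype gets its own projection. Below is a generic sketch of that localized-distance idea; the shapes and the einsum are illustrative and are not taken from the library's `lomega_distance` implementation.

```python
import torch

num_prototypes, input_dim, latent_dim = 4, 5, 2
x = torch.randn(8, input_dim)                                # batch of inputs
protos = torch.randn(num_prototypes, input_dim)              # prototype vectors
omegas = torch.randn(num_prototypes, input_dim, latent_dim)  # one Omega per prototype

diff = x.unsqueeze(1) - protos.unsqueeze(0)              # (batch, proto, input_dim)
projected = torch.einsum("bpi,pil->bpl", diff, omegas)   # project each diff with its own Omega
distances = projected.pow(2).sum(-1)                     # squared distances, (batch, proto)

print(distances.shape)  # torch.Size([8, 4])
```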