[REFACTOR] Clean up GLVQ-types

Jensun Ravichandran 2021-06-07 17:00:38 +02:00
parent 34ffeb95bc
commit d558fa6a4a


@@ -119,28 +119,25 @@ class SiameseGLVQ(GLVQ):
     def configure_optimizers(self):
         proto_opt = self.optimizer(self.proto_layer.parameters(),
                                    lr=self.hparams.proto_lr)
-        optimizer = None
-        if list(self.backbone.parameters()):
-            # only add an optimizer is the backbone has trainable parameters
-            # otherwise, the next line fails
-            bb_opt = self.optimizer(self.backbone.parameters(),
-                                    lr=self.hparams.bb_lr)
-            optimizer = [proto_opt, bb_opt]
+        # Only add a backbone optimizer if backbone has trainable parameters
+        if (bb_params := list(self.backbone.parameters())):
+            bb_opt = self.optimizer(bb_params, lr=self.hparams.bb_lr)
+            optimizers = [proto_opt, bb_opt]
         else:
-            optimizer = proto_opt
+            optimizers = [proto_opt]
         if self.lr_scheduler is not None:
-            scheduler = self.lr_scheduler(optimizer,
-                                          **self.lr_scheduler_kwargs)
-            sch = {
-                "scheduler": scheduler,
-                "interval": "step",
-            }  # called after each training step
-            return optimizer, [sch]
+            schedulers = []
+            for optimizer in optimizers:
+                scheduler = self.lr_scheduler(optimizer,
+                                              **self.lr_scheduler_kwargs)
+                schedulers.append(scheduler)
+            return optimizers, schedulers
         else:
-            return optimizer
+            return optimizers

     def compute_distances(self, x):
         protos, _ = self.proto_layer()
+        x, protos = get_flat(x, protos)
         latent_x = self.backbone(x)
         self.backbone.requires_grad_(self.both_path_gradients)
         latent_protos = self.backbone(protos)
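The refactor above makes configure_optimizers always return lists, so the scheduler branch can build one scheduler per optimizer instead of special-casing a single optimizer. A minimal standalone sketch of the resulting control flow and return shape (not part of the commit; the Adam optimizers, learning rates, and StepLR scheduler are placeholder choices):

# Standalone sketch of the refactored control flow (hypothetical values).
import torch

proto = torch.nn.Parameter(torch.zeros(3, 2))   # stand-in for proto_layer parameters
backbone = torch.nn.Linear(2, 2, bias=False)    # stand-in for a trainable backbone

proto_opt = torch.optim.Adam([proto], lr=0.05)
# Walrus operator: bind the parameter list and test it in one expression, so a
# parameter-free backbone (e.g. an identity LambdaLayer) gets no optimizer.
if (bb_params := list(backbone.parameters())):
    bb_opt = torch.optim.Adam(bb_params, lr=0.5)
    optimizers = [proto_opt, bb_opt]
else:
    optimizers = [proto_opt]

# One scheduler per optimizer; Lightning accepts the ([opt, ...], [sched, ...]) tuple.
schedulers = [torch.optim.lr_scheduler.StepLR(opt, step_size=10) for opt in optimizers]
result = (optimizers, schedulers)

One behavioral detail worth noting: the new code returns bare scheduler objects, so PyTorch Lightning falls back to its default per-epoch stepping, whereas the removed code explicitly requested per-step stepping via {"scheduler": ..., "interval": "step"}.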
@@ -165,64 +162,6 @@ class SiameseGLVQ(GLVQ):
         return y_pred


-class GRLVQ(SiameseGLVQ):
-    """Generalized Relevance Learning Vector Quantization.
-
-    TODO Make a RelevanceLayer. `bb_lr` is ignored otherwise.
-
-    """
-    def __init__(self, hparams, **kwargs):
-        distance_fn = kwargs.pop("distance_fn", omega_distance)
-        super().__init__(hparams, distance_fn=distance_fn, **kwargs)
-        relevances = torch.ones(self.hparams.input_dim, device=self.device)
-        self.register_parameter("_relevances", Parameter(relevances))
-
-        # Override the backbone.
-        self.backbone = LambdaLayer(lambda x: x @ torch.diag(self.relevances),
-                                    name="relevances")
-
-    @property
-    def relevance_profile(self):
-        return self.relevances.detach().cpu()
-
-    def compute_distances(self, x):
-        protos, _ = self.proto_layer()
-        distances = self.distance_layer(x, protos, torch.diag(self.relevances))
-        return distances
-
-
-class SiameseGMLVQ(SiameseGLVQ):
-    """Generalized Matrix Learning Vector Quantization.
-
-    Implemented as a Siamese network with a linear transformation backbone.
-
-    """
-    def __init__(self, hparams, **kwargs):
-        super().__init__(hparams, **kwargs)
-
-        # Override the backbone.
-        self.backbone = torch.nn.Linear(self.hparams.input_dim,
-                                        self.hparams.latent_dim,
-                                        bias=False)
-
-    @property
-    def omega_matrix(self):
-        return self.backbone.weight.detach().cpu()
-
-    @property
-    def lambda_matrix(self):
-        omega = self.backbone.weight  # (latent_dim, input_dim)
-        lam = omega.T @ omega
-        return lam.detach().cpu()
-
-    def compute_distances(self, x):
-        protos, _ = self.proto_layer()
-        x, protos = get_flat(x, protos)
-        latent_x = self.backbone(x)
-        self.backbone.requires_grad_(self.both_path_gradients)
-        latent_protos = self.backbone(protos)
-        self.backbone.requires_grad_(True)
-        distances = self.distance_layer(latent_x, latent_protos)
-        return distances
-
-
 class LVQMLN(SiameseGLVQ):
     """Learning Vector Quantization Multi-Layer Network.
@@ -239,21 +178,79 @@ class LVQMLN(SiameseGLVQ):
         return distances


+class GRLVQ(SiameseGLVQ):
+    """Generalized Relevance Learning Vector Quantization.
+
+    Implemented as a Siamese network with a linear transformation backbone.
+
+    TODO Make a RelevanceLayer. `bb_lr` is ignored otherwise.
+
+    """
+    def __init__(self, hparams, **kwargs):
+        super().__init__(hparams, **kwargs)
+
+        # Additional parameters
+        relevances = torch.ones(self.hparams.input_dim, device=self.device)
+        self.register_parameter("_relevances", Parameter(relevances))
+
+        # Override the backbone
+        self.backbone = LambdaLayer(lambda x: x @ torch.diag(self._relevances),
+                                    name="relevance scaling")
+
+    @property
+    def relevance_profile(self):
+        return self._relevances.detach().cpu()
+
+    def extra_repr(self):
+        return f"(relevances): (shape: {tuple(self._relevances.shape)})"
+class SiameseGMLVQ(SiameseGLVQ):
+    """Generalized Matrix Learning Vector Quantization.
+
+    Implemented as a Siamese network with a linear transformation backbone.
+
+    """
+    def __init__(self, hparams, **kwargs):
+        super().__init__(hparams, **kwargs)
+
+        # Override the backbone
+        self.backbone = torch.nn.Linear(self.hparams.input_dim,
+                                        self.hparams.latent_dim,
+                                        bias=False)
+
+    @property
+    def omega_matrix(self):
+        return self.backbone.weight.detach().cpu()
+
+    @property
+    def lambda_matrix(self):
+        omega = self.backbone.weight  # (latent_dim, input_dim)
+        lam = omega.T @ omega
+        return lam.detach().cpu()
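The lambda_matrix property exposes the usual GMLVQ relevance matrix Λ = ΩᵀΩ. A quick sketch (not part of the commit) checking that the squared Euclidean distance behind the linear backbone equals the quadratic form with Λ:

# With a linear backbone Omega of shape (latent_dim, input_dim),
# ||Omega(x - w)||^2 == (x - w)^T (Omega^T Omega) (x - w).
import torch

torch.manual_seed(0)
latent_dim, input_dim = 2, 4
omega = torch.randn(latent_dim, input_dim)   # like SiameseGMLVQ's backbone weight
x, w = torch.randn(input_dim), torch.randn(input_dim)

d_latent = torch.sum((omega @ (x - w))**2)   # squared distance after projection
lam = omega.T @ omega                        # as in the lambda_matrix property
d_quadform = (x - w) @ lam @ (x - w)
assert torch.allclose(d_latent, d_quadform)

Its diagonal plays the role of the per-feature relevance profile that GRLVQ learns directly.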
 class GMLVQ(GLVQ):
     """Generalized Matrix Learning Vector Quantization.

     Implemented as a regular GLVQ network that simply uses a different distance
-    function.
+    function. This makes it easier to implement a localized variant.
     """
     def __init__(self, hparams, **kwargs):
         distance_fn = kwargs.pop("distance_fn", omega_distance)
         super().__init__(hparams, distance_fn=distance_fn, **kwargs)
+
+        # Additional parameters
         omega = torch.randn(self.hparams.input_dim,
                             self.hparams.latent_dim,
                             device=self.device)
         self.register_parameter("_omega", Parameter(omega))

+    @property
+    def omega_matrix(self):
+        return self._omega.detach().cpu()
+
     def compute_distances(self, x):
         protos, _ = self.proto_layer()
         distances = self.distance_layer(x, protos, self._omega)
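Unlike SiameseGMLVQ, GMLVQ keeps Ω inside the distance function: compute_distances passes _omega (shape input_dim x latent_dim) to the configured omega_distance. A rough sketch of what such a distance does, assuming it projects inputs and prototypes through Ω and then takes squared Euclidean distances (the name sketch_omega_distance and its signature are illustrative stand-ins, not the library's implementation):

# Hypothetical stand-in for an omega distance: project, then squared Euclidean.
import torch

def sketch_omega_distance(x, protos, omega):
    # x: (batch, input_dim), protos: (num_protos, input_dim),
    # omega: (input_dim, latent_dim)
    lat_x = x @ omega                                 # (batch, latent_dim)
    lat_p = protos @ omega                            # (num_protos, latent_dim)
    diff = lat_x.unsqueeze(1) - lat_p.unsqueeze(0)    # (batch, num_protos, latent_dim)
    return torch.sum(diff**2, dim=-1)                 # (batch, num_protos)

d = sketch_omega_distance(torch.randn(8, 4), torch.randn(3, 4), torch.randn(4, 2))
assert d.shape == (8, 3)

Because only the distance function changes, a localized variant (as the updated docstring notes) just needs one Ω per prototype, e.g. of shape (num_prototypes, input_dim, latent_dim), which is what LGMLVQ re-registers below.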
@@ -268,6 +265,7 @@ class LGMLVQ(GMLVQ):
     def __init__(self, hparams, **kwargs):
         distance_fn = kwargs.pop("distance_fn", lomega_distance)
         super().__init__(hparams, distance_fn=distance_fn, **kwargs)

         # Re-register `_omega` to override the one from the super class.
         omega = torch.randn(
             self.num_prototypes,