feat: gtlvq with examples
This commit is contained in:
		
				
					committed by
					
						
						Jensun Ravichandran
					
				
			
			
				
	
			
			
			
						parent
						
							6ffd27d12a
						
					
				
				
					commit
					d3bb430104
				
			
							
								
								
									
										130
									
								
								examples/gtlvq_moons.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										130
									
								
								examples/gtlvq_moons.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,130 @@
 | 
				
			|||||||
 | 
					"""Localized-GMLVQ example using the Moons dataset."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import argparse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import prototorch as pt
 | 
				
			||||||
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
					    # Command-line arguments
 | 
				
			||||||
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
 | 
					    parser = pl.Trainer.add_argparse_args(parser)
 | 
				
			||||||
 | 
					    args = parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Reproducibility
 | 
				
			||||||
 | 
					    pl.utilities.seed.seed_everything(seed=2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Dataset
 | 
				
			||||||
 | 
					    train_ds = pt.datasets.Moons(num_samples=300, noise=0.2, seed=42)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Dataloaders
 | 
				
			||||||
 | 
					    train_loader = torch.utils.data.DataLoader(train_ds,
 | 
				
			||||||
 | 
					                                               batch_size=256,
 | 
				
			||||||
 | 
					                                               shuffle=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Hyperparameters
 | 
				
			||||||
 | 
					    hparams = dict(distribution=[1, 3], input_dim=2, latent_dim=2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Initialize the model
 | 
				
			||||||
 | 
					    model = pt.models.GTLVQ(
 | 
				
			||||||
 | 
					        hparams,
 | 
				
			||||||
 | 
					        prototypes_initializer=pt.initializers.SMCI(train_ds),
 | 
				
			||||||
 | 
					        omega_initializer=-pt.initializers.PCALinearTransformInitializer(
 | 
				
			||||||
 | 
					            train_ds))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Compute intermediate input and output sizes
 | 
				
			||||||
 | 
					    model.example_input_array = torch.zeros(4, 2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Summary
 | 
				
			||||||
 | 
					    print(model)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Callbacks
 | 
				
			||||||
 | 
					    vis = pt.models.VisGLVQ2D(data=train_ds)
 | 
				
			||||||
 | 
					    es = pl.callbacks.EarlyStopping(
 | 
				
			||||||
 | 
					        monitor="train_acc",
 | 
				
			||||||
 | 
					        min_delta=0.001,
 | 
				
			||||||
 | 
					        patience=20,
 | 
				
			||||||
 | 
					        mode="max",
 | 
				
			||||||
 | 
					        verbose=False,
 | 
				
			||||||
 | 
					        check_on_train_epoch_end=True,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Setup trainer
 | 
				
			||||||
 | 
					    trainer = pl.Trainer.from_argparse_args(
 | 
				
			||||||
 | 
					        args,
 | 
				
			||||||
 | 
					        callbacks=[
 | 
				
			||||||
 | 
					            vis,
 | 
				
			||||||
 | 
					            es,
 | 
				
			||||||
 | 
					        ],
 | 
				
			||||||
 | 
					        weights_summary="full",
 | 
				
			||||||
 | 
					        accelerator="ddp",
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Training loop
 | 
				
			||||||
 | 
					    trainer.fit(model, train_loader)
 | 
				
			||||||
 | 
					"""Localized-GMLVQ example using the Moons dataset."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import argparse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import prototorch as pt
 | 
				
			||||||
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
					    # Command-line arguments
 | 
				
			||||||
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
 | 
					    parser = pl.Trainer.add_argparse_args(parser)
 | 
				
			||||||
 | 
					    args = parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Reproducibility
 | 
				
			||||||
 | 
					    pl.utilities.seed.seed_everything(seed=2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Dataset
 | 
				
			||||||
 | 
					    train_ds = pt.datasets.Moons(num_samples=300, noise=0.2, seed=42)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Dataloaders
 | 
				
			||||||
 | 
					    train_loader = torch.utils.data.DataLoader(train_ds,
 | 
				
			||||||
 | 
					                                               batch_size=256,
 | 
				
			||||||
 | 
					                                               shuffle=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Hyperparameters
 | 
				
			||||||
 | 
					    hparams = dict(distribution=[1, 3], input_dim=2, latent_dim=2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Initialize the model
 | 
				
			||||||
 | 
					    model = pt.models.GTLVQ(
 | 
				
			||||||
 | 
					        hparams,
 | 
				
			||||||
 | 
					        prototypes_initializer=pt.initializers.SMCI(train_ds),
 | 
				
			||||||
 | 
					        omega_initializer=-pt.initializers.PCALinearTransformInitializer(
 | 
				
			||||||
 | 
					            train_ds))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Compute intermediate input and output sizes
 | 
				
			||||||
 | 
					    model.example_input_array = torch.zeros(4, 2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Summary
 | 
				
			||||||
 | 
					    print(model)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Callbacks
 | 
				
			||||||
 | 
					    vis = pt.models.VisGLVQ2D(data=train_ds)
 | 
				
			||||||
 | 
					    es = pl.callbacks.EarlyStopping(
 | 
				
			||||||
 | 
					        monitor="train_acc",
 | 
				
			||||||
 | 
					        min_delta=0.001,
 | 
				
			||||||
 | 
					        patience=20,
 | 
				
			||||||
 | 
					        mode="max",
 | 
				
			||||||
 | 
					        verbose=False,
 | 
				
			||||||
 | 
					        check_on_train_epoch_end=True,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Setup trainer
 | 
				
			||||||
 | 
					    trainer = pl.Trainer.from_argparse_args(
 | 
				
			||||||
 | 
					        args,
 | 
				
			||||||
 | 
					        callbacks=[
 | 
				
			||||||
 | 
					            vis,
 | 
				
			||||||
 | 
					            es,
 | 
				
			||||||
 | 
					        ],
 | 
				
			||||||
 | 
					        weights_summary="full",
 | 
				
			||||||
 | 
					        accelerator="ddp",
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Training loop
 | 
				
			||||||
 | 
					    trainer.fit(model, train_loader)
 | 
				
			||||||
@@ -8,6 +8,7 @@ from .glvq import (
 | 
				
			|||||||
    GLVQ21,
 | 
					    GLVQ21,
 | 
				
			||||||
    GMLVQ,
 | 
					    GMLVQ,
 | 
				
			||||||
    GRLVQ,
 | 
					    GRLVQ,
 | 
				
			||||||
 | 
					    GTLVQ,
 | 
				
			||||||
    LGMLVQ,
 | 
					    LGMLVQ,
 | 
				
			||||||
    LVQMLN,
 | 
					    LVQMLN,
 | 
				
			||||||
    ImageGLVQ,
 | 
					    ImageGLVQ,
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -15,6 +15,44 @@ def rank_scaled_gaussian(distances, lambd):
 | 
				
			|||||||
    return torch.exp(-torch.exp(-ranks / lambd) * distances)
 | 
					    return torch.exp(-torch.exp(-ranks / lambd) * distances)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def orthogonalization(tensors):
 | 
				
			||||||
 | 
					    """Orthogonalization via polar decomposition """
 | 
				
			||||||
 | 
					    u, _, v = torch.svd(tensors, compute_uv=True)
 | 
				
			||||||
 | 
					    u_shape = tuple(list(u.shape))
 | 
				
			||||||
 | 
					    v_shape = tuple(list(v.shape))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # reshape to (num x N x M)
 | 
				
			||||||
 | 
					    u = torch.reshape(u, (-1, u_shape[-2], u_shape[-1]))
 | 
				
			||||||
 | 
					    v = torch.reshape(v, (-1, v_shape[-2], v_shape[-1]))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    out = u @ v.permute([0, 2, 1])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    out = torch.reshape(out, u_shape[:-1] + (v_shape[-2], ))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return out
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def ltangent_distance(x, y, omegas):
 | 
				
			||||||
 | 
					    r"""Localized Tangent distance.
 | 
				
			||||||
 | 
					    Compute Orthogonal Complement: math:`\bm P_k = \bm I - \Omega_k \Omega_k^T`
 | 
				
			||||||
 | 
					    Compute Tangent Distance: math:`{\| \bm P \bm x - \bm P_k \bm y_k \|}_2`
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    :param `torch.tensor` omegas: Three dimensional matrix
 | 
				
			||||||
 | 
					    :rtype: `torch.tensor`
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
 | 
				
			||||||
 | 
					    p = torch.eye(omegas.shape[-2], device=omegas.device) - torch.bmm(
 | 
				
			||||||
 | 
					        omegas, omegas.permute([0, 2, 1]))
 | 
				
			||||||
 | 
					    projected_x = x @ p
 | 
				
			||||||
 | 
					    projected_y = torch.diagonal(y @ p).T
 | 
				
			||||||
 | 
					    expanded_y = torch.unsqueeze(projected_y, dim=1)
 | 
				
			||||||
 | 
					    batchwise_difference = expanded_y - projected_x
 | 
				
			||||||
 | 
					    differences_squared = batchwise_difference**2
 | 
				
			||||||
 | 
					    distances = torch.sqrt(torch.sum(differences_squared, dim=2))
 | 
				
			||||||
 | 
					    distances = distances.permute(1, 0)
 | 
				
			||||||
 | 
					    return distances
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class GaussianPrior(torch.nn.Module):
 | 
					class GaussianPrior(torch.nn.Module):
 | 
				
			||||||
    def __init__(self, variance):
 | 
					    def __init__(self, variance):
 | 
				
			||||||
        super().__init__()
 | 
					        super().__init__()
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -10,6 +10,7 @@ from ..core.losses import GLVQLoss, lvq1_loss, lvq21_loss
 | 
				
			|||||||
from ..core.transforms import LinearTransform
 | 
					from ..core.transforms import LinearTransform
 | 
				
			||||||
from ..nn.wrappers import LambdaLayer, LossLayer
 | 
					from ..nn.wrappers import LambdaLayer, LossLayer
 | 
				
			||||||
from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
 | 
					from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
 | 
				
			||||||
 | 
					from .extras import ltangent_distance, orthogonalization
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class GLVQ(SupervisedPrototypeModel):
 | 
					class GLVQ(SupervisedPrototypeModel):
 | 
				
			||||||
@@ -282,6 +283,30 @@ class LGMLVQ(GMLVQ):
 | 
				
			|||||||
        self.register_parameter("_omega", Parameter(omega))
 | 
					        self.register_parameter("_omega", Parameter(omega))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class GTLVQ(LGMLVQ):
 | 
				
			||||||
 | 
					    """Localized and Generalized Matrix Learning Vector Quantization."""
 | 
				
			||||||
 | 
					    def __init__(self, hparams, **kwargs):
 | 
				
			||||||
 | 
					        distance_fn = kwargs.pop("distance_fn", ltangent_distance)
 | 
				
			||||||
 | 
					        super().__init__(hparams, distance_fn=distance_fn, **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        omega_initializer = kwargs.get("omega_initializer")
 | 
				
			||||||
 | 
					        omega = omega_initializer.generate(self.hparams.input_dim,
 | 
				
			||||||
 | 
					                                           self.hparams.latent_dim)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Re-register `_omega` to override the one from the super class.
 | 
				
			||||||
 | 
					        omega = torch.rand(
 | 
				
			||||||
 | 
					            self.num_prototypes,
 | 
				
			||||||
 | 
					            self.hparams.input_dim,
 | 
				
			||||||
 | 
					            self.hparams.latent_dim,
 | 
				
			||||||
 | 
					            device=self.device,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.register_parameter("_omega", Parameter(omega))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
 | 
				
			||||||
 | 
					        with torch.no_grad():
 | 
				
			||||||
 | 
					            self._omega.copy_(orthogonalization(self._omega))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class GLVQ1(GLVQ):
 | 
					class GLVQ1(GLVQ):
 | 
				
			||||||
    """Generalized Learning Vector Quantization 1."""
 | 
					    """Generalized Learning Vector Quantization 1."""
 | 
				
			||||||
    def __init__(self, hparams, **kwargs):
 | 
					    def __init__(self, hparams, **kwargs):
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user