Added Vis for GMLVQ with more then 2 dims using PCA (#11)
* Added Vis for GMLVQ with more then 2 dims using PCA * Added initialization possibility to GMlVQ with PCA and one example with omega init + PCA vis of 3 dims * test(githooks): Add githooks for automatic commit checks Co-authored-by: staps@hs-mittweida.de <staps@hs-mittweida.de> Co-authored-by: Alexander Engelsberger <alexanderengelsberger@gmail.com>
This commit is contained in:
		
							
								
								
									
										59
									
								
								examples/gmlvq_iris.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										59
									
								
								examples/gmlvq_iris.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,59 @@
 | 
				
			|||||||
 | 
					"""GLVQ example using the Iris dataset."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import argparse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import prototorch as pt
 | 
				
			||||||
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					from torch.optim.lr_scheduler import ExponentialLR
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
					    # Command-line arguments
 | 
				
			||||||
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
 | 
					    parser = pl.Trainer.add_argparse_args(parser)
 | 
				
			||||||
 | 
					    args = parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Dataset
 | 
				
			||||||
 | 
					    train_ds = pt.datasets.Iris()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Dataloaders
 | 
				
			||||||
 | 
					    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Hyperparameters
 | 
				
			||||||
 | 
					    hparams = dict(
 | 
				
			||||||
 | 
					        input_dim=4,
 | 
				
			||||||
 | 
					        latent_dim=3,
 | 
				
			||||||
 | 
					        distribution={
 | 
				
			||||||
 | 
					            "num_classes": 3,
 | 
				
			||||||
 | 
					            "prototypes_per_class": 2
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        proto_lr=0.0005,
 | 
				
			||||||
 | 
					        bb_lr=0.0005,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Initialize the model
 | 
				
			||||||
 | 
					    model = pt.models.GMLVQ(
 | 
				
			||||||
 | 
					        hparams,
 | 
				
			||||||
 | 
					        optimizer=torch.optim.Adam,
 | 
				
			||||||
 | 
					        prototype_initializer=pt.components.SSI(train_ds),
 | 
				
			||||||
 | 
					        lr_scheduler=ExponentialLR,
 | 
				
			||||||
 | 
					        lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
 | 
				
			||||||
 | 
					        omega_initializer=pt.components.PCA(train_ds.data)
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Compute intermediate input and output sizes
 | 
				
			||||||
 | 
					    #model.example_input_array = torch.zeros(4, 2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Callbacks
 | 
				
			||||||
 | 
					    vis = pt.models.VisGMLVQ2D(data=train_ds, border=0.1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Setup trainer
 | 
				
			||||||
 | 
					    trainer = pl.Trainer.from_argparse_args(
 | 
				
			||||||
 | 
					        args,
 | 
				
			||||||
 | 
					        callbacks=[vis],
 | 
				
			||||||
 | 
					        weights_summary="full",
 | 
				
			||||||
 | 
					        accelerator="ddp",
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Training loop
 | 
				
			||||||
 | 
					    trainer.fit(model, train_loader)
 | 
				
			||||||
							
								
								
									
										
											BIN
										
									
								
								prototorch/models/.glvq.py.swp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								prototorch/models/.glvq.py.swp
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							@@ -7,6 +7,7 @@ from prototorch.functions.distances import (lomega_distance, omega_distance,
 | 
				
			|||||||
                                            squared_euclidean_distance)
 | 
					                                            squared_euclidean_distance)
 | 
				
			||||||
from prototorch.functions.helper import get_flat
 | 
					from prototorch.functions.helper import get_flat
 | 
				
			||||||
from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
 | 
					from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
 | 
				
			||||||
 | 
					from prototorch.components import LinearMapping
 | 
				
			||||||
from prototorch.modules import LambdaLayer, LossLayer
 | 
					from prototorch.modules import LambdaLayer, LossLayer
 | 
				
			||||||
from torch.nn.parameter import Parameter
 | 
					from torch.nn.parameter import Parameter
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -239,10 +240,17 @@ class GMLVQ(GLVQ):
 | 
				
			|||||||
        super().__init__(hparams, distance_fn=distance_fn, **kwargs)
 | 
					        super().__init__(hparams, distance_fn=distance_fn, **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Additional parameters
 | 
					        # Additional parameters
 | 
				
			||||||
        omega = torch.randn(self.hparams.input_dim,
 | 
					        omega_initializer = kwargs.get("omega_initializer", None)
 | 
				
			||||||
                            self.hparams.latent_dim,
 | 
					        initialized_omega = kwargs.get("initialized_omega", None)
 | 
				
			||||||
                            device=self.device)
 | 
					        if omega_initializer is not None or initialized_omega is not None:
 | 
				
			||||||
        self.register_parameter("_omega", Parameter(omega))
 | 
					            self.omega_layer = LinearMapping(
 | 
				
			||||||
 | 
					                mapping_shape=(self.hparams.input_dim, self.hparams.latent_dim),
 | 
				
			||||||
 | 
					                initializer=omega_initializer,
 | 
				
			||||||
 | 
					                initialized_linearmapping=initialized_omega,
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.register_parameter("_omega", Parameter(self.omega_layer.mapping))
 | 
				
			||||||
 | 
					        self.backbone = LambdaLayer(lambda x: x @ self._omega, name = "omega matrix")
 | 
				
			||||||
       
 | 
					       
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def omega_matrix(self):
 | 
					    def omega_matrix(self):
 | 
				
			||||||
@@ -256,6 +264,24 @@ class GMLVQ(GLVQ):
 | 
				
			|||||||
    def extra_repr(self):
 | 
					    def extra_repr(self):
 | 
				
			||||||
        return f"(omega): (shape: {tuple(self._omega.shape)})"
 | 
					        return f"(omega): (shape: {tuple(self._omega.shape)})"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def predict_latent(self, x, map_protos=True):
 | 
				
			||||||
 | 
					        """Predict `x` assuming it is already embedded in the latent space.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Only the prototypes are embedded in the latent space using the
 | 
				
			||||||
 | 
					        backbone.
 | 
				
			||||||
 | 
					 
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        self.eval()
 | 
				
			||||||
 | 
					        with torch.no_grad():
 | 
				
			||||||
 | 
					            protos, plabels = self.proto_layer()
 | 
				
			||||||
 | 
					            if map_protos:
 | 
				
			||||||
 | 
					                protos = self.backbone(protos)
 | 
				
			||||||
 | 
					            d = squared_euclidean_distance(x, protos)
 | 
				
			||||||
 | 
					            y_pred = wtac(d, plabels)
 | 
				
			||||||
 | 
					        return y_pred
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class LGMLVQ(GMLVQ):
 | 
					class LGMLVQ(GMLVQ):
 | 
				
			||||||
    """Localized and Generalized Matrix Learning Vector Quantization."""
 | 
					    """Localized and Generalized Matrix Learning Vector Quantization."""
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -83,7 +83,13 @@ class Vis2DAbstract(pl.Callback):
 | 
				
			|||||||
        mesh_input = np.c_[xx.ravel(), yy.ravel()]
 | 
					        mesh_input = np.c_[xx.ravel(), yy.ravel()]
 | 
				
			||||||
        return mesh_input, xx, yy
 | 
					        return mesh_input, xx, yy
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def plot_data(self, ax, x, y):
 | 
					    def perform_pca_2D(self, data):
 | 
				
			||||||
 | 
					        (_, eigVal, eigVec) = torch.pca_lowrank(data, q=2)
 | 
				
			||||||
 | 
					        return data @ eigVec
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def plot_data(self, ax, x, y, pca=False):
 | 
				
			||||||
 | 
					        if pca:
 | 
				
			||||||
 | 
					            x = self.perform_pca_2D(x)
 | 
				
			||||||
        ax.scatter(
 | 
					        ax.scatter(
 | 
				
			||||||
            x[:, 0],
 | 
					            x[:, 0],
 | 
				
			||||||
            x[:, 1],
 | 
					            x[:, 1],
 | 
				
			||||||
@@ -94,7 +100,9 @@ class Vis2DAbstract(pl.Callback):
 | 
				
			|||||||
            s=30,
 | 
					            s=30,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def plot_protos(self, ax, protos, plabels):
 | 
					    def plot_protos(self, ax, protos, plabels, pca=False):
 | 
				
			||||||
 | 
					        if pca:
 | 
				
			||||||
 | 
					            protos = self.perform_pca_2D(protos)
 | 
				
			||||||
        ax.scatter(
 | 
					        ax.scatter(
 | 
				
			||||||
            protos[:, 0],
 | 
					            protos[:, 0],
 | 
				
			||||||
            protos[:, 1],
 | 
					            protos[:, 1],
 | 
				
			||||||
@@ -186,6 +194,50 @@ class VisSiameseGLVQ2D(Vis2DAbstract):
 | 
				
			|||||||
        self.log_and_display(trainer, pl_module)
 | 
					        self.log_and_display(trainer, pl_module)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class VisGMLVQ2D(Vis2DAbstract):
 | 
				
			||||||
 | 
					    def __init__(self, *args, map_protos=True, **kwargs):
 | 
				
			||||||
 | 
					        super().__init__(*args, **kwargs)
 | 
				
			||||||
 | 
					        self.map_protos = map_protos
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def on_epoch_end(self, trainer, pl_module):
 | 
				
			||||||
 | 
					        if not self.precheck(trainer):
 | 
				
			||||||
 | 
					            return True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        protos = pl_module.prototypes
 | 
				
			||||||
 | 
					        plabels = pl_module.prototype_labels
 | 
				
			||||||
 | 
					        x_train, y_train = self.x_train, self.y_train
 | 
				
			||||||
 | 
					        device = pl_module.device
 | 
				
			||||||
 | 
					        with torch.no_grad():
 | 
				
			||||||
 | 
					            x_train = pl_module.backbone(torch.Tensor(x_train).to(device))
 | 
				
			||||||
 | 
					            x_train = x_train.cpu().detach()
 | 
				
			||||||
 | 
					        if self.map_protos:
 | 
				
			||||||
 | 
					            with torch.no_grad():
 | 
				
			||||||
 | 
					                protos = pl_module.backbone(torch.Tensor(protos).to(device))
 | 
				
			||||||
 | 
					                protos = protos.cpu().detach()
 | 
				
			||||||
 | 
					        ax = self.setup_ax()
 | 
				
			||||||
 | 
					        if x_train.shape[1] > 2:
 | 
				
			||||||
 | 
					            self.plot_data(ax, x_train, y_train, pca=True)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            self.plot_data(ax, x_train, y_train, pca=False)
 | 
				
			||||||
 | 
					        if self.show_protos:
 | 
				
			||||||
 | 
					            if protos.shape[1] > 2:
 | 
				
			||||||
 | 
					                self.plot_protos(ax, protos, plabels, pca=True)
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                self.plot_protos(ax, protos, plabels, pca=False)
 | 
				
			||||||
 | 
					        ### something to work on: meshgrid with pca
 | 
				
			||||||
 | 
					        #    x = np.vstack((x_train, protos))
 | 
				
			||||||
 | 
					        #    mesh_input, xx, yy = self.get_mesh_input(x)
 | 
				
			||||||
 | 
					        #else:
 | 
				
			||||||
 | 
					        #    mesh_input, xx, yy = self.get_mesh_input(x_train)
 | 
				
			||||||
 | 
					        #_components = pl_module.proto_layer._components
 | 
				
			||||||
 | 
					        #mesh_input = torch.Tensor(mesh_input).type_as(_components)
 | 
				
			||||||
 | 
					        #y_pred = pl_module.predict_latent(mesh_input,
 | 
				
			||||||
 | 
					        #                                  map_protos=self.map_protos)
 | 
				
			||||||
 | 
					        #y_pred = y_pred.cpu().reshape(xx.shape)
 | 
				
			||||||
 | 
					        #ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
 | 
				
			||||||
 | 
					        self.log_and_display(trainer, pl_module)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class VisCBC2D(Vis2DAbstract):
 | 
					class VisCBC2D(Vis2DAbstract):
 | 
				
			||||||
    def on_epoch_end(self, trainer, pl_module):
 | 
					    def on_epoch_end(self, trainer, pl_module):
 | 
				
			||||||
        if not self.precheck(trainer):
 | 
					        if not self.precheck(trainer):
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user