Compare commits
	
		
			5 Commits
		
	
	
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 
						 | 
					
						
						
							
						
						adafb49985
	
				 | 
					
					
						||
| 
						 | 
					
						
						
							
						
						78f8b6cc00
	
				 | 
					
					
						||
| 
						 | 
					
						
						
							
						
						c6f718a1d4
	
				 | 
					
					
						||
| 
						 | 
					
						
						
							
						
						1786031b4e
	
				 | 
					
					
						||
| 
						 | 
					
						
						
							
						
						824dfced92
	
				 | 
					
					
						
@@ -1,13 +1,15 @@
 | 
			
		||||
"""Models based on the GLVQ framework."""
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from numpy.typing import NDArray
 | 
			
		||||
from prototorch.core.competitions import wtac
 | 
			
		||||
from prototorch.core.distances import (
 | 
			
		||||
    ML_omega_distance,
 | 
			
		||||
    lomega_distance,
 | 
			
		||||
    omega_distance,
 | 
			
		||||
    squared_euclidean_distance,
 | 
			
		||||
)
 | 
			
		||||
from prototorch.core.initializers import EyeLinearTransformInitializer
 | 
			
		||||
from prototorch.core.initializers import LLTI, EyeLinearTransformInitializer
 | 
			
		||||
from prototorch.core.losses import (
 | 
			
		||||
    GLVQLoss,
 | 
			
		||||
    lvq1_loss,
 | 
			
		||||
@@ -15,7 +17,7 @@ from prototorch.core.losses import (
 | 
			
		||||
)
 | 
			
		||||
from prototorch.core.transforms import LinearTransform
 | 
			
		||||
from prototorch.nn.wrappers import LambdaLayer, LossLayer
 | 
			
		||||
from torch.nn.parameter import Parameter
 | 
			
		||||
from torch.nn import Parameter, ParameterList
 | 
			
		||||
 | 
			
		||||
from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
 | 
			
		||||
from .extras import ltangent_distance, orthogonalization
 | 
			
		||||
@@ -45,26 +47,28 @@ class GLVQ(SupervisedPrototypeModel):
 | 
			
		||||
 | 
			
		||||
    def initialize_prototype_win_ratios(self):
 | 
			
		||||
        self.register_buffer(
 | 
			
		||||
            "prototype_win_ratios",
 | 
			
		||||
            torch.zeros(self.num_prototypes, device=self.device))
 | 
			
		||||
            "prototype_win_ratios", torch.zeros(self.num_prototypes, device=self.device)
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def on_train_epoch_start(self):
 | 
			
		||||
        self.initialize_prototype_win_ratios()
 | 
			
		||||
 | 
			
		||||
    def log_prototype_win_ratios(self, distances):
 | 
			
		||||
        batch_size = len(distances)
 | 
			
		||||
        prototype_wc = torch.zeros(self.num_prototypes,
 | 
			
		||||
                                   dtype=torch.long,
 | 
			
		||||
                                   device=self.device)
 | 
			
		||||
        wi, wc = torch.unique(distances.min(dim=-1).indices,
 | 
			
		||||
                              sorted=True,
 | 
			
		||||
                              return_counts=True)
 | 
			
		||||
        prototype_wc = torch.zeros(
 | 
			
		||||
            self.num_prototypes, dtype=torch.long, device=self.device
 | 
			
		||||
        )
 | 
			
		||||
        wi, wc = torch.unique(
 | 
			
		||||
            distances.min(dim=-1).indices, sorted=True, return_counts=True
 | 
			
		||||
        )
 | 
			
		||||
        prototype_wc[wi] = wc
 | 
			
		||||
        prototype_wr = prototype_wc / batch_size
 | 
			
		||||
        self.prototype_win_ratios = torch.vstack([
 | 
			
		||||
        self.prototype_win_ratios = torch.vstack(
 | 
			
		||||
            [
 | 
			
		||||
                self.prototype_win_ratios,
 | 
			
		||||
                prototype_wr,
 | 
			
		||||
        ])
 | 
			
		||||
            ]
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def shared_step(self, batch, batch_idx):
 | 
			
		||||
        x, y = batch
 | 
			
		||||
@@ -109,11 +113,9 @@ class SiameseGLVQ(GLVQ):
 | 
			
		||||
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self,
 | 
			
		||||
                 hparams,
 | 
			
		||||
                 backbone=torch.nn.Identity(),
 | 
			
		||||
                 both_path_gradients=False,
 | 
			
		||||
                 **kwargs):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self, hparams, backbone=torch.nn.Identity(), both_path_gradients=False, **kwargs
 | 
			
		||||
    ):
 | 
			
		||||
        distance_fn = kwargs.pop("distance_fn", squared_euclidean_distance)
 | 
			
		||||
        super().__init__(hparams, distance_fn=distance_fn, **kwargs)
 | 
			
		||||
        self.backbone = backbone
 | 
			
		||||
@@ -175,6 +177,7 @@ class GRLVQ(SiameseGLVQ):
 | 
			
		||||
    TODO Make a RelevanceLayer. `bb_lr` is ignored otherwise.
 | 
			
		||||
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    _relevances: torch.Tensor
 | 
			
		||||
 | 
			
		||||
    def __init__(self, hparams, **kwargs):
 | 
			
		||||
@@ -185,8 +188,7 @@ class GRLVQ(SiameseGLVQ):
 | 
			
		||||
        self.register_parameter("_relevances", Parameter(relevances))
 | 
			
		||||
 | 
			
		||||
        # Override the backbone
 | 
			
		||||
        self.backbone = LambdaLayer(self._apply_relevances,
 | 
			
		||||
                                    name="relevance scaling")
 | 
			
		||||
        self.backbone = LambdaLayer(self._apply_relevances, name="relevance scaling")
 | 
			
		||||
 | 
			
		||||
    def _apply_relevances(self, x):
 | 
			
		||||
        return x @ torch.diag(self._relevances)
 | 
			
		||||
@@ -210,8 +212,9 @@ class SiameseGMLVQ(SiameseGLVQ):
 | 
			
		||||
        super().__init__(hparams, **kwargs)
 | 
			
		||||
 | 
			
		||||
        # Override the backbone
 | 
			
		||||
        omega_initializer = kwargs.get("omega_initializer",
 | 
			
		||||
                                       EyeLinearTransformInitializer())
 | 
			
		||||
        omega_initializer = kwargs.get(
 | 
			
		||||
            "omega_initializer", EyeLinearTransformInitializer()
 | 
			
		||||
        )
 | 
			
		||||
        self.backbone = LinearTransform(
 | 
			
		||||
            self.hparams["input_dim"],
 | 
			
		||||
            self.hparams["latent_dim"],
 | 
			
		||||
@@ -229,6 +232,49 @@ class SiameseGMLVQ(SiameseGLVQ):
 | 
			
		||||
        return lam.detach().cpu()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class GMLMLVQ(GLVQ):
 | 
			
		||||
    """Generalized Multi-Layer Matrix Learning Vector Quantization.
 | 
			
		||||
    Masks are applied to the omega layers to achieve sparsity and constrain
 | 
			
		||||
    learning to certain items of each omega.
 | 
			
		||||
 | 
			
		||||
    Implemented as a regular GLVQ network that simply uses a different distance
 | 
			
		||||
    function. This makes it easier to implement a localized variant.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    # Parameters
 | 
			
		||||
    _omegas: list[torch.Tensor]
 | 
			
		||||
    masks: list[torch.Tensor]
 | 
			
		||||
 | 
			
		||||
    def __init__(self, hparams, **kwargs):
 | 
			
		||||
        distance_fn = kwargs.pop("distance_fn", ML_omega_distance)
 | 
			
		||||
        super().__init__(hparams, distance_fn=distance_fn, **kwargs)
 | 
			
		||||
 | 
			
		||||
        # Additional parameters
 | 
			
		||||
        self._masks = ParameterList(
 | 
			
		||||
            [Parameter(mask, requires_grad=False) for mask in kwargs.get("masks")]
 | 
			
		||||
        )
 | 
			
		||||
        self._omegas = ParameterList([LLTI(mask).generate(1, 1) for mask in self._masks])
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def omega_matrices(self):
 | 
			
		||||
        return [_omega.detach().cpu() for _omega in self._omegas]
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def lambda_matrix(self):
 | 
			
		||||
        # TODO update to respective lambda calculation rules.
 | 
			
		||||
        omega = self._omega.detach()  # (input_dim, latent_dim)
 | 
			
		||||
        lam = omega @ omega.T
 | 
			
		||||
        return lam.detach().cpu()
 | 
			
		||||
 | 
			
		||||
    def compute_distances(self, x):
 | 
			
		||||
        protos, _ = self.proto_layer()
 | 
			
		||||
        distances = self.distance_layer(x, protos, self._omegas, self._masks)
 | 
			
		||||
        return distances
 | 
			
		||||
 | 
			
		||||
    def extra_repr(self):
 | 
			
		||||
        return f"(omegas): (shapes: {[tuple(_omega.shape) for _omega in self._omegas]})"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class GMLVQ(GLVQ):
 | 
			
		||||
    """Generalized Matrix Learning Vector Quantization.
 | 
			
		||||
 | 
			
		||||
@@ -245,10 +291,12 @@ class GMLVQ(GLVQ):
 | 
			
		||||
        super().__init__(hparams, distance_fn=distance_fn, **kwargs)
 | 
			
		||||
 | 
			
		||||
        # Additional parameters
 | 
			
		||||
        omega_initializer = kwargs.get("omega_initializer",
 | 
			
		||||
                                       EyeLinearTransformInitializer())
 | 
			
		||||
        omega = omega_initializer.generate(self.hparams["input_dim"],
 | 
			
		||||
                                           self.hparams["latent_dim"])
 | 
			
		||||
        omega_initializer = kwargs.get(
 | 
			
		||||
            "omega_initializer", EyeLinearTransformInitializer()
 | 
			
		||||
        )
 | 
			
		||||
        omega = omega_initializer.generate(
 | 
			
		||||
            self.hparams["input_dim"], self.hparams["latent_dim"]
 | 
			
		||||
        )
 | 
			
		||||
        self.register_parameter("_omega", Parameter(omega))
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user