Compare commits

5 Commits
| Author | SHA1 | Date |
|---|---|---|
|  | adafb49985 |  |
|  | 78f8b6cc00 |  |
|  | c6f718a1d4 |  |
|  | 1786031b4e |  |
|  | 824dfced92 |  |

```diff
@@ -1,13 +1,15 @@
 """Models based on the GLVQ framework."""
 
 import torch
+from numpy.typing import NDArray
 from prototorch.core.competitions import wtac
 from prototorch.core.distances import (
+    ML_omega_distance,
     lomega_distance,
     omega_distance,
     squared_euclidean_distance,
 )
-from prototorch.core.initializers import EyeLinearTransformInitializer
+from prototorch.core.initializers import LLTI, EyeLinearTransformInitializer
 from prototorch.core.losses import (
     GLVQLoss,
     lvq1_loss,
@@ -15,7 +17,7 @@ from prototorch.core.losses import (
 )
 from prototorch.core.transforms import LinearTransform
 from prototorch.nn.wrappers import LambdaLayer, LossLayer
-from torch.nn.parameter import Parameter
+from torch.nn import Parameter, ParameterList
 
 from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
 from .extras import ltangent_distance, orthogonalization
@@ -45,26 +47,28 @@ class GLVQ(SupervisedPrototypeModel):
 
     def initialize_prototype_win_ratios(self):
         self.register_buffer(
-            "prototype_win_ratios",
-            torch.zeros(self.num_prototypes, device=self.device))
+            "prototype_win_ratios", torch.zeros(self.num_prototypes, device=self.device)
+        )
 
     def on_train_epoch_start(self):
         self.initialize_prototype_win_ratios()
 
     def log_prototype_win_ratios(self, distances):
         batch_size = len(distances)
-        prototype_wc = torch.zeros(self.num_prototypes,
-                                   dtype=torch.long,
-                                   device=self.device)
-        wi, wc = torch.unique(distances.min(dim=-1).indices,
-                              sorted=True,
-                              return_counts=True)
+        prototype_wc = torch.zeros(
+            self.num_prototypes, dtype=torch.long, device=self.device
+        )
+        wi, wc = torch.unique(
+            distances.min(dim=-1).indices, sorted=True, return_counts=True
+        )
         prototype_wc[wi] = wc
         prototype_wr = prototype_wc / batch_size
-        self.prototype_win_ratios = torch.vstack([
-            self.prototype_win_ratios,
-            prototype_wr,
-        ])
+        self.prototype_win_ratios = torch.vstack(
+            [
+                self.prototype_win_ratios,
+                prototype_wr,
+            ]
+        )
 
     def shared_step(self, batch, batch_idx):
         x, y = batch
@@ -109,11 +113,9 @@ class SiameseGLVQ(GLVQ):
 
     """
 
-    def __init__(self,
-                 hparams,
-                 backbone=torch.nn.Identity(),
-                 both_path_gradients=False,
-                 **kwargs):
+    def __init__(
+        self, hparams, backbone=torch.nn.Identity(), both_path_gradients=False, **kwargs
+    ):
         distance_fn = kwargs.pop("distance_fn", squared_euclidean_distance)
         super().__init__(hparams, distance_fn=distance_fn, **kwargs)
         self.backbone = backbone
@@ -175,6 +177,7 @@ class GRLVQ(SiameseGLVQ):
     TODO Make a RelevanceLayer. `bb_lr` is ignored otherwise.
 
     """
+
     _relevances: torch.Tensor
 
     def __init__(self, hparams, **kwargs):
@@ -185,8 +188,7 @@ class GRLVQ(SiameseGLVQ):
         self.register_parameter("_relevances", Parameter(relevances))
 
         # Override the backbone
-        self.backbone = LambdaLayer(self._apply_relevances,
-                                    name="relevance scaling")
+        self.backbone = LambdaLayer(self._apply_relevances, name="relevance scaling")
 
     def _apply_relevances(self, x):
         return x @ torch.diag(self._relevances)
@@ -210,8 +212,9 @@ class SiameseGMLVQ(SiameseGLVQ):
         super().__init__(hparams, **kwargs)
 
         # Override the backbone
-        omega_initializer = kwargs.get("omega_initializer",
-                                       EyeLinearTransformInitializer())
+        omega_initializer = kwargs.get(
+            "omega_initializer", EyeLinearTransformInitializer()
+        )
         self.backbone = LinearTransform(
             self.hparams["input_dim"],
             self.hparams["latent_dim"],
@@ -229,6 +232,49 @@ class SiameseGMLVQ(SiameseGLVQ):
         return lam.detach().cpu()
 
 
+class GMLMLVQ(GLVQ):
+    """Generalized Multi-Layer Matrix Learning Vector Quantization.
+    Masks are applied to the omega layers to achieve sparsity and constrain
+    learning to certain items of each omega.
+
+    Implemented as a regular GLVQ network that simply uses a different distance
+    function. This makes it easier to implement a localized variant.
+    """
+
+    # Parameters
+    _omegas: list[torch.Tensor]
+    masks: list[torch.Tensor]
+
+    def __init__(self, hparams, **kwargs):
+        distance_fn = kwargs.pop("distance_fn", ML_omega_distance)
+        super().__init__(hparams, distance_fn=distance_fn, **kwargs)
+
+        # Additional parameters
+        self._masks = ParameterList(
+            [Parameter(mask, requires_grad=False) for mask in kwargs.get("masks")]
+        )
+        self._omegas = ParameterList([LLTI(mask).generate(1, 1) for mask in self._masks])
+
+    @property
+    def omega_matrices(self):
+        return [_omega.detach().cpu() for _omega in self._omegas]
+
+    @property
+    def lambda_matrix(self):
+        # TODO update to respective lambda calculation rules.
+        omega = self._omega.detach()  # (input_dim, latent_dim)
+        lam = omega @ omega.T
+        return lam.detach().cpu()
+
+    def compute_distances(self, x):
+        protos, _ = self.proto_layer()
+        distances = self.distance_layer(x, protos, self._omegas, self._masks)
+        return distances
+
+    def extra_repr(self):
+        return f"(omegas): (shapes: {[tuple(_omega.shape) for _omega in self._omegas]})"
+
+
 class GMLVQ(GLVQ):
     """Generalized Matrix Learning Vector Quantization.
 
@@ -245,10 +291,12 @@ class GMLVQ(GLVQ):
         super().__init__(hparams, distance_fn=distance_fn, **kwargs)
 
         # Additional parameters
-        omega_initializer = kwargs.get("omega_initializer",
-                                       EyeLinearTransformInitializer())
-        omega = omega_initializer.generate(self.hparams["input_dim"],
-                                           self.hparams["latent_dim"])
+        omega_initializer = kwargs.get(
+            "omega_initializer", EyeLinearTransformInitializer()
+        )
+        omega = omega_initializer.generate(
+            self.hparams["input_dim"], self.hparams["latent_dim"]
+        )
         self.register_parameter("_omega", Parameter(omega))
 
     @property
```
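
The new `GMLMLVQ` defaults its `distance_fn` to `ML_omega_distance`, whose implementation is not part of this diff. Below is a minimal sketch of what a masked multi-layer omega distance could look like, matching the `(x, protos, omegas, masks)` call in `compute_distances`; the element-wise masking and the layer-by-layer chaining are assumptions, not the actual `prototorch.core.distances` code.

```python
import torch


def ml_omega_distance_sketch(x, protos, omegas, masks):
    """Hypothetical stand-in for prototorch.core.distances.ML_omega_distance.

    Each omega is element-wise masked, then the batch and the prototypes are
    projected through the masked transforms layer by layer.
    """
    for omega, mask in zip(omegas, masks):
        masked_omega = omega * mask  # zero out the entries the mask disables
        x = x @ masked_omega  # project the batch
        protos = protos @ masked_omega  # project the prototypes
    # Pairwise squared Euclidean distances, shape (num_samples, num_prototypes).
    return torch.cdist(x, protos) ** 2
```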
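
Note that `GMLMLVQ.lambda_matrix` still reads `self._omega`, which this class never registers (it stores a `ParameterList` named `_omegas`), so accessing the property raises an `AttributeError`; the in-code TODO flags the pending calculation rule. A hypothetical per-layer stand-in, mirroring the single-layer rule `lam = omega @ omega.T` used elsewhere in this module:

```python
import torch


def lambda_matrices(omegas):
    # Hypothetical helper: one lambda per omega layer, applying the
    # single-layer GMLVQ rule lam = omega @ omega.T to each entry.
    return [(omega.detach() @ omega.detach().T).cpu() for omega in omegas]
```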
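
`GMLMLVQ.__init__` reads its masks from `kwargs.get("masks")` and derives one omega per mask via `LLTI`, so every omega starts out shaped like its mask. A construction sketch under stated assumptions: the package-level import, the `distribution` and `lr` hparams keys (copied from the sibling models in this module), and the omitted prototype-initializer kwargs are all guesses, since none of this appears in the diff.

```python
import torch

from prototorch.models import GMLMLVQ  # assumed package-level export

# Square masks; nonzero entries mark the omega items that stay in play.
# Shapes and values are purely illustrative.
masks = [torch.eye(4), torch.ones(4, 4)]

# hparams keys assumed from the sibling GLVQ models in this module.
hparams = dict(distribution=[2, 2], lr=0.01)

# masks is read via kwargs.get("masks"); any prototype-initializer kwargs
# the GLVQ base class may require are omitted here.
model = GMLMLVQ(hparams, masks=masks)
print(model.extra_repr())  # (omegas): (shapes: [(4, 4), (4, 4)])
```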