prototorch_models/prototorch/models/architectures/comparison.py
2022-08-16 15:55:05 +02:00

140 lines
4.5 KiB
Python

from __future__ import annotations
from dataclasses import dataclass, field
from typing import Callable
import torch
from prototorch.core.distances import euclidean_distance
from prototorch.core.initializers import (
AbstractLinearTransformInitializer,
EyeLinearTransformInitializer,
)
from prototorch.models.architectures.base import BaseYArchitecture
from prototorch.nn.wrappers import LambdaLayer
from torch import Tensor
from torch.nn.parameter import Parameter
class SimpleComparisonMixin(BaseYArchitecture):
    """
    Simple Comparison

    A comparison layer that only uses the positions of the components
    and the batch for dissimilarity computation.
    """

    # HyperParameters
    # ----------------------------------------------------------------------------------------------
    @dataclass
    class HyperParameters(BaseYArchitecture.HyperParameters):
        """
        comparison_fn: The comparison / dissimilarity function to use. Default: euclidean_distance.
        comparison_args:
            Intended as keyword arguments for the comparison function. Note that
            ``init_comparison`` currently forwards them to the ``LambdaLayer``
            constructor that wraps ``comparison_fn``. Default: {}.
        comparison_parameters:
            Additional comparison parameters. Not read by ``init_comparison`` or
            ``comparison`` in this mixin; presumably consumed elsewhere (TODO confirm).
            Default: {}.
        """
        comparison_fn: Callable = euclidean_distance
        # ``dict`` itself is the factory, so every instance gets its own empty dict.
        comparison_args: dict = field(default_factory=dict)
        comparison_parameters: dict = field(default_factory=dict)

    # Steps
    # ----------------------------------------------------------------------------------------------
    def init_comparison(self, hparams: HyperParameters) -> None:
        """
        Build the comparison layer from the hyperparameters.

        Wraps ``hparams.comparison_fn`` in a ``LambdaLayer`` stored as
        ``self.comparison_layer``. It also resets ``self.comparison_kwargs``,
        the extra keyword arguments passed on every call in ``comparison``,
        to an empty dict.

        :param hparams: Hyperparameters providing ``comparison_fn`` and
            ``comparison_args``.
        """
        # NOTE(review): ``comparison_args`` are passed to the LambdaLayer
        # constructor, not bound to ``comparison_fn``. Confirm that LambdaLayer
        # accepts/forwards them, or the docstring's intent is not met.
        self.comparison_layer = LambdaLayer(
            fn=hparams.comparison_fn,
            **hparams.comparison_args,
        )
        self.comparison_kwargs: dict[str, Tensor] = dict()

    def comparison(self, batch, components):
        """
        Compute dissimilarities between a batch and the components.

        :param batch: Two-element sequence. The first element is the input
            tensor; the second is ignored here.
        :param components: Two-element sequence. The first element holds the
            component positions; the second is ignored here.
        :return: Result of ``self.comparison_layer`` called with the batch
            tensor, the component tensor (with a new axis at dim 1) and
            ``self.comparison_kwargs``.
        """
        comp_tensor, _ = components
        batch_tensor, _ = batch
        # Add a singleton axis at dim 1 of the component positions. This
        # presumably matches the shape convention comparison_fn expects
        # (TODO confirm).
        comp_tensor = comp_tensor.unsqueeze(1)
        return self.comparison_layer(
            batch_tensor,
            comp_tensor,
            **self.comparison_kwargs,
        )
class OmegaComparisonMixin(SimpleComparisonMixin):
    """
    Omega Comparison

    A comparison layer that uses the positions of the components
    and the batch for dissimilarity computation.

    On top of the simple comparison, a learnable ``omega`` matrix is
    registered as the parameter ``_omega``. It is passed to the comparison
    function as the ``omega`` keyword argument on every comparison.
    """

    _omega: torch.Tensor

    # HyperParameters
    # ----------------------------------------------------------------------------------------------
    @dataclass
    class HyperParameters(SimpleComparisonMixin.HyperParameters):
        """
        input_dim: Necessary Field: The dimensionality of the input.
        latent_dim:
            The dimensionality of the latent space. Default: 2.
        omega_initializer:
            The initializer to use for the omega matrix. Default: EyeLinearTransformInitializer.
        omega_initializer_kwargs:
            Keyword arguments used to construct ``omega_initializer``. Default: {}.
        """
        input_dim: int | None = None
        latent_dim: int = 2
        omega_initializer: type[
            AbstractLinearTransformInitializer] = EyeLinearTransformInitializer
        # ``dict`` itself is the factory, so every instance gets its own empty dict.
        omega_initializer_kwargs: dict = field(default_factory=dict)

    # Steps
    # ----------------------------------------------------------------------------------------------
    def init_comparison(self, hparams: HyperParameters) -> None:
        """
        Build the comparison layer and the learnable omega matrix.

        The omega matrix is generated by
        ``omega_initializer(**omega_initializer_kwargs).generate(input_dim, latent_dim)``
        and registered as the parameter ``_omega``.

        :param hparams: Hyperparameters; ``input_dim`` must be set.
        :raises ValueError: If ``hparams.input_dim`` is ``None``.
        """
        super().init_comparison(hparams)

        # Initialize the omega matrix
        if hparams.input_dim is None:
            raise ValueError("input_dim must be specified.")

        initializer = hparams.omega_initializer(
            **hparams.omega_initializer_kwargs)
        omega = initializer.generate(
            hparams.input_dim,
            hparams.latent_dim,
        )
        self.register_parameter("_omega", Parameter(omega))
        # Passed as ``omega=...`` to comparison_fn on every ``comparison`` call.
        self.comparison_kwargs = dict(omega=self._omega)

    # Properties
    # ----------------------------------------------------------------------------------------------
    @property
    def omega_matrix(self) -> torch.Tensor:
        """
        Omega Matrix. Mapping applied to data and prototypes.

        Returned detached from the autograd graph and moved to the CPU.
        """
        return self._omega.detach().cpu()

    @property
    def lambda_matrix(self) -> torch.Tensor:
        """
        Lambda Matrix, computed as ``omega @ omega.T``.

        Returned detached from the autograd graph and moved to the CPU.
        """
        # The product of a detached tensor does not track gradients, so a
        # second ``.detach()`` is unnecessary.
        omega = self._omega.detach()
        return (omega @ omega.T).cpu()

    @property
    def relevance_profile(self) -> torch.Tensor:
        """
        Relevance Profile: absolute values of the main diagonal of the Lambda Matrix.
        """
        return self.lambda_matrix.diag().abs()

    @property
    def classification_influence_profile(self) -> torch.Tensor:
        """
        Classification Influence Profile: per-dimension influence, computed as
        the sum over dim 0 of the absolute Lambda Matrix entries.
        """
        lam = self.lambda_matrix
        return lam.abs().sum(0)