prototorch_models/prototorch/models/architectures/comparison.py

140 lines
4.5 KiB
Python
Raw Normal View History

2022-05-19 14:13:08 +00:00
from __future__ import annotations
from dataclasses import dataclass, field
2022-08-15 10:14:14 +00:00
from typing import Callable
2022-05-19 14:13:08 +00:00
import torch
2022-05-31 15:56:03 +00:00
from prototorch.core.distances import euclidean_distance
2022-05-19 14:13:08 +00:00
from prototorch.core.initializers import (
AbstractLinearTransformInitializer,
EyeLinearTransformInitializer,
)
2022-08-15 10:14:14 +00:00
from prototorch.models.architectures.base import BaseYArchitecture
2022-05-19 14:13:08 +00:00
from prototorch.nn.wrappers import LambdaLayer
2022-05-31 15:56:03 +00:00
from torch import Tensor
2022-05-19 14:13:08 +00:00
from torch.nn.parameter import Parameter
2022-05-31 15:56:03 +00:00
class SimpleComparisonMixin(BaseYArchitecture):
    """
    Simple Comparison

    A comparison layer that only uses the positions of the components
    and the batch for dissimilarity computation.
    """

    # HyperParameters
    # ----------------------------------------------------------------------------------------------
    @dataclass
    class HyperParameters(BaseYArchitecture.HyperParameters):
        """
        comparison_fn: The comparison / dissimilarity function to use. Default: euclidean_distance.
        comparison_args:
            Keyword arguments for the comparison function; forwarded to the
            ``LambdaLayer`` constructor in ``init_comparison``. Default: {}.
        comparison_parameters:
            Additional comparison parameters. Not read by this mixin. Default: {}.
        """
        comparison_fn: Callable = euclidean_distance
        # NOTE(review): init_comparison passes these to the LambdaLayer
        # constructor, not directly to comparison_fn — confirm LambdaLayer
        # accepts/forwards them as intended.
        comparison_args: dict = field(default_factory=dict)
        # Not used in this mixin; presumably consumed elsewhere — verify.
        comparison_parameters: dict = field(default_factory=dict)

    # Steps
    # ----------------------------------------------------------------------------------------------
    def init_comparison(self, hparams: HyperParameters):
        """
        Build the comparison layer from the hyperparameters.

        Wraps ``hparams.comparison_fn`` in a ``LambdaLayer`` (forwarding
        ``hparams.comparison_args`` to its constructor) and resets
        ``self.comparison_kwargs`` to an empty dict.

        Args:
            hparams: Hyperparameters providing ``comparison_fn`` and ``comparison_args``.
        """
        self.comparison_layer = LambdaLayer(
            fn=hparams.comparison_fn,
            **hparams.comparison_args,
        )
        # Extra keyword arguments passed on every call of the comparison layer.
        # Empty here; subclasses may replace it (e.g. to supply learnable tensors).
        self.comparison_kwargs: dict[str, Tensor] = dict()

    def comparison(self, batch, components):
        """
        Compute dissimilarities between the batch and the components.

        Args:
            batch: Pair whose first element is the input tensor; the second element is ignored.
            components: Pair whose first element is the component tensor; the second element is ignored.

        Returns:
            The result of ``self.comparison_layer`` applied to the batch tensor and the
            component tensor (with a singleton axis inserted at dim 1), plus
            ``self.comparison_kwargs``.
        """
        comp_tensor, _ = components
        batch_tensor, _ = batch
        # Insert a singleton axis at dim 1 of the components before comparing.
        comp_tensor = comp_tensor.unsqueeze(1)
        distances = self.comparison_layer(
            batch_tensor,
            comp_tensor,
            **self.comparison_kwargs,
        )
        return distances
2022-05-31 15:56:03 +00:00
class OmegaComparisonMixin(SimpleComparisonMixin):
    """
    Omega Comparison

    A comparison layer that uses the positions of the components
    and the batch for dissimilarity computation, together with a learnable
    matrix ``omega``. The matrix is generated by ``omega_initializer`` from
    ``input_dim`` and ``latent_dim``. It is passed to the comparison function
    as the ``omega`` keyword argument.
    """

    _omega: torch.Tensor

    # HyperParameters
    # ----------------------------------------------------------------------------------------------
    @dataclass
    class HyperParameters(SimpleComparisonMixin.HyperParameters):
        """
        input_dim: Necessary Field: The dimensionality of the input.
        latent_dim:
            The dimensionality of the latent space. Default: 2.
        omega_initializer:
            The initializer to use for the omega matrix. Default: EyeLinearTransformInitializer.
        omega_initializer_kwargs:
            Keyword arguments used to construct ``omega_initializer``. Default: {}.
        """
        input_dim: int | None = None
        latent_dim: int = 2
        omega_initializer: type[
            AbstractLinearTransformInitializer] = EyeLinearTransformInitializer
        omega_initializer_kwargs: dict = field(default_factory=dict)

    # Steps
    # ----------------------------------------------------------------------------------------------
    def init_comparison(self, hparams: HyperParameters) -> None:
        """
        Build the comparison layer and register the learnable omega matrix.

        Args:
            hparams: Hyperparameters; ``input_dim`` must be set.

        Raises:
            ValueError: If ``hparams.input_dim`` is None.
        """
        # Validate before touching module state so an invalid configuration
        # does not leave a half-initialized comparison layer behind.
        if hparams.input_dim is None:
            raise ValueError("input_dim must be specified.")

        super().init_comparison(hparams)

        # Initialize the omega matrix
        initializer = hparams.omega_initializer(
            **hparams.omega_initializer_kwargs)
        omega = initializer.generate(hparams.input_dim, hparams.latent_dim)
        self.register_parameter("_omega", Parameter(omega))
        # Hand the registered parameter to the comparison function on every call.
        self.comparison_kwargs = dict(omega=self._omega)

    # Properties
    # ----------------------------------------------------------------------------------------------
    @property
    def omega_matrix(self):
        """
        Omega Matrix. Mapping applied to data and prototypes.

        Returned detached from the autograd graph and moved to the CPU.
        """
        return self._omega.detach().cpu()

    @property
    def lambda_matrix(self):
        """
        Lambda Matrix: ``omega @ omega.T``, detached and moved to the CPU.
        """
        # omega is already detached, so the product carries no graph; compute
        # on omega's device and only then move the result to the CPU.
        omega = self._omega.detach()
        return (omega @ omega.T).cpu()

    @property
    def relevance_profile(self):
        """
        Relevance Profile. Absolute values of the main diagonal of the Lambda Matrix.
        """
        return self.lambda_matrix.diag().abs()

    @property
    def classification_influence_profile(self):
        """
        Classification Influence Profile. Influence of each dimension,
        computed as the column sums of the absolute Lambda Matrix.
        """
        lam = self.lambda_matrix
        return lam.abs().sum(0)