prototorch_models/prototorch/y/architectures/comparison.py

113 lines
3.8 KiB
Python
Raw Normal View History

2022-05-19 14:13:08 +00:00
from __future__ import annotations
from dataclasses import dataclass, field
2022-05-31 15:56:03 +00:00
from typing import Callable, Dict
2022-05-19 14:13:08 +00:00
import torch
2022-05-31 15:56:03 +00:00
from prototorch.core.distances import euclidean_distance
2022-05-19 14:13:08 +00:00
from prototorch.core.initializers import (
AbstractLinearTransformInitializer,
EyeLinearTransformInitializer,
)
from prototorch.nn.wrappers import LambdaLayer
2022-06-03 08:39:11 +00:00
from prototorch.y.architectures.base import BaseYArchitecture
2022-05-31 15:56:03 +00:00
from torch import Tensor
2022-05-19 14:13:08 +00:00
from torch.nn.parameter import Parameter
2022-05-31 15:56:03 +00:00
class SimpleComparisonMixin(BaseYArchitecture):
    """
    Simple Comparison

    A comparison layer that only uses the positions of the components and the
    batch for dissimilarity computation.
    """

    # HyperParameters
    # ----------------------------------------------------------------------------------------------------
    @dataclass
    class HyperParameters(BaseYArchitecture.HyperParameters):
        """
        comparison_fn: The comparison / dissimilarity function to use. Default: euclidean_distance.
        comparison_args: Keyword arguments for the comparison function; they are passed to the
            ``LambdaLayer`` constructor that wraps ``comparison_fn``. Default: {}.
        comparison_parameters: Additional comparison parameters. Not read by this mixin. Default: {}.
        """

        comparison_fn: Callable = euclidean_distance
        # NOTE(review): these are handed to ``LambdaLayer(...)`` once at construction time
        # (see ``init_comparison``), not to ``comparison_fn`` on every call. Confirm that
        # LambdaLayer accepts / forwards them as the docstring intends.
        comparison_args: dict = field(default_factory=dict)
        # NOTE(review): not read anywhere in this mixin; presumably consumed by subclasses.
        comparison_parameters: dict = field(default_factory=dict)

    # Steps
    # ----------------------------------------------------------------------------------------------------
    def init_comparison(self, hparams: HyperParameters):
        """Build the comparison layer and reset the per-call comparison kwargs.

        Wraps ``hparams.comparison_fn`` in a ``LambdaLayer`` (stored as
        ``self.comparison_layer``) and sets ``self.comparison_kwargs`` to an empty
        dict. Subclasses may replace ``comparison_kwargs`` with extra tensors that
        ``comparison`` forwards to the layer on every call.
        """
        self.comparison_layer = LambdaLayer(
            fn=hparams.comparison_fn,
            **hparams.comparison_args,
        )
        # Extra keyword arguments forwarded to ``comparison_layer`` by ``comparison``.
        self.comparison_kwargs: dict[str, Tensor] = {}

    def comparison(self, batch, components):
        """Compute dissimilarities between a batch and the components.

        Args:
            batch: A pair whose first element is the batch tensor; the second
                element is ignored here.
            components: A pair whose first element is the component tensor; the
                second element is ignored here.

        Returns:
            The output of ``self.comparison_layer`` called with the batch tensor,
            the component tensor with a singleton axis inserted at dim 1, and
            ``self.comparison_kwargs``.
        """
        comp_tensor, _ = components
        batch_tensor, _ = batch
        # Insert a singleton axis at dim 1 of the components before comparing
        # (presumably for broadcasting against the batch; confirm with comparison_fn).
        comp_tensor = comp_tensor.unsqueeze(1)
        return self.comparison_layer(
            batch_tensor,
            comp_tensor,
            **self.comparison_kwargs,
        )
class OmegaComparisonMixin(SimpleComparisonMixin):
    """
    Omega Comparison

    A comparison layer that, besides the positions of the components and the
    batch, passes a learnable ``omega`` matrix to the comparison function.
    """

    _omega: torch.Tensor

    # HyperParameters
    # ----------------------------------------------------------------------------------------------------
    @dataclass
    class HyperParameters(SimpleComparisonMixin.HyperParameters):
        """
        input_dim: Required field: The dimensionality of the input.
        latent_dim: The dimensionality of the latent space. Default: 2.
        omega_initializer: The initializer to use for the omega matrix. Default: EyeLinearTransformInitializer.
        """

        input_dim: int | None = None
        latent_dim: int = 2
        omega_initializer: type[
            AbstractLinearTransformInitializer] = EyeLinearTransformInitializer

    # Steps
    # ----------------------------------------------------------------------------------------------------
    def init_comparison(self, hparams: HyperParameters) -> None:
        """Build the comparison layer and register the learnable omega matrix.

        Calls the parent initialisation, generates omega with
        ``hparams.omega_initializer().generate(input_dim, latent_dim)``, registers
        it as the parameter ``_omega`` and makes ``comparison`` forward it to the
        comparison function as the ``omega`` keyword argument.

        Raises:
            ValueError: If ``hparams.input_dim`` is None.
        """
        super().init_comparison(hparams)

        # Initialize the omega matrix
        if hparams.input_dim is None:
            raise ValueError("input_dim must be specified.")

        omega = hparams.omega_initializer().generate(
            hparams.input_dim,
            hparams.latent_dim,
        )
        self.register_parameter("_omega", Parameter(omega))
        # NOTE(review): the configured comparison_fn must accept an ``omega`` keyword.
        # The inherited default (euclidean_distance) presumably does not, so concrete
        # models are expected to override comparison_fn. Confirm.
        self.comparison_kwargs = dict(omega=self._omega)

    # Properties
    # ----------------------------------------------------------------------------------------------------
    @property
    def omega_matrix(self):
        """The omega matrix, detached from autograd and moved to the CPU."""
        return self._omega.detach().cpu()

    @property
    def lambda_matrix(self):
        """``omega @ omega.T`` computed from the detached omega matrix, on the CPU."""
        # Detaching once is enough: the product of detached tensors carries no graph.
        omega = self._omega.detach()
        return (omega @ omega.T).cpu()