2022-05-19 14:13:08 +00:00
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
from dataclasses import dataclass, field
|
2022-05-31 15:56:03 +00:00
|
|
|
from typing import Callable, Dict
|
2022-05-19 14:13:08 +00:00
|
|
|
|
|
|
|
import torch
|
2022-05-31 15:56:03 +00:00
|
|
|
from prototorch.core.distances import euclidean_distance
|
2022-05-19 14:13:08 +00:00
|
|
|
from prototorch.core.initializers import (
|
|
|
|
AbstractLinearTransformInitializer,
|
|
|
|
EyeLinearTransformInitializer,
|
|
|
|
)
|
|
|
|
from prototorch.nn.wrappers import LambdaLayer
|
2022-06-03 08:39:11 +00:00
|
|
|
from prototorch.y.architectures.base import BaseYArchitecture
|
2022-05-31 15:56:03 +00:00
|
|
|
from torch import Tensor
|
2022-05-19 14:13:08 +00:00
|
|
|
from torch.nn.parameter import Parameter
|
|
|
|
|
|
|
|
|
2022-05-31 15:56:03 +00:00
|
|
|
class SimpleComparisonMixin(BaseYArchitecture):
    """
    Simple Comparison

    A comparison layer that only uses the positions of the components and the
    batch for dissimilarity computation.
    """

    # HyperParameters
    # ----------------------------------------------------------------------------------------------------
    @dataclass
    class HyperParameters(BaseYArchitecture.HyperParameters):
        """
        comparison_fn: The comparison / dissimilarity function to use. Default: euclidean_distance.
        comparison_args: Keyword arguments for the comparison function. Default: {}.
        comparison_parameters: Not read by this mixin. Default: {}.
        """

        comparison_fn: Callable = euclidean_distance
        comparison_args: dict = field(default_factory=dict)
        # NOTE(review): not referenced anywhere in this mixin; presumably
        # consumed by subclasses or kept for config compatibility — confirm.
        comparison_parameters: dict = field(default_factory=dict)

    # Steps
    # ----------------------------------------------------------------------------------------------------
    def init_comparison(self, hparams: HyperParameters):
        """Build the comparison layer from ``hparams``.

        ``hparams.comparison_fn`` is wrapped in a ``LambdaLayer``. The entries
        of ``hparams.comparison_args`` are stored separately and forwarded as
        keyword arguments to the comparison function on every call, which is
        what the ``comparison_args`` field is documented to do. They are not
        passed to the layer's constructor.
        """
        self.comparison_layer = LambdaLayer(fn=hparams.comparison_fn)

        # Static keyword arguments for comparison_fn. Copied so that later
        # mutation of the hparams dict does not silently change behaviour.
        self._comparison_fn_kwargs: dict = dict(hparams.comparison_args)

        # Per-call keyword arguments (e.g. tensors) that subclasses may fill
        # in; on a key conflict they override comparison_args.
        self.comparison_kwargs: dict[str, Tensor] = dict()

    def comparison(self, batch, components):
        """Compute dissimilarities between a batch and the components.

        Args:
            batch: A pair whose first element is the input tensor; the second
                element is ignored here.
            components: A pair whose first element is the component tensor;
                the second element is ignored here.

        Returns:
            The result of ``comparison_fn(batch_tensor, comp_tensor, **kwargs)``,
            where ``comp_tensor`` has a singleton axis inserted at dim 1 and
            ``kwargs`` merges ``comparison_args`` with ``comparison_kwargs``.
        """
        comp_tensor, _ = components
        batch_tensor, _ = batch

        # NOTE(review): inserts a singleton axis at dim 1 of the components;
        # presumably the comparison functions reshape their inputs — confirm.
        comp_tensor = comp_tensor.unsqueeze(1)

        # getattr default keeps subclasses that override init_comparison
        # without calling super() working as before.
        fn_kwargs = {
            **getattr(self, "_comparison_fn_kwargs", {}),
            **self.comparison_kwargs,
        }
        distances = self.comparison_layer(
            batch_tensor,
            comp_tensor,
            **fn_kwargs,
        )

        return distances
|
|
|
|
|
2022-05-31 15:56:03 +00:00
|
|
|
|
|
|
|
class OmegaComparisonMixin(SimpleComparisonMixin):
    """
    Omega Comparison

    A comparison layer that uses the positions of the components and the
    batch together with a learnable matrix ``omega`` for dissimilarity
    computation. ``omega`` is registered as a module parameter and is passed
    to the comparison function as the ``omega`` keyword argument.
    """

    _omega: torch.Tensor

    # HyperParameters
    # ----------------------------------------------------------------------------------------------------
    @dataclass
    class HyperParameters(SimpleComparisonMixin.HyperParameters):
        """
        input_dim: Necessary Field: The dimensionality of the input.
        latent_dim: The dimensionality of the latent space. Default: 2.
        omega_initializer: The initializer to use for the omega matrix. Default: EyeLinearTransformInitializer.
        """

        input_dim: int | None = None
        latent_dim: int = 2
        omega_initializer: type[
            AbstractLinearTransformInitializer] = EyeLinearTransformInitializer

    # Steps
    # ----------------------------------------------------------------------------------------------------
    def init_comparison(self, hparams: HyperParameters) -> None:
        """Build the comparison layer and the omega matrix.

        The omega matrix is produced by
        ``hparams.omega_initializer().generate(input_dim, latent_dim)``,
        registered as the ``_omega`` parameter and exposed to the comparison
        function through ``comparison_kwargs``.

        Raises:
            ValueError: If ``hparams.input_dim`` is not set.
        """
        super().init_comparison(hparams)

        # omega cannot be sized without knowing the input dimensionality.
        if hparams.input_dim is None:
            raise ValueError("input_dim must be specified.")

        omega = hparams.omega_initializer().generate(
            hparams.input_dim,
            hparams.latent_dim,
        )
        # Registered as a Parameter so it appears in parameters()/state_dict.
        self.register_parameter("_omega", Parameter(omega))
        # Handed to the comparison function as `omega=` on every call.
        self.comparison_kwargs = dict(omega=self._omega)

    # Properties
    # ----------------------------------------------------------------------------------------------------
    @property
    def omega_matrix(self) -> Tensor:
        """The omega matrix, detached from the autograd graph and moved to the CPU."""
        return self._omega.detach().cpu()

    @property
    def lambda_matrix(self) -> Tensor:
        """``omega @ omega.T``, detached from the autograd graph and moved to the CPU."""
        omega = self._omega.detach()
        # The product of detached tensors is already outside the autograd
        # graph, so no second detach() is needed.
        return (omega @ omega.T).cpu()
|