42 lines
1.5 KiB
Python
42 lines
1.5 KiB
Python
from dataclasses import dataclass, field
|
|
from typing import Callable
|
|
|
|
from prototorch.core.distances import euclidean_distance
|
|
from prototorch.models.y_arch.architectures.base import BaseYArchitecture
|
|
from prototorch.nn.wrappers import LambdaLayer
|
|
|
|
|
|
class SimpleComparisonMixin(BaseYArchitecture):
    """Simple comparison step for a Y-architecture model.

    The dissimilarity is computed only from the positions of the components
    and from the batch inputs; the second item of each incoming pair is
    ignored.
    """

    # HyperParameters
    # ----------------------------------------------------------------------------------------------------
    @dataclass
    class HyperParameters(BaseYArchitecture.HyperParameters):
        """Hyperparameters of the simple comparison step.

        comparison_fn: The comparison / dissimilarity function to use.
            Default: euclidean_distance.
        comparison_args: Keyword arguments for the comparison function.
            Default: {} (a fresh, empty dict for every instance).
        """

        comparison_fn: Callable = euclidean_distance
        # NOTE(review): init_comparison unpacks these into the LambdaLayer
        # constructor, not into comparison_fn at call time. Confirm that
        # LambdaLayer accepts / forwards them as the docstring above implies.
        comparison_args: dict = field(default_factory=dict)

    # Steps
    # ----------------------------------------------------------------------------------------------------
    def init_comparison(self, hparams: HyperParameters):
        """Create ``self.comparison_layer`` from the hyperparameters.

        ``hparams.comparison_fn`` is wrapped in a ``LambdaLayer``, and the
        entries of ``hparams.comparison_args`` are passed as extra keyword
        arguments to the ``LambdaLayer`` constructor.
        """
        wrapped_fn = hparams.comparison_fn
        self.comparison_layer = LambdaLayer(fn=wrapped_fn, **hparams.comparison_args)

    def comparison(self, batch, components):
        """Compare a batch against the components using ``self.comparison_layer``.

        Args:
            batch: A two-item pair; the first item (expected to be a tensor)
                holds the inputs, the second item is ignored.
            components: A two-item pair; the first item (expected to be a
                tensor) holds the component positions, the second item is
                ignored.

        Returns:
            Whatever ``self.comparison_layer`` returns when called with the
            batch inputs and the component positions after a singleton axis
            has been inserted at dim 1.
        """
        positions, _unused_component_info = components
        inputs, _unused_batch_info = batch

        # Insert a singleton axis at dim 1 before comparing; presumably
        # (num_components, dim) -> (num_components, 1, dim) — TODO confirm.
        expanded_positions = positions.unsqueeze(1)

        return self.comparison_layer(inputs, expanded_positions)
|