from dataclasses import dataclass, field
from typing import Callable, Type

import torch

from prototorch.core.competitions import WTAC
from prototorch.core.components import LabeledComponents
from prototorch.core.distances import euclidean_distance
from prototorch.core.initializers import (
    AbstractComponentsInitializer,
    LabelsInitializer,
)
from prototorch.core.losses import GLVQLoss
from prototorch.models.proto_y_architecture.base import BaseYArchitecture
from prototorch.nn.wrappers import LambdaLayer

class SupervisedArchitecture(BaseYArchitecture):
    """Architecture step that owns a set of labeled prototype components."""

    components_layer: LabeledComponents

    @dataclass
    class HyperParameters:
        # Prototype distribution spec; exact schema is interpreted by
        # LabeledComponents — confirm there.
        distribution: dict[str, int]
        # Initializer that produces the raw component tensors.
        component_initializer: AbstractComponentsInitializer

    def init_components(self, hparams: HyperParameters):
        """Build the labeled components layer from the hyperparameters."""
        layer = LabeledComponents(
            distribution=hparams.distribution,
            components_initializer=hparams.component_initializer,
            labels_initializer=LabelsInitializer(),
        )
        self.components_layer = layer

    @property
    def prototypes(self):
        """Prototype tensors as a detached CPU copy."""
        return self.components_layer.components.detach().cpu()

    @property
    def prototype_labels(self):
        """Prototype labels as a detached CPU copy."""
        return self.components_layer.labels.detach().cpu()


class WTACompetitionMixin(BaseYArchitecture):
    """Inference step: winner-takes-all competition over the prototypes."""

    @dataclass
    class HyperParameters(BaseYArchitecture.HyperParameters):
        # WTA competition needs no extra hyperparameters.
        pass

    def init_inference(self, hparams: HyperParameters):
        """Set up the winner-takes-all competition layer."""
        self.competition_layer = WTAC()

    def inference(self, comparison_measures, components):
        """Predict labels by letting the closest prototype win."""
        winner_labels = components[1]
        return self.competition_layer(comparison_measures, winner_labels)


class GLVQLossMixin(BaseYArchitecture):
    """Loss step: GLVQ classification loss with a transfer function."""

    @dataclass
    class HyperParameters(BaseYArchitecture.HyperParameters):
        # Margin of the GLVQ loss.
        margin: float = 0.0

        # Transfer function applied to the relative distances, plus its
        # keyword arguments.
        transfer_fn: str = "sigmoid_beta"
        transfer_args: dict = field(default_factory=lambda: {"beta": 10.0})

    def init_loss(self, hparams: HyperParameters):
        """Instantiate the GLVQ loss layer from the hyperparameters."""
        self.loss_layer = GLVQLoss(
            margin=hparams.margin,
            transfer_fn=hparams.transfer_fn,
            **hparams.transfer_args,
        )

    def loss(self, comparison_measures, batch, components):
        """Compute, log, and return the GLVQ loss for one batch."""
        targets = batch[1]
        prototype_labels = components[1]
        loss_value = self.loss_layer(comparison_measures, targets,
                                     prototype_labels)
        self.log('loss', loss_value)
        return loss_value


class SingleLearningRateMixin(BaseYArchitecture):
    """Optimization step: one shared learning rate for all parameters."""

    @dataclass
    class HyperParameters(BaseYArchitecture.HyperParameters):
        # Training hyperparameters: learning rate and optimizer class.
        lr: float = 0.01
        optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam

    def __init__(self, hparams: HyperParameters) -> None:
        super().__init__(hparams)
        # Keep the optimizer settings around for configure_optimizers.
        self.optimizer = hparams.optimizer
        self.lr = hparams.lr

    def configure_optimizers(self):
        """Instantiate the configured optimizer over all model parameters."""
        return self.optimizer(self.parameters(), lr=self.lr)  # type: ignore


class SimpleComparisonMixin(BaseYArchitecture):
    """Comparison step that applies a plain distance function.

    Each batch sample is compared against every component tensor with
    ``comparison_fn`` (Euclidean distance by default).
    """

    @dataclass
    class HyperParameters(BaseYArchitecture.HyperParameters):
        # Comparison hyperparameters: the distance function and its extra
        # keyword arguments.  (Idiom fix: `default_factory=dict` replaces
        # the redundant `lambda: dict()` — same empty-dict default.)
        comparison_fn: Callable = euclidean_distance
        comparison_args: dict = field(default_factory=dict)

    def init_comparison(self, hparams: HyperParameters):
        """Wrap the comparison function in a LambdaLayer for the pipeline."""
        self.comparison_layer = LambdaLayer(fn=hparams.comparison_fn,
                                            **hparams.comparison_args)

    def comparison(self, batch, components):
        """Return pairwise comparison measures between batch and components."""
        comp_tensor, _ = components
        batch_tensor, _ = batch

        # Add a broadcast axis so every sample is compared with every
        # component.  NOTE(review): assumes comparison_fn broadcasts over
        # this inserted dimension — confirm for non-default functions.
        comp_tensor = comp_tensor.unsqueeze(1)

        distances = self.comparison_layer(batch_tensor, comp_tensor)

        return distances


# ##############################################################################
# GLVQ
# ##############################################################################
class GLVQ(
        SupervisedArchitecture,
        SimpleComparisonMixin,
        GLVQLossMixin,
        WTACompetitionMixin,
        SingleLearningRateMixin,
):
    """GLVQ assembled from the Y-architecture mixins (the "new Scheme")."""

    @dataclass
    class HyperParameters(
            SimpleComparisonMixin.HyperParameters,
            SingleLearningRateMixin.HyperParameters,
            GLVQLossMixin.HyperParameters,
            WTACompetitionMixin.HyperParameters,
            SupervisedArchitecture.HyperParameters,
    ):
        # Union of all mixin hyperparameters; nothing extra is added.
        pass