2022-05-19 14:13:08 +00:00
|
|
|
from dataclasses import dataclass
|
|
|
|
|
|
|
|
from prototorch.core.competitions import WTAC
|
2022-06-03 08:39:11 +00:00
|
|
|
from prototorch.y.architectures.base import BaseYArchitecture
|
2022-05-19 14:13:08 +00:00
|
|
|
|
|
|
|
|
|
|
|
class WTACompetitionMixin(BaseYArchitecture):
    """
    Winner-Take-All Competition mixin.

    Provides a competition stage built on the winner-take-all layer
    ``WTAC`` from ``prototorch.core.competitions``.
    """

    # HyperParameters
    # ----------------------------------------------------------------------------------------------------
    @dataclass
    class HyperParameters(BaseYArchitecture.HyperParameters):
        """
        This competition stage adds no hyperparameters of its own.
        """

    # Steps
    # ----------------------------------------------------------------------------------------------------
    def init_inference(self, hparams: HyperParameters):
        """Create the winner-take-all layer and store it as ``self.competition_layer``.

        ``hparams`` is not read by this method.
        """
        self.competition_layer = WTAC()

    def inference(self, comparison_measures, components):
        """Apply the winner-take-all competition.

        Args:
            comparison_measures: Passed unchanged as the first argument of
                the competition layer.
            components: Indexable whose second entry (``components[1]``) is
                used as the component labels. Presumably a
                ``(prototypes, labels)`` pair; confirm against the step
                that produces it.

        Returns:
            Whatever ``self.competition_layer`` (a ``WTAC`` instance)
            returns for these inputs.
        """
        # Only the labels take part in the competition; entry 0 is not used here.
        return self.competition_layer(comparison_measures, components[1])
|