prototorch_models/prototorch/y_arch/architectures/competition.py
2022-05-31 17:56:03 +02:00

30 lines
951 B
Python

from dataclasses import dataclass
from prototorch.core.competitions import WTAC
from prototorch.y_arch.architectures.base import BaseYArchitecture
class WTACompetitionMixin(BaseYArchitecture):
    """
    Winner-Take-All (WTA) competition mixin for a Y-architecture.

    The competition step hands the comparison measures and the component
    labels to prototorch's ``WTAC`` layer and returns its result.
    """

    # ------------------------------------------------------------------
    # Hyperparameters
    # ------------------------------------------------------------------
    @dataclass
    class HyperParameters(BaseYArchitecture.HyperParameters):
        """
        This mixin introduces no hyperparameters of its own.
        """

    # ------------------------------------------------------------------
    # Steps
    # ------------------------------------------------------------------
    def init_inference(self, hparams: HyperParameters):
        """
        Create the winner-take-all competition layer.

        The layer is stored on ``self.competition_layer`` for use by
        ``inference``.

        :param hparams: Model hyperparameters; not read by this step.
        """
        wta_layer = WTAC()
        self.competition_layer = wta_layer

    def inference(self, comparison_measures, components):
        """
        Run the winner-take-all competition.

        :param comparison_measures: Output of the comparison step, passed
            unchanged to the competition layer.
        :param components: Indexable container; the entry at index 1 is
            used as the component labels.
        :return: Whatever ``self.competition_layer`` returns for the
            comparison measures and component labels.
        """
        labels = components[1]
        layer = self.competition_layer
        return layer(comparison_measures, labels)