prototorch_models/prototorch/y_arch/architectures/competition.py

30 lines
951 B
Python
Raw Normal View History

2022-05-19 14:13:08 +00:00
from dataclasses import dataclass
from prototorch.core.competitions import WTAC
2022-05-31 15:56:03 +00:00
from prototorch.y_arch.architectures.base import BaseYArchitecture
2022-05-19 14:13:08 +00:00
class WTACompetitionMixin(BaseYArchitecture):
    """
    Winner-Take-All Competition.

    Architecture mixin that provides the competition step: inference
    results are produced by a ``WTAC`` (winner-take-all) layer.
    """

    # ------------------------------------------------------------------
    # HyperParameters
    # ------------------------------------------------------------------
    @dataclass
    class HyperParameters(BaseYArchitecture.HyperParameters):
        """
        This step declares no hyperparameters of its own.
        """

    # ------------------------------------------------------------------
    # Steps
    # ------------------------------------------------------------------
    def init_inference(self, hparams: HyperParameters):
        """
        Create the competition layer used by :meth:`inference`.

        ``hparams`` is part of the step signature but is not read here;
        the ``WTAC`` layer is constructed without arguments.
        """
        self.competition_layer = WTAC()

    def inference(self, comparison_measures, components):
        """
        Run the competition layer and return its result unchanged.

        Args:
            comparison_measures: Forwarded as-is as the first argument
                of the competition layer.
            components: Indexable container; only the element at index 1
                is used, and it is handed to the layer as the labels.
                NOTE(review): presumably a ``(prototypes, labels)`` pair,
                confirm against the components step.
        """
        # Read the labels first, then dispatch to the layer.
        labels = components[1]
        return self.competition_layer(comparison_measures, labels)