54 lines
1.8 KiB
Python
54 lines
1.8 KiB
Python
from dataclasses import dataclass
|
|
|
|
from prototorch.core.components import LabeledComponents
|
|
from prototorch.core.initializers import (
|
|
AbstractComponentsInitializer,
|
|
LabelsInitializer,
|
|
)
|
|
from prototorch.y_arch import BaseYArchitecture
|
|
|
|
|
|
class SupervisedArchitecture(BaseYArchitecture):
    """
    Supervised Architecture

    An architecture that uses labeled components (``LabeledComponents``)
    as its component layer.
    """

    # Set by ``init_components``; holds both prototype positions and labels.
    components_layer: LabeledComponents

    # HyperParameters
    # ----------------------------------------------------------------------------------------------------
    @dataclass
    class HyperParameters:
        """
        Hyperparameters required to build the labeled component layer.

        distribution: A valid prototype distribution. No default possible.
        component_initializer: An implementation of AbstractComponentsInitializer. No default possible.
        """

        # Forwarded unchanged to ``LabeledComponents(distribution=...)``.
        # NOTE(review): presumably maps class label -> number of prototypes;
        # confirm against the LabeledComponents documentation.
        distribution: "dict[str, int]"
        # Forwarded to ``LabeledComponents(components_initializer=...)``.
        # Note the singular field name here vs. the plural keyword it feeds.
        component_initializer: AbstractComponentsInitializer

    # Steps
    # ----------------------------------------------------------------------------------------------------
    def init_components(self, hparams: HyperParameters) -> None:
        """
        Build the labeled component layer from ``hparams``.

        Assigns a new ``LabeledComponents`` instance to
        ``self.components_layer``, replacing any previous one.

        Args:
            hparams: Supplies the prototype ``distribution`` and the
                ``component_initializer`` used for prototype positions.
        """
        self.components_layer = LabeledComponents(
            distribution=hparams.distribution,
            components_initializer=hparams.component_initializer,
            # Labels are not configurable through HyperParameters; a default
            # LabelsInitializer is always used.
            labels_initializer=LabelsInitializer(),
        )

    # Properties
    # ----------------------------------------------------------------------------------------------------
    @property
    def prototypes(self):
        """
        Returns the position of the prototypes.

        The value is ``components_layer.components`` after ``.detach().cpu()``,
        i.e. cut from the autograd graph and placed on the CPU, so it is safe
        to inspect or plot without affecting training.
        """
        return self.components_layer.components.detach().cpu()

    @property
    def prototype_labels(self):
        """
        Returns the labels of the prototypes.

        The value is ``components_layer.labels`` after ``.detach().cpu()``.
        """
        return self.components_layer.labels.detach().cpu()
|