54 lines
1.8 KiB
Python
54 lines
1.8 KiB
Python
|
from dataclasses import dataclass
|
||
|
|
||
|
from prototorch.core.components import LabeledComponents
|
||
|
from prototorch.core.initializers import (
|
||
|
AbstractComponentsInitializer,
|
||
|
LabelsInitializer,
|
||
|
)
|
||
|
from prototorch.models.y_arch import BaseYArchitecture
|
||
|
|
||
|
|
||
|
class SupervisedArchitecture(BaseYArchitecture):
    """
    Supervised Architecture

    An architecture that uses labeled Components as component Layer.

    The component layer is a ``LabeledComponents`` instance built in
    ``init_components`` from the ``HyperParameters`` below. Each component
    (prototype) therefore carries a class label alongside its position.
    """

    # Set by ``init_components``; holds the prototypes and their labels.
    components_layer: LabeledComponents

    # HyperParameters
    # ----------------------------------------------------------------------------------------------------
    @dataclass
    class HyperParameters:
        """
        Hyperparameters required to build the labeled component layer.

        distribution: A valid prototype distribution. No default possible.
            It is forwarded unchanged as ``distribution=`` to ``LabeledComponents``.
            The annotation suggests a mapping of label -> prototype count.
            NOTE(review): confirm which other formats LabeledComponents accepts.
        component_initializer: An implementation of AbstractComponentsInitializer.
            No default possible. It is forwarded as ``components_initializer=``
            to ``LabeledComponents``.
        """

        distribution: "dict[str, int]"
        component_initializer: AbstractComponentsInitializer

    # Steps
    # ----------------------------------------------------------------------------------------------------
    def init_components(self, hparams: HyperParameters):
        """
        Build the labeled component layer from the hyperparameters.

        Presumably invoked as a setup step by ``BaseYArchitecture`` (not visible here).

        Args:
            hparams: Provides ``distribution`` and ``component_initializer``.

        Side effects:
            Assigns a new ``LabeledComponents`` to ``self.components_layer``.
        """
        # Labels are always produced by the default LabelsInitializer. Only the
        # component positions are configurable via hparams.
        self.components_layer = LabeledComponents(
            distribution=hparams.distribution,
            components_initializer=hparams.component_initializer,
            labels_initializer=LabelsInitializer(),
        )

    # Properties
    # ----------------------------------------------------------------------------------------------------
    @property
    def prototypes(self):
        """
        Returns the position of the prototypes.

        The result is ``components_layer.components`` after ``.detach().cpu()``:
        a copy that is detached from autograd and located on the CPU. It is
        safe for inspection and plotting, and is not meant for training.
        Requires ``init_components`` to have run first.
        """
        return self.components_layer.components.detach().cpu()

    @property
    def prototype_labels(self):
        """
        Returns the labels of the prototypes.

        The result is ``components_layer.labels`` after ``.detach().cpu()``.
        Requires ``init_components`` to have run first.
        """
        return self.components_layer.labels.detach().cpu()
|