65 lines
2.3 KiB
Python
65 lines
2.3 KiB
Python
from dataclasses import dataclass
|
|
|
|
from prototorch.core.components import LabeledComponents
|
|
from prototorch.core.initializers import (
|
|
AbstractComponentsInitializer,
|
|
LabelsInitializer,
|
|
ZerosCompInitializer,
|
|
)
|
|
from prototorch.y import BaseYArchitecture
|
|
|
|
|
|
class SupervisedArchitecture(BaseYArchitecture):
    """
    Supervised Architecture

    An architecture whose components layer is a ``LabeledComponents``,
    i.e. every component (prototype) is paired with a label.
    """

    components_layer: LabeledComponents

    # HyperParameters
    # ----------------------------------------------------------------------------------------------------
    @dataclass
    class HyperParameters:
        """
        distribution: A valid prototype distribution, passed unchanged to
            ``LabeledComponents``. No default possible.
        component_initializer: An implementation of
            AbstractComponentsInitializer. No default possible. Pass ``None``
            when restoring a checkpointed model; the components are then
            zero-initialised with the shape stored in
            ``hparams["initialized_proto_shape"]``.
        """

        distribution: "dict[str, int]"
        # ``None`` is accepted and selects the checkpoint-restore path in
        # ``init_components``.
        component_initializer: "AbstractComponentsInitializer | None"

    # Steps
    # ----------------------------------------------------------------------------------------------------
    def init_components(self, hparams: HyperParameters) -> None:
        """
        Build ``self.components_layer`` from the hyperparameters.

        With a ``component_initializer``, the components are created by it
        and their shape without the first axis is recorded in
        ``self.hparams["initialized_proto_shape"]``. Without one (checkpoint
        restore), that recorded shape is read back and the components are
        zero-initialised; presumably the checkpoint's weights overwrite them
        afterwards -- TODO confirm.

        Raises:
            KeyError: if ``component_initializer`` is ``None`` and
                ``self.hparams`` holds no ``initialized_proto_shape``.
        """
        if hparams.component_initializer is not None:
            self.components_layer = LabeledComponents(
                distribution=hparams.distribution,
                components_initializer=hparams.component_initializer,
                labels_initializer=LabelsInitializer(),
            )
            # Drop the first axis (presumably the component count) and keep
            # the per-component shape so a restore can rebuild the layer.
            proto_shape = self.components_layer.components.shape[1:]
            self.hparams["initialized_proto_shape"] = proto_shape
        else:
            # When restoring a checkpointed model: no initializer is given,
            # so rebuild with zeros of the shape recorded on the first init.
            try:
                proto_shape = self.hparams["initialized_proto_shape"]
            except KeyError as err:
                raise KeyError(
                    "component_initializer is None, but hparams contains no "
                    "'initialized_proto_shape' to restore the components from"
                ) from err
            self.components_layer = LabeledComponents(
                distribution=hparams.distribution,
                components_initializer=ZerosCompInitializer(proto_shape),
                # Passed explicitly so both branches build the layer the
                # same way.
                labels_initializer=LabelsInitializer(),
            )

    # Properties
    # ----------------------------------------------------------------------------------------------------
    @property
    def prototypes(self):
        """
        Return the prototype positions, detached from the autograd graph and
        moved to the CPU.
        """
        return self.components_layer.components.detach().cpu()

    @property
    def prototype_labels(self):
        """
        Return the prototype labels, detached from the autograd graph and
        moved to the CPU.
        """
        return self.components_layer.labels.detach().cpu()
|