[REFACTOR] Add reasonings initializers
This commit is contained in:
parent
668c9a1fb7
commit
083cc929be
@ -205,10 +205,7 @@ class AbstractLabelsInitializer(ABC):
|
|||||||
|
|
||||||
|
|
||||||
class LabelsInitializer(AbstractLabelsInitializer):
|
class LabelsInitializer(AbstractLabelsInitializer):
|
||||||
"""Generate labels with `self.distribution`."""
|
"""Generate labels from `distribution`."""
|
||||||
def __init__(self, override_labels: list = []):
|
|
||||||
self.override_labels = override_labels
|
|
||||||
|
|
||||||
def generate(self, distribution: Union[dict, list, tuple]):
|
def generate(self, distribution: Union[dict, list, tuple]):
|
||||||
distribution = parse_distribution(distribution)
|
distribution = parse_distribution(distribution)
|
||||||
labels = []
|
labels = []
|
||||||
@ -218,25 +215,79 @@ class LabelsInitializer(AbstractLabelsInitializer):
|
|||||||
return labels
|
return labels
|
||||||
|
|
||||||
|
|
||||||
# Reasonings
|
class OneHotLabelsInitializer(LabelsInitializer):
|
||||||
class AbstractReasoningsInitializer(ABC):
|
"""Generate one-hot-encoded labels from `distribution`."""
|
||||||
"""Abstract class for all reasonings initializers."""
|
|
||||||
@abstractmethod
|
|
||||||
def generate(self, distribution: Union[dict, list, tuple]):
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer):
|
|
||||||
"""Generate labels with `self.distribution`."""
|
|
||||||
def generate(self, distribution: Union[dict, list, tuple]):
|
def generate(self, distribution: Union[dict, list, tuple]):
|
||||||
distribution = parse_distribution(distribution)
|
distribution = parse_distribution(distribution)
|
||||||
num_classes = len(distribution.keys())
|
num_classes = len(distribution.keys())
|
||||||
|
# this breaks if class labels are not [0,...,nclasses]
|
||||||
|
labels = torch.eye(num_classes)[super().generate(distribution)]
|
||||||
|
return labels
|
||||||
|
|
||||||
|
|
||||||
|
# Reasonings
|
||||||
|
class AbstractReasoningsInitializer(ABC):
|
||||||
|
"""Abstract class for all reasonings initializers."""
|
||||||
|
def __init__(self, components_first=True):
|
||||||
|
self.components_first = components_first
|
||||||
|
|
||||||
|
def compute_shape(self, distribution):
|
||||||
|
distribution = parse_distribution(distribution)
|
||||||
num_components = sum(distribution.values())
|
num_components = sum(distribution.values())
|
||||||
assert num_classes == num_components
|
num_classes = len(distribution.keys())
|
||||||
reasonings = torch.stack(
|
return (num_components, num_classes, 2)
|
||||||
[torch.eye(num_classes),
|
|
||||||
torch.zeros(num_classes, num_classes)],
|
def generate_end_hook(self, reasonings):
|
||||||
dim=0)
|
if not self.components_first:
|
||||||
|
reasonings = reasonings.permute(2, 1, 0)
|
||||||
|
return reasonings
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def generate(self, distribution: Union[dict, list, tuple]):
|
||||||
|
...
|
||||||
|
return generate_end_hook(...)
|
||||||
|
|
||||||
|
|
||||||
|
class ZerosReasoningsInitializer(AbstractReasoningsInitializer):
|
||||||
|
"""Reasonings are all initialized with zeros."""
|
||||||
|
def generate(self, distribution: Union[dict, list, tuple]):
|
||||||
|
shape = self.compute_shape(distribution)
|
||||||
|
reasonings = torch.zeros(*shape)
|
||||||
|
reasonings = self.generate_end_hook(reasonings)
|
||||||
|
return reasonings
|
||||||
|
|
||||||
|
|
||||||
|
class OnesReasoningsInitializer(AbstractReasoningsInitializer):
|
||||||
|
"""Reasonings are all initialized with ones."""
|
||||||
|
def generate(self, distribution: Union[dict, list, tuple]):
|
||||||
|
shape = self.compute_shape(distribution)
|
||||||
|
reasonings = torch.ones(*shape)
|
||||||
|
reasonings = self.generate_end_hook(reasonings)
|
||||||
|
return reasonings
|
||||||
|
|
||||||
|
|
||||||
|
class RandomReasoningsInitializer(AbstractReasoningsInitializer):
|
||||||
|
"""Reasonings are randomly initialized."""
|
||||||
|
def __init__(self, minimum=0.4, maximum=0.6, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.minimum = minimum
|
||||||
|
self.maximum = maximum
|
||||||
|
|
||||||
|
def generate(self, distribution: Union[dict, list, tuple]):
|
||||||
|
shape = self.compute_shape(distribution)
|
||||||
|
reasonings = torch.ones(*shape).uniform_(self.minimum, self.maximum)
|
||||||
|
reasonings = self.generate_end_hook(reasonings)
|
||||||
|
return reasonings
|
||||||
|
|
||||||
|
|
||||||
|
class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer):
|
||||||
|
"""Each component reasons positively for exactly one class."""
|
||||||
|
def generate(self, distribution: Union[dict, list, tuple]):
|
||||||
|
num_components, num_classes, _ = self.compute_shape(distribution)
|
||||||
|
A = OneHotLabelsInitializer().generate(distribution)
|
||||||
|
B = torch.zeros(num_components, num_classes)
|
||||||
|
reasonings = torch.stack([A, B]).permute(2, 1, 0)
|
||||||
|
reasonings = self.generate_end_hook(reasonings)
|
||||||
return reasonings
|
return reasonings
|
||||||
|
|
||||||
|
|
||||||
@ -251,4 +302,13 @@ SCI = SelectionCompInitializer
|
|||||||
MCI = MeanCompInitializer
|
MCI = MeanCompInitializer
|
||||||
SSCI = StratifiedSelectionCompInitializer
|
SSCI = StratifiedSelectionCompInitializer
|
||||||
SMCI = StratifiedMeanCompInitializer
|
SMCI = StratifiedMeanCompInitializer
|
||||||
|
|
||||||
|
# Aliases - Labels
|
||||||
|
LI = LabelsInitializer
|
||||||
|
OHLI = OneHotLabelsInitializer
|
||||||
|
|
||||||
|
# Aliases - Reasonings
|
||||||
|
ZRI = ZerosReasoningsInitializer
|
||||||
|
ORI = OnesReasoningsInitializer
|
||||||
|
RRI = RandomReasoningsInitializer
|
||||||
PPRI = PurePositiveReasoningsInitializer
|
PPRI = PurePositiveReasoningsInitializer
|
||||||
|
Loading…
Reference in New Issue
Block a user