[REFACTOR] Add reasonings initializers
This commit is contained in:
parent
668c9a1fb7
commit
083cc929be
@ -205,10 +205,7 @@ class AbstractLabelsInitializer(ABC):
|
||||
|
||||
|
||||
class LabelsInitializer(AbstractLabelsInitializer):
|
||||
"""Generate labels with `self.distribution`."""
|
||||
def __init__(self, override_labels: list = []):
|
||||
self.override_labels = override_labels
|
||||
|
||||
"""Generate labels from `distribution`."""
|
||||
def generate(self, distribution: Union[dict, list, tuple]):
|
||||
distribution = parse_distribution(distribution)
|
||||
labels = []
|
||||
@ -218,25 +215,79 @@ class LabelsInitializer(AbstractLabelsInitializer):
|
||||
return labels
|
||||
|
||||
|
||||
# Reasonings
|
||||
class AbstractReasoningsInitializer(ABC):
|
||||
"""Abstract class for all reasonings initializers."""
|
||||
@abstractmethod
|
||||
def generate(self, distribution: Union[dict, list, tuple]):
|
||||
...
|
||||
|
||||
|
||||
class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer):
|
||||
"""Generate labels with `self.distribution`."""
|
||||
class OneHotLabelsInitializer(LabelsInitializer):
|
||||
"""Generate one-hot-encoded labels from `distribution`."""
|
||||
def generate(self, distribution: Union[dict, list, tuple]):
|
||||
distribution = parse_distribution(distribution)
|
||||
num_classes = len(distribution.keys())
|
||||
# this breaks if class labels are not [0,...,nclasses]
|
||||
labels = torch.eye(num_classes)[super().generate(distribution)]
|
||||
return labels
|
||||
|
||||
|
||||
# Reasonings
|
||||
class AbstractReasoningsInitializer(ABC):
|
||||
"""Abstract class for all reasonings initializers."""
|
||||
def __init__(self, components_first=True):
|
||||
self.components_first = components_first
|
||||
|
||||
def compute_shape(self, distribution):
|
||||
distribution = parse_distribution(distribution)
|
||||
num_components = sum(distribution.values())
|
||||
assert num_classes == num_components
|
||||
reasonings = torch.stack(
|
||||
[torch.eye(num_classes),
|
||||
torch.zeros(num_classes, num_classes)],
|
||||
dim=0)
|
||||
num_classes = len(distribution.keys())
|
||||
return (num_components, num_classes, 2)
|
||||
|
||||
def generate_end_hook(self, reasonings):
|
||||
if not self.components_first:
|
||||
reasonings = reasonings.permute(2, 1, 0)
|
||||
return reasonings
|
||||
|
||||
@abstractmethod
|
||||
def generate(self, distribution: Union[dict, list, tuple]):
|
||||
...
|
||||
return generate_end_hook(...)
|
||||
|
||||
|
||||
class ZerosReasoningsInitializer(AbstractReasoningsInitializer):
|
||||
"""Reasonings are all initialized with zeros."""
|
||||
def generate(self, distribution: Union[dict, list, tuple]):
|
||||
shape = self.compute_shape(distribution)
|
||||
reasonings = torch.zeros(*shape)
|
||||
reasonings = self.generate_end_hook(reasonings)
|
||||
return reasonings
|
||||
|
||||
|
||||
class OnesReasoningsInitializer(AbstractReasoningsInitializer):
|
||||
"""Reasonings are all initialized with ones."""
|
||||
def generate(self, distribution: Union[dict, list, tuple]):
|
||||
shape = self.compute_shape(distribution)
|
||||
reasonings = torch.ones(*shape)
|
||||
reasonings = self.generate_end_hook(reasonings)
|
||||
return reasonings
|
||||
|
||||
|
||||
class RandomReasoningsInitializer(AbstractReasoningsInitializer):
|
||||
"""Reasonings are randomly initialized."""
|
||||
def __init__(self, minimum=0.4, maximum=0.6, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.minimum = minimum
|
||||
self.maximum = maximum
|
||||
|
||||
def generate(self, distribution: Union[dict, list, tuple]):
|
||||
shape = self.compute_shape(distribution)
|
||||
reasonings = torch.ones(*shape).uniform_(self.minimum, self.maximum)
|
||||
reasonings = self.generate_end_hook(reasonings)
|
||||
return reasonings
|
||||
|
||||
|
||||
class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer):
|
||||
"""Each component reasons positively for exactly one class."""
|
||||
def generate(self, distribution: Union[dict, list, tuple]):
|
||||
num_components, num_classes, _ = self.compute_shape(distribution)
|
||||
A = OneHotLabelsInitializer().generate(distribution)
|
||||
B = torch.zeros(num_components, num_classes)
|
||||
reasonings = torch.stack([A, B]).permute(2, 1, 0)
|
||||
reasonings = self.generate_end_hook(reasonings)
|
||||
return reasonings
|
||||
|
||||
|
||||
@ -251,4 +302,13 @@ SCI = SelectionCompInitializer
|
||||
MCI = MeanCompInitializer
|
||||
SSCI = StratifiedSelectionCompInitializer
|
||||
SMCI = StratifiedMeanCompInitializer
|
||||
|
||||
# Aliases - Labels
|
||||
LI = LabelsInitializer
|
||||
OHLI = OneHotLabelsInitializer
|
||||
|
||||
# Aliases - Reasonings
|
||||
ZRI = ZerosReasoningsInitializer
|
||||
ORI = OnesReasoningsInitializer
|
||||
RRI = RandomReasoningsInitializer
|
||||
PPRI = PurePositiveReasoningsInitializer
|
||||
|
Loading…
Reference in New Issue
Block a user