[REFACTOR] Add reasonings initializers

This commit is contained in:
Jensun Ravichandran 2021-06-14 17:19:45 +02:00
parent 668c9a1fb7
commit 083cc929be

View File

@ -205,10 +205,7 @@ class AbstractLabelsInitializer(ABC):
class LabelsInitializer(AbstractLabelsInitializer): class LabelsInitializer(AbstractLabelsInitializer):
"""Generate labels with `self.distribution`.""" """Generate labels from `distribution`."""
def __init__(self, override_labels: list = []):
self.override_labels = override_labels
def generate(self, distribution: Union[dict, list, tuple]): def generate(self, distribution: Union[dict, list, tuple]):
distribution = parse_distribution(distribution) distribution = parse_distribution(distribution)
labels = [] labels = []
@ -218,25 +215,79 @@ class LabelsInitializer(AbstractLabelsInitializer):
return labels return labels
# Reasonings class OneHotLabelsInitializer(LabelsInitializer):
class AbstractReasoningsInitializer(ABC): """Generate one-hot-encoded labels from `distribution`."""
"""Abstract class for all reasonings initializers."""
@abstractmethod
def generate(self, distribution: Union[dict, list, tuple]):
...
class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer):
"""Generate labels with `self.distribution`."""
def generate(self, distribution: Union[dict, list, tuple]): def generate(self, distribution: Union[dict, list, tuple]):
distribution = parse_distribution(distribution) distribution = parse_distribution(distribution)
num_classes = len(distribution.keys()) num_classes = len(distribution.keys())
# this breaks if class labels are not [0,...,nclasses]
labels = torch.eye(num_classes)[super().generate(distribution)]
return labels
# Reasonings
class AbstractReasoningsInitializer(ABC):
"""Abstract class for all reasonings initializers."""
def __init__(self, components_first=True):
self.components_first = components_first
def compute_shape(self, distribution):
distribution = parse_distribution(distribution)
num_components = sum(distribution.values()) num_components = sum(distribution.values())
assert num_classes == num_components num_classes = len(distribution.keys())
reasonings = torch.stack( return (num_components, num_classes, 2)
[torch.eye(num_classes),
torch.zeros(num_classes, num_classes)], def generate_end_hook(self, reasonings):
dim=0) if not self.components_first:
reasonings = reasonings.permute(2, 1, 0)
return reasonings
@abstractmethod
def generate(self, distribution: Union[dict, list, tuple]):
...
return generate_end_hook(...)
class ZerosReasoningsInitializer(AbstractReasoningsInitializer):
"""Reasonings are all initialized with zeros."""
def generate(self, distribution: Union[dict, list, tuple]):
shape = self.compute_shape(distribution)
reasonings = torch.zeros(*shape)
reasonings = self.generate_end_hook(reasonings)
return reasonings
class OnesReasoningsInitializer(AbstractReasoningsInitializer):
"""Reasonings are all initialized with ones."""
def generate(self, distribution: Union[dict, list, tuple]):
shape = self.compute_shape(distribution)
reasonings = torch.ones(*shape)
reasonings = self.generate_end_hook(reasonings)
return reasonings
class RandomReasoningsInitializer(AbstractReasoningsInitializer):
"""Reasonings are randomly initialized."""
def __init__(self, minimum=0.4, maximum=0.6, **kwargs):
super().__init__(**kwargs)
self.minimum = minimum
self.maximum = maximum
def generate(self, distribution: Union[dict, list, tuple]):
shape = self.compute_shape(distribution)
reasonings = torch.ones(*shape).uniform_(self.minimum, self.maximum)
reasonings = self.generate_end_hook(reasonings)
return reasonings
class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer):
"""Each component reasons positively for exactly one class."""
def generate(self, distribution: Union[dict, list, tuple]):
num_components, num_classes, _ = self.compute_shape(distribution)
A = OneHotLabelsInitializer().generate(distribution)
B = torch.zeros(num_components, num_classes)
reasonings = torch.stack([A, B]).permute(2, 1, 0)
reasonings = self.generate_end_hook(reasonings)
return reasonings return reasonings
@ -251,4 +302,13 @@ SCI = SelectionCompInitializer
MCI = MeanCompInitializer MCI = MeanCompInitializer
SSCI = StratifiedSelectionCompInitializer SSCI = StratifiedSelectionCompInitializer
SMCI = StratifiedMeanCompInitializer SMCI = StratifiedMeanCompInitializer
# Aliases - Labels
LI = LabelsInitializer
OHLI = OneHotLabelsInitializer
# Aliases - Reasonings
ZRI = ZerosReasoningsInitializer
ORI = OnesReasoningsInitializer
RRI = RandomReasoningsInitializer
PPRI = PurePositiveReasoningsInitializer PPRI = PurePositiveReasoningsInitializer