diff --git a/prototorch/core/initializers.py b/prototorch/core/initializers.py index b361f35..da0341d 100644 --- a/prototorch/core/initializers.py +++ b/prototorch/core/initializers.py @@ -205,10 +205,7 @@ class AbstractLabelsInitializer(ABC): class LabelsInitializer(AbstractLabelsInitializer): - """Generate labels with `self.distribution`.""" - def __init__(self, override_labels: list = []): - self.override_labels = override_labels - + """Generate labels from `distribution`.""" def generate(self, distribution: Union[dict, list, tuple]): distribution = parse_distribution(distribution) labels = [] @@ -218,25 +215,79 @@ class LabelsInitializer(AbstractLabelsInitializer): return labels -# Reasonings -class AbstractReasoningsInitializer(ABC): - """Abstract class for all reasonings initializers.""" - @abstractmethod - def generate(self, distribution: Union[dict, list, tuple]): - ... - - -class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer): - """Generate labels with `self.distribution`.""" +class OneHotLabelsInitializer(LabelsInitializer): + """Generate one-hot-encoded labels from `distribution`.""" def generate(self, distribution: Union[dict, list, tuple]): distribution = parse_distribution(distribution) num_classes = len(distribution.keys()) + # this breaks if class labels are not [0,...,nclasses] + labels = torch.eye(num_classes)[super().generate(distribution)] + return labels + + +# Reasonings +class AbstractReasoningsInitializer(ABC): + """Abstract class for all reasonings initializers.""" + def __init__(self, components_first=True): + self.components_first = components_first + + def compute_shape(self, distribution): + distribution = parse_distribution(distribution) num_components = sum(distribution.values()) - assert num_classes == num_components - reasonings = torch.stack( - [torch.eye(num_classes), - torch.zeros(num_classes, num_classes)], - dim=0) + num_classes = len(distribution.keys()) + return (num_components, num_classes, 2) + + def generate_end_hook(self, reasonings): + if not self.components_first: + reasonings = reasonings.permute(2, 1, 0) + return reasonings + + @abstractmethod + def generate(self, distribution: Union[dict, list, tuple]): + ... + return generate_end_hook(...) + + +class ZerosReasoningsInitializer(AbstractReasoningsInitializer): + """Reasonings are all initialized with zeros.""" + def generate(self, distribution: Union[dict, list, tuple]): + shape = self.compute_shape(distribution) + reasonings = torch.zeros(*shape) + reasonings = self.generate_end_hook(reasonings) + return reasonings + + +class OnesReasoningsInitializer(AbstractReasoningsInitializer): + """Reasonings are all initialized with ones.""" + def generate(self, distribution: Union[dict, list, tuple]): + shape = self.compute_shape(distribution) + reasonings = torch.ones(*shape) + reasonings = self.generate_end_hook(reasonings) + return reasonings + + +class RandomReasoningsInitializer(AbstractReasoningsInitializer): + """Reasonings are randomly initialized.""" + def __init__(self, minimum=0.4, maximum=0.6, **kwargs): + super().__init__(**kwargs) + self.minimum = minimum + self.maximum = maximum + + def generate(self, distribution: Union[dict, list, tuple]): + shape = self.compute_shape(distribution) + reasonings = torch.ones(*shape).uniform_(self.minimum, self.maximum) + reasonings = self.generate_end_hook(reasonings) + return reasonings + + +class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer): + """Each component reasons positively for exactly one class.""" + def generate(self, distribution: Union[dict, list, tuple]): + num_components, num_classes, _ = self.compute_shape(distribution) + A = OneHotLabelsInitializer().generate(distribution) + B = torch.zeros(num_components, num_classes) + reasonings = torch.stack([A, B]).permute(2, 1, 0) + reasonings = self.generate_end_hook(reasonings) return reasonings @@ -251,4 +302,13 @@ SCI = SelectionCompInitializer MCI = MeanCompInitializer SSCI = StratifiedSelectionCompInitializer SMCI = StratifiedMeanCompInitializer + +# Aliases - Labels +LI = LabelsInitializer +OHLI = OneHotLabelsInitializer + +# Aliases - Reasonings +ZRI = ZerosReasoningsInitializer +ORI = OnesReasoningsInitializer +RRI = RandomReasoningsInitializer PPRI = PurePositiveReasoningsInitializer