[REFACTOR] Add reasonings initializers
This commit is contained in:
		@@ -205,10 +205,7 @@ class AbstractLabelsInitializer(ABC):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class LabelsInitializer(AbstractLabelsInitializer):
 | 
					class LabelsInitializer(AbstractLabelsInitializer):
 | 
				
			||||||
    """Generate labels with `self.distribution`."""
 | 
					    """Generate labels from `distribution`."""
 | 
				
			||||||
    def __init__(self, override_labels: list = []):
 | 
					 | 
				
			||||||
        self.override_labels = override_labels
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def generate(self, distribution: Union[dict, list, tuple]):
 | 
					    def generate(self, distribution: Union[dict, list, tuple]):
 | 
				
			||||||
        distribution = parse_distribution(distribution)
 | 
					        distribution = parse_distribution(distribution)
 | 
				
			||||||
        labels = []
 | 
					        labels = []
 | 
				
			||||||
@@ -218,25 +215,79 @@ class LabelsInitializer(AbstractLabelsInitializer):
 | 
				
			|||||||
        return labels
 | 
					        return labels
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Reasonings
 | 
					class OneHotLabelsInitializer(LabelsInitializer):
 | 
				
			||||||
class AbstractReasoningsInitializer(ABC):
 | 
					    """Generate one-hot-encoded labels from `distribution`."""
 | 
				
			||||||
    """Abstract class for all reasonings initializers."""
 | 
					 | 
				
			||||||
    @abstractmethod
 | 
					 | 
				
			||||||
    def generate(self, distribution: Union[dict, list, tuple]):
 | 
					 | 
				
			||||||
        ...
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer):
 | 
					 | 
				
			||||||
    """Generate labels with `self.distribution`."""
 | 
					 | 
				
			||||||
    def generate(self, distribution: Union[dict, list, tuple]):
 | 
					    def generate(self, distribution: Union[dict, list, tuple]):
 | 
				
			||||||
        distribution = parse_distribution(distribution)
 | 
					        distribution = parse_distribution(distribution)
 | 
				
			||||||
        num_classes = len(distribution.keys())
 | 
					        num_classes = len(distribution.keys())
 | 
				
			||||||
 | 
					        # this breaks if class labels are not [0,...,nclasses]
 | 
				
			||||||
 | 
					        labels = torch.eye(num_classes)[super().generate(distribution)]
 | 
				
			||||||
 | 
					        return labels
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Reasonings
 | 
				
			||||||
 | 
					class AbstractReasoningsInitializer(ABC):
 | 
				
			||||||
 | 
					    """Abstract class for all reasonings initializers."""
 | 
				
			||||||
 | 
					    def __init__(self, components_first=True):
 | 
				
			||||||
 | 
					        self.components_first = components_first
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def compute_shape(self, distribution):
 | 
				
			||||||
 | 
					        distribution = parse_distribution(distribution)
 | 
				
			||||||
        num_components = sum(distribution.values())
 | 
					        num_components = sum(distribution.values())
 | 
				
			||||||
        assert num_classes == num_components
 | 
					        num_classes = len(distribution.keys())
 | 
				
			||||||
        reasonings = torch.stack(
 | 
					        return (num_components, num_classes, 2)
 | 
				
			||||||
            [torch.eye(num_classes),
 | 
					
 | 
				
			||||||
             torch.zeros(num_classes, num_classes)],
 | 
					    def generate_end_hook(self, reasonings):
 | 
				
			||||||
            dim=0)
 | 
					        if not self.components_first:
 | 
				
			||||||
 | 
					            reasonings = reasonings.permute(2, 1, 0)
 | 
				
			||||||
 | 
					        return reasonings
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @abstractmethod
 | 
				
			||||||
 | 
					    def generate(self, distribution: Union[dict, list, tuple]):
 | 
				
			||||||
 | 
					        ...
 | 
				
			||||||
 | 
					        return generate_end_hook(...)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ZerosReasoningsInitializer(AbstractReasoningsInitializer):
 | 
				
			||||||
 | 
					    """Reasonings are all initialized with zeros."""
 | 
				
			||||||
 | 
					    def generate(self, distribution: Union[dict, list, tuple]):
 | 
				
			||||||
 | 
					        shape = self.compute_shape(distribution)
 | 
				
			||||||
 | 
					        reasonings = torch.zeros(*shape)
 | 
				
			||||||
 | 
					        reasonings = self.generate_end_hook(reasonings)
 | 
				
			||||||
 | 
					        return reasonings
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class OnesReasoningsInitializer(AbstractReasoningsInitializer):
 | 
				
			||||||
 | 
					    """Reasonings are all initialized with ones."""
 | 
				
			||||||
 | 
					    def generate(self, distribution: Union[dict, list, tuple]):
 | 
				
			||||||
 | 
					        shape = self.compute_shape(distribution)
 | 
				
			||||||
 | 
					        reasonings = torch.ones(*shape)
 | 
				
			||||||
 | 
					        reasonings = self.generate_end_hook(reasonings)
 | 
				
			||||||
 | 
					        return reasonings
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class RandomReasoningsInitializer(AbstractReasoningsInitializer):
 | 
				
			||||||
 | 
					    """Reasonings are randomly initialized."""
 | 
				
			||||||
 | 
					    def __init__(self, minimum=0.4, maximum=0.6, **kwargs):
 | 
				
			||||||
 | 
					        super().__init__(**kwargs)
 | 
				
			||||||
 | 
					        self.minimum = minimum
 | 
				
			||||||
 | 
					        self.maximum = maximum
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def generate(self, distribution: Union[dict, list, tuple]):
 | 
				
			||||||
 | 
					        shape = self.compute_shape(distribution)
 | 
				
			||||||
 | 
					        reasonings = torch.ones(*shape).uniform_(self.minimum, self.maximum)
 | 
				
			||||||
 | 
					        reasonings = self.generate_end_hook(reasonings)
 | 
				
			||||||
 | 
					        return reasonings
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer):
 | 
				
			||||||
 | 
					    """Each component reasons positively for exactly one class."""
 | 
				
			||||||
 | 
					    def generate(self, distribution: Union[dict, list, tuple]):
 | 
				
			||||||
 | 
					        num_components, num_classes, _ = self.compute_shape(distribution)
 | 
				
			||||||
 | 
					        A = OneHotLabelsInitializer().generate(distribution)
 | 
				
			||||||
 | 
					        B = torch.zeros(num_components, num_classes)
 | 
				
			||||||
 | 
					        reasonings = torch.stack([A, B]).permute(2, 1, 0)
 | 
				
			||||||
 | 
					        reasonings = self.generate_end_hook(reasonings)
 | 
				
			||||||
        return reasonings
 | 
					        return reasonings
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -251,4 +302,13 @@ SCI = SelectionCompInitializer
 | 
				
			|||||||
MCI = MeanCompInitializer
 | 
					MCI = MeanCompInitializer
 | 
				
			||||||
SSCI = StratifiedSelectionCompInitializer
 | 
					SSCI = StratifiedSelectionCompInitializer
 | 
				
			||||||
SMCI = StratifiedMeanCompInitializer
 | 
					SMCI = StratifiedMeanCompInitializer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Aliases - Labels
 | 
				
			||||||
 | 
					LI = LabelsInitializer
 | 
				
			||||||
 | 
					OHLI = OneHotLabelsInitializer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Aliases - Reasonings
 | 
				
			||||||
 | 
					ZRI = ZerosReasoningsInitializer
 | 
				
			||||||
 | 
					ORI = OnesReasoningsInitializer
 | 
				
			||||||
 | 
					RRI = RandomReasoningsInitializer
 | 
				
			||||||
PPRI = PurePositiveReasoningsInitializer
 | 
					PPRI = PurePositiveReasoningsInitializer
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user