From 2af1da7f23705ff0162bc83d27b3575f88aef455 Mon Sep 17 00:00:00 2001
From: Jensun Ravichandran
Date: Sun, 13 Jun 2021 22:54:29 +0000
Subject: [PATCH] Add standalone labels module

---
 prototorch/core/components.py | 158 +++++++++++++++++++++++++---------
 tests/test_core.py            |  34 ++++++++
 2 files changed, 149 insertions(+), 43 deletions(-)

diff --git a/prototorch/core/components.py b/prototorch/core/components.py
index 53555af..d2b0f40 100644
--- a/prototorch/core/components.py
+++ b/prototorch/core/components.py
@@ -41,6 +41,24 @@ def validate_reasonings_initializer(initializer):
     return validate_initializer(initializer, AbstractReasoningsInitializer)
 
 
+def gencat(ins, attr, init, *iargs, **ikwargs):
+    """Generate new items and concatenate with existing items."""
+    new_items = init.generate(*iargs, **ikwargs)
+    if hasattr(ins, attr):
+        items = torch.cat([getattr(ins, attr), new_items])
+    else:
+        items = new_items
+    return items, new_items
+
+
+def removeind(ins, attr, indices):
+    """Remove items at specified indices."""
+    mask = torch.ones(len(ins), dtype=torch.bool)
+    mask[indices] = False
+    items = getattr(ins, attr)[mask]
+    return items, mask
+
+
 class AbstractComponents(torch.nn.Module):
     """Abstract class for all components modules."""
     @property
@@ -57,7 +75,10 @@ class AbstractComponents(torch.nn.Module):
         self.register_parameter("_components", Parameter(components))
 
     def extra_repr(self):
-        return f"(components): (shape: {tuple(self._components.shape)})"
+        return f"components: (shape: {tuple(self._components.shape)})"
+
+    def __len__(self):
+        return self.num_components
 
 
 class Components(AbstractComponents):
@@ -67,24 +88,18 @@ class Components(AbstractComponents):
         super().__init__(**kwargs)
         self.add_components(num_components, initializer)
 
-    def add_components(self, num: int,
+    def add_components(self, num_components: int,
                        initializer: AbstractComponentsInitializer):
-        """Add new components."""
+        """Generate and add new components."""
         assert validate_components_initializer(initializer)
-        new_components = initializer.generate(num)
-        # Register
-        if hasattr(self, "_components"):
-            _components = torch.cat([self._components, new_components])
-        else:
-            _components = new_components
+        _components, new_components = gencat(self, "_components", initializer,
+                                             num_components)
         self._register_components(_components)
         return new_components
 
     def remove_components(self, indices):
         """Remove components at specified indices."""
-        mask = torch.ones(self.num_components, dtype=torch.bool)
-        mask[indices] = False
-        _components = self._components[mask]
+        _components, mask = removeind(self, "_components", indices)
         self._register_components(_components)
         return mask
 
@@ -93,19 +108,90 @@ class Components(AbstractComponents):
         return self._components
 
 
+class AbstractLabels(torch.nn.Module):
+    """Abstract class for all labels modules."""
+    @property
+    def labels(self):
+        return self._labels
+
+    @property
+    def num_labels(self):
+        return len(self.labels)
+
+    @property
+    def unique_labels(self):
+        return torch.unique(self._labels)
+
+    @property
+    def num_unique(self):
+        return len(self.unique_labels)
+
+    @property
+    def distribution(self):
+        unique, counts = torch.unique(self._labels,
+                                      sorted=True,
+                                      return_counts=True)
+        return dict(zip(unique.tolist(), counts.tolist()))
+
+    def _register_labels(self, labels):
+        self.register_buffer("_labels", labels)
+
+    def extra_repr(self):
+        r = f"num_labels: {self.num_labels}, num_unique: {self.num_unique}"
+        if len(self.distribution) < 11:  # avoid lengthy representations
+            d = self.distribution
+            unique, counts = list(d.keys()), list(d.values())
+            r += f", unique: {unique}, counts: {counts}"
+        return r
+
+    def __len__(self):
+        return self.num_labels
+
+
+class Labels(AbstractLabels):
+    """A set of standalone labels."""
+    def __init__(self,
+                 distribution: Union[dict, list, tuple],
+                 initializer: AbstractLabelsInitializer = LabelsInitializer(),
+                 **kwargs):
+        super().__init__(**kwargs)
+        self.add_labels(distribution, initializer)
+
+    def add_labels(
+            self,
+            distribution: Union[dict, tuple, list],
+            initializer: AbstractLabelsInitializer = LabelsInitializer()):
+        """Generate and add new labels."""
+        assert validate_labels_initializer(initializer)
+        _labels, new_labels = gencat(self, "_labels", initializer,
+                                     distribution)
+        self._register_labels(_labels)
+        return new_labels
+
+    def remove_labels(self, indices):
+        """Remove labels at specified indices."""
+        _labels, mask = removeind(self, "_labels", indices)
+        self._register_labels(_labels)
+        return mask
+
+
 class LabeledComponents(AbstractComponents):
     """A set of adaptable components and corresponding unadaptable labels."""
-    def __init__(self, distribution: Union[dict, list, tuple],
-                 components_initializer: AbstractComponentsInitializer,
-                 labels_initializer: AbstractLabelsInitializer, **kwargs):
+    def __init__(
+            self,
+            distribution: Union[dict, list, tuple],
+            components_initializer: AbstractComponentsInitializer,
+            labels_initializer: AbstractLabelsInitializer = LabelsInitializer(
+            ),
+            **kwargs):
         super().__init__(**kwargs)
         self.add_components(distribution, components_initializer,
                             labels_initializer)
 
     @property
-    def component_labels(self):
-        """Tensor containing the component tensors."""
-        return self._labels.detach()
+    def labels(self):
+        """Tensor containing the component labels."""
+        return self._labels
 
     def _register_labels(self, labels):
         self.register_buffer("_labels", labels)
@@ -115,42 +201,28 @@ class LabeledComponents(AbstractComponents):
             distribution,
             components_initializer,
             labels_initializer: AbstractLabelsInitializer = LabelsInitializer()):
-        # Checks
+        """Generate and add new components and labels."""
         assert validate_components_initializer(components_initializer)
         assert validate_labels_initializer(labels_initializer)
-
-        distribution = parse_distribution(distribution)
-
-        # Generate new components
         if isinstance(components_initializer, ClassAwareCompInitializer):
-            new_components = components_initializer.generate(distribution)
+            cikwargs = dict(distribution=distribution)
         else:
+            distribution = parse_distribution(distribution)
             num_components = sum(distribution.values())
-            new_components = components_initializer.generate(num_components)
-
-        # Generate new labels
-        new_labels = labels_initializer.generate(distribution)
-
-        # Register
-        if hasattr(self, "_components"):
-            _components = torch.cat([self._components, new_components])
-        else:
-            _components = new_components
-        if hasattr(self, "_labels"):
-            _labels = torch.cat([self._labels, new_labels])
-        else:
-            _labels = new_labels
+            cikwargs = dict(num_components=num_components)
+        _components, new_components = gencat(self, "_components",
+                                             components_initializer,
+                                             **cikwargs)
+        _labels, new_labels = gencat(self, "_labels", labels_initializer,
+                                     distribution)
         self._register_components(_components)
         self._register_labels(_labels)
-
         return new_components, new_labels
 
     def remove_components(self, indices):
         """Remove components and labels at specified indices."""
-        mask = torch.ones(self.num_components, dtype=torch.bool)
-        mask[indices] = False
-        _components = self._components[mask]
-        _labels = self._labels[mask]
+        _components, mask = removeind(self, "_components", indices)
+        _labels, mask = removeind(self, "_labels", indices)
         self._register_components(_components)
         self._register_labels(_labels)
         return mask
diff --git a/tests/test_core.py b/tests/test_core.py
index d2496c8..191569e 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -157,6 +157,40 @@ def test_components_zeros_init():
     assert torch.allclose(c.components, torch.zeros(3, 2))
 
 
+def test_labeled_components_dict_init():
+    c = pt.components.LabeledComponents({0: 3}, pt.initializers.OCI(2))
+    assert torch.allclose(c.components, torch.ones(3, 2))
+    assert torch.allclose(c.labels, torch.zeros(3, dtype=torch.long))
+
+
+def test_labeled_components_list_init():
+    c = pt.components.LabeledComponents([3], pt.initializers.OCI(2))
+    assert torch.allclose(c.components, torch.ones(3, 2))
+    assert torch.allclose(c.labels, torch.zeros(3, dtype=torch.long))
+
+
+def test_labeled_components_tuple_init():
+    c = pt.components.LabeledComponents({0: 1, 1: 2}, pt.initializers.OCI(2))
+    assert torch.allclose(c.components, torch.ones(3, 2))
+    assert torch.allclose(c.labels, torch.LongTensor([0, 1, 1]))
+
+
+# Labels
+def test_standalone_labels_dict_init():
+    l = pt.components.Labels({0: 3})
+    assert torch.allclose(l.labels, torch.zeros(3, dtype=torch.long))
+
+
+def test_standalone_labels_list_init():
+    l = pt.components.Labels([3])
+    assert torch.allclose(l.labels, torch.zeros(3, dtype=torch.long))
+
+
+def test_standalone_labels_tuple_init():
+    l = pt.components.Labels({0: 1, 1: 2})
+    assert torch.allclose(l.labels, torch.LongTensor([0, 1, 1]))
+
+
 # Losses
 def test_glvq_loss_int_labels():
     d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1)