Add standalone labels module

This commit is contained in:
Jensun Ravichandran 2021-06-13 22:54:29 +00:00
parent 84e08955f7
commit 2af1da7f23
2 changed files with 149 additions and 43 deletions

View File

@ -41,6 +41,24 @@ def validate_reasonings_initializer(initializer):
return validate_initializer(initializer, AbstractReasoningsInitializer) return validate_initializer(initializer, AbstractReasoningsInitializer)
def gencat(ins, attr, init, *iargs, **ikwargs):
"""Generate new items and concatenate with existing items."""
new_items = init.generate(*iargs, **ikwargs)
if hasattr(ins, attr):
items = torch.cat([getattr(ins, attr), new_items])
else:
items = new_items
return items, new_items
def removeind(ins, attr, indices):
"""Remove items at specified indices."""
mask = torch.ones(len(ins), dtype=torch.bool)
mask[indices] = False
items = getattr(ins, attr)[mask]
return items, mask
class AbstractComponents(torch.nn.Module): class AbstractComponents(torch.nn.Module):
"""Abstract class for all components modules.""" """Abstract class for all components modules."""
@property @property
@ -57,7 +75,10 @@ class AbstractComponents(torch.nn.Module):
self.register_parameter("_components", Parameter(components)) self.register_parameter("_components", Parameter(components))
def extra_repr(self): def extra_repr(self):
return f"(components): (shape: {tuple(self._components.shape)})" return f"components: (shape: {tuple(self._components.shape)})"
def __len__(self):
return self.num_components
class Components(AbstractComponents): class Components(AbstractComponents):
@ -67,24 +88,18 @@ class Components(AbstractComponents):
super().__init__(**kwargs) super().__init__(**kwargs)
self.add_components(num_components, initializer) self.add_components(num_components, initializer)
def add_components(self, num: int, def add_components(self, num_components: int,
initializer: AbstractComponentsInitializer): initializer: AbstractComponentsInitializer):
"""Add new components.""" """Generate and add new components."""
assert validate_components_initializer(initializer) assert validate_components_initializer(initializer)
new_components = initializer.generate(num) _components, new_components = gencat(self, "_components", initializer,
# Register num_components)
if hasattr(self, "_components"):
_components = torch.cat([self._components, new_components])
else:
_components = new_components
self._register_components(_components) self._register_components(_components)
return new_components return new_components
def remove_components(self, indices): def remove_components(self, indices):
"""Remove components at specified indices.""" """Remove components at specified indices."""
mask = torch.ones(self.num_components, dtype=torch.bool) _components, mask = removeind(self, "_components", indices)
mask[indices] = False
_components = self._components[mask]
self._register_components(_components) self._register_components(_components)
return mask return mask
@ -93,19 +108,90 @@ class Components(AbstractComponents):
return self._components return self._components
class AbstractLabels(torch.nn.Module):
"""Abstract class for all labels modules."""
@property
def labels(self):
return self._labels
@property
def num_labels(self):
return len(self.labels)
@property
def unique_labels(self):
return torch.unique(self._labels)
@property
def num_unique(self):
return len(self.unique_labels)
@property
def distribution(self):
unique, counts = torch.unique(self._labels,
sorted=True,
return_counts=True)
return dict(zip(unique.tolist(), counts.tolist()))
def _register_labels(self, labels):
self.register_buffer("_labels", labels)
def extra_repr(self):
r = f"num_labels: {self.num_labels}, num_unique: {self.num_unique}"
if len(self.distribution) < 11: # avoid lengthy representations
d = self.distribution
unique, counts = list(d.keys()), list(d.values())
r += f", unique: {unique}, counts: {counts}"
return r
def __len__(self):
return self.num_labels
class Labels(AbstractLabels):
"""A set of standalone labels."""
def __init__(self,
distribution: Union[dict, list, tuple],
initializer: AbstractLabelsInitializer = LabelsInitializer(),
**kwargs):
super().__init__(**kwargs)
self.add_labels(distribution, initializer)
def add_labels(
self,
distribution: Union[dict, tuple, list],
initializer: AbstractLabelsInitializer = LabelsInitializer()):
"""Generate and add new labels."""
assert validate_labels_initializer(initializer)
_labels, new_labels = gencat(self, "_labels", initializer,
distribution)
self._register_labels(_labels)
return new_labels
def remove_labels(self, indices):
"""Remove labels at specified indices."""
_labels, mask = removeind(self, "_labels", indices)
self._register_labels(_labels)
return mask
class LabeledComponents(AbstractComponents): class LabeledComponents(AbstractComponents):
"""A set of adaptable components and corresponding unadaptable labels.""" """A set of adaptable components and corresponding unadaptable labels."""
def __init__(self, distribution: Union[dict, list, tuple], def __init__(
self,
distribution: Union[dict, list, tuple],
components_initializer: AbstractComponentsInitializer, components_initializer: AbstractComponentsInitializer,
labels_initializer: AbstractLabelsInitializer, **kwargs): labels_initializer: AbstractLabelsInitializer = LabelsInitializer(
),
**kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.add_components(distribution, components_initializer, self.add_components(distribution, components_initializer,
labels_initializer) labels_initializer)
@property @property
def component_labels(self): def labels(self):
"""Tensor containing the component tensors.""" """Tensor containing the component labels."""
return self._labels.detach() return self._labels
def _register_labels(self, labels): def _register_labels(self, labels):
self.register_buffer("_labels", labels) self.register_buffer("_labels", labels)
@ -115,42 +201,28 @@ class LabeledComponents(AbstractComponents):
distribution, distribution,
components_initializer, components_initializer,
labels_initializer: AbstractLabelsInitializer = LabelsInitializer()): labels_initializer: AbstractLabelsInitializer = LabelsInitializer()):
# Checks """Generate and add new components and labels."""
assert validate_components_initializer(components_initializer) assert validate_components_initializer(components_initializer)
assert validate_labels_initializer(labels_initializer) assert validate_labels_initializer(labels_initializer)
distribution = parse_distribution(distribution)
# Generate new components
if isinstance(components_initializer, ClassAwareCompInitializer): if isinstance(components_initializer, ClassAwareCompInitializer):
new_components = components_initializer.generate(distribution) cikwargs = dict(distribution=distribution)
else: else:
distribution = parse_distribution(distribution)
num_components = sum(distribution.values()) num_components = sum(distribution.values())
new_components = components_initializer.generate(num_components) cikwargs = dict(num_components=num_components)
_components, new_components = gencat(self, "_components",
# Generate new labels components_initializer,
new_labels = labels_initializer.generate(distribution) **cikwargs)
_labels, new_labels = gencat(self, "_labels", labels_initializer,
# Register distribution)
if hasattr(self, "_components"):
_components = torch.cat([self._components, new_components])
else:
_components = new_components
if hasattr(self, "_labels"):
_labels = torch.cat([self._labels, new_labels])
else:
_labels = new_labels
self._register_components(_components) self._register_components(_components)
self._register_labels(_labels) self._register_labels(_labels)
return new_components, new_labels return new_components, new_labels
def remove_components(self, indices): def remove_components(self, indices):
"""Remove components and labels at specified indices.""" """Remove components and labels at specified indices."""
mask = torch.ones(self.num_components, dtype=torch.bool) _components, mask = removeind(self, "_components", indices)
mask[indices] = False _labels, mask = removeind(self, "_labels", indices)
_components = self._components[mask]
_labels = self._labels[mask]
self._register_components(_components) self._register_components(_components)
self._register_labels(_labels) self._register_labels(_labels)
return mask return mask

View File

@ -157,6 +157,40 @@ def test_components_zeros_init():
assert torch.allclose(c.components, torch.zeros(3, 2)) assert torch.allclose(c.components, torch.zeros(3, 2))
def test_labeled_components_dict_init():
c = pt.components.LabeledComponents({0: 3}, pt.initializers.OCI(2))
assert torch.allclose(c.components, torch.ones(3, 2))
assert torch.allclose(c.labels, torch.zeros(3, dtype=torch.long))
def test_labeled_components_list_init():
c = pt.components.LabeledComponents([3], pt.initializers.OCI(2))
assert torch.allclose(c.components, torch.ones(3, 2))
assert torch.allclose(c.labels, torch.zeros(3, dtype=torch.long))
def test_labeled_components_tuple_init():
c = pt.components.LabeledComponents({0: 1, 1: 2}, pt.initializers.OCI(2))
assert torch.allclose(c.components, torch.ones(3, 2))
assert torch.allclose(c.labels, torch.LongTensor([0, 1, 1]))
# Labels
def test_standalone_labels_dict_init():
l = pt.components.Labels({0: 3})
assert torch.allclose(l.labels, torch.zeros(3, dtype=torch.long))
def test_standalone_labels_list_init():
l = pt.components.Labels([3])
assert torch.allclose(l.labels, torch.zeros(3, dtype=torch.long))
def test_standalone_labels_tuple_init():
l = pt.components.Labels({0: 1, 1: 2})
assert torch.allclose(l.labels, torch.LongTensor([0, 1, 1]))
# Losses # Losses
def test_glvq_loss_int_labels(): def test_glvq_loss_int_labels():
d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1) d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1)