Add standalone labels module
This commit is contained in:
parent
84e08955f7
commit
2af1da7f23
@ -41,6 +41,24 @@ def validate_reasonings_initializer(initializer):
|
|||||||
return validate_initializer(initializer, AbstractReasoningsInitializer)
|
return validate_initializer(initializer, AbstractReasoningsInitializer)
|
||||||
|
|
||||||
|
|
||||||
|
def gencat(ins, attr, init, *iargs, **ikwargs):
|
||||||
|
"""Generate new items and concatenate with existing items."""
|
||||||
|
new_items = init.generate(*iargs, **ikwargs)
|
||||||
|
if hasattr(ins, attr):
|
||||||
|
items = torch.cat([getattr(ins, attr), new_items])
|
||||||
|
else:
|
||||||
|
items = new_items
|
||||||
|
return items, new_items
|
||||||
|
|
||||||
|
|
||||||
|
def removeind(ins, attr, indices):
|
||||||
|
"""Remove items at specified indices."""
|
||||||
|
mask = torch.ones(len(ins), dtype=torch.bool)
|
||||||
|
mask[indices] = False
|
||||||
|
items = getattr(ins, attr)[mask]
|
||||||
|
return items, mask
|
||||||
|
|
||||||
|
|
||||||
class AbstractComponents(torch.nn.Module):
|
class AbstractComponents(torch.nn.Module):
|
||||||
"""Abstract class for all components modules."""
|
"""Abstract class for all components modules."""
|
||||||
@property
|
@property
|
||||||
@ -57,7 +75,10 @@ class AbstractComponents(torch.nn.Module):
|
|||||||
self.register_parameter("_components", Parameter(components))
|
self.register_parameter("_components", Parameter(components))
|
||||||
|
|
||||||
def extra_repr(self):
|
def extra_repr(self):
|
||||||
return f"(components): (shape: {tuple(self._components.shape)})"
|
return f"components: (shape: {tuple(self._components.shape)})"
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.num_components
|
||||||
|
|
||||||
|
|
||||||
class Components(AbstractComponents):
|
class Components(AbstractComponents):
|
||||||
@ -67,24 +88,18 @@ class Components(AbstractComponents):
|
|||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.add_components(num_components, initializer)
|
self.add_components(num_components, initializer)
|
||||||
|
|
||||||
def add_components(self, num: int,
|
def add_components(self, num_components: int,
|
||||||
initializer: AbstractComponentsInitializer):
|
initializer: AbstractComponentsInitializer):
|
||||||
"""Add new components."""
|
"""Generate and add new components."""
|
||||||
assert validate_components_initializer(initializer)
|
assert validate_components_initializer(initializer)
|
||||||
new_components = initializer.generate(num)
|
_components, new_components = gencat(self, "_components", initializer,
|
||||||
# Register
|
num_components)
|
||||||
if hasattr(self, "_components"):
|
|
||||||
_components = torch.cat([self._components, new_components])
|
|
||||||
else:
|
|
||||||
_components = new_components
|
|
||||||
self._register_components(_components)
|
self._register_components(_components)
|
||||||
return new_components
|
return new_components
|
||||||
|
|
||||||
def remove_components(self, indices):
|
def remove_components(self, indices):
|
||||||
"""Remove components at specified indices."""
|
"""Remove components at specified indices."""
|
||||||
mask = torch.ones(self.num_components, dtype=torch.bool)
|
_components, mask = removeind(self, "_components", indices)
|
||||||
mask[indices] = False
|
|
||||||
_components = self._components[mask]
|
|
||||||
self._register_components(_components)
|
self._register_components(_components)
|
||||||
return mask
|
return mask
|
||||||
|
|
||||||
@ -93,19 +108,90 @@ class Components(AbstractComponents):
|
|||||||
return self._components
|
return self._components
|
||||||
|
|
||||||
|
|
||||||
|
class AbstractLabels(torch.nn.Module):
|
||||||
|
"""Abstract class for all labels modules."""
|
||||||
|
@property
|
||||||
|
def labels(self):
|
||||||
|
return self._labels
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_labels(self):
|
||||||
|
return len(self.labels)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def unique_labels(self):
|
||||||
|
return torch.unique(self._labels)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_unique(self):
|
||||||
|
return len(self.unique_labels)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def distribution(self):
|
||||||
|
unique, counts = torch.unique(self._labels,
|
||||||
|
sorted=True,
|
||||||
|
return_counts=True)
|
||||||
|
return dict(zip(unique.tolist(), counts.tolist()))
|
||||||
|
|
||||||
|
def _register_labels(self, labels):
|
||||||
|
self.register_buffer("_labels", labels)
|
||||||
|
|
||||||
|
def extra_repr(self):
|
||||||
|
r = f"num_labels: {self.num_labels}, num_unique: {self.num_unique}"
|
||||||
|
if len(self.distribution) < 11: # avoid lengthy representations
|
||||||
|
d = self.distribution
|
||||||
|
unique, counts = list(d.keys()), list(d.values())
|
||||||
|
r += f", unique: {unique}, counts: {counts}"
|
||||||
|
return r
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.num_labels
|
||||||
|
|
||||||
|
|
||||||
|
class Labels(AbstractLabels):
|
||||||
|
"""A set of standalone labels."""
|
||||||
|
def __init__(self,
|
||||||
|
distribution: Union[dict, list, tuple],
|
||||||
|
initializer: AbstractLabelsInitializer = LabelsInitializer(),
|
||||||
|
**kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.add_labels(distribution, initializer)
|
||||||
|
|
||||||
|
def add_labels(
|
||||||
|
self,
|
||||||
|
distribution: Union[dict, tuple, list],
|
||||||
|
initializer: AbstractLabelsInitializer = LabelsInitializer()):
|
||||||
|
"""Generate and add new labels."""
|
||||||
|
assert validate_labels_initializer(initializer)
|
||||||
|
_labels, new_labels = gencat(self, "_labels", initializer,
|
||||||
|
distribution)
|
||||||
|
self._register_labels(_labels)
|
||||||
|
return new_labels
|
||||||
|
|
||||||
|
def remove_labels(self, indices):
|
||||||
|
"""Remove labels at specified indices."""
|
||||||
|
_labels, mask = removeind(self, "_labels", indices)
|
||||||
|
self._register_labels(_labels)
|
||||||
|
return mask
|
||||||
|
|
||||||
|
|
||||||
class LabeledComponents(AbstractComponents):
|
class LabeledComponents(AbstractComponents):
|
||||||
"""A set of adaptable components and corresponding unadaptable labels."""
|
"""A set of adaptable components and corresponding unadaptable labels."""
|
||||||
def __init__(self, distribution: Union[dict, list, tuple],
|
def __init__(
|
||||||
components_initializer: AbstractComponentsInitializer,
|
self,
|
||||||
labels_initializer: AbstractLabelsInitializer, **kwargs):
|
distribution: Union[dict, list, tuple],
|
||||||
|
components_initializer: AbstractComponentsInitializer,
|
||||||
|
labels_initializer: AbstractLabelsInitializer = LabelsInitializer(
|
||||||
|
),
|
||||||
|
**kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.add_components(distribution, components_initializer,
|
self.add_components(distribution, components_initializer,
|
||||||
labels_initializer)
|
labels_initializer)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def component_labels(self):
|
def labels(self):
|
||||||
"""Tensor containing the component tensors."""
|
"""Tensor containing the component labels."""
|
||||||
return self._labels.detach()
|
return self._labels
|
||||||
|
|
||||||
def _register_labels(self, labels):
|
def _register_labels(self, labels):
|
||||||
self.register_buffer("_labels", labels)
|
self.register_buffer("_labels", labels)
|
||||||
@ -115,42 +201,28 @@ class LabeledComponents(AbstractComponents):
|
|||||||
distribution,
|
distribution,
|
||||||
components_initializer,
|
components_initializer,
|
||||||
labels_initializer: AbstractLabelsInitializer = LabelsInitializer()):
|
labels_initializer: AbstractLabelsInitializer = LabelsInitializer()):
|
||||||
# Checks
|
"""Generate and add new components and labels."""
|
||||||
assert validate_components_initializer(components_initializer)
|
assert validate_components_initializer(components_initializer)
|
||||||
assert validate_labels_initializer(labels_initializer)
|
assert validate_labels_initializer(labels_initializer)
|
||||||
|
|
||||||
distribution = parse_distribution(distribution)
|
|
||||||
|
|
||||||
# Generate new components
|
|
||||||
if isinstance(components_initializer, ClassAwareCompInitializer):
|
if isinstance(components_initializer, ClassAwareCompInitializer):
|
||||||
new_components = components_initializer.generate(distribution)
|
cikwargs = dict(distribution=distribution)
|
||||||
else:
|
else:
|
||||||
|
distribution = parse_distribution(distribution)
|
||||||
num_components = sum(distribution.values())
|
num_components = sum(distribution.values())
|
||||||
new_components = components_initializer.generate(num_components)
|
cikwargs = dict(num_components=num_components)
|
||||||
|
_components, new_components = gencat(self, "_components",
|
||||||
# Generate new labels
|
components_initializer,
|
||||||
new_labels = labels_initializer.generate(distribution)
|
**cikwargs)
|
||||||
|
_labels, new_labels = gencat(self, "_labels", labels_initializer,
|
||||||
# Register
|
distribution)
|
||||||
if hasattr(self, "_components"):
|
|
||||||
_components = torch.cat([self._components, new_components])
|
|
||||||
else:
|
|
||||||
_components = new_components
|
|
||||||
if hasattr(self, "_labels"):
|
|
||||||
_labels = torch.cat([self._labels, new_labels])
|
|
||||||
else:
|
|
||||||
_labels = new_labels
|
|
||||||
self._register_components(_components)
|
self._register_components(_components)
|
||||||
self._register_labels(_labels)
|
self._register_labels(_labels)
|
||||||
|
|
||||||
return new_components, new_labels
|
return new_components, new_labels
|
||||||
|
|
||||||
def remove_components(self, indices):
|
def remove_components(self, indices):
|
||||||
"""Remove components and labels at specified indices."""
|
"""Remove components and labels at specified indices."""
|
||||||
mask = torch.ones(self.num_components, dtype=torch.bool)
|
_components, mask = removeind(self, "_components", indices)
|
||||||
mask[indices] = False
|
_labels, mask = removeind(self, "_labels", indices)
|
||||||
_components = self._components[mask]
|
|
||||||
_labels = self._labels[mask]
|
|
||||||
self._register_components(_components)
|
self._register_components(_components)
|
||||||
self._register_labels(_labels)
|
self._register_labels(_labels)
|
||||||
return mask
|
return mask
|
||||||
|
@ -157,6 +157,40 @@ def test_components_zeros_init():
|
|||||||
assert torch.allclose(c.components, torch.zeros(3, 2))
|
assert torch.allclose(c.components, torch.zeros(3, 2))
|
||||||
|
|
||||||
|
|
||||||
|
def test_labeled_components_dict_init():
|
||||||
|
c = pt.components.LabeledComponents({0: 3}, pt.initializers.OCI(2))
|
||||||
|
assert torch.allclose(c.components, torch.ones(3, 2))
|
||||||
|
assert torch.allclose(c.labels, torch.zeros(3, dtype=torch.long))
|
||||||
|
|
||||||
|
|
||||||
|
def test_labeled_components_list_init():
|
||||||
|
c = pt.components.LabeledComponents([3], pt.initializers.OCI(2))
|
||||||
|
assert torch.allclose(c.components, torch.ones(3, 2))
|
||||||
|
assert torch.allclose(c.labels, torch.zeros(3, dtype=torch.long))
|
||||||
|
|
||||||
|
|
||||||
|
def test_labeled_components_tuple_init():
|
||||||
|
c = pt.components.LabeledComponents({0: 1, 1: 2}, pt.initializers.OCI(2))
|
||||||
|
assert torch.allclose(c.components, torch.ones(3, 2))
|
||||||
|
assert torch.allclose(c.labels, torch.LongTensor([0, 1, 1]))
|
||||||
|
|
||||||
|
|
||||||
|
# Labels
|
||||||
|
def test_standalone_labels_dict_init():
|
||||||
|
l = pt.components.Labels({0: 3})
|
||||||
|
assert torch.allclose(l.labels, torch.zeros(3, dtype=torch.long))
|
||||||
|
|
||||||
|
|
||||||
|
def test_standalone_labels_list_init():
|
||||||
|
l = pt.components.Labels([3])
|
||||||
|
assert torch.allclose(l.labels, torch.zeros(3, dtype=torch.long))
|
||||||
|
|
||||||
|
|
||||||
|
def test_standalone_labels_tuple_init():
|
||||||
|
l = pt.components.Labels({0: 1, 1: 2})
|
||||||
|
assert torch.allclose(l.labels, torch.LongTensor([0, 1, 1]))
|
||||||
|
|
||||||
|
|
||||||
# Losses
|
# Losses
|
||||||
def test_glvq_loss_int_labels():
|
def test_glvq_loss_int_labels():
|
||||||
d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1)
|
d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1)
|
||||||
|
Loading…
Reference in New Issue
Block a user