prototorch/prototorch/core/components.py
2021-06-13 22:54:29 +00:00

316 lines
11 KiB
Python

"""ProtoTorch components"""
import inspect
from typing import Union
import torch
from torch.nn.parameter import Parameter
from ..utils import parse_distribution
from .initializers import (
AbstractComponentsInitializer,
AbstractLabelsInitializer,
AbstractReasoningsInitializer,
ClassAwareCompInitializer,
LabelsInitializer,
)
def validate_initializer(initializer, instanceof):
if not isinstance(initializer, instanceof):
emsg = f"`initializer` has to be an instance " \
f"of some subtype of {instanceof}. " \
f"You have provided: {initializer} instead. "
helpmsg = ""
if inspect.isclass(initializer):
helpmsg = f"Perhaps you meant to say, {initializer.__name__}() " \
f"with the brackets instead of just {initializer.__name__}?"
raise TypeError(emsg + helpmsg)
return True
def validate_components_initializer(initializer):
return validate_initializer(initializer, AbstractComponentsInitializer)
def validate_labels_initializer(initializer):
return validate_initializer(initializer, AbstractLabelsInitializer)
def validate_reasonings_initializer(initializer):
return validate_initializer(initializer, AbstractReasoningsInitializer)
def gencat(ins, attr, init, *iargs, **ikwargs):
"""Generate new items and concatenate with existing items."""
new_items = init.generate(*iargs, **ikwargs)
if hasattr(ins, attr):
items = torch.cat([getattr(ins, attr), new_items])
else:
items = new_items
return items, new_items
def removeind(ins, attr, indices):
"""Remove items at specified indices."""
mask = torch.ones(len(ins), dtype=torch.bool)
mask[indices] = False
items = getattr(ins, attr)[mask]
return items, mask
class AbstractComponents(torch.nn.Module):
"""Abstract class for all components modules."""
@property
def num_components(self):
"""Current number of components."""
return len(self._components)
@property
def components(self):
"""Detached Tensor containing the components."""
return self._components.detach()
def _register_components(self, components):
self.register_parameter("_components", Parameter(components))
def extra_repr(self):
return f"components: (shape: {tuple(self._components.shape)})"
def __len__(self):
return self.num_components
class Components(AbstractComponents):
"""A set of adaptable Tensors."""
def __init__(self, num_components: int,
initializer: AbstractComponentsInitializer, **kwargs):
super().__init__(**kwargs)
self.add_components(num_components, initializer)
def add_components(self, num_components: int,
initializer: AbstractComponentsInitializer):
"""Generate and add new components."""
assert validate_components_initializer(initializer)
_components, new_components = gencat(self, "_components", initializer,
num_components)
self._register_components(_components)
return new_components
def remove_components(self, indices):
"""Remove components at specified indices."""
_components, mask = removeind(self, "_components", indices)
self._register_components(_components)
return mask
def forward(self):
"""Simply return the components parameter Tensor."""
return self._components
class AbstractLabels(torch.nn.Module):
"""Abstract class for all labels modules."""
@property
def labels(self):
return self._labels
@property
def num_labels(self):
return len(self.labels)
@property
def unique_labels(self):
return torch.unique(self._labels)
@property
def num_unique(self):
return len(self.unique_labels)
@property
def distribution(self):
unique, counts = torch.unique(self._labels,
sorted=True,
return_counts=True)
return dict(zip(unique.tolist(), counts.tolist()))
def _register_labels(self, labels):
self.register_buffer("_labels", labels)
def extra_repr(self):
r = f"num_labels: {self.num_labels}, num_unique: {self.num_unique}"
if len(self.distribution) < 11: # avoid lengthy representations
d = self.distribution
unique, counts = list(d.keys()), list(d.values())
r += f", unique: {unique}, counts: {counts}"
return r
def __len__(self):
return self.num_labels
class Labels(AbstractLabels):
"""A set of standalone labels."""
def __init__(self,
distribution: Union[dict, list, tuple],
initializer: AbstractLabelsInitializer = LabelsInitializer(),
**kwargs):
super().__init__(**kwargs)
self.add_labels(distribution, initializer)
def add_labels(
self,
distribution: Union[dict, tuple, list],
initializer: AbstractLabelsInitializer = LabelsInitializer()):
"""Generate and add new labels."""
assert validate_labels_initializer(initializer)
_labels, new_labels = gencat(self, "_labels", initializer,
distribution)
self._register_labels(_labels)
return new_labels
def remove_labels(self, indices):
"""Remove labels at specified indices."""
_labels, mask = removeind(self, "_labels", indices)
self._register_labels(_labels)
return mask
class LabeledComponents(AbstractComponents):
"""A set of adaptable components and corresponding unadaptable labels."""
def __init__(
self,
distribution: Union[dict, list, tuple],
components_initializer: AbstractComponentsInitializer,
labels_initializer: AbstractLabelsInitializer = LabelsInitializer(
),
**kwargs):
super().__init__(**kwargs)
self.add_components(distribution, components_initializer,
labels_initializer)
@property
def labels(self):
"""Tensor containing the component labels."""
return self._labels
def _register_labels(self, labels):
self.register_buffer("_labels", labels)
def add_components(
self,
distribution,
components_initializer,
labels_initializer: AbstractLabelsInitializer = LabelsInitializer()):
"""Generate and add new components and labels."""
assert validate_components_initializer(components_initializer)
assert validate_labels_initializer(labels_initializer)
if isinstance(components_initializer, ClassAwareCompInitializer):
cikwargs = dict(distribution=distribution)
else:
distribution = parse_distribution(distribution)
num_components = sum(distribution.values())
cikwargs = dict(num_components=num_components)
_components, new_components = gencat(self, "_components",
components_initializer,
**cikwargs)
_labels, new_labels = gencat(self, "_labels", labels_initializer,
distribution)
self._register_components(_components)
self._register_labels(_labels)
return new_components, new_labels
def remove_components(self, indices):
"""Remove components and labels at specified indices."""
_components, mask = removeind(self, "_components", indices)
_labels, mask = removeind(self, "_labels", indices)
self._register_components(_components)
self._register_labels(_labels)
return mask
def forward(self):
"""Simply return the components parameter Tensor and labels."""
return self._components, self._labels
class ReasoningComponents(AbstractComponents):
"""A set of components and a corresponding adapatable reasoning matrices.
Every component has its own reasoning matrix.
A reasoning matrix is an Nx2 matrix, where N is the number of classes. The
first element is called positive reasoning :math:`p`, the second negative
reasoning :math:`n`. A components can reason in favour (positive) of a
class, against (negative) a class or not at all (neutral).
It holds that :math:`0 \leq n \leq 1`, :math:`0 \leq p \leq 1` and :math:`0
\leq n+p \leq 1`. Therefore :math:`n` and :math:`p` are two elements of a
three element probability distribution.
"""
def __init__(self, distribution: Union[dict, list, tuple],
components_initializer: AbstractComponentsInitializer,
reasonings_initializer: AbstractReasoningsInitializer,
**kwargs):
super().__init__(**kwargs)
self.add_components(distribution, components_initializer,
reasonings_initializer)
@property
def reasonings(self):
"""Returns Reasoning Matrix.
Dimension NxCx2
"""
return self._reasonings.detach()
def _register_reasonings(self, reasonings):
self.register_parameter("_reasonings", Parameter(reasonings))
def add_components(self, distribution, components_initializer,
reasonings_initializer: AbstractReasoningsInitializer):
# Checks
assert validate_components_initializer(components_initializer)
assert validate_reasonings_initializer(reasonings_initializer)
distribution = parse_distribution(distribution)
# Generate new components
if isinstance(components_initializer, ClassAwareCompInitializer):
new_components = components_initializer.generate(distribution)
else:
num_components = sum(distribution.values())
new_components = components_initializer.generate(num_components)
# Generate new reasonings
new_reasonings = reasonings_initializer.generate(distribution)
# Register
if hasattr(self, "_components"):
_components = torch.cat([self._components, new_components])
else:
_components = new_components
if hasattr(self, "_reasonings"):
_reasonings = torch.cat([self._reasonings, new_reasonings])
else:
_reasonings = new_reasonings
self._register_components(_components)
self._register_reasonings(_reasonings)
return new_components, new_reasonings
def remove_components(self, indices):
"""Remove components and labels at specified indices."""
mask = torch.ones(self.num_components, dtype=torch.bool)
mask[indices] = False
_components = self._components[mask]
# TODO
# _reasonings = self._reasonings[mask]
self._register_components(_components)
# self._register_reasonings(_reasonings)
return mask
def forward(self):
"""Simply return the components and reasonings."""
return self._components, self._reasonings