[REFACTOR] Clean and move components and initializers into core

This commit is contained in:
Jensun Ravichandran
2021-06-12 20:29:24 +02:00
parent b8969347b1
commit 5dddb39ec4
7 changed files with 510 additions and 549 deletions

View File

@@ -0,0 +1,5 @@
"""ProtoTorch core"""
from .components import *
from .initializers import *
from .labels import *

View File

@@ -0,0 +1,243 @@
"""ProtoTorch components"""
import inspect
from typing import Union
import torch
from torch.nn.parameter import Parameter
from ..utils import parse_distribution
from .initializers import (
AbstractComponentsInitializer,
AbstractLabelsInitializer,
AbstractReasoningsInitializer,
ClassAwareCompInitializer,
LabelsInitializer,
)
def validate_initializer(initializer, instanceof):
if not isinstance(initializer, instanceof):
emsg = f"`initializer` has to be an instance " \
f"of some subtype of {instanceof}. " \
f"You have provided: {initializer} instead. "
helpmsg = ""
if inspect.isclass(initializer):
helpmsg = f"Perhaps you meant to say, {initializer.__name__}() " \
f"with the brackets instead of just {initializer.__name__}?"
raise TypeError(emsg + helpmsg)
return True
def validate_components_initializer(initializer):
return validate_initializer(initializer, AbstractComponentsInitializer)
def validate_labels_initializer(initializer):
return validate_initializer(initializer, AbstractLabelsInitializer)
def validate_reasonings_initializer(initializer):
return validate_initializer(initializer, AbstractReasoningsInitializer)
class AbstractComponents(torch.nn.Module):
"""Abstract class for all components modules."""
@property
def num_components(self):
"""Current number of components."""
return len(self._components)
@property
def components(self):
"""Detached Tensor containing the components."""
return self._components.detach()
def _register_components(self, components):
self.register_parameter("_components", Parameter(components))
def extra_repr(self):
return f"(components): (shape: {tuple(self._components.shape)})"
class Components(AbstractComponents):
"""A set of adaptable Tensors."""
def __init__(self, num_components: int,
initializer: AbstractComponentsInitializer, **kwargs):
super().__init__(**kwargs)
self.add_components(num_components, initializer)
def add_components(self, num: int,
initializer: AbstractComponentsInitializer):
"""Add new components."""
assert validate_components_initializer(initializer)
new_components = initializer.generate(num)
# Register
if hasattr(self, "_components"):
_components = torch.cat([self._components, new_components])
else:
_components = new_components
self._register_components(_components)
return new_components
def remove_components(self, indices):
"""Remove components at specified indices."""
mask = torch.ones(self.num_components, dtype=torch.bool)
mask[indices] = False
_components = self._components[mask]
self._register_components(_components)
return mask
def forward(self):
"""Simply return the components parameter Tensor."""
return self._components
class LabeledComponents(AbstractComponents):
"""A set of adaptable components and corresponding unadaptable labels."""
def __init__(self, distribution: Union[dict, list, tuple],
components_initializer: AbstractComponentsInitializer,
labels_initializer: AbstractLabelsInitializer, **kwargs):
super().__init__(**kwargs)
self.add_components(distribution, components_initializer,
labels_initializer)
@property
def component_labels(self):
"""Tensor containing the component tensors."""
return self._labels.detach()
def _register_labels(self, labels):
self.register_buffer("_labels", labels)
def add_components(
self,
distribution,
components_initializer,
labels_initializer: AbstractLabelsInitializer = LabelsInitializer()):
# Checks
assert validate_components_initializer(components_initializer)
assert validate_labels_initializer(labels_initializer)
distribution = parse_distribution(distribution)
# Generate new components
if isinstance(components_initializer, ClassAwareCompInitializer):
new_components = components_initializer.generate(distribution)
else:
num_components = sum(distribution.values())
new_components = components_initializer.generate(num_components)
# Generate new labels
new_labels = labels_initializer.generate(distribution)
# Register
if hasattr(self, "_components"):
_components = torch.cat([self._components, new_components])
else:
_components = new_components
if hasattr(self, "_labels"):
_labels = torch.cat([self._labels, new_labels])
else:
_labels = new_labels
self._register_components(_components)
self._register_labels(_labels)
return new_components, new_labels
def remove_components(self, indices):
"""Remove components and labels at specified indices."""
mask = torch.ones(self.num_components, dtype=torch.bool)
mask[indices] = False
_components = self._components[mask]
_labels = self._labels[mask]
self._register_components(_components)
self._register_labels(_labels)
return mask
def forward(self):
"""Simply return the components parameter Tensor and labels."""
return self._components, self._labels
class ReasoningComponents(AbstractComponents):
"""A set of components and a corresponding adapatable reasoning matrices.
Every component has its own reasoning matrix.
A reasoning matrix is an Nx2 matrix, where N is the number of classes. The
first element is called positive reasoning :math:`p`, the second negative
reasoning :math:`n`. A components can reason in favour (positive) of a
class, against (negative) a class or not at all (neutral).
It holds that :math:`0 \leq n \leq 1`, :math:`0 \leq p \leq 1` and :math:`0
\leq n+p \leq 1`. Therefore :math:`n` and :math:`p` are two elements of a
three element probability distribution.
"""
def __init__(self, distribution: Union[dict, list, tuple],
components_initializer: AbstractComponentsInitializer,
reasonings_initializer: AbstractReasoningsInitializer,
**kwargs):
super().__init__(**kwargs)
self.add_components(distribution, components_initializer,
reasonings_initializer)
@property
def reasonings(self):
"""Returns Reasoning Matrix.
Dimension NxCx2
"""
return self._reasonings.detach()
def _register_reasonings(self, reasonings):
self.register_parameter("_reasonings", Parameter(reasonings))
def add_components(self, distribution, components_initializer,
reasonings_initializer: AbstractReasoningsInitializer):
# Checks
assert validate_components_initializer(components_initializer)
assert validate_reasonings_initializer(reasonings_initializer)
distribution = parse_distribution(distribution)
# Generate new components
if isinstance(components_initializer, ClassAwareCompInitializer):
new_components = components_initializer.generate(distribution)
else:
num_components = sum(distribution.values())
new_components = components_initializer.generate(num_components)
# Generate new reasonings
new_reasonings = reasonings_initializer.generate(distribution)
# Register
if hasattr(self, "_components"):
_components = torch.cat([self._components, new_components])
else:
_components = new_components
if hasattr(self, "_reasonings"):
_reasonings = torch.cat([self._reasonings, new_reasonings])
else:
_reasonings = new_reasonings
self._register_components(_components)
self._register_reasonings(_reasonings)
return new_components, new_reasonings
def remove_components(self, indices):
"""Remove components and labels at specified indices."""
mask = torch.ones(self.num_components, dtype=torch.bool)
mask[indices] = False
_components = self._components[mask]
# TODO
# _reasonings = self._reasonings[mask]
self._register_components(_components)
# self._register_reasonings(_reasonings)
return mask
def forward(self):
"""Simply return the components and reasonings."""
return self._components, self._reasonings

View File

@@ -0,0 +1,258 @@
"""ProtoTorch code initializers"""
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import Union
import torch
from ..utils import parse_data_arg, parse_distribution
# Components
class AbstractComponentsInitializer(ABC):
"""Abstract class for all components initializers."""
...
class ShapeAwareCompInitializer(AbstractComponentsInitializer):
"""Abstract class for all dimension-aware components initializers."""
def __init__(self, shape: Union[Iterable, int]):
if isinstance(shape, Iterable):
self.component_shape = tuple(shape)
else:
self.component_shape = (shape, )
@abstractmethod
def generate(self, num_components: int):
...
class DataAwareCompInitializer(AbstractComponentsInitializer):
"""Abstract class for all data-aware components initializers.
Components generated by data-aware components initializers inherit the shape
of the provided data.
"""
def __init__(self,
data,
noise: float = 0.0,
transform: callable = torch.nn.Identity()):
self.data = data
self.noise = noise
self.transform = transform
def generate_end_hook(self, samples):
drift = torch.rand_like(samples) * self.noise
components = self.transform(samples + drift)
return components
@abstractmethod
def generate(self, num_components: int):
...
return self.generate_end_hook(...)
def __del__(self):
del self.data
class ClassAwareCompInitializer(AbstractComponentsInitializer):
"""Abstract class for all class-aware components initializers.
Components generated by class-aware components initializers inherit the shape
of the provided data.
"""
def __init__(self,
data,
noise: float = 0.0,
transform: callable = torch.nn.Identity()):
self.data, self.targets = parse_data_arg(data)
self.noise = noise
self.transform = transform
self.clabels = torch.unique(self.targets).int().tolist()
self.num_classes = len(self.clabels)
@property
@abstractmethod
def subinit_type(self) -> DataAwareCompInitializer:
...
def generate(self, distribution: Union[dict, list, tuple]):
distribution = parse_distribution(distribution)
initializers = {
k: self.subinit_type(self.data[self.targets == k])
for k in distribution.keys()
}
components = torch.tensor([])
for k, v in distribution.items():
stratified_data = self.data[self.targets == k]
# skip transform here
initializer = self.subinit_type(
stratified_data,
noise=self.noise,
transform=self.transform,
)
samples = initializer.generate(num_components=v)
components = torch.cat([components, samples])
return components
def __del__(self):
del self.data
del self.targets
class LiteralCompInitializer(DataAwareCompInitializer):
"""'Generate' the provided components.
Use this to 'generate' pre-initialized components from elsewhere.
"""
def generate(self, num_components: int):
"""Ignore `num_components` and simply return transformed `self.data`."""
components = self.transform(self.data)
return components
class ZerosCompInitializer(ShapeAwareCompInitializer):
"""Generate zeros corresponding to the components shape."""
def generate(self, num_components: int):
components = torch.zeros((num_components, ) + self.component_shape)
return components
class OnesCompInitializer(ShapeAwareCompInitializer):
"""Generate ones corresponding to the components shape."""
def generate(self, num_components: int):
components = torch.ones((num_components, ) + self.component_shape)
return components
class FillValueCompInitializer(OnesCompInitializer):
"""Generate components with the provided `fill_value`."""
def __init__(self, shape, fill_value: float = 1.0):
super().__init__(shape)
self.fill_value = fill_value
def generate(self, num_components: int):
ones = super().generate(num_components)
components = ones.fill_(self.fill_value)
return components
class UniformCompInitializer(OnesCompInitializer):
"""Generate components by sampling from a continuous uniform distribution."""
def __init__(self, shape, minimum=0.0, maximum=1.0, scale=1.0):
super().__init__(shape)
self.minimum = minimum
self.maximum = maximum
self.scale = scale
def generate(self, num_components: int):
ones = super().generate(num_components)
components = self.scale * ones.uniform_(self.minimum, self.maximum)
return components
class RandomNormalCompInitializer(OnesCompInitializer):
"""Generate components by sampling from a standard normal distribution."""
def __init__(self, shape, scale=1.0):
super().__init__(shape)
self.scale = scale
def generate(self, num_components: int):
ones = super().generate(num_components)
components = self.scale * torch.randn_like(ones)
return components
class SelectionCompInitializer(DataAwareCompInitializer):
"""Generate components by uniformly sampling from the provided data."""
def generate(self, num_components: int):
indices = torch.LongTensor(num_components).random_(0, len(self.data))
samples = self.data[indices]
components = self.generate_end_hook(samples)
return components
class MeanCompInitializer(DataAwareCompInitializer):
"""Generate components by computing the mean of the provided data."""
def generate(self, num_components: int):
mean = torch.mean(self.data, dim=0)
repeat_dim = [num_components] + [1] * len(mean.shape)
samples = mean.repeat(repeat_dim)
components = self.generate_end_hook(samples)
return components
class StratifiedSelectionCompInitializer(ClassAwareCompInitializer):
"""Generate components using stratified sampling from the provided data."""
@property
def subinit_type(self):
return SelectionCompInitializer
class StratifiedMeanCompInitializer(ClassAwareCompInitializer):
"""Generate components at stratified means of the provided data."""
@property
def subinit_type(self):
return MeanCompInitializer
# Labels
class AbstractLabelsInitializer(ABC):
"""Abstract class for all labels initializers."""
@abstractmethod
def generate(self, distribution: Union[dict, list, tuple]):
...
class LabelsInitializer(AbstractLabelsInitializer):
"""Generate labels with `self.distribution`."""
def __init__(self, override_labels: list = []):
self.override_labels = override_labels
def generate(self, distribution: Union[dict, list, tuple]):
distribution = parse_distribution(distribution)
labels = []
for k, v in distribution.items():
labels.extend([k] * v)
labels = torch.LongTensor(labels)
return labels
# Reasonings
class AbstractReasoningsInitializer(ABC):
"""Abstract class for all reasonings initializers."""
@abstractmethod
def generate(self, distribution: Union[dict, list, tuple]):
...
class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer):
"""Generate labels with `self.distribution`."""
def generate(self, distribution: Union[dict, list, tuple]):
distribution = parse_distribution(distribution)
num_classes = len(distribution.keys())
num_components = sum(distribution.values())
assert num_classes == num_components
reasonings = torch.stack(
[torch.eye(num_classes),
torch.zeros(num_classes, num_classes)],
dim=0)
return reasonings
# Aliases - Components
ZCI = ZerosCompInitializer
OCI = OnesCompInitializer
FVCI = FillValueCompInitializer
LCI = LiteralCompInitializer
UCI = UniformCompInitializer
RNCI = RandomNormalCompInitializer
SCI = SelectionCompInitializer
MCI = MeanCompInitializer
SSCI = StratifiedSelectionCompInitializer
SMCI = StratifiedMeanCompInitializer
PPRI = PurePositiveReasoningsInitializer