[REFACTOR] Clean and move components and initializers into core
This commit is contained in:
5
prototorch/core/__init__.py
Normal file
5
prototorch/core/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""ProtoTorch core"""
|
||||
|
||||
from .components import *
|
||||
from .initializers import *
|
||||
from .labels import *
|
243
prototorch/core/components.py
Normal file
243
prototorch/core/components.py
Normal file
@@ -0,0 +1,243 @@
|
||||
"""ProtoTorch components"""
|
||||
|
||||
import inspect
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from ..utils import parse_distribution
|
||||
from .initializers import (
|
||||
AbstractComponentsInitializer,
|
||||
AbstractLabelsInitializer,
|
||||
AbstractReasoningsInitializer,
|
||||
ClassAwareCompInitializer,
|
||||
LabelsInitializer,
|
||||
)
|
||||
|
||||
|
||||
def validate_initializer(initializer, instanceof):
    """Check that `initializer` is an instance of `instanceof`.

    Returns True when the check passes. Otherwise a TypeError is raised; if
    `initializer` is itself a class, the message additionally hints that the
    caller probably forgot to instantiate it.
    """
    if isinstance(initializer, instanceof):
        return True

    message_parts = [
        "`initializer` has to be an instance ",
        f"of some subtype of {instanceof}. ",
        f"You have provided: {initializer} instead. ",
    ]
    if inspect.isclass(initializer):
        class_name = initializer.__name__
        message_parts.append(
            f"Perhaps you meant to say, {class_name}() "
            f"with the brackets instead of just {class_name}?")
    raise TypeError("".join(message_parts))
|
||||
|
||||
|
||||
def validate_components_initializer(initializer):
    """Validate `initializer` against `AbstractComponentsInitializer`.

    Delegates to `validate_initializer`: returns True on success and raises
    TypeError otherwise.
    """
    expected_type = AbstractComponentsInitializer
    return validate_initializer(initializer, expected_type)
|
||||
|
||||
|
||||
def validate_labels_initializer(initializer):
    """Validate `initializer` against `AbstractLabelsInitializer`.

    Delegates to `validate_initializer`: returns True on success and raises
    TypeError otherwise.
    """
    expected_type = AbstractLabelsInitializer
    return validate_initializer(initializer, expected_type)
|
||||
|
||||
|
||||
def validate_reasonings_initializer(initializer):
    """Validate `initializer` against `AbstractReasoningsInitializer`.

    Delegates to `validate_initializer`: returns True on success and raises
    TypeError otherwise.
    """
    expected_type = AbstractReasoningsInitializer
    return validate_initializer(initializer, expected_type)
|
||||
|
||||
|
||||
class AbstractComponents(torch.nn.Module):
    """Base module shared by all components containers.

    The components live in the `_components` parameter, which is (re)created
    through `_register_components`.
    """

    @property
    def num_components(self):
        """Number of components currently registered."""
        stored = self._components
        return len(stored)

    @property
    def components(self):
        """The components as a Tensor detached from the autograd graph."""
        stored = self._components
        return stored.detach()

    def _register_components(self, components):
        """Wrap `components` in a Parameter and register it as `_components`."""
        as_parameter = Parameter(components)
        self.register_parameter("_components", as_parameter)

    def extra_repr(self):
        """Describe the shape of the components for the module repr."""
        shape = tuple(self._components.shape)
        return f"(components): (shape: {shape})"
|
||||
|
||||
|
||||
class Components(AbstractComponents):
    """A set of adaptable Tensors."""

    def __init__(self, num_components: int,
                 initializer: AbstractComponentsInitializer, **kwargs):
        """Create `num_components` components using `initializer`.

        Extra keyword arguments are forwarded to the parent constructor.
        """
        super().__init__(**kwargs)
        self.add_components(num_components, initializer)

    def add_components(self, num: int,
                       initializer: AbstractComponentsInitializer):
        """Add new components.

        Generates `num` components with `initializer`, appends them to any
        existing ones and re-registers the result as the `_components`
        parameter.

        Returns the newly generated components.
        Raises TypeError if `initializer` is not a components initializer.
        """
        # Called as a plain statement instead of inside `assert`, so the
        # check still runs under `python -O`. It raises TypeError on failure.
        validate_components_initializer(initializer)

        new_components = initializer.generate(num)

        # Register
        if hasattr(self, "_components"):
            _components = torch.cat([self._components, new_components])
        else:
            _components = new_components
        self._register_components(_components)
        return new_components

    def remove_components(self, indices):
        """Remove components at specified indices.

        Returns the boolean mask that was applied: True marks the components
        that were kept, False those that were removed.
        """
        mask = torch.ones(self.num_components, dtype=torch.bool)
        mask[indices] = False
        _components = self._components[mask]
        self._register_components(_components)
        return mask

    def forward(self):
        """Simply return the components parameter Tensor."""
        return self._components
|
||||
|
||||
|
||||
class LabeledComponents(AbstractComponents):
    """A set of adaptable components and corresponding unadaptable labels.

    Components are registered as a parameter, labels as a buffer.
    """

    def __init__(self, distribution: Union[dict, list, tuple],
                 components_initializer: AbstractComponentsInitializer,
                 labels_initializer: AbstractLabelsInitializer, **kwargs):
        """Create components and labels according to `distribution`.

        Extra keyword arguments are forwarded to the parent constructor.
        """
        super().__init__(**kwargs)
        self.add_components(distribution, components_initializer,
                            labels_initializer)

    @property
    def component_labels(self):
        """Detached Tensor containing the component labels."""
        return self._labels.detach()

    def _register_labels(self, labels):
        """Register `labels` as the `_labels` buffer."""
        self.register_buffer("_labels", labels)

    def add_components(
            self,
            distribution,
            components_initializer,
            labels_initializer: AbstractLabelsInitializer = LabelsInitializer()):
        """Add new components together with their labels.

        `distribution` is normalized with `parse_distribution`. Class-aware
        components initializers receive the whole distribution, all other
        initializers only the total number of components. Labels are always
        generated from the distribution.

        Returns the tuple `(new_components, new_labels)`.
        Raises TypeError if either initializer has the wrong type.
        """
        # Checks. Called as plain statements instead of inside `assert`, so
        # they still run under `python -O`; each raises TypeError on failure.
        validate_components_initializer(components_initializer)
        validate_labels_initializer(labels_initializer)

        distribution = parse_distribution(distribution)

        # Generate new components
        if isinstance(components_initializer, ClassAwareCompInitializer):
            new_components = components_initializer.generate(distribution)
        else:
            num_components = sum(distribution.values())
            new_components = components_initializer.generate(num_components)

        # Generate new labels
        new_labels = labels_initializer.generate(distribution)

        # Register
        if hasattr(self, "_components"):
            _components = torch.cat([self._components, new_components])
        else:
            _components = new_components
        if hasattr(self, "_labels"):
            _labels = torch.cat([self._labels, new_labels])
        else:
            _labels = new_labels
        self._register_components(_components)
        self._register_labels(_labels)

        return new_components, new_labels

    def remove_components(self, indices):
        """Remove components and labels at specified indices.

        Returns the boolean mask that was applied: True marks the entries
        that were kept, False those that were removed.
        """
        mask = torch.ones(self.num_components, dtype=torch.bool)
        mask[indices] = False
        _components = self._components[mask]
        _labels = self._labels[mask]
        self._register_components(_components)
        self._register_labels(_labels)
        return mask

    def forward(self):
        """Simply return the components parameter Tensor and labels."""
        return self._components, self._labels
|
||||
|
||||
|
||||
class ReasoningComponents(AbstractComponents):
    r"""A set of components and corresponding adaptable reasoning matrices.

    Every component has its own reasoning matrix.

    A reasoning matrix is an Nx2 matrix, where N is the number of classes. The
    first element is called positive reasoning :math:`p`, the second negative
    reasoning :math:`n`. A component can reason in favour (positive) of a
    class, against (negative) a class or not at all (neutral).

    It holds that :math:`0 \leq n \leq 1`, :math:`0 \leq p \leq 1` and :math:`0
    \leq n+p \leq 1`. Therefore :math:`n` and :math:`p` are two elements of a
    three element probability distribution.

    """

    def __init__(self, distribution: Union[dict, list, tuple],
                 components_initializer: AbstractComponentsInitializer,
                 reasonings_initializer: AbstractReasoningsInitializer,
                 **kwargs):
        """Create components and reasonings according to `distribution`.

        Extra keyword arguments are forwarded to the parent constructor.
        """
        super().__init__(**kwargs)
        self.add_components(distribution, components_initializer,
                            reasonings_initializer)

    @property
    def reasonings(self):
        """Returns Reasoning Matrix.

        Dimension NxCx2

        """
        return self._reasonings.detach()

    def _register_reasonings(self, reasonings):
        """Wrap `reasonings` in a Parameter and register it as `_reasonings`."""
        self.register_parameter("_reasonings", Parameter(reasonings))

    def add_components(self, distribution, components_initializer,
                       reasonings_initializer: AbstractReasoningsInitializer):
        """Add new components together with their reasonings.

        `distribution` is normalized with `parse_distribution`. Class-aware
        components initializers receive the whole distribution, all other
        initializers only the total number of components. Reasonings are
        always generated from the distribution.

        Returns the tuple `(new_components, new_reasonings)`.
        Raises TypeError if either initializer has the wrong type.
        """
        # Checks. Called as plain statements instead of inside `assert`, so
        # they still run under `python -O`; each raises TypeError on failure.
        validate_components_initializer(components_initializer)
        validate_reasonings_initializer(reasonings_initializer)

        distribution = parse_distribution(distribution)

        # Generate new components
        if isinstance(components_initializer, ClassAwareCompInitializer):
            new_components = components_initializer.generate(distribution)
        else:
            num_components = sum(distribution.values())
            new_components = components_initializer.generate(num_components)

        # Generate new reasonings
        new_reasonings = reasonings_initializer.generate(distribution)

        # Register
        if hasattr(self, "_components"):
            _components = torch.cat([self._components, new_components])
        else:
            _components = new_components
        if hasattr(self, "_reasonings"):
            # NOTE(review): concatenates along dim 0 - confirm that dim 0 is
            # the component axis of the generated reasonings.
            _reasonings = torch.cat([self._reasonings, new_reasonings])
        else:
            _reasonings = new_reasonings
        self._register_components(_components)
        self._register_reasonings(_reasonings)

        return new_components, new_reasonings

    def remove_components(self, indices):
        """Remove components at specified indices.

        Only the components are removed; the reasonings are left unchanged
        (see the TODO below). Returns the boolean mask that was applied:
        True marks the components that were kept.
        """
        mask = torch.ones(self.num_components, dtype=torch.bool)
        mask[indices] = False
        _components = self._components[mask]
        # TODO: also remove the matching entries of `self._reasonings` and
        # re-register them, so components and reasonings stay in sync.
        self._register_components(_components)
        return mask

    def forward(self):
        """Simply return the components and reasonings."""
        return self._components, self._reasonings
|
258
prototorch/core/initializers.py
Normal file
258
prototorch/core/initializers.py
Normal file
@@ -0,0 +1,258 @@
|
||||
"""ProtoTorch code initializers"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Iterable
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..utils import parse_data_arg, parse_distribution
|
||||
|
||||
|
||||
# Components
|
||||
class AbstractComponentsInitializer(ABC):
    """Common base type of every components initializer.

    It declares no members of its own; it only provides a shared type for
    components initializers.
    """
|
||||
|
||||
|
||||
class ShapeAwareCompInitializer(AbstractComponentsInitializer):
    """Base class of components initializers that know the component shape.

    The shape of a single component is stored as the tuple
    `component_shape`.
    """

    def __init__(self, shape: Union[Iterable, int]):
        """Store `shape` as a tuple; a non-iterable becomes a 1-tuple."""
        self.component_shape = (tuple(shape)
                                if isinstance(shape, Iterable) else (shape, ))

    @abstractmethod
    def generate(self, num_components: int):
        """Generate `num_components` components (implemented by subclasses)."""
|
||||
|
||||
|
||||
class DataAwareCompInitializer(AbstractComponentsInitializer):
    """Abstract class for all data-aware components initializers.

    Components generated by data-aware components initializers inherit the shape
    of the provided data.

    """

    def __init__(self,
                 data,
                 noise: float = 0.0,
                 transform: callable = torch.nn.Identity()):
        """Store the data, the noise scale and the transform.

        `noise` scales the random drift added in `generate_end_hook`;
        `transform` is applied to the drifted samples.
        """
        self.data = data
        self.noise = noise
        self.transform = transform

    def generate_end_hook(self, samples):
        """Add random drift scaled by `self.noise`, then apply the transform."""
        drift = torch.rand_like(samples) * self.noise
        components = self.transform(samples + drift)
        return components

    @abstractmethod
    def generate(self, num_components: int):
        """Generate `num_components` components from `self.data`.

        The body is only a template: concrete subclasses select their samples
        and then pass them through `generate_end_hook`.
        """
        ...
        return self.generate_end_hook(...)

    def __del__(self):
        # `__init__` may not have run (e.g. it was called with wrong
        # arguments), in which case `data` was never set. Swallow that one
        # case so the finalizer does not emit an "Exception ignored" warning.
        try:
            del self.data
        except AttributeError:
            pass
|
||||
|
||||
|
||||
class ClassAwareCompInitializer(AbstractComponentsInitializer):
    """Abstract class for all class-aware components initializers.

    Components generated by class-aware components initializers inherit the shape
    of the provided data.

    """

    def __init__(self,
                 data,
                 noise: float = 0.0,
                 transform: callable = torch.nn.Identity()):
        """Split `data` into inputs and targets and record the class labels.

        `noise` and `transform` are handed on to the per-class
        sub-initializers created in `generate`.
        """
        self.data, self.targets = parse_data_arg(data)
        self.noise = noise
        self.transform = transform
        self.clabels = torch.unique(self.targets).int().tolist()
        self.num_classes = len(self.clabels)

    @property
    @abstractmethod
    def subinit_type(self) -> DataAwareCompInitializer:
        """Initializer class that is instantiated once per class in `generate`."""
        ...

    def generate(self, distribution: Union[dict, list, tuple]):
        """Generate components class by class.

        For every `(label, count)` entry of the parsed distribution, a
        sub-initializer of type `subinit_type` is built on the data of that
        class and asked for `count` components. The per-class results are
        concatenated in distribution order and returned.
        """
        distribution = parse_distribution(distribution)
        per_class_samples = []
        for k, v in distribution.items():
            stratified_data = self.data[self.targets == k]
            # Noise and transform are applied by the sub-initializer.
            initializer = self.subinit_type(
                stratified_data,
                noise=self.noise,
                transform=self.transform,
            )
            per_class_samples.append(initializer.generate(num_components=v))
        # Concatenate once instead of once per class. The leading empty
        # tensor is kept so the result (including for an empty distribution)
        # matches what the incremental concatenation produced.
        return torch.cat([torch.tensor([]), *per_class_samples])

    def __del__(self):
        # `__init__` may have failed before these attributes were set (e.g.
        # `parse_data_arg` raised). Swallow that one case so the finalizer
        # does not emit an "Exception ignored" warning.
        for name in ("data", "targets"):
            try:
                delattr(self, name)
            except AttributeError:
                pass
|
||||
|
||||
|
||||
class LiteralCompInitializer(DataAwareCompInitializer):
    """'Generate' the provided components.

    Use this to hand over components that were already initialized elsewhere.

    """

    def generate(self, num_components: int):
        """Return the transformed `self.data`; `num_components` is ignored."""
        return self.transform(self.data)
|
||||
|
||||
|
||||
class ZerosCompInitializer(ShapeAwareCompInitializer):
    """Generate all-zero components of the configured component shape."""

    def generate(self, num_components: int):
        """Return a zeros Tensor of shape `(num_components, *component_shape)`."""
        full_shape = (num_components, *self.component_shape)
        return torch.zeros(full_shape)
|
||||
|
||||
|
||||
class OnesCompInitializer(ShapeAwareCompInitializer):
    """Generate all-one components of the configured component shape."""

    def generate(self, num_components: int):
        """Return a ones Tensor of shape `(num_components, *component_shape)`."""
        full_shape = (num_components, *self.component_shape)
        return torch.ones(full_shape)
|
||||
|
||||
|
||||
class FillValueCompInitializer(OnesCompInitializer):
    """Generate components whose entries all equal `fill_value`."""

    def __init__(self, shape, fill_value: float = 1.0):
        """Store the component shape and the value to fill with."""
        super().__init__(shape)
        self.fill_value = fill_value

    def generate(self, num_components: int):
        """Return the parent's ones Tensor, overwritten with `fill_value`."""
        return super().generate(num_components).fill_(self.fill_value)
|
||||
|
||||
|
||||
class UniformCompInitializer(OnesCompInitializer):
    """Generate components by sampling from a continuous uniform distribution."""

    def __init__(self, shape, minimum=0.0, maximum=1.0, scale=1.0):
        """Store the component shape, the sampling bounds and the scale."""
        super().__init__(shape)
        self.minimum, self.maximum = minimum, maximum
        self.scale = scale

    def generate(self, num_components: int):
        """Sample between `minimum` and `maximum`, then multiply by `scale`."""
        buffer = super().generate(num_components)
        # In-place sampling: the parent's ones Tensor only provides the shape.
        buffer.uniform_(self.minimum, self.maximum)
        return self.scale * buffer
|
||||
|
||||
|
||||
class RandomNormalCompInitializer(OnesCompInitializer):
    """Generate components by sampling from a standard normal distribution."""

    def __init__(self, shape, scale=1.0):
        """Store the component shape and the scale applied to the samples."""
        super().__init__(shape)
        self.scale = scale

    def generate(self, num_components: int):
        """Draw standard normal samples and multiply them by `scale`."""
        # The parent's ones Tensor only serves as a shape/dtype template.
        template = super().generate(num_components)
        samples = torch.randn_like(template)
        return self.scale * samples
|
||||
|
||||
|
||||
class SelectionCompInitializer(DataAwareCompInitializer):
    """Generate components by uniformly sampling from the provided data."""

    def generate(self, num_components: int):
        """Pick `num_components` random rows of the data (with replacement)."""
        num_samples = len(self.data)
        indices = torch.LongTensor(num_components)
        # Fill in place with random row indices from [0, num_samples).
        indices.random_(0, num_samples)
        return self.generate_end_hook(self.data[indices])
|
||||
|
||||
|
||||
class MeanCompInitializer(DataAwareCompInitializer):
    """Generate components by computing the mean of the provided data."""

    def generate(self, num_components: int):
        """Return `num_components` copies of the data mean over dim 0."""
        mean = torch.mean(self.data, dim=0)
        # Repeat only along a new leading axis; keep every mean axis as is.
        repeats = [num_components, *([1] * len(mean.shape))]
        stacked_means = mean.repeat(repeats)
        return self.generate_end_hook(stacked_means)
|
||||
|
||||
|
||||
class StratifiedSelectionCompInitializer(ClassAwareCompInitializer):
    """Generate components using stratified sampling from the provided data."""

    @property
    def subinit_type(self):
        """Per-class initializer type: `SelectionCompInitializer`."""
        per_class_initializer = SelectionCompInitializer
        return per_class_initializer
|
||||
|
||||
|
||||
class StratifiedMeanCompInitializer(ClassAwareCompInitializer):
    """Generate components at stratified means of the provided data."""

    @property
    def subinit_type(self):
        """Per-class initializer type: `MeanCompInitializer`."""
        per_class_initializer = MeanCompInitializer
        return per_class_initializer
|
||||
|
||||
|
||||
# Labels
|
||||
class AbstractLabelsInitializer(ABC):
    """Base class that every labels initializer derives from."""

    @abstractmethod
    def generate(self, distribution: Union[dict, list, tuple]):
        """Generate labels for `distribution` (implemented by subclasses)."""
|
||||
|
||||
|
||||
class LabelsInitializer(AbstractLabelsInitializer):
    """Generate labels according to a class distribution."""

    def __init__(self, override_labels: list = None):
        """Store `override_labels`.

        A fresh empty list is used when none is given, so instances do not
        share one mutable default.
        """
        # NOTE(review): `override_labels` is stored but not used by
        # `generate` - confirm whether it is meant to replace the labels.
        self.override_labels = ([] if override_labels is None else
                                override_labels)

    def generate(self, distribution: Union[dict, list, tuple]):
        """Return a LongTensor repeating each label by its count.

        `distribution` is normalized with `parse_distribution`; the labels
        follow the order of its entries.
        """
        distribution = parse_distribution(distribution)
        label_list = []
        for k, v in distribution.items():
            label_list.extend([k] * v)
        return torch.LongTensor(label_list)
|
||||
|
||||
|
||||
# Reasonings
|
||||
class AbstractReasoningsInitializer(ABC):
    """Base class that every reasonings initializer derives from."""

    @abstractmethod
    def generate(self, distribution: Union[dict, list, tuple]):
        """Generate reasonings for `distribution` (implemented by subclasses)."""
|
||||
|
||||
|
||||
class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer):
    """Generate purely positive reasonings with one component per class.

    The result stacks an identity matrix (positive part) on top of a zero
    matrix (negative part).
    """

    def generate(self, distribution: Union[dict, list, tuple]):
        """Return a Tensor of shape `(2, num_classes, num_classes)`.

        `distribution` is normalized with `parse_distribution`.

        Raises ValueError unless the distribution holds exactly as many
        components as classes.
        """
        # NOTE(review): `ReasoningComponents.reasonings` documents the layout
        # "NxCx2", while this stacks along dim 0 - confirm the intended axes.
        distribution = parse_distribution(distribution)
        num_classes = len(distribution.keys())
        num_components = sum(distribution.values())
        # Raised explicitly instead of asserted, so the check is not
        # stripped under `python -O`.
        if num_classes != num_components:
            raise ValueError(
                f"Expected as many components as classes, but got "
                f"{num_components} components for {num_classes} classes.")
        positive = torch.eye(num_classes)
        negative = torch.zeros(num_classes, num_classes)
        return torch.stack([positive, negative], dim=0)
|
||||
|
||||
|
||||
# Aliases - Components
|
||||
# Shape-aware components initializers
ZCI = ZerosCompInitializer
OCI = OnesCompInitializer
FVCI = FillValueCompInitializer

# Data-aware components initializers
LCI = LiteralCompInitializer

# Random shape-aware components initializers
UCI = UniformCompInitializer
RNCI = RandomNormalCompInitializer

# Data-aware sampling / mean components initializers
SCI = SelectionCompInitializer
MCI = MeanCompInitializer

# Class-aware (stratified) components initializers
SSCI = StratifiedSelectionCompInitializer
SMCI = StratifiedMeanCompInitializer

# Aliases - Reasonings
PPRI = PurePositiveReasoningsInitializer
|
Reference in New Issue
Block a user