[REFACTOR] Clean and move components and initializers into core
parent b8969347b1
commit 5dddb39ec4
prototorch/components/components.py (deleted)
@@ -1,235 +0,0 @@
"""ProtoTorch Components."""

import warnings

import torch
from prototorch.components.initializers import (ClassAwareInitializer,
                                                ComponentsInitializer,
                                                EqualLabelsInitializer,
                                                UnequalLabelsInitializer,
                                                ZeroReasoningsInitializer)
from torch.nn.parameter import Parameter

from .initializers import parse_data_arg


def get_labels_initializer(distribution):
    if isinstance(distribution, dict):
        if "num_classes" in distribution.keys():
            labels = EqualLabelsInitializer(
                distribution["num_classes"],
                distribution["prototypes_per_class"])
        else:
            clabels = list(distribution.keys())
            dist = list(distribution.values())
            labels = UnequalLabelsInitializer(dist, clabels)
    elif isinstance(distribution, tuple):
        num_classes, prototypes_per_class = distribution
        labels = EqualLabelsInitializer(num_classes, prototypes_per_class)
    elif isinstance(distribution, list):
        labels = UnequalLabelsInitializer(distribution)
    else:
        msg = f"`distribution` not understood. " \
              f"You have provided: {distribution=}."
        raise ValueError(msg)
    return labels


def _precheck_initializer(initializer):
    if not isinstance(initializer, ComponentsInitializer):
        emsg = f"`initializer` has to be some subtype of " \
               f"{ComponentsInitializer}. " \
               f"You have provided: {initializer=} instead."
        raise TypeError(emsg)


class Components(torch.nn.Module):
    """Components is a set of learnable Tensors."""
    def __init__(self,
                 num_components=None,
                 initializer=None,
                 *,
                 initialized_components=None):
        super().__init__()

        # Ignore all initialization settings if initialized_components is given.
        if initialized_components is not None:
            self._register_components(initialized_components)
            if num_components is not None or initializer is not None:
                wmsg = "Arguments ignored while initializing Components"
                warnings.warn(wmsg)
        else:
            self._initialize_components(num_components, initializer)

    @property
    def num_components(self):
        return len(self._components)

    def _register_components(self, components):
        self.register_parameter("_components", Parameter(components))

    def _initialize_components(self, num_components, initializer):
        _precheck_initializer(initializer)
        _components = initializer.generate(num_components)
        self._register_components(_components)

    def add_components(self,
                       num=1,
                       initializer=None,
                       *,
                       initialized_components=None):
        if initialized_components is not None:
            _components = torch.cat([self._components, initialized_components])
        else:
            _precheck_initializer(initializer)
            _new = initializer.generate(num)
            _components = torch.cat([self._components, _new])
        self._register_components(_components)

    def remove_components(self, indices=None):
        mask = torch.ones(self.num_components, dtype=torch.bool)
        mask[indices] = False
        _components = self._components[mask]
        self._register_components(_components)
        return mask

    @property
    def components(self):
        """Tensor containing the component tensors."""
        return self._components.detach()

    def forward(self):
        return self._components

    def extra_repr(self):
        return f"(components): (shape: {tuple(self._components.shape)})"


class LabeledComponents(Components):
    """LabeledComponents generates a set of components and a set of labels.

    Every component has a label assigned.
    """
    def __init__(self,
                 distribution=None,
                 initializer=None,
                 *,
                 initialized_components=None):
        if initialized_components is not None:
            components, component_labels = parse_data_arg(
                initialized_components)
            super().__init__(initialized_components=components)
            self._labels = component_labels
        else:
            labels_initializer = get_labels_initializer(distribution)
            self.initial_distribution = labels_initializer.distribution
            _labels = labels_initializer.generate()
            super().__init__(len(_labels), initializer=initializer)
            self._register_labels(_labels)

    def _register_labels(self, labels):
        self.register_buffer("_labels", labels)

    @property
    def distribution(self):
        clabels, counts = torch.unique(self._labels,
                                       sorted=True,
                                       return_counts=True)
        return dict(zip(clabels.tolist(), counts.tolist()))

    def _initialize_components(self, num_components, initializer):
        if isinstance(initializer, ClassAwareInitializer):
            _precheck_initializer(initializer)
            _components = initializer.generate(num_components,
                                               self.initial_distribution)
            self._register_components(_components)
        else:
            super()._initialize_components(num_components, initializer)

    def add_components(self, distribution, initializer):
        _precheck_initializer(initializer)

        # Labels
        labels_initializer = get_labels_initializer(distribution)
        new_labels = labels_initializer.generate()
        _labels = torch.cat([self._labels, new_labels])
        self._register_labels(_labels)

        # Components
        if isinstance(initializer, ClassAwareInitializer):
            _new = initializer.generate(len(new_labels), distribution)
        else:
            _new = initializer.generate(len(new_labels))
        _components = torch.cat([self._components, _new])
        self._register_components(_components)

    def remove_components(self, indices=None):
        # Components
        mask = super().remove_components(indices)

        # Labels
        _labels = self._labels[mask]
        self._register_labels(_labels)

    @property
    def component_labels(self):
        """Tensor containing the component labels."""
        return self._labels.detach()

    def forward(self):
        return super().forward(), self._labels


class ReasoningComponents(Components):
    """ReasoningComponents generates a set of components and a set of reasoning matrices.

    Every component has a reasoning matrix assigned.

    A reasoning matrix is an Nx2 matrix, where N is the number of classes. The
    first element is called positive reasoning :math:`p`, the second negative
    reasoning :math:`n`. A component can reason in favour (positive) of a
    class, against (negative) a class or not at all (neutral).

    It holds that :math:`0 \leq n \leq 1`, :math:`0 \leq p \leq 1` and :math:`0
    \leq n+p \leq 1`. Therefore :math:`n` and :math:`p` are two elements of a
    three-element probability distribution.

    """
    def __init__(self,
                 distribution=None,
                 initializer=None,
                 reasoning_initializer=None,
                 *,
                 initialized_components=None):
        if initialized_components is not None:
            components, reasonings = initialized_components
            super().__init__(initialized_components=components)
            self.register_parameter("_reasonings", Parameter(reasonings))
        else:
            labels_initializer = get_labels_initializer(distribution)
            self.initial_distribution = labels_initializer.distribution
            super().__init__(len(self.initial_distribution),
                             initializer=initializer)
            reasonings = reasoning_initializer.generate()
            self._register_reasonings(reasonings)

    def _initialize_reasonings(self, reasoning_initializer):
        if isinstance(reasoning_initializer, tuple):
            num_classes, num_components = reasoning_initializer
            reasoning_initializer = ZeroReasoningsInitializer(
                num_classes, num_components)

        _reasonings = reasoning_initializer.generate()
        self.register_parameter("_reasonings", Parameter(_reasonings))

    @property
    def reasonings(self):
        """Returns Reasoning Matrix.

        Dimension NxCx2

        """
        return self._reasonings.detach()

    def forward(self):
        return super().forward(), self._reasonings
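For orientation before the replacement below: the removed API was driven by a `distribution` argument plus a single components initializer. A minimal sketch of how it was typically used (the data, shapes, and class counts here are made up for illustration):

import torch
from prototorch.components.components import LabeledComponents
from prototorch.components.initializers import StratifiedMeanInitializer

x = torch.randn(100, 2)           # hypothetical 2-D training data
y = torch.randint(0, 3, (100, ))  # hypothetical targets for 3 classes

# Three classes with two prototypes each, placed at the class means.
lc = LabeledComponents(distribution=(3, 2),
                       initializer=StratifiedMeanInitializer((x, y)))
components, labels = lc()         # (6, 2) components, 6 labels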
prototorch/components/initializers.py (deleted)
@@ -1,225 +0,0 @@
"""ProtoTorch Initializers."""

import warnings
from collections.abc import Iterable
from itertools import chain
from typing import List

import torch
from torch.utils.data import DataLoader, Dataset


def parse_data_arg(data_arg):
    if isinstance(data_arg, Dataset):
        data_arg = DataLoader(data_arg, batch_size=len(data_arg))

    if isinstance(data_arg, DataLoader):
        data = torch.tensor([])
        targets = torch.tensor([])
        for x, y in data_arg:
            data = torch.cat([data, x])
            targets = torch.cat([targets, y])
    else:
        data, targets = data_arg
        if not isinstance(data, torch.Tensor):
            wmsg = f"Converting data to {torch.Tensor}."
            warnings.warn(wmsg)
            data = torch.Tensor(data)
        if not isinstance(targets, torch.Tensor):
            wmsg = f"Converting targets to {torch.Tensor}."
            warnings.warn(wmsg)
            targets = torch.Tensor(targets)
    return data, targets


def get_subinitializers(data, targets, clabels, subinit_type):
    initializers = dict()
    for clabel in clabels:
        class_data = data[targets == clabel]
        class_initializer = subinit_type(class_data)
        initializers[clabel] = class_initializer
    return initializers


# Components
class ComponentsInitializer(object):
    def generate(self, number_of_components):
        raise NotImplementedError("Subclasses should implement this!")


class DimensionAwareInitializer(ComponentsInitializer):
    def __init__(self, dims):
        super().__init__()
        if isinstance(dims, Iterable):
            self.components_dims = tuple(dims)
        else:
            self.components_dims = (dims, )


class OnesInitializer(DimensionAwareInitializer):
    def __init__(self, dims, scale=1.0):
        super().__init__(dims)
        self.scale = scale

    def generate(self, length):
        gen_dims = (length, ) + self.components_dims
        return torch.ones(gen_dims) * self.scale


class ZerosInitializer(DimensionAwareInitializer):
    def generate(self, length):
        gen_dims = (length, ) + self.components_dims
        return torch.zeros(gen_dims)


class UniformInitializer(DimensionAwareInitializer):
    def __init__(self, dims, minimum=0.0, maximum=1.0, scale=1.0):
        super().__init__(dims)
        self.minimum = minimum
        self.maximum = maximum
        self.scale = scale

    def generate(self, length):
        gen_dims = (length, ) + self.components_dims
        return torch.ones(gen_dims).uniform_(self.minimum,
                                             self.maximum) * self.scale


class DataAwareInitializer(ComponentsInitializer):
    def __init__(self, data, transform=torch.nn.Identity()):
        super().__init__()
        self.data = data
        self.transform = transform

    def __del__(self):
        del self.data


class SelectionInitializer(DataAwareInitializer):
    def generate(self, length):
        indices = torch.LongTensor(length).random_(0, len(self.data))
        return self.transform(self.data[indices])


class MeanInitializer(DataAwareInitializer):
    def generate(self, length):
        mean = torch.mean(self.data, dim=0)
        repeat_dim = [length] + [1] * len(mean.shape)
        return self.transform(mean.repeat(repeat_dim))


class ClassAwareInitializer(DataAwareInitializer):
    def __init__(self, data, transform=torch.nn.Identity()):
        data, targets = parse_data_arg(data)
        super().__init__(data, transform)
        self.targets = targets
        self.clabels = torch.unique(self.targets).int().tolist()
        self.num_classes = len(self.clabels)

    def _get_samples_from_initializer(self, length, dist):
        if not dist:
            per_class = length // self.num_classes
            dist = dict(zip(self.clabels, self.num_classes * [per_class]))
        if isinstance(dist, list):
            dist = dict(zip(self.clabels, dist))
        samples = [self.initializers[k].generate(n) for k, n in dist.items()]
        out = torch.vstack(samples)
        with torch.no_grad():
            out = self.transform(out)
        return out

    def __del__(self):
        del self.data
        del self.targets


class StratifiedMeanInitializer(ClassAwareInitializer):
    def __init__(self, data, **kwargs):
        super().__init__(data, **kwargs)
        self.initializers = get_subinitializers(self.data, self.targets,
                                                self.clabels, MeanInitializer)

    def generate(self, length, dist):
        samples = self._get_samples_from_initializer(length, dist)
        return samples


class StratifiedSelectionInitializer(ClassAwareInitializer):
    def __init__(self, data, noise=None, **kwargs):
        super().__init__(data, **kwargs)
        self.noise = noise
        self.initializers = get_subinitializers(self.data, self.targets,
                                                self.clabels,
                                                SelectionInitializer)

    def add_noise_v1(self, x):
        return x + self.noise

    def add_noise_v2(self, x):
        """Shifts some dimensions of the data randomly."""
        n1 = torch.rand_like(x)
        n2 = torch.rand_like(x)
        mask = torch.bernoulli(n1) - torch.bernoulli(n2)
        return x + (self.noise * mask)

    def generate(self, length, dist):
        samples = self._get_samples_from_initializer(length, dist)
        if self.noise is not None:
            samples = self.add_noise_v1(samples)
        return samples


# Labels
class LabelsInitializer:
    def generate(self):
        raise NotImplementedError("Subclasses should implement this!")


class UnequalLabelsInitializer(LabelsInitializer):
    def __init__(self, dist, clabels=None):
        self.dist = dist
        self.clabels = clabels or range(len(self.dist))

    @property
    def distribution(self) -> List:
        return self.dist

    def generate(self):
        targets = list(
            chain(*[[i] * n for i, n in zip(self.clabels, self.dist)]))
        return torch.LongTensor(targets)


class EqualLabelsInitializer(LabelsInitializer):
    def __init__(self, classes, per_class):
        self.classes = classes
        self.per_class = per_class

    @property
    def distribution(self) -> List:
        return self.classes * [self.per_class]

    def generate(self):
        return torch.arange(self.classes).repeat(self.per_class, 1).T.flatten()


# Reasonings
class ReasoningsInitializer:
    def generate(self, length):
        raise NotImplementedError("Subclasses should implement this!")


class ZeroReasoningsInitializer(ReasoningsInitializer):
    def __init__(self, classes, length):
        self.classes = classes
        self.length = length

    def generate(self):
        return torch.zeros((self.length, self.classes, 2))


# Aliases
SSI = StratifiedSampleInitializer = StratifiedSelectionInitializer
SMI = StratifiedMeanInitializer
Random = RandomInitializer = UniformInitializer
Zeros = ZerosInitializer
Ones = OnesInitializer
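A quick sanity check on the shapes and values these (now removed) initializers produce; both calls are deterministic:

import torch
from prototorch.components.initializers import (EqualLabelsInitializer,
                                                OnesInitializer)

OnesInitializer(2).generate(3)           # ones of shape (3, 2)
EqualLabelsInitializer(3, 2).generate()  # tensor([0, 0, 1, 1, 2, 2])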
prototorch/components/labels.py (deleted)
@@ -1,86 +0,0 @@
"""ProtoTorch Labels."""

import torch
from prototorch.components.components import get_labels_initializer
from prototorch.components.initializers import (ClassAwareInitializer,
                                                ComponentsInitializer,
                                                EqualLabelsInitializer,
                                                UnequalLabelsInitializer)
from torch.nn.parameter import Parameter


def get_labels_initializer(distribution):
    if isinstance(distribution, dict):
        if "num_classes" in distribution.keys():
            labels = EqualLabelsInitializer(
                distribution["num_classes"],
                distribution["prototypes_per_class"])
        else:
            clabels = list(distribution.keys())
            dist = list(distribution.values())
            labels = UnequalLabelsInitializer(dist, clabels)
    elif isinstance(distribution, tuple):
        num_classes, prototypes_per_class = distribution
        labels = EqualLabelsInitializer(num_classes, prototypes_per_class)
    elif isinstance(distribution, list):
        labels = UnequalLabelsInitializer(distribution)
    else:
        msg = f"`distribution` not understood. " \
              f"You have provided: {distribution=}."
        raise ValueError(msg)
    return labels


class Labels(torch.nn.Module):
    def __init__(self,
                 distribution=None,
                 initializer=None,
                 *,
                 initialized_labels=None):
        super().__init__()
        _labels = self.get_labels(distribution,
                                  initializer,
                                  initialized_labels=initialized_labels)
        self._register_labels(_labels)

    def _register_labels(self, labels):
        # self.register_buffer("_labels", labels)
        self.register_parameter("_labels",
                                Parameter(labels, requires_grad=False))

    def get_labels(self,
                   distribution=None,
                   initializer=None,
                   *,
                   initialized_labels=None):
        if initialized_labels is not None:
            _labels = initialized_labels
        else:
            labels_initializer = initializer or get_labels_initializer(
                distribution)
            self.initial_distribution = labels_initializer.distribution
            _labels = labels_initializer.generate()
        return _labels

    def add_labels(self,
                   distribution=None,
                   initializer=None,
                   *,
                   initialized_labels=None):
        new_labels = self.get_labels(distribution,
                                     initializer,
                                     initialized_labels=initialized_labels)
        _labels = torch.cat([self._labels, new_labels])
        self._register_labels(_labels)

    def remove_labels(self, indices=None):
        mask = torch.ones(len(self._labels), dtype=torch.bool)
        mask[indices] = False
        _labels = self._labels[mask]
        self._register_labels(_labels)

    @property
    def labels(self):
        return self._labels

    def forward(self):
        return self._labels
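The removed `Labels` module simply wrapped these label initializers in an `nn.Module`; roughly (a sketch, assuming the distribution tuple form as above):

from prototorch.components.labels import Labels

labels = Labels(distribution=(3, 2))  # three classes, two labels each
labels()  # tensor([0, 0, 1, 1, 2, 2]) as a non-trainable Parameter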
prototorch/core/__init__.py (modified)
@@ -1,3 +1,5 @@
"""ProtoTorch core"""

from .components import *
from .initializers import *
from .labels import *

prototorch/core/components.py (new file, 243 lines)
@@ -0,0 +1,243 @@
"""ProtoTorch components"""

import inspect
from typing import Union

import torch
from torch.nn.parameter import Parameter

from ..utils import parse_distribution
from .initializers import (
    AbstractComponentsInitializer,
    AbstractLabelsInitializer,
    AbstractReasoningsInitializer,
    ClassAwareCompInitializer,
    LabelsInitializer,
)


def validate_initializer(initializer, instanceof):
    if not isinstance(initializer, instanceof):
        emsg = f"`initializer` has to be an instance " \
            f"of some subtype of {instanceof}. " \
            f"You have provided: {initializer} instead. "
        helpmsg = ""
        if inspect.isclass(initializer):
            helpmsg = f"Perhaps you meant to say, {initializer.__name__}() " \
                f"with the brackets instead of just {initializer.__name__}?"
        raise TypeError(emsg + helpmsg)
    return True


def validate_components_initializer(initializer):
    return validate_initializer(initializer, AbstractComponentsInitializer)


def validate_labels_initializer(initializer):
    return validate_initializer(initializer, AbstractLabelsInitializer)


def validate_reasonings_initializer(initializer):
    return validate_initializer(initializer, AbstractReasoningsInitializer)


class AbstractComponents(torch.nn.Module):
    """Abstract class for all components modules."""
    @property
    def num_components(self):
        """Current number of components."""
        return len(self._components)

    @property
    def components(self):
        """Detached Tensor containing the components."""
        return self._components.detach()

    def _register_components(self, components):
        self.register_parameter("_components", Parameter(components))

    def extra_repr(self):
        return f"(components): (shape: {tuple(self._components.shape)})"


class Components(AbstractComponents):
    """A set of adaptable Tensors."""
    def __init__(self, num_components: int,
                 initializer: AbstractComponentsInitializer, **kwargs):
        super().__init__(**kwargs)
        self.add_components(num_components, initializer)

    def add_components(self, num: int,
                       initializer: AbstractComponentsInitializer):
        """Add new components."""
        assert validate_components_initializer(initializer)
        new_components = initializer.generate(num)
        # Register
        if hasattr(self, "_components"):
            _components = torch.cat([self._components, new_components])
        else:
            _components = new_components
        self._register_components(_components)
        return new_components

    def remove_components(self, indices):
        """Remove components at specified indices."""
        mask = torch.ones(self.num_components, dtype=torch.bool)
        mask[indices] = False
        _components = self._components[mask]
        self._register_components(_components)
        return mask

    def forward(self):
        """Simply return the components parameter Tensor."""
        return self._components


class LabeledComponents(AbstractComponents):
    """A set of adaptable components and corresponding unadaptable labels."""
    def __init__(self, distribution: Union[dict, list, tuple],
                 components_initializer: AbstractComponentsInitializer,
                 labels_initializer: AbstractLabelsInitializer, **kwargs):
        super().__init__(**kwargs)
        self.add_components(distribution, components_initializer,
                            labels_initializer)

    @property
    def component_labels(self):
        """Tensor containing the component labels."""
        return self._labels.detach()

    def _register_labels(self, labels):
        self.register_buffer("_labels", labels)

    def add_components(
            self,
            distribution,
            components_initializer,
            labels_initializer: AbstractLabelsInitializer = LabelsInitializer()):
        # Checks
        assert validate_components_initializer(components_initializer)
        assert validate_labels_initializer(labels_initializer)

        distribution = parse_distribution(distribution)

        # Generate new components
        if isinstance(components_initializer, ClassAwareCompInitializer):
            new_components = components_initializer.generate(distribution)
        else:
            num_components = sum(distribution.values())
            new_components = components_initializer.generate(num_components)

        # Generate new labels
        new_labels = labels_initializer.generate(distribution)

        # Register
        if hasattr(self, "_components"):
            _components = torch.cat([self._components, new_components])
        else:
            _components = new_components
        if hasattr(self, "_labels"):
            _labels = torch.cat([self._labels, new_labels])
        else:
            _labels = new_labels
        self._register_components(_components)
        self._register_labels(_labels)

        return new_components, new_labels

    def remove_components(self, indices):
        """Remove components and labels at specified indices."""
        mask = torch.ones(self.num_components, dtype=torch.bool)
        mask[indices] = False
        _components = self._components[mask]
        _labels = self._labels[mask]
        self._register_components(_components)
        self._register_labels(_labels)
        return mask

    def forward(self):
        """Simply return the components parameter Tensor and labels."""
        return self._components, self._labels


class ReasoningComponents(AbstractComponents):
    """A set of components and corresponding adaptable reasoning matrices.

    Every component has its own reasoning matrix.

    A reasoning matrix is an Nx2 matrix, where N is the number of classes. The
    first element is called positive reasoning :math:`p`, the second negative
    reasoning :math:`n`. A component can reason in favour (positive) of a
    class, against (negative) a class or not at all (neutral).

    It holds that :math:`0 \leq n \leq 1`, :math:`0 \leq p \leq 1` and :math:`0
    \leq n+p \leq 1`. Therefore :math:`n` and :math:`p` are two elements of a
    three-element probability distribution.

    """
    def __init__(self, distribution: Union[dict, list, tuple],
                 components_initializer: AbstractComponentsInitializer,
                 reasonings_initializer: AbstractReasoningsInitializer,
                 **kwargs):
        super().__init__(**kwargs)
        self.add_components(distribution, components_initializer,
                            reasonings_initializer)

    @property
    def reasonings(self):
        """Returns Reasoning Matrix.

        Dimension NxCx2

        """
        return self._reasonings.detach()

    def _register_reasonings(self, reasonings):
        self.register_parameter("_reasonings", Parameter(reasonings))

    def add_components(self, distribution, components_initializer,
                       reasonings_initializer: AbstractReasoningsInitializer):
        # Checks
        assert validate_components_initializer(components_initializer)
        assert validate_reasonings_initializer(reasonings_initializer)

        distribution = parse_distribution(distribution)

        # Generate new components
        if isinstance(components_initializer, ClassAwareCompInitializer):
            new_components = components_initializer.generate(distribution)
        else:
            num_components = sum(distribution.values())
            new_components = components_initializer.generate(num_components)

        # Generate new reasonings
        new_reasonings = reasonings_initializer.generate(distribution)

        # Register
        if hasattr(self, "_components"):
            _components = torch.cat([self._components, new_components])
        else:
            _components = new_components
        if hasattr(self, "_reasonings"):
            _reasonings = torch.cat([self._reasonings, new_reasonings])
        else:
            _reasonings = new_reasonings
        self._register_components(_components)
        self._register_reasonings(_reasonings)

        return new_components, new_reasonings

    def remove_components(self, indices):
        """Remove components and reasonings at specified indices."""
        mask = torch.ones(self.num_components, dtype=torch.bool)
        mask[indices] = False
        _components = self._components[mask]
        # TODO
        # _reasonings = self._reasonings[mask]
        self._register_components(_components)
        # self._register_reasonings(_reasonings)
        return mask

    def forward(self):
        """Simply return the components and reasonings."""
        return self._components, self._reasonings
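A minimal sketch of the reworked API, wiring a components initializer and a labels initializer through one `distribution` argument (the data and shapes are illustrative; `SMCI` is the StratifiedMeanCompInitializer alias defined in the new initializers module below):

import torch
from prototorch.core.components import LabeledComponents
from prototorch.core.initializers import SMCI, LabelsInitializer

x = torch.randn(100, 2)           # hypothetical 2-D training data
y = torch.randint(0, 3, (100, ))  # hypothetical targets for 3 classes

lc = LabeledComponents(distribution=(3, 2),
                       components_initializer=SMCI((x, y)),
                       labels_initializer=LabelsInitializer())
protos, plabels = lc()            # (6, 2) components and their 6 labels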
prototorch/core/initializers.py (new file, 258 lines)
@@ -0,0 +1,258 @@
"""ProtoTorch core initializers"""

from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import Union

import torch

from ..utils import parse_data_arg, parse_distribution


# Components
class AbstractComponentsInitializer(ABC):
    """Abstract class for all components initializers."""
    ...


class ShapeAwareCompInitializer(AbstractComponentsInitializer):
    """Abstract class for all shape-aware components initializers."""
    def __init__(self, shape: Union[Iterable, int]):
        if isinstance(shape, Iterable):
            self.component_shape = tuple(shape)
        else:
            self.component_shape = (shape, )

    @abstractmethod
    def generate(self, num_components: int):
        ...


class DataAwareCompInitializer(AbstractComponentsInitializer):
    """Abstract class for all data-aware components initializers.

    Components generated by data-aware components initializers inherit the
    shape of the provided data.

    """
    def __init__(self,
                 data,
                 noise: float = 0.0,
                 transform: callable = torch.nn.Identity()):
        self.data = data
        self.noise = noise
        self.transform = transform

    def generate_end_hook(self, samples):
        drift = torch.rand_like(samples) * self.noise
        components = self.transform(samples + drift)
        return components

    @abstractmethod
    def generate(self, num_components: int):
        ...
        return self.generate_end_hook(...)

    def __del__(self):
        del self.data


class ClassAwareCompInitializer(AbstractComponentsInitializer):
    """Abstract class for all class-aware components initializers.

    Components generated by class-aware components initializers inherit the
    shape of the provided data.

    """
    def __init__(self,
                 data,
                 noise: float = 0.0,
                 transform: callable = torch.nn.Identity()):
        self.data, self.targets = parse_data_arg(data)
        self.noise = noise
        self.transform = transform
        self.clabels = torch.unique(self.targets).int().tolist()
        self.num_classes = len(self.clabels)

    @property
    @abstractmethod
    def subinit_type(self) -> DataAwareCompInitializer:
        ...

    def generate(self, distribution: Union[dict, list, tuple]):
        distribution = parse_distribution(distribution)
        initializers = {
            k: self.subinit_type(self.data[self.targets == k])
            for k in distribution.keys()
        }
        components = torch.tensor([])
        for k, v in distribution.items():
            stratified_data = self.data[self.targets == k]
            # skip transform here
            initializer = self.subinit_type(
                stratified_data,
                noise=self.noise,
                transform=self.transform,
            )
            samples = initializer.generate(num_components=v)
            components = torch.cat([components, samples])
        return components

    def __del__(self):
        del self.data
        del self.targets


class LiteralCompInitializer(DataAwareCompInitializer):
    """'Generate' the provided components.

    Use this to 'generate' pre-initialized components from elsewhere.

    """
    def generate(self, num_components: int):
        """Ignore `num_components` and simply return transformed `self.data`."""
        components = self.transform(self.data)
        return components


class ZerosCompInitializer(ShapeAwareCompInitializer):
    """Generate zeros corresponding to the components shape."""
    def generate(self, num_components: int):
        components = torch.zeros((num_components, ) + self.component_shape)
        return components


class OnesCompInitializer(ShapeAwareCompInitializer):
    """Generate ones corresponding to the components shape."""
    def generate(self, num_components: int):
        components = torch.ones((num_components, ) + self.component_shape)
        return components


class FillValueCompInitializer(OnesCompInitializer):
    """Generate components with the provided `fill_value`."""
    def __init__(self, shape, fill_value: float = 1.0):
        super().__init__(shape)
        self.fill_value = fill_value

    def generate(self, num_components: int):
        ones = super().generate(num_components)
        components = ones.fill_(self.fill_value)
        return components


class UniformCompInitializer(OnesCompInitializer):
    """Generate components by sampling from a continuous uniform distribution."""
    def __init__(self, shape, minimum=0.0, maximum=1.0, scale=1.0):
        super().__init__(shape)
        self.minimum = minimum
        self.maximum = maximum
        self.scale = scale

    def generate(self, num_components: int):
        ones = super().generate(num_components)
        components = self.scale * ones.uniform_(self.minimum, self.maximum)
        return components


class RandomNormalCompInitializer(OnesCompInitializer):
    """Generate components by sampling from a standard normal distribution."""
    def __init__(self, shape, scale=1.0):
        super().__init__(shape)
        self.scale = scale

    def generate(self, num_components: int):
        ones = super().generate(num_components)
        components = self.scale * torch.randn_like(ones)
        return components


class SelectionCompInitializer(DataAwareCompInitializer):
    """Generate components by uniformly sampling from the provided data."""
    def generate(self, num_components: int):
        indices = torch.LongTensor(num_components).random_(0, len(self.data))
        samples = self.data[indices]
        components = self.generate_end_hook(samples)
        return components


class MeanCompInitializer(DataAwareCompInitializer):
    """Generate components by computing the mean of the provided data."""
    def generate(self, num_components: int):
        mean = torch.mean(self.data, dim=0)
        repeat_dim = [num_components] + [1] * len(mean.shape)
        samples = mean.repeat(repeat_dim)
        components = self.generate_end_hook(samples)
        return components


class StratifiedSelectionCompInitializer(ClassAwareCompInitializer):
    """Generate components using stratified sampling from the provided data."""
    @property
    def subinit_type(self):
        return SelectionCompInitializer


class StratifiedMeanCompInitializer(ClassAwareCompInitializer):
    """Generate components at stratified means of the provided data."""
    @property
    def subinit_type(self):
        return MeanCompInitializer


# Labels
class AbstractLabelsInitializer(ABC):
    """Abstract class for all labels initializers."""
    @abstractmethod
    def generate(self, distribution: Union[dict, list, tuple]):
        ...


class LabelsInitializer(AbstractLabelsInitializer):
    """Generate labels from the given distribution."""
    def __init__(self, override_labels: list = []):
        self.override_labels = override_labels

    def generate(self, distribution: Union[dict, list, tuple]):
        distribution = parse_distribution(distribution)
        labels = []
        for k, v in distribution.items():
            labels.extend([k] * v)
        labels = torch.LongTensor(labels)
        return labels


# Reasonings
class AbstractReasoningsInitializer(ABC):
    """Abstract class for all reasonings initializers."""
    @abstractmethod
    def generate(self, distribution: Union[dict, list, tuple]):
        ...


class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer):
    """Generate purely positive reasonings for the given distribution."""
    def generate(self, distribution: Union[dict, list, tuple]):
        distribution = parse_distribution(distribution)
        num_classes = len(distribution.keys())
        num_components = sum(distribution.values())
        assert num_classes == num_components
        reasonings = torch.stack(
            [torch.eye(num_classes),
             torch.zeros(num_classes, num_classes)],
            dim=0)
        return reasonings


# Aliases
ZCI = ZerosCompInitializer
OCI = OnesCompInitializer
FVCI = FillValueCompInitializer
LCI = LiteralCompInitializer
UCI = UniformCompInitializer
RNCI = RandomNormalCompInitializer
SCI = SelectionCompInitializer
MCI = MeanCompInitializer
SSCI = StratifiedSelectionCompInitializer
SMCI = StratifiedMeanCompInitializer
PPRI = PurePositiveReasoningsInitializer
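For reference, a few of the new initializers in isolation (output shapes follow from the classes above):

import torch
from prototorch.core.initializers import (FillValueCompInitializer,
                                          LabelsInitializer,
                                          UniformCompInitializer)

FillValueCompInitializer(2, fill_value=0.5).generate(3)           # (3, 2), all 0.5
UniformCompInitializer(2, minimum=-1.0, maximum=1.0).generate(4)  # (4, 2)
LabelsInitializer().generate({0: 1, 1: 3})                        # tensor([0, 1, 1, 1])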
prototorch/utils (modified)
@@ -23,7 +23,10 @@ def mesh2d(x=None, border: float = 1.0, resolution: int = 100):
     return mesh, xx, yy
 
 
-def parse_distribution(user_distribution: Union[dict, list, tuple]):
+def parse_distribution(
+    user_distribution: Union[dict[int, int], dict[str, str], list[int],
+                             tuple[int]]
+) -> dict[int, int]:
     """Parse user-provided distribution.
 
     Return a dictionary with integer keys that represent the class labels and
@@ -51,14 +54,15 @@ def parse_distribution(user_distribution: Union[dict, list, tuple]):
 
     if isinstance(user_distribution, dict):
         if "num_classes" in user_distribution.keys():
-            num_classes = user_distribution["num_classes"]
-            per_class = user_distribution["per_class"]
+            num_classes = int(user_distribution["num_classes"])
+            per_class = int(user_distribution["per_class"])
             return from_list([per_class] * num_classes)
         else:
             return user_distribution
     elif isinstance(user_distribution, tuple):
         assert len(user_distribution) == 2
         num_classes, per_class = user_distribution
+        num_classes, per_class = int(num_classes), int(per_class)
         return from_list([per_class] * num_classes)
     elif isinstance(user_distribution, list):
         return from_list(user_distribution)
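All three user-facing formats funnel into the same dictionary form; given the docstring above, `from_list` presumably maps positional counts to integer class labels, so:

parse_distribution((3, 2))                              # -> {0: 2, 1: 2, 2: 2}
parse_distribution([2, 2, 2])                           # -> {0: 2, 1: 2, 2: 2}
parse_distribution({"num_classes": 3, "per_class": 2})  # -> {0: 2, 1: 2, 2: 2}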