[REFACTOR] Clean and move components and initializers into core

This commit is contained in:
Jensun Ravichandran 2021-06-12 20:29:24 +02:00
parent b8969347b1
commit 5dddb39ec4
7 changed files with 510 additions and 549 deletions

View File

@ -1,235 +0,0 @@
"""ProtoTorch Components."""
import warnings
import torch
from prototorch.components.initializers import (ClassAwareInitializer,
ComponentsInitializer,
EqualLabelsInitializer,
UnequalLabelsInitializer,
ZeroReasoningsInitializer)
from torch.nn.parameter import Parameter
from .initializers import parse_data_arg
def get_labels_initializer(distribution):
if isinstance(distribution, dict):
if "num_classes" in distribution.keys():
labels = EqualLabelsInitializer(
distribution["num_classes"],
distribution["prototypes_per_class"])
else:
clabels = list(distribution.keys())
dist = list(distribution.values())
labels = UnequalLabelsInitializer(dist, clabels)
elif isinstance(distribution, tuple):
num_classes, prototypes_per_class = distribution
labels = EqualLabelsInitializer(num_classes, prototypes_per_class)
elif isinstance(distribution, list):
labels = UnequalLabelsInitializer(distribution)
else:
msg = f"`distribution` not understood." \
f"You have provided: {distribution=}."
raise ValueError(msg)
return labels
def _precheck_initializer(initializer):
if not isinstance(initializer, ComponentsInitializer):
emsg = f"`initializer` has to be some subtype of " \
f"{ComponentsInitializer}. " \
f"You have provided: {initializer=} instead."
raise TypeError(emsg)
class Components(torch.nn.Module):
"""Components is a set of learnable Tensors."""
def __init__(self,
num_components=None,
initializer=None,
*,
initialized_components=None):
super().__init__()
# Ignore all initialization settings if initialized_components is given.
if initialized_components is not None:
self._register_components(initialized_components)
if num_components is not None or initializer is not None:
wmsg = "Arguments ignored while initializing Components"
warnings.warn(wmsg)
else:
self._initialize_components(num_components, initializer)
@property
def num_components(self):
return len(self._components)
def _register_components(self, components):
self.register_parameter("_components", Parameter(components))
def _initialize_components(self, num_components, initializer):
_precheck_initializer(initializer)
_components = initializer.generate(num_components)
self._register_components(_components)
def add_components(self,
num=1,
initializer=None,
*,
initialized_components=None):
if initialized_components is not None:
_components = torch.cat([self._components, initialized_components])
else:
_precheck_initializer(initializer)
_new = initializer.generate(num)
_components = torch.cat([self._components, _new])
self._register_components(_components)
def remove_components(self, indices=None):
mask = torch.ones(self.num_components, dtype=torch.bool)
mask[indices] = False
_components = self._components[mask]
self._register_components(_components)
return mask
@property
def components(self):
"""Tensor containing the component tensors."""
return self._components.detach()
def forward(self):
return self._components
def extra_repr(self):
return f"(components): (shape: {tuple(self._components.shape)})"
class LabeledComponents(Components):
"""LabeledComponents generate a set of components and a set of labels.
Every Component has a label assigned.
"""
def __init__(self,
distribution=None,
initializer=None,
*,
initialized_components=None):
if initialized_components is not None:
components, component_labels = parse_data_arg(
initialized_components)
super().__init__(initialized_components=components)
# self._labels = component_labels
self._labels = component_labels
else:
labels_initializer = get_labels_initializer(distribution)
self.initial_distribution = labels_initializer.distribution
_labels = labels.generate()
super().__init__(len(_labels), initializer=initializer)
self._register_labels(_labels)
def _register_labels(self, labels):
self.register_buffer("_labels", labels)
@property
def distribution(self):
clabels, counts = torch.unique(self._labels,
sorted=True,
return_counts=True)
return dict(zip(clabels.tolist(), counts.tolist()))
def _initialize_components(self, num_components, initializer):
if isinstance(initializer, ClassAwareInitializer):
_precheck_initializer(initializer)
_components = initializer.generate(num_components,
self.initial_distribution)
self._register_components(_components)
else:
super()._initialize_components(num_components, initializer)
def add_components(self, distribution, initializer):
_precheck_initializer(initializer)
# Labels
labels_initializer = get_labels_initializer(distribution)
new_labels = labels_initializer.generate()
_labels = torch.cat([self._labels, new_labels])
self._register_labels(_labels)
# Components
if isinstance(initializer, ClassAwareInitializer):
_new = initializer.generate(len(new_labels), distribution)
else:
_new = initializer.generate(len(new_labels))
_components = torch.cat([self._components, _new])
self._register_components(_components)
def remove_components(self, indices=None):
# Components
mask = super().remove_components(indices)
# Labels
_labels = self._labels[mask]
self._register_labels(_labels)
@property
def component_labels(self):
"""Tensor containing the component tensors."""
return self._labels.detach()
def forward(self):
return super().forward(), self._labels
class ReasoningComponents(Components):
"""ReasoningComponents generate a set of components and a set of reasoning matrices.
Every Component has a reasoning matrix assigned.
A reasoning matrix is a Nx2 matrix, where N is the number of Classes. The
first element is called positive reasoning :math:`p`, the second negative
reasoning :math:`n`. A components can reason in favour (positive) of a
class, against (negative) a class or not at all (neutral).
It holds that :math:`0 \leq n \leq 1`, :math:`0 \leq p \leq 1` and :math:`0
\leq n+p \leq 1`. Therefore :math:`n` and :math:`p` are two elements of a
three element probability distribution.
"""
def __init__(self,
distribution=None,
initializer=None,
reasoning_initializer=None,
*,
initialized_components=None):
if initialized_components is not None:
components, reasonings = initialized_components
super().__init__(initialized_components=components)
self.register_parameter("_reasonings", reasonings)
else:
labels_initializer = get_labels_initializer(distribution)
self.initial_distribution = labels_initializer.distribution
super().__init__(len(self.initial_distribution),
initializer=initializer)
reasonings = reasoning_initializer.generate()
self._register_reasonings(reasonings)
def _initialize_reasonings(self, reasoning_initializer):
if isinstance(reasonings, tuple):
num_classes, num_components = reasonings
reasonings = ZeroReasoningsInitializer(num_classes, num_components)
_reasonings = reasonings.generate()
self.register_parameter("_reasonings", _reasonings)
@property
def reasonings(self):
"""Returns Reasoning Matrix.
Dimension NxCx2
"""
return self._reasonings.detach()
def forward(self):
return super().forward(), self._reasonings

View File

@ -1,225 +0,0 @@
"""ProtoTroch Initializers."""
import warnings
from collections.abc import Iterable
from itertools import chain
from typing import List
import torch
from torch.utils.data import DataLoader, Dataset
def parse_data_arg(data_arg):
if isinstance(data_arg, Dataset):
data_arg = DataLoader(data_arg, batch_size=len(data_arg))
if isinstance(data_arg, DataLoader):
data = torch.tensor([])
targets = torch.tensor([])
for x, y in data_arg:
data = torch.cat([data, x])
targets = torch.cat([targets, y])
else:
data, targets = data_arg
if not isinstance(data, torch.Tensor):
wmsg = f"Converting data to {torch.Tensor}."
warnings.warn(wmsg)
data = torch.Tensor(data)
if not isinstance(targets, torch.Tensor):
wmsg = f"Converting targets to {torch.Tensor}."
warnings.warn(wmsg)
targets = torch.Tensor(targets)
return data, targets
def get_subinitializers(data, targets, clabels, subinit_type):
initializers = dict()
for clabel in clabels:
class_data = data[targets == clabel]
class_initializer = subinit_type(class_data)
initializers[clabel] = (class_initializer)
return initializers
# Components
class ComponentsInitializer(object):
def generate(self, number_of_components):
raise NotImplementedError("Subclasses should implement this!")
class DimensionAwareInitializer(ComponentsInitializer):
def __init__(self, dims):
super().__init__()
if isinstance(dims, Iterable):
self.components_dims = tuple(dims)
else:
self.components_dims = (dims, )
class OnesInitializer(DimensionAwareInitializer):
def __init__(self, dims, scale=1.0):
super().__init__(dims)
self.scale = scale
def generate(self, length):
gen_dims = (length, ) + self.components_dims
return torch.ones(gen_dims) * self.scale
class ZerosInitializer(DimensionAwareInitializer):
def generate(self, length):
gen_dims = (length, ) + self.components_dims
return torch.zeros(gen_dims)
class UniformInitializer(DimensionAwareInitializer):
def __init__(self, dims, minimum=0.0, maximum=1.0, scale=1.0):
super().__init__(dims)
self.minimum = minimum
self.maximum = maximum
self.scale = scale
def generate(self, length):
gen_dims = (length, ) + self.components_dims
return torch.ones(gen_dims).uniform_(self.minimum,
self.maximum) * self.scale
class DataAwareInitializer(ComponentsInitializer):
def __init__(self, data, transform=torch.nn.Identity()):
super().__init__()
self.data = data
self.transform = transform
def __del__(self):
del self.data
class SelectionInitializer(DataAwareInitializer):
def generate(self, length):
indices = torch.LongTensor(length).random_(0, len(self.data))
return self.transform(self.data[indices])
class MeanInitializer(DataAwareInitializer):
def generate(self, length):
mean = torch.mean(self.data, dim=0)
repeat_dim = [length] + [1] * len(mean.shape)
return self.transform(mean.repeat(repeat_dim))
class ClassAwareInitializer(DataAwareInitializer):
def __init__(self, data, transform=torch.nn.Identity()):
data, targets = parse_data_arg(data)
super().__init__(data, transform)
self.targets = targets
self.clabels = torch.unique(self.targets).int().tolist()
self.num_classes = len(self.clabels)
def _get_samples_from_initializer(self, length, dist):
if not dist:
per_class = length // self.num_classes
dist = dict(zip(self.clabels, self.num_classes * [per_class]))
if isinstance(dist, list):
dist = dict(zip(self.clabels, dist))
samples = [self.initializers[k].generate(n) for k, n in dist.items()]
out = torch.vstack(samples)
with torch.no_grad():
out = self.transform(out)
return out
def __del__(self):
del self.data
del self.targets
class StratifiedMeanInitializer(ClassAwareInitializer):
def __init__(self, data, **kwargs):
super().__init__(data, **kwargs)
self.initializers = get_subinitializers(self.data, self.targets,
self.clabels, MeanInitializer)
def generate(self, length, dist):
samples = self._get_samples_from_initializer(length, dist)
return samples
class StratifiedSelectionInitializer(ClassAwareInitializer):
def __init__(self, data, noise=None, **kwargs):
super().__init__(data, **kwargs)
self.noise = noise
self.initializers = get_subinitializers(self.data, self.targets,
self.clabels,
SelectionInitializer)
def add_noise_v1(self, x):
return x + self.noise
def add_noise_v2(self, x):
"""Shifts some dimensions of the data randomly."""
n1 = torch.rand_like(x)
n2 = torch.rand_like(x)
mask = torch.bernoulli(n1) - torch.bernoulli(n2)
return x + (self.noise * mask)
def generate(self, length, dist):
samples = self._get_samples_from_initializer(length, dist)
if self.noise is not None:
samples = self.add_noise_v1(samples)
return samples
# Labels
class LabelsInitializer:
def generate(self):
raise NotImplementedError("Subclasses should implement this!")
class UnequalLabelsInitializer(LabelsInitializer):
def __init__(self, dist, clabels=None):
self.dist = dist
self.clabels = clabels or range(len(self.dist))
@property
def distribution(self) -> List:
return self.dist
def generate(self):
targets = list(
chain(*[[i] * n for i, n in zip(self.clabels, self.dist)]))
return torch.LongTensor(targets)
class EqualLabelsInitializer(LabelsInitializer):
def __init__(self, classes, per_class):
self.classes = classes
self.per_class = per_class
@property
def distribution(self) -> List:
return self.classes * [self.per_class]
def generate(self):
return torch.arange(self.classes).repeat(self.per_class, 1).T.flatten()
# Reasonings
class ReasoningsInitializer:
def generate(self, length):
raise NotImplementedError("Subclasses should implement this!")
class ZeroReasoningsInitializer(ReasoningsInitializer):
def __init__(self, classes, length):
self.classes = classes
self.length = length
def generate(self):
return torch.zeros((self.length, self.classes, 2))
# Aliases
SSI = StratifiedSampleInitializer = StratifiedSelectionInitializer
SMI = StratifiedMeanInitializer
Random = RandomInitializer = UniformInitializer
Zeros = ZerosInitializer
Ones = OnesInitializer

View File

@ -1,86 +0,0 @@
"""ProtoTorch Labels."""
import torch
from prototorch.components.components import get_labels_initializer
from prototorch.components.initializers import (ClassAwareInitializer,
ComponentsInitializer,
EqualLabelsInitializer,
UnequalLabelsInitializer)
from torch.nn.parameter import Parameter
def get_labels_initializer(distribution):
if isinstance(distribution, dict):
if "num_classes" in distribution.keys():
labels = EqualLabelsInitializer(
distribution["num_classes"],
distribution["prototypes_per_class"])
else:
clabels = list(distribution.keys())
dist = list(distribution.values())
labels = UnequalLabelsInitializer(dist, clabels)
elif isinstance(distribution, tuple):
num_classes, prototypes_per_class = distribution
labels = EqualLabelsInitializer(num_classes, prototypes_per_class)
elif isinstance(distribution, list):
labels = UnequalLabelsInitializer(distribution)
else:
msg = f"`distribution` not understood." \
f"You have provided: {distribution=}."
raise ValueError(msg)
return labels
class Labels(torch.nn.Module):
def __init__(self,
distribution=None,
initializer=None,
*,
initialized_labels=None):
_labels = self.get_labels(distribution,
initializer,
initialized_labels=initialized_labels)
self._register_labels(_labels)
def _register_labels(self, labels):
# self.register_buffer("_labels", labels)
self.register_parameter("_labels",
Parameter(labels, requires_grad=False))
def get_labels(self,
distribution=None,
initializer=None,
*,
initialized_labels=None):
if initialized_labels is not None:
_labels = initialized_labels
else:
labels_initializer = initializer or get_labels_initializer(
distribution)
self.initial_distribution = labels_initializer.distribution
_labels = labels_initializer.generate()
return _labels
def add_labels(self,
distribution=None,
initializer=None,
*,
initialized_labels=None):
new_labels = self.get_labels(distribution,
initializer,
initialized_labels=initialized_labels)
_labels = torch.cat([self._labels, new_labels])
self._register_labels(_labels)
def remove_labels(self, indices=None):
mask = torch.ones(len(self._labels, dtype=torch.bool))
mask[indices] = False
_labels = self._labels[mask]
self._register_labels(_labels)
@property
def labels(self):
return self._labels
def forward(self):
return self._labels

View File

@ -1,3 +1,5 @@
"""ProtoTorch core"""
from .components import * from .components import *
from .initializers import * from .initializers import *
from .labels import * from .labels import *

View File

@ -0,0 +1,243 @@
"""ProtoTorch components"""
import inspect
from typing import Union
import torch
from torch.nn.parameter import Parameter
from ..utils import parse_distribution
from .initializers import (
AbstractComponentsInitializer,
AbstractLabelsInitializer,
AbstractReasoningsInitializer,
ClassAwareCompInitializer,
LabelsInitializer,
)
def validate_initializer(initializer, instanceof):
if not isinstance(initializer, instanceof):
emsg = f"`initializer` has to be an instance " \
f"of some subtype of {instanceof}. " \
f"You have provided: {initializer} instead. "
helpmsg = ""
if inspect.isclass(initializer):
helpmsg = f"Perhaps you meant to say, {initializer.__name__}() " \
f"with the brackets instead of just {initializer.__name__}?"
raise TypeError(emsg + helpmsg)
return True
def validate_components_initializer(initializer):
return validate_initializer(initializer, AbstractComponentsInitializer)
def validate_labels_initializer(initializer):
return validate_initializer(initializer, AbstractLabelsInitializer)
def validate_reasonings_initializer(initializer):
return validate_initializer(initializer, AbstractReasoningsInitializer)
class AbstractComponents(torch.nn.Module):
"""Abstract class for all components modules."""
@property
def num_components(self):
"""Current number of components."""
return len(self._components)
@property
def components(self):
"""Detached Tensor containing the components."""
return self._components.detach()
def _register_components(self, components):
self.register_parameter("_components", Parameter(components))
def extra_repr(self):
return f"(components): (shape: {tuple(self._components.shape)})"
class Components(AbstractComponents):
"""A set of adaptable Tensors."""
def __init__(self, num_components: int,
initializer: AbstractComponentsInitializer, **kwargs):
super().__init__(**kwargs)
self.add_components(num_components, initializer)
def add_components(self, num: int,
initializer: AbstractComponentsInitializer):
"""Add new components."""
assert validate_components_initializer(initializer)
new_components = initializer.generate(num)
# Register
if hasattr(self, "_components"):
_components = torch.cat([self._components, new_components])
else:
_components = new_components
self._register_components(_components)
return new_components
def remove_components(self, indices):
"""Remove components at specified indices."""
mask = torch.ones(self.num_components, dtype=torch.bool)
mask[indices] = False
_components = self._components[mask]
self._register_components(_components)
return mask
def forward(self):
"""Simply return the components parameter Tensor."""
return self._components
class LabeledComponents(AbstractComponents):
"""A set of adaptable components and corresponding unadaptable labels."""
def __init__(self, distribution: Union[dict, list, tuple],
components_initializer: AbstractComponentsInitializer,
labels_initializer: AbstractLabelsInitializer, **kwargs):
super().__init__(**kwargs)
self.add_components(distribution, components_initializer,
labels_initializer)
@property
def component_labels(self):
"""Tensor containing the component tensors."""
return self._labels.detach()
def _register_labels(self, labels):
self.register_buffer("_labels", labels)
def add_components(
self,
distribution,
components_initializer,
labels_initializer: AbstractLabelsInitializer = LabelsInitializer()):
# Checks
assert validate_components_initializer(components_initializer)
assert validate_labels_initializer(labels_initializer)
distribution = parse_distribution(distribution)
# Generate new components
if isinstance(components_initializer, ClassAwareCompInitializer):
new_components = components_initializer.generate(distribution)
else:
num_components = sum(distribution.values())
new_components = components_initializer.generate(num_components)
# Generate new labels
new_labels = labels_initializer.generate(distribution)
# Register
if hasattr(self, "_components"):
_components = torch.cat([self._components, new_components])
else:
_components = new_components
if hasattr(self, "_labels"):
_labels = torch.cat([self._labels, new_labels])
else:
_labels = new_labels
self._register_components(_components)
self._register_labels(_labels)
return new_components, new_labels
def remove_components(self, indices):
"""Remove components and labels at specified indices."""
mask = torch.ones(self.num_components, dtype=torch.bool)
mask[indices] = False
_components = self._components[mask]
_labels = self._labels[mask]
self._register_components(_components)
self._register_labels(_labels)
return mask
def forward(self):
"""Simply return the components parameter Tensor and labels."""
return self._components, self._labels
class ReasoningComponents(AbstractComponents):
"""A set of components and a corresponding adapatable reasoning matrices.
Every component has its own reasoning matrix.
A reasoning matrix is an Nx2 matrix, where N is the number of classes. The
first element is called positive reasoning :math:`p`, the second negative
reasoning :math:`n`. A components can reason in favour (positive) of a
class, against (negative) a class or not at all (neutral).
It holds that :math:`0 \leq n \leq 1`, :math:`0 \leq p \leq 1` and :math:`0
\leq n+p \leq 1`. Therefore :math:`n` and :math:`p` are two elements of a
three element probability distribution.
"""
def __init__(self, distribution: Union[dict, list, tuple],
components_initializer: AbstractComponentsInitializer,
reasonings_initializer: AbstractReasoningsInitializer,
**kwargs):
super().__init__(**kwargs)
self.add_components(distribution, components_initializer,
reasonings_initializer)
@property
def reasonings(self):
"""Returns Reasoning Matrix.
Dimension NxCx2
"""
return self._reasonings.detach()
def _register_reasonings(self, reasonings):
self.register_parameter("_reasonings", Parameter(reasonings))
def add_components(self, distribution, components_initializer,
reasonings_initializer: AbstractReasoningsInitializer):
# Checks
assert validate_components_initializer(components_initializer)
assert validate_reasonings_initializer(reasonings_initializer)
distribution = parse_distribution(distribution)
# Generate new components
if isinstance(components_initializer, ClassAwareCompInitializer):
new_components = components_initializer.generate(distribution)
else:
num_components = sum(distribution.values())
new_components = components_initializer.generate(num_components)
# Generate new reasonings
new_reasonings = reasonings_initializer.generate(distribution)
# Register
if hasattr(self, "_components"):
_components = torch.cat([self._components, new_components])
else:
_components = new_components
if hasattr(self, "_reasonings"):
_reasonings = torch.cat([self._reasonings, new_reasonings])
else:
_reasonings = new_reasonings
self._register_components(_components)
self._register_reasonings(_reasonings)
return new_components, new_reasonings
def remove_components(self, indices):
"""Remove components and labels at specified indices."""
mask = torch.ones(self.num_components, dtype=torch.bool)
mask[indices] = False
_components = self._components[mask]
# TODO
# _reasonings = self._reasonings[mask]
self._register_components(_components)
# self._register_reasonings(_reasonings)
return mask
def forward(self):
"""Simply return the components and reasonings."""
return self._components, self._reasonings

View File

@ -0,0 +1,258 @@
"""ProtoTorch code initializers"""
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import Union
import torch
from ..utils import parse_data_arg, parse_distribution
# Components
class AbstractComponentsInitializer(ABC):
"""Abstract class for all components initializers."""
...
class ShapeAwareCompInitializer(AbstractComponentsInitializer):
"""Abstract class for all dimension-aware components initializers."""
def __init__(self, shape: Union[Iterable, int]):
if isinstance(shape, Iterable):
self.component_shape = tuple(shape)
else:
self.component_shape = (shape, )
@abstractmethod
def generate(self, num_components: int):
...
class DataAwareCompInitializer(AbstractComponentsInitializer):
"""Abstract class for all data-aware components initializers.
Components generated by data-aware components initializers inherit the shape
of the provided data.
"""
def __init__(self,
data,
noise: float = 0.0,
transform: callable = torch.nn.Identity()):
self.data = data
self.noise = noise
self.transform = transform
def generate_end_hook(self, samples):
drift = torch.rand_like(samples) * self.noise
components = self.transform(samples + drift)
return components
@abstractmethod
def generate(self, num_components: int):
...
return self.generate_end_hook(...)
def __del__(self):
del self.data
class ClassAwareCompInitializer(AbstractComponentsInitializer):
"""Abstract class for all class-aware components initializers.
Components generated by class-aware components initializers inherit the shape
of the provided data.
"""
def __init__(self,
data,
noise: float = 0.0,
transform: callable = torch.nn.Identity()):
self.data, self.targets = parse_data_arg(data)
self.noise = noise
self.transform = transform
self.clabels = torch.unique(self.targets).int().tolist()
self.num_classes = len(self.clabels)
@property
@abstractmethod
def subinit_type(self) -> DataAwareCompInitializer:
...
def generate(self, distribution: Union[dict, list, tuple]):
distribution = parse_distribution(distribution)
initializers = {
k: self.subinit_type(self.data[self.targets == k])
for k in distribution.keys()
}
components = torch.tensor([])
for k, v in distribution.items():
stratified_data = self.data[self.targets == k]
# skip transform here
initializer = self.subinit_type(
stratified_data,
noise=self.noise,
transform=self.transform,
)
samples = initializer.generate(num_components=v)
components = torch.cat([components, samples])
return components
def __del__(self):
del self.data
del self.targets
class LiteralCompInitializer(DataAwareCompInitializer):
"""'Generate' the provided components.
Use this to 'generate' pre-initialized components from elsewhere.
"""
def generate(self, num_components: int):
"""Ignore `num_components` and simply return transformed `self.data`."""
components = self.transform(self.data)
return components
class ZerosCompInitializer(ShapeAwareCompInitializer):
"""Generate zeros corresponding to the components shape."""
def generate(self, num_components: int):
components = torch.zeros((num_components, ) + self.component_shape)
return components
class OnesCompInitializer(ShapeAwareCompInitializer):
"""Generate ones corresponding to the components shape."""
def generate(self, num_components: int):
components = torch.ones((num_components, ) + self.component_shape)
return components
class FillValueCompInitializer(OnesCompInitializer):
"""Generate components with the provided `fill_value`."""
def __init__(self, shape, fill_value: float = 1.0):
super().__init__(shape)
self.fill_value = fill_value
def generate(self, num_components: int):
ones = super().generate(num_components)
components = ones.fill_(self.fill_value)
return components
class UniformCompInitializer(OnesCompInitializer):
"""Generate components by sampling from a continuous uniform distribution."""
def __init__(self, shape, minimum=0.0, maximum=1.0, scale=1.0):
super().__init__(shape)
self.minimum = minimum
self.maximum = maximum
self.scale = scale
def generate(self, num_components: int):
ones = super().generate(num_components)
components = self.scale * ones.uniform_(self.minimum, self.maximum)
return components
class RandomNormalCompInitializer(OnesCompInitializer):
"""Generate components by sampling from a standard normal distribution."""
def __init__(self, shape, scale=1.0):
super().__init__(shape)
self.scale = scale
def generate(self, num_components: int):
ones = super().generate(num_components)
components = self.scale * torch.randn_like(ones)
return components
class SelectionCompInitializer(DataAwareCompInitializer):
"""Generate components by uniformly sampling from the provided data."""
def generate(self, num_components: int):
indices = torch.LongTensor(num_components).random_(0, len(self.data))
samples = self.data[indices]
components = self.generate_end_hook(samples)
return components
class MeanCompInitializer(DataAwareCompInitializer):
"""Generate components by computing the mean of the provided data."""
def generate(self, num_components: int):
mean = torch.mean(self.data, dim=0)
repeat_dim = [num_components] + [1] * len(mean.shape)
samples = mean.repeat(repeat_dim)
components = self.generate_end_hook(samples)
return components
class StratifiedSelectionCompInitializer(ClassAwareCompInitializer):
"""Generate components using stratified sampling from the provided data."""
@property
def subinit_type(self):
return SelectionCompInitializer
class StratifiedMeanCompInitializer(ClassAwareCompInitializer):
"""Generate components at stratified means of the provided data."""
@property
def subinit_type(self):
return MeanCompInitializer
# Labels
class AbstractLabelsInitializer(ABC):
"""Abstract class for all labels initializers."""
@abstractmethod
def generate(self, distribution: Union[dict, list, tuple]):
...
class LabelsInitializer(AbstractLabelsInitializer):
"""Generate labels with `self.distribution`."""
def __init__(self, override_labels: list = []):
self.override_labels = override_labels
def generate(self, distribution: Union[dict, list, tuple]):
distribution = parse_distribution(distribution)
labels = []
for k, v in distribution.items():
labels.extend([k] * v)
labels = torch.LongTensor(labels)
return labels
# Reasonings
class AbstractReasoningsInitializer(ABC):
"""Abstract class for all reasonings initializers."""
@abstractmethod
def generate(self, distribution: Union[dict, list, tuple]):
...
class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer):
"""Generate labels with `self.distribution`."""
def generate(self, distribution: Union[dict, list, tuple]):
distribution = parse_distribution(distribution)
num_classes = len(distribution.keys())
num_components = sum(distribution.values())
assert num_classes == num_components
reasonings = torch.stack(
[torch.eye(num_classes),
torch.zeros(num_classes, num_classes)],
dim=0)
return reasonings
# Aliases - Components
ZCI = ZerosCompInitializer
OCI = OnesCompInitializer
FVCI = FillValueCompInitializer
LCI = LiteralCompInitializer
UCI = UniformCompInitializer
RNCI = RandomNormalCompInitializer
SCI = SelectionCompInitializer
MCI = MeanCompInitializer
SSCI = StratifiedSelectionCompInitializer
SMCI = StratifiedMeanCompInitializer
PPRI = PurePositiveReasoningsInitializer

View File

@ -23,7 +23,10 @@ def mesh2d(x=None, border: float = 1.0, resolution: int = 100):
return mesh, xx, yy return mesh, xx, yy
def parse_distribution(user_distribution: Union[dict, list, tuple]): def parse_distribution(
user_distribution: Union[dict[int, int], dict[str, str], list[int],
tuple[int]]
) -> dict[int, int]:
"""Parse user-provided distribution. """Parse user-provided distribution.
Return a dictionary with integer keys that represent the class labels and Return a dictionary with integer keys that represent the class labels and
@ -51,14 +54,15 @@ def parse_distribution(user_distribution: Union[dict, list, tuple]):
if isinstance(user_distribution, dict): if isinstance(user_distribution, dict):
if "num_classes" in user_distribution.keys(): if "num_classes" in user_distribution.keys():
num_classes = user_distribution["num_classes"] num_classes = int(user_distribution["num_classes"])
per_class = user_distribution["per_class"] per_class = int(user_distribution["per_class"])
return from_list([per_class] * num_classes) return from_list([per_class] * num_classes)
else: else:
return user_distribution return user_distribution
elif isinstance(user_distribution, tuple): elif isinstance(user_distribution, tuple):
assert len(user_distribution) == 2 assert len(user_distribution) == 2
num_classes, per_class = user_distribution num_classes, per_class = user_distribution
num_classes, per_class = int(num_classes), int(per_class)
return from_list([per_class] * num_classes) return from_list([per_class] * num_classes)
elif isinstance(user_distribution, list): elif isinstance(user_distribution, list):
return from_list(user_distribution) return from_list(user_distribution)