diff --git a/prototorch/components/components.py b/prototorch/components/components.py
deleted file mode 100644
index 7ae1df6..0000000
--- a/prototorch/components/components.py
+++ /dev/null
@@ -1,235 +0,0 @@
-"""ProtoTorch Components."""
-
-import warnings
-
-import torch
-from prototorch.components.initializers import (ClassAwareInitializer,
-                                                ComponentsInitializer,
-                                                EqualLabelsInitializer,
-                                                UnequalLabelsInitializer,
-                                                ZeroReasoningsInitializer)
-from torch.nn.parameter import Parameter
-
-from .initializers import parse_data_arg
-
-
-def get_labels_initializer(distribution):
-    if isinstance(distribution, dict):
-        if "num_classes" in distribution.keys():
-            labels = EqualLabelsInitializer(
-                distribution["num_classes"],
-                distribution["prototypes_per_class"])
-        else:
-            clabels = list(distribution.keys())
-            dist = list(distribution.values())
-            labels = UnequalLabelsInitializer(dist, clabels)
-    elif isinstance(distribution, tuple):
-        num_classes, prototypes_per_class = distribution
-        labels = EqualLabelsInitializer(num_classes, prototypes_per_class)
-    elif isinstance(distribution, list):
-        labels = UnequalLabelsInitializer(distribution)
-    else:
-        msg = f"`distribution` not understood." \
-            f"You have provided: {distribution=}."
-        raise ValueError(msg)
-    return labels
-
-
-def _precheck_initializer(initializer):
-    if not isinstance(initializer, ComponentsInitializer):
-        emsg = f"`initializer` has to be some subtype of " \
-            f"{ComponentsInitializer}. " \
-            f"You have provided: {initializer=} instead."
-        raise TypeError(emsg)
-
-
-class Components(torch.nn.Module):
-    """Components is a set of learnable Tensors."""
-    def __init__(self,
-                 num_components=None,
-                 initializer=None,
-                 *,
-                 initialized_components=None):
-        super().__init__()
-
-        # Ignore all initialization settings if initialized_components is given.
-        if initialized_components is not None:
-            self._register_components(initialized_components)
-            if num_components is not None or initializer is not None:
-                wmsg = "Arguments ignored while initializing Components"
-                warnings.warn(wmsg)
-        else:
-            self._initialize_components(num_components, initializer)
-
-    @property
-    def num_components(self):
-        return len(self._components)
-
-    def _register_components(self, components):
-        self.register_parameter("_components", Parameter(components))
-
-    def _initialize_components(self, num_components, initializer):
-        _precheck_initializer(initializer)
-        _components = initializer.generate(num_components)
-        self._register_components(_components)
-
-    def add_components(self,
-                       num=1,
-                       initializer=None,
-                       *,
-                       initialized_components=None):
-        if initialized_components is not None:
-            _components = torch.cat([self._components, initialized_components])
-        else:
-            _precheck_initializer(initializer)
-            _new = initializer.generate(num)
-            _components = torch.cat([self._components, _new])
-        self._register_components(_components)
-
-    def remove_components(self, indices=None):
-        mask = torch.ones(self.num_components, dtype=torch.bool)
-        mask[indices] = False
-        _components = self._components[mask]
-        self._register_components(_components)
-        return mask
-
-    @property
-    def components(self):
-        """Tensor containing the component tensors."""
-        return self._components.detach()
-
-    def forward(self):
-        return self._components
-
-    def extra_repr(self):
-        return f"(components): (shape: {tuple(self._components.shape)})"
-
-
-class LabeledComponents(Components):
-    """LabeledComponents generate a set of components and a set of labels.
-
-    Every Component has a label assigned.
-    """
-    def __init__(self,
-                 distribution=None,
-                 initializer=None,
-                 *,
-                 initialized_components=None):
-        if initialized_components is not None:
-            components, component_labels = parse_data_arg(
-                initialized_components)
-            super().__init__(initialized_components=components)
-            # self._labels = component_labels
-            self._labels = component_labels
-        else:
-            labels_initializer = get_labels_initializer(distribution)
-            self.initial_distribution = labels_initializer.distribution
-            _labels = labels.generate()
-            super().__init__(len(_labels), initializer=initializer)
-            self._register_labels(_labels)
-
-    def _register_labels(self, labels):
-        self.register_buffer("_labels", labels)
-
-    @property
-    def distribution(self):
-        clabels, counts = torch.unique(self._labels,
-                                       sorted=True,
-                                       return_counts=True)
-        return dict(zip(clabels.tolist(), counts.tolist()))
-
-    def _initialize_components(self, num_components, initializer):
-        if isinstance(initializer, ClassAwareInitializer):
-            _precheck_initializer(initializer)
-            _components = initializer.generate(num_components,
-                                               self.initial_distribution)
-            self._register_components(_components)
-        else:
-            super()._initialize_components(num_components, initializer)
-
-    def add_components(self, distribution, initializer):
-        _precheck_initializer(initializer)
-
-        # Labels
-        labels_initializer = get_labels_initializer(distribution)
-        new_labels = labels_initializer.generate()
-        _labels = torch.cat([self._labels, new_labels])
-        self._register_labels(_labels)
-
-        # Components
-        if isinstance(initializer, ClassAwareInitializer):
-            _new = initializer.generate(len(new_labels), distribution)
-        else:
-            _new = initializer.generate(len(new_labels))
-        _components = torch.cat([self._components, _new])
-        self._register_components(_components)
-
-    def remove_components(self, indices=None):
-        # Components
-        mask = super().remove_components(indices)
-
-        # Labels
-        _labels = self._labels[mask]
-        self._register_labels(_labels)
-
-    @property
-    def component_labels(self):
-        """Tensor containing the component tensors."""
-        return self._labels.detach()
-
-    def forward(self):
-        return super().forward(), self._labels
-
-
-class ReasoningComponents(Components):
-    """ReasoningComponents generate a set of components and a set of reasoning matrices.
-
-    Every Component has a reasoning matrix assigned.
-
-    A reasoning matrix is a Nx2 matrix, where N is the number of Classes. The
-    first element is called positive reasoning :math:`p`, the second negative
-    reasoning :math:`n`. A components can reason in favour (positive) of a
-    class, against (negative) a class or not at all (neutral).
-
-    It holds that :math:`0 \leq n \leq 1`, :math:`0 \leq p \leq 1` and :math:`0
-    \leq n+p \leq 1`. Therefore :math:`n` and :math:`p` are two elements of a
-    three element probability distribution.
-
-    """
-    def __init__(self,
-                 distribution=None,
-                 initializer=None,
-                 reasoning_initializer=None,
-                 *,
-                 initialized_components=None):
-        if initialized_components is not None:
-            components, reasonings = initialized_components
-            super().__init__(initialized_components=components)
-            self.register_parameter("_reasonings", reasonings)
-        else:
-            labels_initializer = get_labels_initializer(distribution)
-            self.initial_distribution = labels_initializer.distribution
-            super().__init__(len(self.initial_distribution),
-                             initializer=initializer)
-            reasonings = reasoning_initializer.generate()
-            self._register_reasonings(reasonings)
-
-    def _initialize_reasonings(self, reasoning_initializer):
-        if isinstance(reasonings, tuple):
-            num_classes, num_components = reasonings
-            reasonings = ZeroReasoningsInitializer(num_classes, num_components)
-
-        _reasonings = reasonings.generate()
-        self.register_parameter("_reasonings", _reasonings)
-
-    @property
-    def reasonings(self):
-        """Returns Reasoning Matrix.
-
-        Dimension NxCx2
-
-        """
-        return self._reasonings.detach()
-
-    def forward(self):
-        return super().forward(), self._reasonings
diff --git a/prototorch/components/initializers.py b/prototorch/components/initializers.py
deleted file mode 100644
index 8839451..0000000
--- a/prototorch/components/initializers.py
+++ /dev/null
@@ -1,225 +0,0 @@
-"""ProtoTroch Initializers."""
-import warnings
-from collections.abc import Iterable
-from itertools import chain
-from typing import List
-
-import torch
-from torch.utils.data import DataLoader, Dataset
-
-
-def parse_data_arg(data_arg):
-    if isinstance(data_arg, Dataset):
-        data_arg = DataLoader(data_arg, batch_size=len(data_arg))
-
-    if isinstance(data_arg, DataLoader):
-        data = torch.tensor([])
-        targets = torch.tensor([])
-        for x, y in data_arg:
-            data = torch.cat([data, x])
-            targets = torch.cat([targets, y])
-    else:
-        data, targets = data_arg
-        if not isinstance(data, torch.Tensor):
-            wmsg = f"Converting data to {torch.Tensor}."
-            warnings.warn(wmsg)
-            data = torch.Tensor(data)
-        if not isinstance(targets, torch.Tensor):
-            wmsg = f"Converting targets to {torch.Tensor}."
-            warnings.warn(wmsg)
-            targets = torch.Tensor(targets)
-    return data, targets
-
-
-def get_subinitializers(data, targets, clabels, subinit_type):
-    initializers = dict()
-    for clabel in clabels:
-        class_data = data[targets == clabel]
-        class_initializer = subinit_type(class_data)
-        initializers[clabel] = (class_initializer)
-    return initializers
-
-
-# Components
-class ComponentsInitializer(object):
-    def generate(self, number_of_components):
-        raise NotImplementedError("Subclasses should implement this!")
-
-
-class DimensionAwareInitializer(ComponentsInitializer):
-    def __init__(self, dims):
-        super().__init__()
-        if isinstance(dims, Iterable):
-            self.components_dims = tuple(dims)
-        else:
-            self.components_dims = (dims, )
-
-
-class OnesInitializer(DimensionAwareInitializer):
-    def __init__(self, dims, scale=1.0):
-        super().__init__(dims)
-        self.scale = scale
-
-    def generate(self, length):
-        gen_dims = (length, ) + self.components_dims
-        return torch.ones(gen_dims) * self.scale
-
-
-class ZerosInitializer(DimensionAwareInitializer):
-    def generate(self, length):
-        gen_dims = (length, ) + self.components_dims
-        return torch.zeros(gen_dims)
-
-
-class UniformInitializer(DimensionAwareInitializer):
-    def __init__(self, dims, minimum=0.0, maximum=1.0, scale=1.0):
-        super().__init__(dims)
-        self.minimum = minimum
-        self.maximum = maximum
-        self.scale = scale
-
-    def generate(self, length):
-        gen_dims = (length, ) + self.components_dims
-        return torch.ones(gen_dims).uniform_(self.minimum,
-                                             self.maximum) * self.scale
-
-
-class DataAwareInitializer(ComponentsInitializer):
-    def __init__(self, data, transform=torch.nn.Identity()):
-        super().__init__()
-        self.data = data
-        self.transform = transform
-
-    def __del__(self):
-        del self.data
-
-
-class SelectionInitializer(DataAwareInitializer):
-    def generate(self, length):
-        indices = torch.LongTensor(length).random_(0, len(self.data))
-        return self.transform(self.data[indices])
-
-
-class MeanInitializer(DataAwareInitializer):
-    def generate(self, length):
-        mean = torch.mean(self.data, dim=0)
-        repeat_dim = [length] + [1] * len(mean.shape)
-        return self.transform(mean.repeat(repeat_dim))
-
-
-class ClassAwareInitializer(DataAwareInitializer):
-    def __init__(self, data, transform=torch.nn.Identity()):
-        data, targets = parse_data_arg(data)
-        super().__init__(data, transform)
-        self.targets = targets
-        self.clabels = torch.unique(self.targets).int().tolist()
-        self.num_classes = len(self.clabels)
-
-    def _get_samples_from_initializer(self, length, dist):
-        if not dist:
-            per_class = length // self.num_classes
-            dist = dict(zip(self.clabels, self.num_classes * [per_class]))
-        if isinstance(dist, list):
-            dist = dict(zip(self.clabels, dist))
-        samples = [self.initializers[k].generate(n) for k, n in dist.items()]
-        out = torch.vstack(samples)
-        with torch.no_grad():
-            out = self.transform(out)
-        return out
-
-    def __del__(self):
-        del self.data
-        del self.targets
-
-
-class StratifiedMeanInitializer(ClassAwareInitializer):
-    def __init__(self, data, **kwargs):
-        super().__init__(data, **kwargs)
-        self.initializers = get_subinitializers(self.data, self.targets,
-                                                self.clabels, MeanInitializer)
-
-    def generate(self, length, dist):
-        samples = self._get_samples_from_initializer(length, dist)
-        return samples
-
-
-class StratifiedSelectionInitializer(ClassAwareInitializer):
-    def __init__(self, data, noise=None, **kwargs):
-        super().__init__(data, **kwargs)
-        self.noise = noise
-        self.initializers = get_subinitializers(self.data, self.targets,
-                                                self.clabels,
-                                                SelectionInitializer)
-
-    def add_noise_v1(self, x):
-        return x + self.noise
-
-    def add_noise_v2(self, x):
-        """Shifts some dimensions of the data randomly."""
-        n1 = torch.rand_like(x)
-        n2 = torch.rand_like(x)
-        mask = torch.bernoulli(n1) - torch.bernoulli(n2)
-        return x + (self.noise * mask)
-
-    def generate(self, length, dist):
-        samples = self._get_samples_from_initializer(length, dist)
-        if self.noise is not None:
-            samples = self.add_noise_v1(samples)
-        return samples
-
-
-# Labels
-class LabelsInitializer:
-    def generate(self):
-        raise NotImplementedError("Subclasses should implement this!")
-
-
-class UnequalLabelsInitializer(LabelsInitializer):
-    def __init__(self, dist, clabels=None):
-        self.dist = dist
-        self.clabels = clabels or range(len(self.dist))
-
-    @property
-    def distribution(self) -> List:
-        return self.dist
-
-    def generate(self):
-        targets = list(
-            chain(*[[i] * n for i, n in zip(self.clabels, self.dist)]))
-        return torch.LongTensor(targets)
-
-
-class EqualLabelsInitializer(LabelsInitializer):
-    def __init__(self, classes, per_class):
-        self.classes = classes
-        self.per_class = per_class
-
-    @property
-    def distribution(self) -> List:
-        return self.classes * [self.per_class]
-
-    def generate(self):
-        return torch.arange(self.classes).repeat(self.per_class, 1).T.flatten()
-
-
-# Reasonings
-class ReasoningsInitializer:
-    def generate(self, length):
-        raise NotImplementedError("Subclasses should implement this!")
-
-
-class ZeroReasoningsInitializer(ReasoningsInitializer):
-    def __init__(self, classes, length):
-        self.classes = classes
-        self.length = length
-
-    def generate(self):
-        return torch.zeros((self.length, self.classes, 2))
-
-
-# Aliases
-SSI = StratifiedSampleInitializer = StratifiedSelectionInitializer
-SMI = StratifiedMeanInitializer
-Random = RandomInitializer = UniformInitializer
-Zeros = ZerosInitializer
-Ones = OnesInitializer
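
The deleted `parse_data_arg` normalized tensors, array-likes, `Dataset`s, and `DataLoader`s into a `(data, targets)` tensor pair; the new code imports the equivalent helper from `prototorch.utils` instead. A small sketch of the tuple route, assuming the old import path:

```python
import torch
from prototorch.components.initializers import parse_data_arg

x = [[0.0, 0.0], [1.0, 1.0]]  # non-tensor inputs are converted with a warning
y = [0, 1]
data, targets = parse_data_arg((x, y))
print(data.shape, targets.shape)  # torch.Size([2, 2]) torch.Size([2])
```
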
diff --git a/prototorch/components/labels.py b/prototorch/components/labels.py
deleted file mode 100644
index bf2620d..0000000
--- a/prototorch/components/labels.py
+++ /dev/null
@@ -1,86 +0,0 @@
-"""ProtoTorch Labels."""
-
-import torch
-from prototorch.components.components import get_labels_initializer
-from prototorch.components.initializers import (ClassAwareInitializer,
-                                                ComponentsInitializer,
-                                                EqualLabelsInitializer,
-                                                UnequalLabelsInitializer)
-from torch.nn.parameter import Parameter
-
-
-def get_labels_initializer(distribution):
-    if isinstance(distribution, dict):
-        if "num_classes" in distribution.keys():
-            labels = EqualLabelsInitializer(
-                distribution["num_classes"],
-                distribution["prototypes_per_class"])
-        else:
-            clabels = list(distribution.keys())
-            dist = list(distribution.values())
-            labels = UnequalLabelsInitializer(dist, clabels)
-    elif isinstance(distribution, tuple):
-        num_classes, prototypes_per_class = distribution
-        labels = EqualLabelsInitializer(num_classes, prototypes_per_class)
-    elif isinstance(distribution, list):
-        labels = UnequalLabelsInitializer(distribution)
-    else:
-        msg = f"`distribution` not understood." \
-            f"You have provided: {distribution=}."
-        raise ValueError(msg)
-    return labels
-
-
-class Labels(torch.nn.Module):
-    def __init__(self,
-                 distribution=None,
-                 initializer=None,
-                 *,
-                 initialized_labels=None):
-        _labels = self.get_labels(distribution,
-                                  initializer,
-                                  initialized_labels=initialized_labels)
-        self._register_labels(_labels)
-
-    def _register_labels(self, labels):
-        # self.register_buffer("_labels", labels)
-        self.register_parameter("_labels",
-                                Parameter(labels, requires_grad=False))
-
-    def get_labels(self,
-                   distribution=None,
-                   initializer=None,
-                   *,
-                   initialized_labels=None):
-        if initialized_labels is not None:
-            _labels = initialized_labels
-        else:
-            labels_initializer = initializer or get_labels_initializer(
-                distribution)
-            self.initial_distribution = labels_initializer.distribution
-            _labels = labels_initializer.generate()
-        return _labels
-
-    def add_labels(self,
-                   distribution=None,
-                   initializer=None,
-                   *,
-                   initialized_labels=None):
-        new_labels = self.get_labels(distribution,
-                                     initializer,
-                                     initialized_labels=initialized_labels)
-        _labels = torch.cat([self._labels, new_labels])
-        self._register_labels(_labels)
-
-    def remove_labels(self, indices=None):
-        mask = torch.ones(len(self._labels, dtype=torch.bool))
-        mask[indices] = False
-        _labels = self._labels[mask]
-        self._register_labels(_labels)
-
-    @property
-    def labels(self):
-        return self._labels
-
-    def forward(self):
-        return self._labels
diff --git a/prototorch/components/__init__.py b/prototorch/core/__init__.py
similarity index 76%
rename from prototorch/components/__init__.py
rename to prototorch/core/__init__.py
index 69293cb..17be644 100644
--- a/prototorch/components/__init__.py
+++ b/prototorch/core/__init__.py
@@ -1,3 +1,5 @@
+"""ProtoTorch core"""
+
 from .components import *
 from .initializers import *
 from .labels import *
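
Two latent bugs in the `Labels` module deleted above are worth flagging for anyone mining it: `__init__` never calls `super().__init__()`, so the subsequent `register_parameter` would raise, and `remove_labels` passes `dtype` to `len()` instead of `torch.ones()`. A corrected sketch of the intended logic (hypothetical, not part of this diff):

```python
import torch
from torch.nn.parameter import Parameter

class Labels(torch.nn.Module):
    def __init__(self, labels: torch.Tensor):
        super().__init__()  # required before register_parameter
        self.register_parameter(
            "_labels", Parameter(labels, requires_grad=False))

    def remove_labels(self, indices=None):
        # dtype belongs to torch.ones, not len()
        mask = torch.ones(len(self._labels), dtype=torch.bool)
        mask[indices] = False
        self.register_parameter(
            "_labels", Parameter(self._labels[mask], requires_grad=False))
```
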
diff --git a/prototorch/core/components.py b/prototorch/core/components.py
new file mode 100644
index 0000000..53555af
--- /dev/null
+++ b/prototorch/core/components.py
@@ -0,0 +1,243 @@
+"""ProtoTorch components"""
+
+import inspect
+from typing import Union
+
+import torch
+from torch.nn.parameter import Parameter
+
+from ..utils import parse_distribution
+from .initializers import (
+    AbstractComponentsInitializer,
+    AbstractLabelsInitializer,
+    AbstractReasoningsInitializer,
+    ClassAwareCompInitializer,
+    LabelsInitializer,
+)
+
+
+def validate_initializer(initializer, instanceof):
+    if not isinstance(initializer, instanceof):
+        emsg = f"`initializer` has to be an instance " \
+            f"of some subtype of {instanceof}. " \
+            f"You have provided: {initializer} instead. "
+        helpmsg = ""
+        if inspect.isclass(initializer):
+            helpmsg = f"Perhaps you meant to say, {initializer.__name__}() " \
+                f"with the brackets instead of just {initializer.__name__}?"
+        raise TypeError(emsg + helpmsg)
+    return True
+
+
+def validate_components_initializer(initializer):
+    return validate_initializer(initializer, AbstractComponentsInitializer)
+
+
+def validate_labels_initializer(initializer):
+    return validate_initializer(initializer, AbstractLabelsInitializer)
+
+
+def validate_reasonings_initializer(initializer):
+    return validate_initializer(initializer, AbstractReasoningsInitializer)
+
+
+class AbstractComponents(torch.nn.Module):
+    """Abstract class for all components modules."""
+    @property
+    def num_components(self):
+        """Current number of components."""
+        return len(self._components)
+
+    @property
+    def components(self):
+        """Detached Tensor containing the components."""
+        return self._components.detach()
+
+    def _register_components(self, components):
+        self.register_parameter("_components", Parameter(components))
+
+    def extra_repr(self):
+        return f"(components): (shape: {tuple(self._components.shape)})"
+
+
+class Components(AbstractComponents):
+    """A set of adaptable Tensors."""
+    def __init__(self, num_components: int,
+                 initializer: AbstractComponentsInitializer, **kwargs):
+        super().__init__(**kwargs)
+        self.add_components(num_components, initializer)
+
+    def add_components(self, num: int,
+                       initializer: AbstractComponentsInitializer):
+        """Add new components."""
+        assert validate_components_initializer(initializer)
+        new_components = initializer.generate(num)
+        # Register
+        if hasattr(self, "_components"):
+            _components = torch.cat([self._components, new_components])
+        else:
+            _components = new_components
+        self._register_components(_components)
+        return new_components
+
+    def remove_components(self, indices):
+        """Remove components at specified indices."""
+        mask = torch.ones(self.num_components, dtype=torch.bool)
+        mask[indices] = False
+        _components = self._components[mask]
+        self._register_components(_components)
+        return mask
+
+    def forward(self):
+        """Simply return the components parameter Tensor."""
+        return self._components
+
+
+class LabeledComponents(AbstractComponents):
+    """A set of adaptable components and corresponding unadaptable labels."""
+    def __init__(self, distribution: Union[dict, list, tuple],
+                 components_initializer: AbstractComponentsInitializer,
+                 labels_initializer: AbstractLabelsInitializer, **kwargs):
+        super().__init__(**kwargs)
+        self.add_components(distribution, components_initializer,
+                            labels_initializer)
+
+    @property
+    def component_labels(self):
+        """Tensor containing the component labels."""
+        return self._labels.detach()
+
+    def _register_labels(self, labels):
+        self.register_buffer("_labels", labels)
+
+    def add_components(
+            self,
+            distribution,
+            components_initializer,
+            labels_initializer: AbstractLabelsInitializer = LabelsInitializer()):
+        # Checks
+        assert validate_components_initializer(components_initializer)
+        assert validate_labels_initializer(labels_initializer)
+
+        distribution = parse_distribution(distribution)
+
+        # Generate new components
+        if isinstance(components_initializer, ClassAwareCompInitializer):
+            new_components = components_initializer.generate(distribution)
+        else:
+            num_components = sum(distribution.values())
+            new_components = components_initializer.generate(num_components)
+
+        # Generate new labels
+        new_labels = labels_initializer.generate(distribution)
+
+        # Register
+        if hasattr(self, "_components"):
+            _components = torch.cat([self._components, new_components])
+        else:
+            _components = new_components
+        if hasattr(self, "_labels"):
+            _labels = torch.cat([self._labels, new_labels])
+        else:
+            _labels = new_labels
+        self._register_components(_components)
+        self._register_labels(_labels)
+
+        return new_components, new_labels
+
+    def remove_components(self, indices):
+        """Remove components and labels at specified indices."""
+        mask = torch.ones(self.num_components, dtype=torch.bool)
+        mask[indices] = False
+        _components = self._components[mask]
+        _labels = self._labels[mask]
+        self._register_components(_components)
+        self._register_labels(_labels)
+        return mask
+
+    def forward(self):
+        """Simply return the components parameter Tensor and labels."""
+        return self._components, self._labels
+
+
+class ReasoningComponents(AbstractComponents):
+    r"""A set of components and corresponding adaptable reasoning matrices.
+
+    Every component has its own reasoning matrix.
+
+    A reasoning matrix is an Nx2 matrix, where N is the number of classes. The
+    first element is called positive reasoning :math:`p`, the second negative
+    reasoning :math:`n`. A component can reason in favour (positive) of a
+    class, against (negative) a class or not at all (neutral).
+
+    It holds that :math:`0 \leq n \leq 1`, :math:`0 \leq p \leq 1` and :math:`0
+    \leq n+p \leq 1`. Therefore :math:`n` and :math:`p` are two elements of a
+    three element probability distribution.
+
+    """
+    def __init__(self, distribution: Union[dict, list, tuple],
+                 components_initializer: AbstractComponentsInitializer,
+                 reasonings_initializer: AbstractReasoningsInitializer,
+                 **kwargs):
+        super().__init__(**kwargs)
+        self.add_components(distribution, components_initializer,
+                            reasonings_initializer)
+
+    @property
+    def reasonings(self):
+        """Returns Reasoning Matrix.
+
+        Dimension NxCx2
+
+        """
+        return self._reasonings.detach()
+
+    def _register_reasonings(self, reasonings):
+        self.register_parameter("_reasonings", Parameter(reasonings))
+
+    def add_components(self, distribution, components_initializer,
+                       reasonings_initializer: AbstractReasoningsInitializer):
+        # Checks
+        assert validate_components_initializer(components_initializer)
+        assert validate_reasonings_initializer(reasonings_initializer)
+
+        distribution = parse_distribution(distribution)
+
+        # Generate new components
+        if isinstance(components_initializer, ClassAwareCompInitializer):
+            new_components = components_initializer.generate(distribution)
+        else:
+            num_components = sum(distribution.values())
+            new_components = components_initializer.generate(num_components)
+
+        # Generate new reasonings
+        new_reasonings = reasonings_initializer.generate(distribution)
+
+        # Register
+        if hasattr(self, "_components"):
+            _components = torch.cat([self._components, new_components])
+        else:
+            _components = new_components
+        if hasattr(self, "_reasonings"):
+            _reasonings = torch.cat([self._reasonings, new_reasonings])
+        else:
+            _reasonings = new_reasonings
+        self._register_components(_components)
+        self._register_reasonings(_reasonings)
+
+        return new_components, new_reasonings
+
+    def remove_components(self, indices):
+        """Remove components and reasonings at specified indices."""
+        mask = torch.ones(self.num_components, dtype=torch.bool)
+        mask[indices] = False
+        _components = self._components[mask]
+        # TODO
+        # _reasonings = self._reasonings[mask]
+        self._register_components(_components)
+        # self._register_reasonings(_reasonings)
+        return mask
+
+    def forward(self):
+        """Simply return the components and reasonings."""
+        return self._components, self._reasonings
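
A minimal usage sketch of the new components API above (assuming a 2-dimensional feature space; expected outputs shown as comments):

```python
import torch
from prototorch.core.components import Components, LabeledComponents
from prototorch.core.initializers import (LabelsInitializer,
                                          RandomNormalCompInitializer,
                                          ZerosCompInitializer)

comps = Components(3, initializer=ZerosCompInitializer(2))
print(comps.num_components)  # 3

lc = LabeledComponents(
    distribution=(2, 4),  # 2 classes, 4 components per class
    components_initializer=RandomNormalCompInitializer(2),
    labels_initializer=LabelsInitializer())
protos, plabels = lc()
print(protos.shape)  # torch.Size([8, 2])
print(plabels)       # tensor([0, 0, 0, 0, 1, 1, 1, 1])
```
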
diff --git a/prototorch/core/initializers.py b/prototorch/core/initializers.py
new file mode 100644
index 0000000..ba48ffd
--- /dev/null
+++ b/prototorch/core/initializers.py
@@ -0,0 +1,258 @@
+"""ProtoTorch core initializers"""
+
+from abc import ABC, abstractmethod
+from collections.abc import Iterable
+from typing import Union
+
+import torch
+
+from ..utils import parse_data_arg, parse_distribution
+
+
+# Components
+class AbstractComponentsInitializer(ABC):
+    """Abstract class for all components initializers."""
+    ...
+
+
+class ShapeAwareCompInitializer(AbstractComponentsInitializer):
+    """Abstract class for all shape-aware components initializers."""
+    def __init__(self, shape: Union[Iterable, int]):
+        if isinstance(shape, Iterable):
+            self.component_shape = tuple(shape)
+        else:
+            self.component_shape = (shape, )
+
+    @abstractmethod
+    def generate(self, num_components: int):
+        ...
+
+
+class DataAwareCompInitializer(AbstractComponentsInitializer):
+    """Abstract class for all data-aware components initializers.
+
+    Components generated by data-aware components initializers inherit the shape
+    of the provided data.
+
+    """
+    def __init__(self,
+                 data,
+                 noise: float = 0.0,
+                 transform: callable = torch.nn.Identity()):
+        self.data = data
+        self.noise = noise
+        self.transform = transform
+
+    def generate_end_hook(self, samples):
+        drift = torch.rand_like(samples) * self.noise
+        components = self.transform(samples + drift)
+        return components
+
+    @abstractmethod
+    def generate(self, num_components: int):
+        ...
+        return self.generate_end_hook(...)
+
+    def __del__(self):
+        del self.data
+
+
+class ClassAwareCompInitializer(AbstractComponentsInitializer):
+    """Abstract class for all class-aware components initializers.
+
+    Components generated by class-aware components initializers inherit the shape
+    of the provided data.
+
+    """
+    def __init__(self,
+                 data,
+                 noise: float = 0.0,
+                 transform: callable = torch.nn.Identity()):
+        self.data, self.targets = parse_data_arg(data)
+        self.noise = noise
+        self.transform = transform
+        self.clabels = torch.unique(self.targets).int().tolist()
+        self.num_classes = len(self.clabels)
+
+    @property
+    @abstractmethod
+    def subinit_type(self) -> DataAwareCompInitializer:
+        ...
+
+    def generate(self, distribution: Union[dict, list, tuple]):
+        distribution = parse_distribution(distribution)
+        initializers = {
+            k: self.subinit_type(self.data[self.targets == k])
+            for k in distribution.keys()
+        }
+        components = torch.tensor([])
+        for k, v in distribution.items():
+            stratified_data = self.data[self.targets == k]
+            # skip transform here
+            initializer = self.subinit_type(
+                stratified_data,
+                noise=self.noise,
+                transform=self.transform,
+            )
+            samples = initializer.generate(num_components=v)
+            components = torch.cat([components, samples])
+        return components
+
+    def __del__(self):
+        del self.data
+        del self.targets
+
+
+class LiteralCompInitializer(DataAwareCompInitializer):
+    """'Generate' the provided components.
+
+    Use this to 'generate' pre-initialized components from elsewhere.
+
+    """
+    def generate(self, num_components: int):
+        """Ignore `num_components` and simply return transformed `self.data`."""
+        components = self.transform(self.data)
+        return components
+
+
+class ZerosCompInitializer(ShapeAwareCompInitializer):
+    """Generate zeros corresponding to the components shape."""
+    def generate(self, num_components: int):
+        components = torch.zeros((num_components, ) + self.component_shape)
+        return components
+
+
+class OnesCompInitializer(ShapeAwareCompInitializer):
+    """Generate ones corresponding to the components shape."""
+    def generate(self, num_components: int):
+        components = torch.ones((num_components, ) + self.component_shape)
+        return components
+
+
+class FillValueCompInitializer(OnesCompInitializer):
+    """Generate components with the provided `fill_value`."""
+    def __init__(self, shape, fill_value: float = 1.0):
+        super().__init__(shape)
+        self.fill_value = fill_value
+
+    def generate(self, num_components: int):
+        ones = super().generate(num_components)
+        components = ones.fill_(self.fill_value)
+        return components
+
+
+class UniformCompInitializer(OnesCompInitializer):
+    """Generate components by sampling from a continuous uniform distribution."""
+    def __init__(self, shape, minimum=0.0, maximum=1.0, scale=1.0):
+        super().__init__(shape)
+        self.minimum = minimum
+        self.maximum = maximum
+        self.scale = scale
+
+    def generate(self, num_components: int):
+        ones = super().generate(num_components)
+        components = self.scale * ones.uniform_(self.minimum, self.maximum)
+        return components
+
+
+class RandomNormalCompInitializer(OnesCompInitializer):
+    """Generate components by sampling from a standard normal distribution."""
+    def __init__(self, shape, scale=1.0):
+        super().__init__(shape)
+        self.scale = scale
+
+    def generate(self, num_components: int):
+        ones = super().generate(num_components)
+        components = self.scale * torch.randn_like(ones)
+        return components
+
+
+class SelectionCompInitializer(DataAwareCompInitializer):
+    """Generate components by uniformly sampling from the provided data."""
+    def generate(self, num_components: int):
+        indices = torch.LongTensor(num_components).random_(0, len(self.data))
+        samples = self.data[indices]
+        components = self.generate_end_hook(samples)
+        return components
+
+
+class MeanCompInitializer(DataAwareCompInitializer):
+    """Generate components by computing the mean of the provided data."""
+    def generate(self, num_components: int):
+        mean = torch.mean(self.data, dim=0)
+        repeat_dim = [num_components] + [1] * len(mean.shape)
+        samples = mean.repeat(repeat_dim)
+        components = self.generate_end_hook(samples)
+        return components
+
+
+class StratifiedSelectionCompInitializer(ClassAwareCompInitializer):
+    """Generate components using stratified sampling from the provided data."""
+    @property
+    def subinit_type(self):
+        return SelectionCompInitializer
+
+
+class StratifiedMeanCompInitializer(ClassAwareCompInitializer):
+    """Generate components at stratified means of the provided data."""
+    @property
+    def subinit_type(self):
+        return MeanCompInitializer
+
+
+# Labels
+class AbstractLabelsInitializer(ABC):
+    """Abstract class for all labels initializers."""
+    @abstractmethod
+    def generate(self, distribution: Union[dict, list, tuple]):
+        ...
+
+
+class LabelsInitializer(AbstractLabelsInitializer):
+    """Generate labels from the provided `distribution`."""
+    def __init__(self, override_labels: list = []):
+        self.override_labels = override_labels
+
+    def generate(self, distribution: Union[dict, list, tuple]):
+        distribution = parse_distribution(distribution)
+        labels = []
+        for k, v in distribution.items():
+            labels.extend([k] * v)
+        labels = torch.LongTensor(labels)
+        return labels
+
+
+# Reasonings
+class AbstractReasoningsInitializer(ABC):
+    """Abstract class for all reasonings initializers."""
+    @abstractmethod
+    def generate(self, distribution: Union[dict, list, tuple]):
+        ...
+
+
+class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer):
+    """Generate purely positive reasonings for the provided `distribution`."""
+    def generate(self, distribution: Union[dict, list, tuple]):
+        distribution = parse_distribution(distribution)
+        num_classes = len(distribution.keys())
+        num_components = sum(distribution.values())
+        assert num_classes == num_components
+        reasonings = torch.stack(
+            [torch.eye(num_classes),
+             torch.zeros(num_classes, num_classes)],
+            dim=0)
+        return reasonings
+
+
+# Aliases
+ZCI = ZerosCompInitializer
+OCI = OnesCompInitializer
+FVCI = FillValueCompInitializer
+LCI = LiteralCompInitializer
+UCI = UniformCompInitializer
+RNCI = RandomNormalCompInitializer
+SCI = SelectionCompInitializer
+MCI = MeanCompInitializer
+SSCI = StratifiedSelectionCompInitializer
+SMCI = StratifiedMeanCompInitializer
+PPRI = PurePositiveReasoningsInitializer
diff --git a/prototorch/utils/utils.py b/prototorch/utils/utils.py
index 316d5eb..b2058cd 100644
--- a/prototorch/utils/utils.py
+++ b/prototorch/utils/utils.py
@@ -23,7 +23,10 @@ def mesh2d(x=None, border: float = 1.0, resolution: int = 100):
     return mesh, xx, yy
 
 
-def parse_distribution(user_distribution: Union[dict, list, tuple]):
+def parse_distribution(
+    user_distribution: Union[dict[int, int], dict[str, str], list[int],
+                             tuple[int]]
+) -> dict[int, int]:
     """Parse user-provided distribution.
 
     Return a dictionary with integer keys that represent the class labels and
@@ -51,14 +54,15 @@ def parse_distribution(user_distribution: Union[dict, list, tuple]):
 
     if isinstance(user_distribution, dict):
         if "num_classes" in user_distribution.keys():
-            num_classes = user_distribution["num_classes"]
-            per_class = user_distribution["per_class"]
+            num_classes = int(user_distribution["num_classes"])
+            per_class = int(user_distribution["per_class"])
             return from_list([per_class] * num_classes)
         else:
             return user_distribution
     elif isinstance(user_distribution, tuple):
         assert len(user_distribution) == 2
         num_classes, per_class = user_distribution
+        num_classes, per_class = int(num_classes), int(per_class)
         return from_list([per_class] * num_classes)
     elif isinstance(user_distribution, list):
         return from_list(user_distribution)
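
Finally, a sketch of the three input formats `parse_distribution` accepts after this change; the expected outputs follow from the hunk above:

```python
from prototorch.utils import parse_distribution

# (num_classes, per_class) tuple
print(parse_distribution((2, 3)))  # {0: 3, 1: 3}
# dict with explicit {class label: count} entries passes through unchanged
print(parse_distribution({0: 1, 1: 2}))  # {0: 1, 1: 2}
# "num_classes"/"per_class" dict, expanded like the tuple form
print(parse_distribution({"num_classes": 2, "per_class": 3}))  # {0: 3, 1: 3}
```
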