[REFACTOR] Clean and move components and initializers into core
prototorch/components/components.py (deleted)
@@ -1,235 +0,0 @@
"""ProtoTorch Components."""
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import warnings
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import torch
 | 
					 | 
				
			||||||
from prototorch.components.initializers import (ClassAwareInitializer,
 | 
					 | 
				
			||||||
                                                ComponentsInitializer,
 | 
					 | 
				
			||||||
                                                EqualLabelsInitializer,
 | 
					 | 
				
			||||||
                                                UnequalLabelsInitializer,
 | 
					 | 
				
			||||||
                                                ZeroReasoningsInitializer)
 | 
					 | 
				
			||||||
from torch.nn.parameter import Parameter
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from .initializers import parse_data_arg
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def get_labels_initializer(distribution):
 | 
					 | 
				
			||||||
    if isinstance(distribution, dict):
 | 
					 | 
				
			||||||
        if "num_classes" in distribution.keys():
 | 
					 | 
				
			||||||
            labels = EqualLabelsInitializer(
 | 
					 | 
				
			||||||
                distribution["num_classes"],
 | 
					 | 
				
			||||||
                distribution["prototypes_per_class"])
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            clabels = list(distribution.keys())
 | 
					 | 
				
			||||||
            dist = list(distribution.values())
 | 
					 | 
				
			||||||
            labels = UnequalLabelsInitializer(dist, clabels)
 | 
					 | 
				
			||||||
    elif isinstance(distribution, tuple):
 | 
					 | 
				
			||||||
        num_classes, prototypes_per_class = distribution
 | 
					 | 
				
			||||||
        labels = EqualLabelsInitializer(num_classes, prototypes_per_class)
 | 
					 | 
				
			||||||
    elif isinstance(distribution, list):
 | 
					 | 
				
			||||||
        labels = UnequalLabelsInitializer(distribution)
 | 
					 | 
				
			||||||
    else:
 | 
					 | 
				
			||||||
        msg = f"`distribution` not understood." \
 | 
					 | 
				
			||||||
            f"You have provided: {distribution=}."
 | 
					 | 
				
			||||||
        raise ValueError(msg)
    return labels


def _precheck_initializer(initializer):
    if not isinstance(initializer, ComponentsInitializer):
        emsg = f"`initializer` has to be some subtype of " \
            f"{ComponentsInitializer}. " \
            f"You have provided: {initializer=} instead."
        raise TypeError(emsg)


class Components(torch.nn.Module):
    """Components is a set of learnable Tensors."""
    def __init__(self,
                 num_components=None,
                 initializer=None,
                 *,
                 initialized_components=None):
        super().__init__()

        # Ignore all initialization settings if initialized_components is given.
        if initialized_components is not None:
            self._register_components(initialized_components)
            if num_components is not None or initializer is not None:
                wmsg = "Arguments ignored while initializing Components"
                warnings.warn(wmsg)
        else:
            self._initialize_components(num_components, initializer)

    @property
    def num_components(self):
        return len(self._components)

    def _register_components(self, components):
        self.register_parameter("_components", Parameter(components))

    def _initialize_components(self, num_components, initializer):
        _precheck_initializer(initializer)
        _components = initializer.generate(num_components)
        self._register_components(_components)

    def add_components(self,
                       num=1,
                       initializer=None,
                       *,
                       initialized_components=None):
        if initialized_components is not None:
            _components = torch.cat([self._components, initialized_components])
        else:
            _precheck_initializer(initializer)
            _new = initializer.generate(num)
            _components = torch.cat([self._components, _new])
        self._register_components(_components)

    def remove_components(self, indices=None):
        mask = torch.ones(self.num_components, dtype=torch.bool)
        mask[indices] = False
        _components = self._components[mask]
        self._register_components(_components)
        return mask

    @property
    def components(self):
        """Tensor containing the component tensors."""
        return self._components.detach()

    def forward(self):
        return self._components

    def extra_repr(self):
        return f"(components): (shape: {tuple(self._components.shape)})"


class LabeledComponents(Components):
    """LabeledComponents generate a set of components and a set of labels.

    Every Component has a label assigned.
    """
    def __init__(self,
                 distribution=None,
                 initializer=None,
                 *,
                 initialized_components=None):
        if initialized_components is not None:
            components, component_labels = parse_data_arg(
                initialized_components)
            super().__init__(initialized_components=components)
            self._register_labels(component_labels)
        else:
            labels_initializer = get_labels_initializer(distribution)
            self.initial_distribution = labels_initializer.distribution
            _labels = labels_initializer.generate()
            super().__init__(len(_labels), initializer=initializer)
            self._register_labels(_labels)

    def _register_labels(self, labels):
        self.register_buffer("_labels", labels)

    @property
    def distribution(self):
        clabels, counts = torch.unique(self._labels,
                                       sorted=True,
                                       return_counts=True)
        return dict(zip(clabels.tolist(), counts.tolist()))

    def _initialize_components(self, num_components, initializer):
        if isinstance(initializer, ClassAwareInitializer):
            _precheck_initializer(initializer)
            _components = initializer.generate(num_components,
                                               self.initial_distribution)
            self._register_components(_components)
        else:
            super()._initialize_components(num_components, initializer)

    def add_components(self, distribution, initializer):
        _precheck_initializer(initializer)

        # Labels
        labels_initializer = get_labels_initializer(distribution)
        new_labels = labels_initializer.generate()
        _labels = torch.cat([self._labels, new_labels])
        self._register_labels(_labels)

        # Components
        if isinstance(initializer, ClassAwareInitializer):
            _new = initializer.generate(len(new_labels), distribution)
        else:
            _new = initializer.generate(len(new_labels))
        _components = torch.cat([self._components, _new])
        self._register_components(_components)

    def remove_components(self, indices=None):
        # Components
        mask = super().remove_components(indices)

        # Labels
        _labels = self._labels[mask]
        self._register_labels(_labels)

    @property
    def component_labels(self):
        """Tensor containing the component tensors."""
 | 
					 | 
				
			||||||
        return self._labels.detach()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def forward(self):
 | 
					 | 
				
			||||||
        return super().forward(), self._labels
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class ReasoningComponents(Components):
 | 
					 | 
				
			||||||
    """ReasoningComponents generate a set of components and a set of reasoning matrices.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Every Component has a reasoning matrix assigned.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    A reasoning matrix is a Nx2 matrix, where N is the number of Classes. The
 | 
					 | 
				
			||||||
    first element is called positive reasoning :math:`p`, the second negative
 | 
					 | 
				
			||||||
    reasoning :math:`n`. A components can reason in favour (positive) of a
 | 
					 | 
				
			||||||
    class, against (negative) a class or not at all (neutral).
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    It holds that :math:`0 \leq n \leq 1`, :math:`0 \leq p \leq 1` and :math:`0
 | 
					 | 
				
			||||||
    \leq n+p \leq 1`. Therefore :math:`n` and :math:`p` are two elements of a
 | 
					 | 
				
			||||||
    three element probability distribution.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    def __init__(self,
                 distribution=None,
                 initializer=None,
                 reasoning_initializer=None,
                 *,
                 initialized_components=None):
        if initialized_components is not None:
            components, reasonings = initialized_components
            super().__init__(initialized_components=components)
            self.register_parameter("_reasonings", Parameter(reasonings))
        else:
            labels_initializer = get_labels_initializer(distribution)
            self.initial_distribution = labels_initializer.distribution
            super().__init__(len(self.initial_distribution),
                             initializer=initializer)
            self._initialize_reasonings(reasoning_initializer)

    def _initialize_reasonings(self, reasoning_initializer):
        if isinstance(reasoning_initializer, tuple):
            num_classes, num_components = reasoning_initializer
            reasoning_initializer = ZeroReasoningsInitializer(
                num_classes, num_components)
        _reasonings = reasoning_initializer.generate()
        self.register_parameter("_reasonings", Parameter(_reasonings))

    @property
    def reasonings(self):
        """Returns Reasoning Matrix.

        Dimension NxCx2

        """
        return self._reasonings.detach()

    def forward(self):
        return super().forward(), self._reasonings
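For context, a minimal usage sketch of the deleted API (not part of the diff); it assumes the fixed code above and uses `UniformInitializer` from the old initializers module shown below:

# Hypothetical usage sketch of the old components API:
from prototorch.components.components import Components, LabeledComponents
from prototorch.components.initializers import UniformInitializer

comps = Components(num_components=6, initializer=UniformInitializer(2))
print(comps().shape)  # torch.Size([6, 2])

# Three classes with two prototypes each:
lcomps = LabeledComponents(distribution=(3, 2),
                           initializer=UniformInitializer(2))
protos, plabels = lcomps()
print(plabels.tolist())  # [0, 0, 1, 1, 2, 2]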

prototorch/components/initializers.py (deleted)
@@ -1,225 +0,0 @@
"""ProtoTroch Initializers."""
 | 
					 | 
				
			||||||
import warnings
from collections.abc import Iterable
from itertools import chain
from typing import List

import torch
from torch.utils.data import DataLoader, Dataset


def parse_data_arg(data_arg):
    if isinstance(data_arg, Dataset):
        data_arg = DataLoader(data_arg, batch_size=len(data_arg))

    if isinstance(data_arg, DataLoader):
        data = torch.tensor([])
        targets = torch.tensor([])
        for x, y in data_arg:
            data = torch.cat([data, x])
            targets = torch.cat([targets, y])
    else:
        data, targets = data_arg
        if not isinstance(data, torch.Tensor):
            wmsg = f"Converting data to {torch.Tensor}."
            warnings.warn(wmsg)
            data = torch.Tensor(data)
        if not isinstance(targets, torch.Tensor):
            wmsg = f"Converting targets to {torch.Tensor}."
            warnings.warn(wmsg)
            targets = torch.Tensor(targets)
    return data, targets
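For illustration (not part of the diff): `parse_data_arg` accepts a `Dataset`, a `DataLoader`, or a `(data, targets)` tuple; a small sketch, assuming float targets so that the `torch.cat` calls above stay dtype-consistent:

# Hypothetical usage sketch of parse_data_arg:
import torch
from torch.utils.data import TensorDataset

x = torch.rand(10, 2)
y = torch.randint(0, 3, (10, )).float()

data, targets = parse_data_arg(TensorDataset(x, y))  # via a Dataset
data, targets = parse_data_arg((x, y))               # via a tuple
data, targets = parse_data_arg(([[1., 2.]], [0.]))   # lists are converted, with a warning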

def get_subinitializers(data, targets, clabels, subinit_type):
    initializers = dict()
    for clabel in clabels:
        class_data = data[targets == clabel]
        class_initializer = subinit_type(class_data)
        initializers[clabel] = class_initializer
    return initializers


# Components
class ComponentsInitializer(object):
    def generate(self, number_of_components):
        raise NotImplementedError("Subclasses should implement this!")


class DimensionAwareInitializer(ComponentsInitializer):
    def __init__(self, dims):
        super().__init__()
        if isinstance(dims, Iterable):
            self.components_dims = tuple(dims)
        else:
            self.components_dims = (dims, )


class OnesInitializer(DimensionAwareInitializer):
    def __init__(self, dims, scale=1.0):
        super().__init__(dims)
        self.scale = scale

    def generate(self, length):
        gen_dims = (length, ) + self.components_dims
        return torch.ones(gen_dims) * self.scale


class ZerosInitializer(DimensionAwareInitializer):
    def generate(self, length):
        gen_dims = (length, ) + self.components_dims
        return torch.zeros(gen_dims)


class UniformInitializer(DimensionAwareInitializer):
    def __init__(self, dims, minimum=0.0, maximum=1.0, scale=1.0):
        super().__init__(dims)
        self.minimum = minimum
        self.maximum = maximum
        self.scale = scale

    def generate(self, length):
        gen_dims = (length, ) + self.components_dims
        return torch.ones(gen_dims).uniform_(self.minimum,
                                             self.maximum) * self.scale


class DataAwareInitializer(ComponentsInitializer):
    def __init__(self, data, transform=torch.nn.Identity()):
        super().__init__()
        self.data = data
        self.transform = transform

    def __del__(self):
        del self.data


class SelectionInitializer(DataAwareInitializer):
    def generate(self, length):
        indices = torch.LongTensor(length).random_(0, len(self.data))
        return self.transform(self.data[indices])


class MeanInitializer(DataAwareInitializer):
    def generate(self, length):
        mean = torch.mean(self.data, dim=0)
        repeat_dim = [length] + [1] * len(mean.shape)
        return self.transform(mean.repeat(repeat_dim))


class ClassAwareInitializer(DataAwareInitializer):
    def __init__(self, data, transform=torch.nn.Identity()):
        data, targets = parse_data_arg(data)
        super().__init__(data, transform)
        self.targets = targets
        self.clabels = torch.unique(self.targets).int().tolist()
        self.num_classes = len(self.clabels)

    def _get_samples_from_initializer(self, length, dist):
        if not dist:
            per_class = length // self.num_classes
            dist = dict(zip(self.clabels, self.num_classes * [per_class]))
        if isinstance(dist, list):
            dist = dict(zip(self.clabels, dist))
        samples = [self.initializers[k].generate(n) for k, n in dist.items()]
        out = torch.vstack(samples)
        with torch.no_grad():
            out = self.transform(out)
        return out

    def __del__(self):
        del self.data
        del self.targets


class StratifiedMeanInitializer(ClassAwareInitializer):
    def __init__(self, data, **kwargs):
        super().__init__(data, **kwargs)
        self.initializers = get_subinitializers(self.data, self.targets,
                                                self.clabels, MeanInitializer)

    def generate(self, length, dist):
        samples = self._get_samples_from_initializer(length, dist)
        return samples


class StratifiedSelectionInitializer(ClassAwareInitializer):
    def __init__(self, data, noise=None, **kwargs):
        super().__init__(data, **kwargs)
        self.noise = noise
        self.initializers = get_subinitializers(self.data, self.targets,
                                                self.clabels,
                                                SelectionInitializer)

    def add_noise_v1(self, x):
        return x + self.noise

    def add_noise_v2(self, x):
        """Shifts some dimensions of the data randomly."""
        n1 = torch.rand_like(x)
        n2 = torch.rand_like(x)
        mask = torch.bernoulli(n1) - torch.bernoulli(n2)
        return x + (self.noise * mask)

    def generate(self, length, dist):
        samples = self._get_samples_from_initializer(length, dist)
        if self.noise is not None:
            samples = self.add_noise_v1(samples)
        return samples


# Labels
class LabelsInitializer:
    def generate(self):
        raise NotImplementedError("Subclasses should implement this!")


class UnequalLabelsInitializer(LabelsInitializer):
    def __init__(self, dist, clabels=None):
        self.dist = dist
        self.clabels = clabels or range(len(self.dist))

    @property
    def distribution(self) -> List:
        return self.dist

    def generate(self):
        targets = list(
            chain(*[[i] * n for i, n in zip(self.clabels, self.dist)]))
        return torch.LongTensor(targets)


class EqualLabelsInitializer(LabelsInitializer):
    def __init__(self, classes, per_class):
        self.classes = classes
        self.per_class = per_class

    @property
    def distribution(self) -> List:
        return self.classes * [self.per_class]

    def generate(self):
        return torch.arange(self.classes).repeat(self.per_class, 1).T.flatten()


# Reasonings
class ReasoningsInitializer:
    def generate(self, length):
        raise NotImplementedError("Subclasses should implement this!")


class ZeroReasoningsInitializer(ReasoningsInitializer):
    def __init__(self, classes, length):
        self.classes = classes
        self.length = length

    def generate(self):
        return torch.zeros((self.length, self.classes, 2))


# Aliases
SSI = StratifiedSampleInitializer = StratifiedSelectionInitializer
SMI = StratifiedMeanInitializer
Random = RandomInitializer = UniformInitializer
Zeros = ZerosInitializer
Ones = OnesInitializer
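For context, a minimal sketch of the old class-aware initializers (not part of the diff), using small toy tensors:

# Hypothetical usage sketch of the old stratified initializers:
import torch

x = torch.rand(12, 2)                  # 12 samples, 2 features
y = torch.tensor([0.] * 6 + [1.] * 6)  # two classes
smi = StratifiedMeanInitializer((x, y))
protos = smi.generate(length=4, dist=[2, 2])  # each class mean, twice
print(protos.shape)  # torch.Size([4, 2])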

prototorch/components/labels.py (deleted)
@@ -1,86 +0,0 @@
"""ProtoTorch Labels."""
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import torch
 | 
					 | 
				
			||||||
from prototorch.components.components import get_labels_initializer
 | 
					 | 
				
			||||||
from prototorch.components.initializers import (ClassAwareInitializer,
 | 
					 | 
				
			||||||
                                                ComponentsInitializer,
 | 
					 | 
				
			||||||
                                                EqualLabelsInitializer,
 | 
					 | 
				
			||||||
                                                UnequalLabelsInitializer)
 | 
					 | 
				
			||||||
from torch.nn.parameter import Parameter
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def get_labels_initializer(distribution):
 | 
					 | 
				
			||||||
    if isinstance(distribution, dict):
 | 
					 | 
				
			||||||
        if "num_classes" in distribution.keys():
 | 
					 | 
				
			||||||
            labels = EqualLabelsInitializer(
 | 
					 | 
				
			||||||
                distribution["num_classes"],
 | 
					 | 
				
			||||||
                distribution["prototypes_per_class"])
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            clabels = list(distribution.keys())
 | 
					 | 
				
			||||||
            dist = list(distribution.values())
 | 
					 | 
				
			||||||
            labels = UnequalLabelsInitializer(dist, clabels)
 | 
					 | 
				
			||||||
    elif isinstance(distribution, tuple):
 | 
					 | 
				
			||||||
        num_classes, prototypes_per_class = distribution
 | 
					 | 
				
			||||||
        labels = EqualLabelsInitializer(num_classes, prototypes_per_class)
 | 
					 | 
				
			||||||
    elif isinstance(distribution, list):
 | 
					 | 
				
			||||||
        labels = UnequalLabelsInitializer(distribution)
 | 
					 | 
				
			||||||
    else:
 | 
					 | 
				
			||||||
        msg = f"`distribution` not understood." \
 | 
					 | 
				
			||||||
            f"You have provided: {distribution=}."
 | 
					 | 
				
			||||||
        raise ValueError(msg)
    return labels


class Labels(torch.nn.Module):
    def __init__(self,
                 distribution=None,
                 initializer=None,
                 *,
                 initialized_labels=None):
        super().__init__()
        _labels = self.get_labels(distribution,
                                  initializer,
                                  initialized_labels=initialized_labels)
        self._register_labels(_labels)

    def _register_labels(self, labels):
        # self.register_buffer("_labels", labels)
        self.register_parameter("_labels",
                                Parameter(labels, requires_grad=False))

    def get_labels(self,
                   distribution=None,
                   initializer=None,
                   *,
                   initialized_labels=None):
        if initialized_labels is not None:
            _labels = initialized_labels
        else:
            labels_initializer = initializer or get_labels_initializer(
                distribution)
            self.initial_distribution = labels_initializer.distribution
            _labels = labels_initializer.generate()
        return _labels

    def add_labels(self,
                   distribution=None,
                   initializer=None,
                   *,
                   initialized_labels=None):
        new_labels = self.get_labels(distribution,
                                     initializer,
                                     initialized_labels=initialized_labels)
        _labels = torch.cat([self._labels, new_labels])
        self._register_labels(_labels)

    def remove_labels(self, indices=None):
        mask = torch.ones(len(self._labels), dtype=torch.bool)
        mask[indices] = False
        _labels = self._labels[mask]
        self._register_labels(_labels)

    @property
    def labels(self):
        return self._labels

    def forward(self):
        return self._labels
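For context, a minimal sketch of the deleted Labels module (not part of the diff); it assumes the fixes above:

# Hypothetical usage sketch of the old Labels module:
labels = Labels(distribution=(3, 2))  # 3 classes, 2 prototypes each
print(labels().tolist())  # [0, 0, 1, 1, 2, 2]
labels.add_labels(distribution=[1, 1, 1])  # one extra prototype per class
print(len(labels.labels))  # 9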

@@ -1,3 +1,5 @@
+"""ProtoTorch core"""
+
 from .components import *
 from .initializers import *
 from .labels import *

prototorch/core/components.py (new file, 243 lines)
@@ -0,0 +1,243 @@
					"""ProtoTorch components"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import inspect
 | 
				
			||||||
 | 
					from typing import Union
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					from torch.nn.parameter import Parameter
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from ..utils import parse_distribution
 | 
				
			||||||
 | 
					from .initializers import (
 | 
				
			||||||
 | 
					    AbstractComponentsInitializer,
 | 
				
			||||||
 | 
					    AbstractLabelsInitializer,
 | 
				
			||||||
 | 
					    AbstractReasoningsInitializer,
 | 
				
			||||||
 | 
					    ClassAwareCompInitializer,
 | 
				
			||||||
 | 
					    LabelsInitializer,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def validate_initializer(initializer, instanceof):
 | 
				
			||||||
 | 
					    if not isinstance(initializer, instanceof):
 | 
				
			||||||
 | 
					        emsg = f"`initializer` has to be an instance " \
 | 
				
			||||||
 | 
					            f"of some subtype of {instanceof}. " \
 | 
				
			||||||
 | 
					            f"You have provided: {initializer} instead. "
 | 
				
			||||||
 | 
					        helpmsg = ""
 | 
				
			||||||
 | 
					        if inspect.isclass(initializer):
 | 
				
			||||||
 | 
					            helpmsg = f"Perhaps you meant to say, {initializer.__name__}() " \
 | 
				
			||||||
 | 
					                f"with the brackets instead of just {initializer.__name__}?"
 | 
				
			||||||
 | 
					        raise TypeError(emsg + helpmsg)
 | 
				
			||||||
 | 
					    return True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def validate_components_initializer(initializer):
 | 
				
			||||||
 | 
					    return validate_initializer(initializer, AbstractComponentsInitializer)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def validate_labels_initializer(initializer):
 | 
				
			||||||
 | 
					    return validate_initializer(initializer, AbstractLabelsInitializer)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def validate_reasonings_initializer(initializer):
 | 
				
			||||||
 | 
					    return validate_initializer(initializer, AbstractReasoningsInitializer)
 | 
				
			||||||
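A quick illustration of what the new validators catch (not part of the diff); `OnesCompInitializer` is defined in the new initializers module below:

# Hypothetical sketch: passing the class instead of an instance raises a
# TypeError that hints at the missing brackets.
try:
    validate_components_initializer(OnesCompInitializer)  # class, no ()
except TypeError as e:
    print(e)  # "... Perhaps you meant to say, OnesCompInitializer() ..."

validate_components_initializer(OnesCompInitializer(2))  # returns True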


class AbstractComponents(torch.nn.Module):
    """Abstract class for all components modules."""
    @property
    def num_components(self):
        """Current number of components."""
        return len(self._components)

    @property
    def components(self):
        """Detached Tensor containing the components."""
        return self._components.detach()

    def _register_components(self, components):
        self.register_parameter("_components", Parameter(components))

    def extra_repr(self):
        return f"(components): (shape: {tuple(self._components.shape)})"


class Components(AbstractComponents):
    """A set of adaptable Tensors."""
    def __init__(self, num_components: int,
                 initializer: AbstractComponentsInitializer, **kwargs):
        super().__init__(**kwargs)
        self.add_components(num_components, initializer)

    def add_components(self, num: int,
                       initializer: AbstractComponentsInitializer):
        """Add new components."""
        assert validate_components_initializer(initializer)
        new_components = initializer.generate(num)
        # Register
        if hasattr(self, "_components"):
            _components = torch.cat([self._components, new_components])
        else:
            _components = new_components
        self._register_components(_components)
        return new_components

    def remove_components(self, indices):
        """Remove components at specified indices."""
        mask = torch.ones(self.num_components, dtype=torch.bool)
        mask[indices] = False
        _components = self._components[mask]
        self._register_components(_components)
        return mask

    def forward(self):
        """Simply return the components parameter Tensor."""
        return self._components
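A minimal usage sketch of the new `Components` API (not part of the diff), assuming `OnesCompInitializer` from the new initializers module:

# Hypothetical usage sketch of the new Components API:
comps = Components(3, initializer=OnesCompInitializer(2))
print(comps().shape)  # torch.Size([3, 2])
comps.add_components(2, OnesCompInitializer(2))
print(comps.num_components)  # 5
comps.remove_components([0, 1])
print(comps.num_components)  # 3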


class LabeledComponents(AbstractComponents):
    """A set of adaptable components and corresponding unadaptable labels."""
    def __init__(self, distribution: Union[dict, list, tuple],
                 components_initializer: AbstractComponentsInitializer,
                 labels_initializer: AbstractLabelsInitializer, **kwargs):
        super().__init__(**kwargs)
        self.add_components(distribution, components_initializer,
                            labels_initializer)

    @property
    def component_labels(self):
        """Tensor containing the component labels."""
        return self._labels.detach()

    def _register_labels(self, labels):
        self.register_buffer("_labels", labels)

    def add_components(
        self,
        distribution,
        components_initializer,
        labels_initializer: AbstractLabelsInitializer = LabelsInitializer()):
        # Checks
        assert validate_components_initializer(components_initializer)
        assert validate_labels_initializer(labels_initializer)

        distribution = parse_distribution(distribution)

        # Generate new components
        if isinstance(components_initializer, ClassAwareCompInitializer):
            new_components = components_initializer.generate(distribution)
        else:
            num_components = sum(distribution.values())
            new_components = components_initializer.generate(num_components)

        # Generate new labels
        new_labels = labels_initializer.generate(distribution)

        # Register
        if hasattr(self, "_components"):
            _components = torch.cat([self._components, new_components])
        else:
            _components = new_components
        if hasattr(self, "_labels"):
            _labels = torch.cat([self._labels, new_labels])
        else:
            _labels = new_labels
        self._register_components(_components)
        self._register_labels(_labels)

        return new_components, new_labels

    def remove_components(self, indices):
        """Remove components and labels at specified indices."""
        mask = torch.ones(self.num_components, dtype=torch.bool)
        mask[indices] = False
        _components = self._components[mask]
        _labels = self._labels[mask]
        self._register_components(_components)
        self._register_labels(_labels)
        return mask

    def forward(self):
        """Simply return the components parameter Tensor and labels."""
        return self._components, self._labels
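A minimal usage sketch of the new `LabeledComponents` API (not part of the diff). It assumes `parse_distribution` normalizes dicts, lists, and `(num_classes, per_class)` tuples into a class-count dict, and that `LabelsInitializer.generate` repeats each class label accordingly; neither helper is shown in this diff:

# Hypothetical usage sketch of the new LabeledComponents API:
lcomps = LabeledComponents({0: 2, 1: 3},
                           components_initializer=ZerosCompInitializer(2),
                           labels_initializer=LabelsInitializer())
protos, plabels = lcomps()
print(protos.shape)      # torch.Size([5, 2])
print(plabels.tolist())  # assumed: [0, 0, 1, 1, 1]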


class ReasoningComponents(AbstractComponents):
    """A set of components and corresponding adaptable reasoning matrices.

    Every component has its own reasoning matrix.

    A reasoning matrix is an Nx2 matrix, where N is the number of classes. The
    first element is called positive reasoning :math:`p`, the second negative
    reasoning :math:`n`. A component can reason in favour (positive) of a
    class, against (negative) a class, or not at all (neutral).

    It holds that :math:`0 \leq n \leq 1`, :math:`0 \leq p \leq 1` and :math:`0
    \leq n+p \leq 1`. Therefore :math:`n` and :math:`p` are two elements of a
    three-element probability distribution.

    """
    def __init__(self, distribution: Union[dict, list, tuple],
                 components_initializer: AbstractComponentsInitializer,
                 reasonings_initializer: AbstractReasoningsInitializer,
                 **kwargs):
        super().__init__(**kwargs)
        self.add_components(distribution, components_initializer,
                            reasonings_initializer)

    @property
    def reasonings(self):
        """Returns Reasoning Matrix.

        Dimension NxCx2

        """
        return self._reasonings.detach()

    def _register_reasonings(self, reasonings):
        self.register_parameter("_reasonings", Parameter(reasonings))

    def add_components(self, distribution, components_initializer,
                       reasonings_initializer: AbstractReasoningsInitializer):
        # Checks
        assert validate_components_initializer(components_initializer)
        assert validate_reasonings_initializer(reasonings_initializer)

        distribution = parse_distribution(distribution)

        # Generate new components
        if isinstance(components_initializer, ClassAwareCompInitializer):
            new_components = components_initializer.generate(distribution)
        else:
            num_components = sum(distribution.values())
            new_components = components_initializer.generate(num_components)

        # Generate new reasonings
        new_reasonings = reasonings_initializer.generate(distribution)

        # Register
        if hasattr(self, "_components"):
            _components = torch.cat([self._components, new_components])
        else:
            _components = new_components
        if hasattr(self, "_reasonings"):
            _reasonings = torch.cat([self._reasonings, new_reasonings])
        else:
            _reasonings = new_reasonings
        self._register_components(_components)
        self._register_reasonings(_reasonings)

        return new_components, new_reasonings

    def remove_components(self, indices):
        """Remove components and reasonings at specified indices."""
        mask = torch.ones(self.num_components, dtype=torch.bool)
        mask[indices] = False
        _components = self._components[mask]
        # TODO
        # _reasonings = self._reasonings[mask]
        self._register_components(_components)
        # self._register_reasonings(_reasonings)
        return mask

    def forward(self):
        """Simply return the components and reasonings."""
        return self._components, self._reasonings
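A minimal usage sketch of the new `ReasoningComponents` API (not part of the diff); the reasonings initializer named here is an assumption, since the corresponding part of the new initializers module is not shown:

# Hypothetical usage sketch, assuming a ZerosReasoningsInitializer exists:
rcomps = ReasoningComponents({0: 1, 1: 1},
                             components_initializer=OnesCompInitializer(2),
                             reasonings_initializer=ZerosReasoningsInitializer())
protos, reasonings = rcomps()
print(protos.shape)  # torch.Size([2, 2])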

prototorch/core/initializers.py (new file, 258 lines)
@@ -0,0 +1,258 @@
					"""ProtoTorch code initializers"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 

from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import Union

import torch

from ..utils import parse_data_arg, parse_distribution


# Components
class AbstractComponentsInitializer(ABC):
    """Abstract class for all components initializers."""
    ...


class ShapeAwareCompInitializer(AbstractComponentsInitializer):
    """Abstract class for all shape-aware components initializers."""
    def __init__(self, shape: Union[Iterable, int]):
        if isinstance(shape, Iterable):
            self.component_shape = tuple(shape)
        else:
            self.component_shape = (shape, )

    @abstractmethod
    def generate(self, num_components: int):
        ...


class DataAwareCompInitializer(AbstractComponentsInitializer):
    """Abstract class for all data-aware components initializers.

    Components generated by data-aware components initializers inherit the
    shape of the provided data.

    """
    def __init__(self,
                 data,
                 noise: float = 0.0,
                 transform: callable = torch.nn.Identity()):
        self.data = data
        self.noise = noise
        self.transform = transform

    def generate_end_hook(self, samples):
        drift = torch.rand_like(samples) * self.noise
        components = self.transform(samples + drift)
        return components

    @abstractmethod
    def generate(self, num_components: int):
        ...
        return self.generate_end_hook(...)

    def __del__(self):
        del self.data


class ClassAwareCompInitializer(AbstractComponentsInitializer):
    """Abstract class for all class-aware components initializers.

    Components generated by class-aware components initializers inherit the
    shape of the provided data.

    """
    def __init__(self,
                 data,
                 noise: float = 0.0,
                 transform: callable = torch.nn.Identity()):
        self.data, self.targets = parse_data_arg(data)
        self.noise = noise
        self.transform = transform
        self.clabels = torch.unique(self.targets).int().tolist()
        self.num_classes = len(self.clabels)

    @property
    @abstractmethod
    def subinit_type(self) -> DataAwareCompInitializer:
        ...

    def generate(self, distribution: Union[dict, list, tuple]):
        distribution = parse_distribution(distribution)
        initializers = {
            k: self.subinit_type(self.data[self.targets == k])
            for k in distribution.keys()
        }
        components = torch.tensor([])
        for k, v in distribution.items():
            stratified_data = self.data[self.targets == k]
            # skip transform here
            initializer = self.subinit_type(
                stratified_data,
                noise=self.noise,
                transform=self.transform,
            )
            samples = initializer.generate(num_components=v)
            components = torch.cat([components, samples])
        return components

    def __del__(self):
        del self.data
        del self.targets
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
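# Usage sketch (illustrative only, `initializer` is a hypothetical concrete
# subclass instance): `generate` splits the data by target label and
# delegates to one data-aware subinitializer per class, so a distribution
# like {0: 2, 1: 3} yields 2 components from class 0 and 3 from class 1:
#
#     comps = initializer.generate({0: 2, 1: 3})   # comps.shape[0] == 5
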
class LiteralCompInitializer(DataAwareCompInitializer):
    """'Generate' the provided components.

    Use this to 'generate' pre-initialized components from elsewhere.

    """
    def generate(self, num_components: int):
        """Ignore `num_components` and simply return transformed `self.data`."""
        components = self.transform(self.data)
        return components

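# Usage sketch (illustrative only): hand over pre-computed prototypes and get
# them back through `transform`; note that, unlike the end hook, no noise is
# added here and `num_components` is ignored.
#
#     protos = torch.tensor([[0.0, 0.0], [1.0, 1.0]])
#     comps = LiteralCompInitializer(protos).generate(num_components=0)
#     # comps equals `protos` under the default identity transform
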
class ZerosCompInitializer(ShapeAwareCompInitializer):
    """Generate zeros corresponding to the components shape."""
    def generate(self, num_components: int):
        components = torch.zeros((num_components, ) + self.component_shape)
        return components


class OnesCompInitializer(ShapeAwareCompInitializer):
    """Generate ones corresponding to the components shape."""
    def generate(self, num_components: int):
        components = torch.ones((num_components, ) + self.component_shape)
        return components

class FillValueCompInitializer(OnesCompInitializer):
    """Generate components with the provided `fill_value`."""
    def __init__(self, shape, fill_value: float = 1.0):
        super().__init__(shape)
        self.fill_value = fill_value

    def generate(self, num_components: int):
        ones = super().generate(num_components)
        components = ones.fill_(self.fill_value)
        return components

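# Usage sketch (illustrative only), assuming the shape-aware base class
# stores `shape` as `self.component_shape`, as the zeros/ones generators
# above imply:
#
#     comps = FillValueCompInitializer((2, ), fill_value=0.5).generate(3)
#     # comps.shape == (3, 2), every entry 0.5
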
class UniformCompInitializer(OnesCompInitializer):
    """Generate components by sampling from a continuous uniform distribution."""
    def __init__(self, shape, minimum=0.0, maximum=1.0, scale=1.0):
        super().__init__(shape)
        self.minimum = minimum
        self.maximum = maximum
        self.scale = scale

    def generate(self, num_components: int):
        ones = super().generate(num_components)
        components = self.scale * ones.uniform_(self.minimum, self.maximum)
        return components

class RandomNormalCompInitializer(OnesCompInitializer):
    """Generate components by sampling from a standard normal distribution."""
    def __init__(self, shape, scale=1.0):
        super().__init__(shape)
        self.scale = scale

    def generate(self, num_components: int):
        ones = super().generate(num_components)
        components = self.scale * torch.randn_like(ones)
        return components

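# Usage sketch (illustrative only): both random initializers first build a
# ones tensor of the right shape and then overwrite it, so the output shape
# is always (num_components, *component_shape):
#
#     comps = RandomNormalCompInitializer((4, ), scale=0.1).generate(2)
#     # comps.shape == (2, 4), entries distributed as 0.1 * N(0, 1)
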
class SelectionCompInitializer(DataAwareCompInitializer):
    """Generate components by uniformly sampling from the provided data."""
    def generate(self, num_components: int):
        indices = torch.LongTensor(num_components).random_(0, len(self.data))
        samples = self.data[indices]
        components = self.generate_end_hook(samples)
        return components

class MeanCompInitializer(DataAwareCompInitializer):
    """Generate components by computing the mean of the provided data."""
    def generate(self, num_components: int):
        mean = torch.mean(self.data, dim=0)
        repeat_dim = [num_components] + [1] * len(mean.shape)
        samples = mean.repeat(repeat_dim)
        components = self.generate_end_hook(samples)
        return components

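# Usage sketch (illustrative only): selection draws rows with replacement,
# while the mean initializer repeats the data mean; both then pass through
# the noise/transform end hook.
#
#     x = torch.tensor([[0.0], [2.0]])
#     MeanCompInitializer(x).generate(3)        # three copies of [1.0]
#     SelectionCompInitializer(x).generate(3)   # three rows drawn from x
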
class StratifiedSelectionCompInitializer(ClassAwareCompInitializer):
    """Generate components using stratified sampling from the provided data."""
    @property
    def subinit_type(self):
        return SelectionCompInitializer


class StratifiedMeanCompInitializer(ClassAwareCompInitializer):
    """Generate components at stratified means of the provided data."""
    @property
    def subinit_type(self):
        return MeanCompInitializer

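# Usage sketch (illustrative only), assuming `parse_data_arg` accepts a
# (data, targets) tuple as used by the class-aware base above:
#
#     x = torch.tensor([[0.0], [2.0], [10.0], [12.0]])
#     y = torch.tensor([0, 0, 1, 1])
#     comps = StratifiedMeanCompInitializer((x, y)).generate({0: 1, 1: 1})
#     # comps -> tensor([[ 1.0], [11.0]]), one per-class mean each
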
# Labels
class AbstractLabelsInitializer(ABC):
    """Abstract class for all labels initializers."""
    @abstractmethod
    def generate(self, distribution: Union[dict, list, tuple]):
        ...

class LabelsInitializer(AbstractLabelsInitializer):
    """Generate labels from the provided `distribution`."""
    def __init__(self, override_labels=None):
        # avoid a shared mutable default argument
        self.override_labels = override_labels or []

    def generate(self, distribution: Union[dict, list, tuple]):
        distribution = parse_distribution(distribution)
        labels = []
        for k, v in distribution.items():
            labels.extend([k] * v)
        labels = torch.LongTensor(labels)
        return labels

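# Usage sketch (illustrative only): each class label is repeated as many
# times as the distribution requests.
#
#     LabelsInitializer().generate({0: 2, 1: 3})
#     # -> tensor([0, 0, 1, 1, 1])
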
# Reasonings
class AbstractReasoningsInitializer(ABC):
    """Abstract class for all reasonings initializers."""
    @abstractmethod
    def generate(self, distribution: Union[dict, list, tuple]):
        ...

class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer):
    """Generate purely positive reasonings: an identity matrix of positive
    reasonings stacked on top of all-zero negative reasonings."""
    def generate(self, distribution: Union[dict, list, tuple]):
        distribution = parse_distribution(distribution)
        num_classes = len(distribution.keys())
        num_components = sum(distribution.values())
        assert num_classes == num_components
        reasonings = torch.stack(
            [torch.eye(num_classes),
             torch.zeros(num_classes, num_classes)],
            dim=0)
        return reasonings

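# Usage sketch (illustrative only): the result stacks a positive and a
# negative reasoning slab along dim 0, so it requires exactly one component
# per class (hence the assert above):
#
#     PurePositiveReasoningsInitializer().generate({0: 1, 1: 1})
#     # shape (2, 2, 2): identity for positive reasonings, zeros for negative
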
# Aliases - Components
ZCI = ZerosCompInitializer
OCI = OnesCompInitializer
FVCI = FillValueCompInitializer
LCI = LiteralCompInitializer
UCI = UniformCompInitializer
RNCI = RandomNormalCompInitializer
SCI = SelectionCompInitializer
MCI = MeanCompInitializer
SSCI = StratifiedSelectionCompInitializer
SMCI = StratifiedMeanCompInitializer

# Aliases - Reasonings
PPRI = PurePositiveReasoningsInitializer
@@ -23,7 +23,10 @@ def mesh2d(x=None, border: float = 1.0, resolution: int = 100):
     return mesh, xx, yy
 
 
-def parse_distribution(user_distribution: Union[dict, list, tuple]):
+def parse_distribution(
+    user_distribution: Union[dict[int, int], dict[str, str], list[int],
+                             tuple[int]]
+) -> dict[int, int]:
     """Parse user-provided distribution.
 
     Return a dictionary with integer keys that represent the class labels and
@@ -51,14 +54,15 @@ def parse_distribution(user_distribution: Union[dict, list, tuple]):
 
     if isinstance(user_distribution, dict):
         if "num_classes" in user_distribution.keys():
-            num_classes = user_distribution["num_classes"]
-            per_class = user_distribution["per_class"]
+            num_classes = int(user_distribution["num_classes"])
+            per_class = int(user_distribution["per_class"])
             return from_list([per_class] * num_classes)
         else:
             return user_distribution
     elif isinstance(user_distribution, tuple):
         assert len(user_distribution) == 2
         num_classes, per_class = user_distribution
+        num_classes, per_class = int(num_classes), int(per_class)
         return from_list([per_class] * num_classes)
     elif isinstance(user_distribution, list):
         return from_list(user_distribution)
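# Usage sketch (illustrative only), assuming `from_list` maps list indices
# to integer class labels as the docstring above describes:
#
#     parse_distribution({"num_classes": 2, "per_class": 3})  # {0: 3, 1: 3}
#     parse_distribution((2, 3))                               # {0: 3, 1: 3}
#     parse_distribution([1, 2])                               # {0: 1, 1: 2}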