"""Example script (examples/new_components.py).

Builds the three component containers from the Iris data set, prints what
each of them returns, and checks that every container survives a
``torch.save`` / ``torch.load`` round trip through an in-memory buffer.
"""

import io

import torch
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler

# Explicit imports instead of ``import *`` so that it is visible which names
# this example actually relies on.
from prototorch.components import (
    Components,
    LabeledComponents,
    ReasoningComponents,
)
from prototorch.components.initializers import (
    SelectionInitializer,
    StratifiedSelectionInitializer,
)


def _roundtrip(module):
    """Serialize ``module`` into an in-memory buffer and load it back.

    NOTE(review): the whole module object is pickled. Recent PyTorch releases
    are reported to default ``torch.load`` to ``weights_only=True``, which
    would reject this -- confirm against the torch version this project pins.
    """
    buffer = io.BytesIO()
    torch.save(module, buffer)
    buffer.seek(0)  # rewind so that torch.load reads from the start
    return torch.load(buffer)


#
# DATASET
#
scaler = StandardScaler()
x_train, y_train = load_iris(return_X_y=True)
x_train = x_train[:, [0, 2]]  # keep only feature columns 0 and 2
scaler.fit(x_train)
x_train = scaler.transform(x_train)

x_train = torch.Tensor(x_train)
y_train = torch.Tensor(y_train)
num_classes = len(torch.unique(y_train))

#
# CREATE NEW COMPONENTS
#
unsupervised = Components(6, SelectionInitializer(x_train))
print(unsupervised())

prototypes = LabeledComponents(
    (3, 2), StratifiedSelectionInitializer(x_train, y_train))
print(prototypes())

components = ReasoningComponents(
    (3, 6), StratifiedSelectionInitializer(x_train, y_train))
print(components())

#
# TEST SERIALIZATION
#
serialized_unsupervised = _roundtrip(unsupervised)
assert torch.all(
    unsupervised.components == serialized_unsupervised.components
), "Serialization of Components failed."

serialized_prototypes = _roundtrip(prototypes)
assert torch.all(
    prototypes.components == serialized_prototypes.components
), "Serialization of LabeledComponents failed (components differ)."
assert torch.all(
    prototypes.labels == serialized_prototypes.labels
), "Serialization of LabeledComponents failed (labels differ)."

serialized_components = _roundtrip(components)
assert torch.all(
    components.components == serialized_components.components
), "Serialization of ReasoningComponents failed (components differ)."
assert torch.all(
    components.reasonings == serialized_components.reasonings
), "Serialization of ReasoningComponents failed (reasonings differ)."
"""ProtoTorch components modules.

Consolidated listing of the ``prototorch.components`` package as introduced
by this patch:

* ``prototorch/components/__init__.py``     -- the public name list
* ``prototorch/components/components.py``   -- the component containers
* ``prototorch/components/initializers.py`` -- the initializers they use

Because everything lives in one listing here, the package-internal imports
of the original files (``prototorch.components.components``,
``prototorch.components.initializers`` and the unused
``prototorch.functions.initializers.get_initializer``) are not needed.
"""

import warnings
from collections.abc import Iterable
from typing import Tuple

import torch
from torch.nn.parameter import Parameter

# The original ``__init__.py`` exported the three container classes; the
# initializer classes defined further down are listed as well so that a
# wildcard import of this consolidated listing exposes all of them.
__all__ = [
    "Components",
    "LabeledComponents",
    "ReasoningComponents",
    "ComponentsInitializer",
    "DimensionAwareInitializer",
    "OnesInitializer",
    "ZerosInitializer",
    "UniformInitializer",
    "PositionAwareInitializer",
    "SelectionInitializer",
    "MeanInitializer",
    "ClassAwareInitializer",
    "StratifiedMeanInitializer",
    "StratifiedSelectionInitializer",
    "LabelsInitializer",
    "EqualLabelInitializer",
    "ReasoningsInitializer",
    "ZeroReasoningsInitializer",
]


# ---------------------------------------------------------------------------
# prototorch/components/components.py
# ---------------------------------------------------------------------------
class Components(torch.nn.Module):
    """A set of learnable tensors stored in a single ``Parameter``.

    Args:
        number_of_components: Value forwarded to ``initializer.generate``.
        initializer: Object whose ``generate(number_of_components)`` returns
            the initial tensor.
        initialized_components: Ready-made tensor. When given it is wrapped
            directly and the two settings above are ignored; a warning is
            emitted if either of them was supplied as well.
        dtype: Accepted by the signature but not read by this constructor.
    """

    def __init__(self,
                 number_of_components=None,
                 initializer=None,
                 *,
                 initialized_components=None,
                 dtype=torch.float32):
        super().__init__()

        # Ignore all initialization settings if initialized_components is
        # given.
        if initialized_components is not None:
            self._components = Parameter(initialized_components)
            if number_of_components is not None or initializer is not None:
                warnings.warn(
                    "Arguments ignored while initializing Components")
        else:
            self._initialize_components(number_of_components, initializer)

    def _initialize_components(self, number_of_components, initializer):
        """Create the parameter from ``initializer.generate(...)``."""
        self._components = Parameter(
            initializer.generate(number_of_components))

    @property
    def components(self):
        """Detached CPU copy/view of the component tensor."""
        return self._components.detach().cpu()

    def forward(self):
        """Return the underlying (trainable) parameter."""
        return self._components

    def extra_repr(self):
        return f"components.shape: {tuple(self._components.shape)}"


class LabeledComponents(Components):
    """Components together with one label per component.

    Args:
        labels: Either a ``(num_classes, prototypes_per_class)`` tuple, which
            is turned into an ``EqualLabelInitializer``, or an object with a
            ``generate()`` method returning the label tensor.
        initializer: Component initializer, see ``Components``.
        initialized_components: Pair ``(components, labels)`` of ready-made
            tensors; when given, the other arguments are not used.
    """

    def __init__(self,
                 labels=None,
                 initializer=None,
                 *,
                 initialized_components=None):
        if initialized_components is not None:
            super().__init__(initialized_components=initialized_components[0])
            self._labels = initialized_components[1]
        else:
            # Labels come first: their length determines how many components
            # the parent constructor has to generate.
            self._initialize_labels(labels, initializer)
            super().__init__(number_of_components=len(self._labels),
                             initializer=initializer)

    def _initialize_labels(self, labels, initializer):
        """Store the generated labels; ``initializer`` is not read here."""
        if isinstance(labels, tuple):
            num_classes, prototypes_per_class = labels
            labels = EqualLabelInitializer(num_classes, prototypes_per_class)

        self._labels = labels.generate()

    @property
    def labels(self):
        """Detached CPU copy/view of the label tensor."""
        return self._labels.detach().cpu()

    def forward(self):
        """Return the pair ``(components, labels)``."""
        return super().forward(), self._labels


class ReasoningComponents(Components):
    r"""Components together with one reasoning matrix per component.

    A reasoning matrix is a Nx2 matrix, where N is the number of classes.
    The first element is called positive reasoning :math:`p`, the second
    negative reasoning :math:`n`. A component can reason in favour
    (positive) of a class, against (negative) a class or not at all
    (neutral).

    It holds that :math:`0 \leq n \leq 1`, :math:`0 \leq p \leq 1` and
    :math:`0 \leq n+p \leq 1`. Therefore :math:`n` and :math:`p` are two
    elements of a three element probability distribution.

    Args:
        reasonings: Either a ``(num_classes, number_of_components)`` tuple,
            which is turned into a ``ZeroReasoningsInitializer``, or an
            object with a ``generate()`` method returning the reasonings.
        initializer: Component initializer, see ``Components``.
        initialized_components: Pair ``(components, reasonings)`` of
            ready-made tensors; when given, the other arguments are not used.
    """

    def __init__(self,
                 reasonings=None,
                 initializer=None,
                 *,
                 initialized_components=None):
        if initialized_components is not None:
            super().__init__(initialized_components=initialized_components[0])
            self._reasonings = initialized_components[1]
        else:
            # Reasonings come first: their length determines how many
            # components the parent constructor has to generate.
            self._initialize_reasonings(reasonings)
            super().__init__(number_of_components=len(self._reasonings),
                             initializer=initializer)

    def _initialize_reasonings(self, reasonings):
        """Store the generated reasoning tensor."""
        if isinstance(reasonings, tuple):
            num_classes, number_of_components = reasonings
            reasonings = ZeroReasoningsInitializer(num_classes,
                                                   number_of_components)

        self._reasonings = reasonings.generate()

    @property
    def reasonings(self):
        """Detached CPU copy/view of the reasoning tensor.

        When produced by ``ZeroReasoningsInitializer`` the shape is
        ``(number_of_components, num_classes, 2)``.
        """
        return self._reasonings.detach().cpu()

    def forward(self):
        """Return the pair ``(components, reasonings)``."""
        return super().forward(), self._reasonings


# ---------------------------------------------------------------------------
# prototorch/components/initializers.py
# ---------------------------------------------------------------------------
# Components
class ComponentsInitializer:
    """Interface of all component initializers."""

    def generate(self, number_of_components):
        """Return the initial component tensor; subclasses must override."""
        raise NotImplementedError(
            f"{type(self).__name__} does not implement generate().")


class DimensionAwareInitializer(ComponentsInitializer):
    """Initializer that knows the shape of a single component.

    Args:
        c_dims: Shape of one component, either an iterable of sizes or a
            single size (stored as a one-element tuple).
    """

    def __init__(self, c_dims):
        super().__init__()
        if isinstance(c_dims, Iterable):
            self.components_dims = tuple(c_dims)
        else:
            self.components_dims = (c_dims, )


class OnesInitializer(DimensionAwareInitializer):
    """Generate components filled with ones."""

    def generate(self, length):
        gen_dims = (length, ) + self.components_dims
        return torch.ones(gen_dims)


class ZerosInitializer(DimensionAwareInitializer):
    """Generate components filled with zeros."""

    def generate(self, length):
        gen_dims = (length, ) + self.components_dims
        return torch.zeros(gen_dims)


class UniformInitializer(DimensionAwareInitializer):
    """Generate components drawn uniformly between ``min`` and ``max``."""

    def __init__(self, c_dims, min=0.0, max=1.0):
        super().__init__(c_dims)

        self.min = min
        self.max = max

    def generate(self, length):
        """Return a float32 tensor of shape ``(length, *components_dims)``."""
        gen_dims = (length, ) + self.components_dims
        # torch.FloatTensor(<tuple>) interprets the tuple as *data*, which
        # yields a 1-D tensor with len(gen_dims) entries instead of a tensor
        # of shape gen_dims. Allocate by shape explicitly instead.
        return torch.empty(gen_dims,
                           dtype=torch.float32).uniform_(self.min, self.max)


class PositionAwareInitializer(ComponentsInitializer):
    """Initializer that derives components from given data positions."""

    def __init__(self, positions):
        super().__init__()
        self.data = positions


class SelectionInitializer(PositionAwareInitializer):
    """Pick ``length`` random rows of the data (indices drawn independently)."""

    def generate(self, length):
        indices = torch.LongTensor(length).random_(0, len(self.data))
        return self.data[indices]


class MeanInitializer(PositionAwareInitializer):
    """Repeat the mean over the first data dimension ``length`` times."""

    def generate(self, length):
        mean = torch.mean(self.data, dim=0)
        repeat_dim = [length] + [1] * len(mean.shape)
        return mean.repeat(repeat_dim)


class ClassAwareInitializer(ComponentsInitializer):
    """Initializer that additionally knows the class of every data point.

    Attributes set here: ``data``, ``classes``, ``names`` (the unique class
    values) and ``num_classes`` (how many unique values there are).
    """

    def __init__(self, positions, classes):
        super().__init__()
        self.data = positions
        self.classes = classes

        self.names = torch.unique(self.classes)
        self.num_classes = len(self.names)


class StratifiedMeanInitializer(ClassAwareInitializer):
    """Per-class ``MeanInitializer``; results are stacked class by class."""

    def __init__(self, positions, classes):
        super().__init__(positions, classes)

        # One initializer per class, in the order of ``self.names``.
        self.initializers = [
            MeanInitializer(self.data[self.classes == name])
            for name in self.names
        ]

    def generate(self, length):
        """Return ``length // num_classes`` components for every class.

        The division is a floor division, so the result has fewer than
        ``length`` rows when ``length`` is not a multiple of ``num_classes``.
        """
        per_class = length // self.num_classes
        return torch.vstack(
            [init.generate(per_class) for init in self.initializers])


class StratifiedSelectionInitializer(ClassAwareInitializer):
    """Per-class ``SelectionInitializer``; results are stacked class by class."""

    def __init__(self, positions, classes):
        super().__init__(positions, classes)

        # One initializer per class, in the order of ``self.names``.
        self.initializers = [
            SelectionInitializer(self.data[self.classes == name])
            for name in self.names
        ]

    def generate(self, length):
        """Return ``length // num_classes`` components for every class.

        The division is a floor division, so the result has fewer than
        ``length`` rows when ``length`` is not a multiple of ``num_classes``.
        """
        per_class = length // self.num_classes
        return torch.vstack(
            [init.generate(per_class) for init in self.initializers])


# Labels
class LabelsInitializer:
    """Interface of all label initializers."""

    def generate(self):
        """Return the label tensor; subclasses must override."""
        raise NotImplementedError(
            f"{type(self).__name__} does not implement generate().")


class EqualLabelInitializer(LabelsInitializer):
    """Generate ``per_class`` consecutive labels for each of ``classes`` classes."""

    def __init__(self, classes, per_class):
        self.classes = classes
        self.per_class = per_class

    def generate(self):
        """Return e.g. ``[0, 0, 1, 1, 2, 2]`` for 3 classes, 2 per class."""
        return torch.arange(self.classes).repeat(self.per_class, 1).T.flatten()


# Reasonings
class ReasoningsInitializer:
    """Interface of all reasoning initializers.

    NOTE(review): this stub takes ``length`` while
    ``ZeroReasoningsInitializer.generate`` and the call in
    ``ReasoningComponents`` take no argument -- the signature is kept as is;
    confirm which form is intended.
    """

    def generate(self, length):
        """Return the reasoning tensor; subclasses must override."""
        raise NotImplementedError(
            f"{type(self).__name__} does not implement generate().")


class ZeroReasoningsInitializer(ReasoningsInitializer):
    """Generate all-zero reasonings of shape ``(length, classes, 2)``."""

    def __init__(self, classes, length):
        self.classes = classes
        self.length = length

    def generate(self):
        return torch.zeros((self.length, self.classes, 2))


# ---------------------------------------------------------------------------
# prototorch/modules/prototypes.py  (hunk @@ -39,6 +39,9 @@ class Prototypes1D)
# ---------------------------------------------------------------------------
# Only a fragment of ``Prototypes1D.__init__`` is visible in the patch, so
# the change is recorded verbatim instead of being rewritten. Directly after
# the end of the signature (``one_hot_labels=False, **kwargs, ):``) and before
# the comment "Convert tensors to python lists before processing" the patch
# inserts:
#
#     warnings.warn(
#         PendingDeprecationWarning(
#             "Prototypes1D will be replaced in future versions."))