Create Component and initializer classes.

This commit is contained in:
Alexander Engelsberger 2021-04-26 20:49:50 +02:00
parent 7c30ffe2c7
commit 40751aa50a
5 changed files with 339 additions and 0 deletions

View File

@ -0,0 +1,67 @@
#
# DATASET
#
import torch
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
x_train, y_train = load_iris(return_X_y=True)
x_train = x_train[:, [0, 2]]
scaler.fit(x_train)
x_train = scaler.transform(x_train)
x_train = torch.Tensor(x_train)
y_train = torch.Tensor(y_train)
num_classes = len(torch.unique(y_train))
#
# CREATE NEW COMPONENTS
#
from prototorch.components import *
from prototorch.components.initializers import *
unsupervised = Components(6, SelectionInitializer(x_train))
print(unsupervised())
prototypes = LabeledComponents(
(3, 2), StratifiedSelectionInitializer(x_train, y_train))
print(prototypes())
components = ReasoningComponents(
(3, 6), StratifiedSelectionInitializer(x_train, y_train))
print(components())
#
# TEST SERIALIZATION
#
import io
save = io.BytesIO()
torch.save(unsupervised, save)
save.seek(0)
serialized_unsupervised = torch.load(save)
assert torch.all(unsupervised.components == serialized_unsupervised.components
), "Serialization of Components failed."
save = io.BytesIO()
torch.save(prototypes, save)
save.seek(0)
serialized_prototypes = torch.load(save)
assert torch.all(prototypes.components == serialized_prototypes.components
), "Serialization of Components failed."
assert torch.all(prototypes.labels == serialized_prototypes.labels
), "Serialization of Components failed."
save = io.BytesIO()
torch.save(components, save)
save.seek(0)
serialized_components = torch.load(save)
assert torch.all(components.components == serialized_components.components
), "Serialization of Components failed."
assert torch.all(components.reasonings == serialized_components.reasonings
), "Serialization of Components failed."

View File

@ -0,0 +1,7 @@
from prototorch.components.components import Components, LabeledComponents, ReasoningComponents
__all__ = [
"Components",
"LabeledComponents",
"ReasoningComponents",
]

View File

@ -0,0 +1,130 @@
"""ProtoTorch components modules."""
from typing import Tuple
import warnings
from prototorch.components.initializers import EqualLabelInitializer, ZeroReasoningsInitializer
import torch
from torch.nn.parameter import Parameter
from prototorch.functions.initializers import get_initializer
class Components(torch.nn.Module):
"""
Components is a set of learnable Tensors.
"""
def __init__(self,
number_of_components=None,
initializer=None,
*,
initialized_components=None,
dtype=torch.float32):
super().__init__()
# Ignore all initialization settings if initialized_components is given.
if initialized_components is not None:
self._components = Parameter(initialized_components)
if number_of_components is not None or initializer is not None:
warnings.warn(
"Arguments ignored while initializing Components")
else:
self._initialize_components(number_of_components, initializer)
def _initialize_components(self, number_of_components, initializer):
self._components = Parameter(
initializer.generate(number_of_components))
@property
def components(self):
"""
Tensor containing the component tensors.
"""
return self._components.detach().cpu()
def forward(self):
return self._components
def extra_repr(self):
return f"components.shape: {tuple(self._components.shape)}"
class LabeledComponents(Components):
"""
LabeledComponents generate a set of components and a set of labels.
Every Component has a label assigned.
"""
def __init__(self,
labels=None,
initializer=None,
*,
initialized_components=None):
if initialized_components is not None:
super().__init__(initialized_components=initialized_components[0])
self._labels = initialized_components[1]
else:
self._initialize_labels(labels, initializer)
super().__init__(number_of_components=len(self._labels),
initializer=initializer)
def _initialize_labels(self, labels, initializer):
if type(labels) == tuple:
num_classes, prototypes_per_class = labels
labels = EqualLabelInitializer(num_classes, prototypes_per_class)
self._labels = labels.generate()
@property
def labels(self):
"""
Tensor containing the component tensors.
"""
return self._labels.detach().cpu()
def forward(self):
return super().forward(), self._labels
class ReasoningComponents(Components):
"""
ReasoningComponents generate a set of components and a set of reasoning matrices.
Every Component has a reasoning matrix assigned.
A reasoning matrix is a Nx2 matrix, where N is the number of Classes.
The first element is called positive reasoning :math:`p`, the second negative reasoning :math:`n`.
A components can reason in favour (positive) of a class, against (negative) a class or not at all (neutral).
It holds that :math:`0 \leq n \leq 1`, :math:`0 \leq p \leq 1` and :math:`0 \leq n+p \leq 1`.
Therefore :math:`n` and :math:`p` are two elements of a three element probability distribution.
"""
def __init__(self,
reasonings=None,
initializer=None,
*,
initialized_components=None):
if initialized_components is not None:
super().__init__(initialized_components=initialized_components[0])
self._reasonings = initialized_components[1]
else:
self._initialize_reasonings(reasonings)
super().__init__(number_of_components=len(self._reasonings),
initializer=initializer)
def _initialize_reasonings(self, reasonings):
if type(reasonings) == tuple:
num_classes, number_of_components = reasonings
reasonings = ZeroReasoningsInitializer(num_classes,
number_of_components)
self._reasonings = reasonings.generate()
@property
def reasonings(self):
"""
Returns Reasoning Matrix.
Dimension NxCx2
"""
return self._reasonings.detach().cpu()
def forward(self):
return super().forward(), self._reasonings

View File

@ -0,0 +1,132 @@
import torch
from collections.abc import Iterable
# Components
class ComponentsInitializer:
def generate(self, number_of_components):
pass
class DimensionAwareInitializer(ComponentsInitializer):
def __init__(self, c_dims):
super().__init__()
if isinstance(c_dims, Iterable):
self.components_dims = tuple(c_dims)
else:
self.components_dims = (c_dims, )
class OnesInitializer(DimensionAwareInitializer):
def generate(self, length):
gen_dims = (length, ) + self.components_dims
return torch.ones(gen_dims)
class ZerosInitializer(DimensionAwareInitializer):
def generate(self, length):
gen_dims = (length, ) + self.components_dims
return torch.zeros(gen_dims)
class UniformInitializer(DimensionAwareInitializer):
def __init__(self, c_dims, min=0.0, max=1.0):
super().__init__(c_dims)
self.min = min
self.max = max
def generate(self, length):
gen_dims = (length, ) + self.components_dims
return torch.FloatTensor(gen_dims).uniform_(self.min, self.max)
class PositionAwareInitializer(ComponentsInitializer):
def __init__(self, positions):
super().__init__()
self.data = positions
class SelectionInitializer(PositionAwareInitializer):
def generate(self, length):
indices = torch.LongTensor(length).random_(0, len(self.data))
return self.data[indices]
class MeanInitializer(PositionAwareInitializer):
def generate(self, length):
mean = torch.mean(self.data, dim=0)
repeat_dim = [length] + [1] * len(mean.shape)
return mean.repeat(repeat_dim)
class ClassAwareInitializer(ComponentsInitializer):
def __init__(self, positions, classes):
super().__init__()
self.data = positions
self.classes = classes
self.names = torch.unique(self.classes)
self.num_classes = len(self.names)
class StratifiedMeanInitializer(ClassAwareInitializer):
def __init__(self, positions, classes):
super().__init__(positions, classes)
self.initializers = []
for name in self.names:
class_data = self.data[self.classes == name]
class_initializer = MeanInitializer(class_data)
self.initializers.append(class_initializer)
def generate(self, length):
per_class = length // self.num_classes
return torch.vstack(
[init.generate(per_class) for init in self.initializers])
class StratifiedSelectionInitializer(ClassAwareInitializer):
def __init__(self, positions, classes):
super().__init__(positions, classes)
self.initializers = []
for name in self.names:
class_data = self.data[self.classes == name]
class_initializer = SelectionInitializer(class_data)
self.initializers.append(class_initializer)
def generate(self, length):
per_class = length // self.num_classes
return torch.vstack(
[init.generate(per_class) for init in self.initializers])
# Labels
class LabelsInitializer:
def generate(self):
pass
class EqualLabelInitializer(LabelsInitializer):
def __init__(self, classes, per_class):
self.classes = classes
self.per_class = per_class
def generate(self):
return torch.arange(self.classes).repeat(self.per_class, 1).T.flatten()
# Reasonings
class ReasoningsInitializer:
def generate(self, length):
pass
class ZeroReasoningsInitializer(ReasoningsInitializer):
def __init__(self, classes, length):
self.classes = classes
self.length = length
def generate(self):
return torch.zeros((self.length, self.classes, 2))

View File

@ -39,6 +39,9 @@ class Prototypes1D(_Prototypes):
one_hot_labels=False,
**kwargs,
):
warnings.warn(
PendingDeprecationWarning(
"Prototypes1D will be replaced in future versions."))
# Convert tensors to python lists before processing
if prototype_distribution is not None: