Create Component and initializer classes.
This commit is contained in:
parent
7c30ffe2c7
commit
40751aa50a
67
examples/new_components.py
Normal file
67
examples/new_components.py
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
#
|
||||||
|
# DATASET
|
||||||
|
#
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from sklearn.datasets import load_iris
|
||||||
|
from sklearn.preprocessing import StandardScaler
|
||||||
|
|
||||||
|
scaler = StandardScaler()
|
||||||
|
x_train, y_train = load_iris(return_X_y=True)
|
||||||
|
x_train = x_train[:, [0, 2]]
|
||||||
|
scaler.fit(x_train)
|
||||||
|
x_train = scaler.transform(x_train)
|
||||||
|
|
||||||
|
x_train = torch.Tensor(x_train)
|
||||||
|
y_train = torch.Tensor(y_train)
|
||||||
|
num_classes = len(torch.unique(y_train))
|
||||||
|
|
||||||
|
#
|
||||||
|
# CREATE NEW COMPONENTS
|
||||||
|
#
|
||||||
|
from prototorch.components import *
|
||||||
|
from prototorch.components.initializers import *
|
||||||
|
|
||||||
|
unsupervised = Components(6, SelectionInitializer(x_train))
|
||||||
|
print(unsupervised())
|
||||||
|
|
||||||
|
prototypes = LabeledComponents(
|
||||||
|
(3, 2), StratifiedSelectionInitializer(x_train, y_train))
|
||||||
|
print(prototypes())
|
||||||
|
|
||||||
|
components = ReasoningComponents(
|
||||||
|
(3, 6), StratifiedSelectionInitializer(x_train, y_train))
|
||||||
|
print(components())
|
||||||
|
|
||||||
|
#
|
||||||
|
# TEST SERIALIZATION
|
||||||
|
#
|
||||||
|
import io
|
||||||
|
|
||||||
|
save = io.BytesIO()
|
||||||
|
torch.save(unsupervised, save)
|
||||||
|
save.seek(0)
|
||||||
|
serialized_unsupervised = torch.load(save)
|
||||||
|
|
||||||
|
assert torch.all(unsupervised.components == serialized_unsupervised.components
|
||||||
|
), "Serialization of Components failed."
|
||||||
|
|
||||||
|
save = io.BytesIO()
|
||||||
|
torch.save(prototypes, save)
|
||||||
|
save.seek(0)
|
||||||
|
serialized_prototypes = torch.load(save)
|
||||||
|
|
||||||
|
assert torch.all(prototypes.components == serialized_prototypes.components
|
||||||
|
), "Serialization of Components failed."
|
||||||
|
assert torch.all(prototypes.labels == serialized_prototypes.labels
|
||||||
|
), "Serialization of Components failed."
|
||||||
|
|
||||||
|
save = io.BytesIO()
|
||||||
|
torch.save(components, save)
|
||||||
|
save.seek(0)
|
||||||
|
serialized_components = torch.load(save)
|
||||||
|
|
||||||
|
assert torch.all(components.components == serialized_components.components
|
||||||
|
), "Serialization of Components failed."
|
||||||
|
assert torch.all(components.reasonings == serialized_components.reasonings
|
||||||
|
), "Serialization of Components failed."
|
7
prototorch/components/__init__.py
Normal file
7
prototorch/components/__init__.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
from prototorch.components.components import Components, LabeledComponents, ReasoningComponents
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Components",
|
||||||
|
"LabeledComponents",
|
||||||
|
"ReasoningComponents",
|
||||||
|
]
|
130
prototorch/components/components.py
Normal file
130
prototorch/components/components.py
Normal file
@ -0,0 +1,130 @@
|
|||||||
|
"""ProtoTorch components modules."""
|
||||||
|
|
||||||
|
from typing import Tuple
|
||||||
|
import warnings
|
||||||
|
from prototorch.components.initializers import EqualLabelInitializer, ZeroReasoningsInitializer
|
||||||
|
import torch
|
||||||
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
|
from prototorch.functions.initializers import get_initializer
|
||||||
|
|
||||||
|
|
||||||
|
class Components(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Components is a set of learnable Tensors.
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
number_of_components=None,
|
||||||
|
initializer=None,
|
||||||
|
*,
|
||||||
|
initialized_components=None,
|
||||||
|
dtype=torch.float32):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# Ignore all initialization settings if initialized_components is given.
|
||||||
|
if initialized_components is not None:
|
||||||
|
self._components = Parameter(initialized_components)
|
||||||
|
if number_of_components is not None or initializer is not None:
|
||||||
|
warnings.warn(
|
||||||
|
"Arguments ignored while initializing Components")
|
||||||
|
else:
|
||||||
|
self._initialize_components(number_of_components, initializer)
|
||||||
|
|
||||||
|
def _initialize_components(self, number_of_components, initializer):
|
||||||
|
self._components = Parameter(
|
||||||
|
initializer.generate(number_of_components))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def components(self):
|
||||||
|
"""
|
||||||
|
Tensor containing the component tensors.
|
||||||
|
"""
|
||||||
|
return self._components.detach().cpu()
|
||||||
|
|
||||||
|
def forward(self):
|
||||||
|
return self._components
|
||||||
|
|
||||||
|
def extra_repr(self):
|
||||||
|
return f"components.shape: {tuple(self._components.shape)}"
|
||||||
|
|
||||||
|
|
||||||
|
class LabeledComponents(Components):
|
||||||
|
"""
|
||||||
|
LabeledComponents generate a set of components and a set of labels.
|
||||||
|
Every Component has a label assigned.
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
labels=None,
|
||||||
|
initializer=None,
|
||||||
|
*,
|
||||||
|
initialized_components=None):
|
||||||
|
if initialized_components is not None:
|
||||||
|
super().__init__(initialized_components=initialized_components[0])
|
||||||
|
self._labels = initialized_components[1]
|
||||||
|
else:
|
||||||
|
self._initialize_labels(labels, initializer)
|
||||||
|
super().__init__(number_of_components=len(self._labels),
|
||||||
|
initializer=initializer)
|
||||||
|
|
||||||
|
def _initialize_labels(self, labels, initializer):
|
||||||
|
if type(labels) == tuple:
|
||||||
|
num_classes, prototypes_per_class = labels
|
||||||
|
labels = EqualLabelInitializer(num_classes, prototypes_per_class)
|
||||||
|
|
||||||
|
self._labels = labels.generate()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def labels(self):
|
||||||
|
"""
|
||||||
|
Tensor containing the component tensors.
|
||||||
|
"""
|
||||||
|
return self._labels.detach().cpu()
|
||||||
|
|
||||||
|
def forward(self):
|
||||||
|
return super().forward(), self._labels
|
||||||
|
|
||||||
|
|
||||||
|
class ReasoningComponents(Components):
|
||||||
|
"""
|
||||||
|
ReasoningComponents generate a set of components and a set of reasoning matrices.
|
||||||
|
Every Component has a reasoning matrix assigned.
|
||||||
|
|
||||||
|
A reasoning matrix is a Nx2 matrix, where N is the number of Classes.
|
||||||
|
The first element is called positive reasoning :math:`p`, the second negative reasoning :math:`n`.
|
||||||
|
A components can reason in favour (positive) of a class, against (negative) a class or not at all (neutral).
|
||||||
|
|
||||||
|
It holds that :math:`0 \leq n \leq 1`, :math:`0 \leq p \leq 1` and :math:`0 \leq n+p \leq 1`.
|
||||||
|
Therefore :math:`n` and :math:`p` are two elements of a three element probability distribution.
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
reasonings=None,
|
||||||
|
initializer=None,
|
||||||
|
*,
|
||||||
|
initialized_components=None):
|
||||||
|
if initialized_components is not None:
|
||||||
|
super().__init__(initialized_components=initialized_components[0])
|
||||||
|
self._reasonings = initialized_components[1]
|
||||||
|
else:
|
||||||
|
self._initialize_reasonings(reasonings)
|
||||||
|
super().__init__(number_of_components=len(self._reasonings),
|
||||||
|
initializer=initializer)
|
||||||
|
|
||||||
|
def _initialize_reasonings(self, reasonings):
|
||||||
|
if type(reasonings) == tuple:
|
||||||
|
num_classes, number_of_components = reasonings
|
||||||
|
reasonings = ZeroReasoningsInitializer(num_classes,
|
||||||
|
number_of_components)
|
||||||
|
|
||||||
|
self._reasonings = reasonings.generate()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def reasonings(self):
|
||||||
|
"""
|
||||||
|
Returns Reasoning Matrix.
|
||||||
|
|
||||||
|
Dimension NxCx2
|
||||||
|
"""
|
||||||
|
return self._reasonings.detach().cpu()
|
||||||
|
|
||||||
|
def forward(self):
|
||||||
|
return super().forward(), self._reasonings
|
132
prototorch/components/initializers.py
Normal file
132
prototorch/components/initializers.py
Normal file
@ -0,0 +1,132 @@
|
|||||||
|
import torch
|
||||||
|
from collections.abc import Iterable
|
||||||
|
|
||||||
|
|
||||||
|
# Components
|
||||||
|
class ComponentsInitializer:
|
||||||
|
def generate(self, number_of_components):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DimensionAwareInitializer(ComponentsInitializer):
|
||||||
|
def __init__(self, c_dims):
|
||||||
|
super().__init__()
|
||||||
|
if isinstance(c_dims, Iterable):
|
||||||
|
self.components_dims = tuple(c_dims)
|
||||||
|
else:
|
||||||
|
self.components_dims = (c_dims, )
|
||||||
|
|
||||||
|
|
||||||
|
class OnesInitializer(DimensionAwareInitializer):
|
||||||
|
def generate(self, length):
|
||||||
|
gen_dims = (length, ) + self.components_dims
|
||||||
|
return torch.ones(gen_dims)
|
||||||
|
|
||||||
|
|
||||||
|
class ZerosInitializer(DimensionAwareInitializer):
|
||||||
|
def generate(self, length):
|
||||||
|
gen_dims = (length, ) + self.components_dims
|
||||||
|
return torch.zeros(gen_dims)
|
||||||
|
|
||||||
|
|
||||||
|
class UniformInitializer(DimensionAwareInitializer):
|
||||||
|
def __init__(self, c_dims, min=0.0, max=1.0):
|
||||||
|
super().__init__(c_dims)
|
||||||
|
|
||||||
|
self.min = min
|
||||||
|
self.max = max
|
||||||
|
|
||||||
|
def generate(self, length):
|
||||||
|
gen_dims = (length, ) + self.components_dims
|
||||||
|
return torch.FloatTensor(gen_dims).uniform_(self.min, self.max)
|
||||||
|
|
||||||
|
|
||||||
|
class PositionAwareInitializer(ComponentsInitializer):
|
||||||
|
def __init__(self, positions):
|
||||||
|
super().__init__()
|
||||||
|
self.data = positions
|
||||||
|
|
||||||
|
|
||||||
|
class SelectionInitializer(PositionAwareInitializer):
|
||||||
|
def generate(self, length):
|
||||||
|
indices = torch.LongTensor(length).random_(0, len(self.data))
|
||||||
|
return self.data[indices]
|
||||||
|
|
||||||
|
|
||||||
|
class MeanInitializer(PositionAwareInitializer):
|
||||||
|
def generate(self, length):
|
||||||
|
mean = torch.mean(self.data, dim=0)
|
||||||
|
repeat_dim = [length] + [1] * len(mean.shape)
|
||||||
|
return mean.repeat(repeat_dim)
|
||||||
|
|
||||||
|
|
||||||
|
class ClassAwareInitializer(ComponentsInitializer):
|
||||||
|
def __init__(self, positions, classes):
|
||||||
|
super().__init__()
|
||||||
|
self.data = positions
|
||||||
|
self.classes = classes
|
||||||
|
|
||||||
|
self.names = torch.unique(self.classes)
|
||||||
|
self.num_classes = len(self.names)
|
||||||
|
|
||||||
|
|
||||||
|
class StratifiedMeanInitializer(ClassAwareInitializer):
|
||||||
|
def __init__(self, positions, classes):
|
||||||
|
super().__init__(positions, classes)
|
||||||
|
|
||||||
|
self.initializers = []
|
||||||
|
for name in self.names:
|
||||||
|
class_data = self.data[self.classes == name]
|
||||||
|
class_initializer = MeanInitializer(class_data)
|
||||||
|
self.initializers.append(class_initializer)
|
||||||
|
|
||||||
|
def generate(self, length):
|
||||||
|
per_class = length // self.num_classes
|
||||||
|
return torch.vstack(
|
||||||
|
[init.generate(per_class) for init in self.initializers])
|
||||||
|
|
||||||
|
|
||||||
|
class StratifiedSelectionInitializer(ClassAwareInitializer):
|
||||||
|
def __init__(self, positions, classes):
|
||||||
|
super().__init__(positions, classes)
|
||||||
|
|
||||||
|
self.initializers = []
|
||||||
|
for name in self.names:
|
||||||
|
class_data = self.data[self.classes == name]
|
||||||
|
class_initializer = SelectionInitializer(class_data)
|
||||||
|
self.initializers.append(class_initializer)
|
||||||
|
|
||||||
|
def generate(self, length):
|
||||||
|
per_class = length // self.num_classes
|
||||||
|
return torch.vstack(
|
||||||
|
[init.generate(per_class) for init in self.initializers])
|
||||||
|
|
||||||
|
|
||||||
|
# Labels
|
||||||
|
class LabelsInitializer:
|
||||||
|
def generate(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class EqualLabelInitializer(LabelsInitializer):
|
||||||
|
def __init__(self, classes, per_class):
|
||||||
|
self.classes = classes
|
||||||
|
self.per_class = per_class
|
||||||
|
|
||||||
|
def generate(self):
|
||||||
|
return torch.arange(self.classes).repeat(self.per_class, 1).T.flatten()
|
||||||
|
|
||||||
|
|
||||||
|
# Reasonings
|
||||||
|
class ReasoningsInitializer:
|
||||||
|
def generate(self, length):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ZeroReasoningsInitializer(ReasoningsInitializer):
|
||||||
|
def __init__(self, classes, length):
|
||||||
|
self.classes = classes
|
||||||
|
self.length = length
|
||||||
|
|
||||||
|
def generate(self):
|
||||||
|
return torch.zeros((self.length, self.classes, 2))
|
@ -39,6 +39,9 @@ class Prototypes1D(_Prototypes):
|
|||||||
one_hot_labels=False,
|
one_hot_labels=False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
warnings.warn(
|
||||||
|
PendingDeprecationWarning(
|
||||||
|
"Prototypes1D will be replaced in future versions."))
|
||||||
|
|
||||||
# Convert tensors to python lists before processing
|
# Convert tensors to python lists before processing
|
||||||
if prototype_distribution is not None:
|
if prototype_distribution is not None:
|
||||||
|
Loading…
Reference in New Issue
Block a user