Create Component and initializer classes.
This commit is contained in:
parent
7c30ffe2c7
commit
40751aa50a
67
examples/new_components.py
Normal file
67
examples/new_components.py
Normal file
@ -0,0 +1,67 @@
|
||||
#
|
||||
# DATASET
|
||||
#
|
||||
import torch
|
||||
|
||||
from sklearn.datasets import load_iris
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
|
||||
scaler = StandardScaler()
|
||||
x_train, y_train = load_iris(return_X_y=True)
|
||||
x_train = x_train[:, [0, 2]]
|
||||
scaler.fit(x_train)
|
||||
x_train = scaler.transform(x_train)
|
||||
|
||||
x_train = torch.Tensor(x_train)
|
||||
y_train = torch.Tensor(y_train)
|
||||
num_classes = len(torch.unique(y_train))
|
||||
|
||||
#
|
||||
# CREATE NEW COMPONENTS
|
||||
#
|
||||
from prototorch.components import *
|
||||
from prototorch.components.initializers import *
|
||||
|
||||
unsupervised = Components(6, SelectionInitializer(x_train))
|
||||
print(unsupervised())
|
||||
|
||||
prototypes = LabeledComponents(
|
||||
(3, 2), StratifiedSelectionInitializer(x_train, y_train))
|
||||
print(prototypes())
|
||||
|
||||
components = ReasoningComponents(
|
||||
(3, 6), StratifiedSelectionInitializer(x_train, y_train))
|
||||
print(components())
|
||||
|
||||
#
|
||||
# TEST SERIALIZATION
|
||||
#
|
||||
import io
|
||||
|
||||
save = io.BytesIO()
|
||||
torch.save(unsupervised, save)
|
||||
save.seek(0)
|
||||
serialized_unsupervised = torch.load(save)
|
||||
|
||||
assert torch.all(unsupervised.components == serialized_unsupervised.components
|
||||
), "Serialization of Components failed."
|
||||
|
||||
save = io.BytesIO()
|
||||
torch.save(prototypes, save)
|
||||
save.seek(0)
|
||||
serialized_prototypes = torch.load(save)
|
||||
|
||||
assert torch.all(prototypes.components == serialized_prototypes.components
|
||||
), "Serialization of Components failed."
|
||||
assert torch.all(prototypes.labels == serialized_prototypes.labels
|
||||
), "Serialization of Components failed."
|
||||
|
||||
save = io.BytesIO()
|
||||
torch.save(components, save)
|
||||
save.seek(0)
|
||||
serialized_components = torch.load(save)
|
||||
|
||||
assert torch.all(components.components == serialized_components.components
|
||||
), "Serialization of Components failed."
|
||||
assert torch.all(components.reasonings == serialized_components.reasonings
|
||||
), "Serialization of Components failed."
|
7
prototorch/components/__init__.py
Normal file
7
prototorch/components/__init__.py
Normal file
@ -0,0 +1,7 @@
|
||||
from prototorch.components.components import Components, LabeledComponents, ReasoningComponents
|
||||
|
||||
__all__ = [
|
||||
"Components",
|
||||
"LabeledComponents",
|
||||
"ReasoningComponents",
|
||||
]
|
130
prototorch/components/components.py
Normal file
130
prototorch/components/components.py
Normal file
@ -0,0 +1,130 @@
|
||||
"""ProtoTorch components modules."""
|
||||
|
||||
from typing import Tuple
|
||||
import warnings
|
||||
from prototorch.components.initializers import EqualLabelInitializer, ZeroReasoningsInitializer
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from prototorch.functions.initializers import get_initializer
|
||||
|
||||
|
||||
class Components(torch.nn.Module):
|
||||
"""
|
||||
Components is a set of learnable Tensors.
|
||||
"""
|
||||
def __init__(self,
|
||||
number_of_components=None,
|
||||
initializer=None,
|
||||
*,
|
||||
initialized_components=None,
|
||||
dtype=torch.float32):
|
||||
super().__init__()
|
||||
|
||||
# Ignore all initialization settings if initialized_components is given.
|
||||
if initialized_components is not None:
|
||||
self._components = Parameter(initialized_components)
|
||||
if number_of_components is not None or initializer is not None:
|
||||
warnings.warn(
|
||||
"Arguments ignored while initializing Components")
|
||||
else:
|
||||
self._initialize_components(number_of_components, initializer)
|
||||
|
||||
def _initialize_components(self, number_of_components, initializer):
|
||||
self._components = Parameter(
|
||||
initializer.generate(number_of_components))
|
||||
|
||||
@property
|
||||
def components(self):
|
||||
"""
|
||||
Tensor containing the component tensors.
|
||||
"""
|
||||
return self._components.detach().cpu()
|
||||
|
||||
def forward(self):
|
||||
return self._components
|
||||
|
||||
def extra_repr(self):
|
||||
return f"components.shape: {tuple(self._components.shape)}"
|
||||
|
||||
|
||||
class LabeledComponents(Components):
|
||||
"""
|
||||
LabeledComponents generate a set of components and a set of labels.
|
||||
Every Component has a label assigned.
|
||||
"""
|
||||
def __init__(self,
|
||||
labels=None,
|
||||
initializer=None,
|
||||
*,
|
||||
initialized_components=None):
|
||||
if initialized_components is not None:
|
||||
super().__init__(initialized_components=initialized_components[0])
|
||||
self._labels = initialized_components[1]
|
||||
else:
|
||||
self._initialize_labels(labels, initializer)
|
||||
super().__init__(number_of_components=len(self._labels),
|
||||
initializer=initializer)
|
||||
|
||||
def _initialize_labels(self, labels, initializer):
|
||||
if type(labels) == tuple:
|
||||
num_classes, prototypes_per_class = labels
|
||||
labels = EqualLabelInitializer(num_classes, prototypes_per_class)
|
||||
|
||||
self._labels = labels.generate()
|
||||
|
||||
@property
|
||||
def labels(self):
|
||||
"""
|
||||
Tensor containing the component tensors.
|
||||
"""
|
||||
return self._labels.detach().cpu()
|
||||
|
||||
def forward(self):
|
||||
return super().forward(), self._labels
|
||||
|
||||
|
||||
class ReasoningComponents(Components):
|
||||
"""
|
||||
ReasoningComponents generate a set of components and a set of reasoning matrices.
|
||||
Every Component has a reasoning matrix assigned.
|
||||
|
||||
A reasoning matrix is a Nx2 matrix, where N is the number of Classes.
|
||||
The first element is called positive reasoning :math:`p`, the second negative reasoning :math:`n`.
|
||||
A components can reason in favour (positive) of a class, against (negative) a class or not at all (neutral).
|
||||
|
||||
It holds that :math:`0 \leq n \leq 1`, :math:`0 \leq p \leq 1` and :math:`0 \leq n+p \leq 1`.
|
||||
Therefore :math:`n` and :math:`p` are two elements of a three element probability distribution.
|
||||
"""
|
||||
def __init__(self,
|
||||
reasonings=None,
|
||||
initializer=None,
|
||||
*,
|
||||
initialized_components=None):
|
||||
if initialized_components is not None:
|
||||
super().__init__(initialized_components=initialized_components[0])
|
||||
self._reasonings = initialized_components[1]
|
||||
else:
|
||||
self._initialize_reasonings(reasonings)
|
||||
super().__init__(number_of_components=len(self._reasonings),
|
||||
initializer=initializer)
|
||||
|
||||
def _initialize_reasonings(self, reasonings):
|
||||
if type(reasonings) == tuple:
|
||||
num_classes, number_of_components = reasonings
|
||||
reasonings = ZeroReasoningsInitializer(num_classes,
|
||||
number_of_components)
|
||||
|
||||
self._reasonings = reasonings.generate()
|
||||
|
||||
@property
|
||||
def reasonings(self):
|
||||
"""
|
||||
Returns Reasoning Matrix.
|
||||
|
||||
Dimension NxCx2
|
||||
"""
|
||||
return self._reasonings.detach().cpu()
|
||||
|
||||
def forward(self):
|
||||
return super().forward(), self._reasonings
|
132
prototorch/components/initializers.py
Normal file
132
prototorch/components/initializers.py
Normal file
@ -0,0 +1,132 @@
|
||||
import torch
|
||||
from collections.abc import Iterable
|
||||
|
||||
|
||||
# Components
|
||||
class ComponentsInitializer:
|
||||
def generate(self, number_of_components):
|
||||
pass
|
||||
|
||||
|
||||
class DimensionAwareInitializer(ComponentsInitializer):
|
||||
def __init__(self, c_dims):
|
||||
super().__init__()
|
||||
if isinstance(c_dims, Iterable):
|
||||
self.components_dims = tuple(c_dims)
|
||||
else:
|
||||
self.components_dims = (c_dims, )
|
||||
|
||||
|
||||
class OnesInitializer(DimensionAwareInitializer):
|
||||
def generate(self, length):
|
||||
gen_dims = (length, ) + self.components_dims
|
||||
return torch.ones(gen_dims)
|
||||
|
||||
|
||||
class ZerosInitializer(DimensionAwareInitializer):
|
||||
def generate(self, length):
|
||||
gen_dims = (length, ) + self.components_dims
|
||||
return torch.zeros(gen_dims)
|
||||
|
||||
|
||||
class UniformInitializer(DimensionAwareInitializer):
|
||||
def __init__(self, c_dims, min=0.0, max=1.0):
|
||||
super().__init__(c_dims)
|
||||
|
||||
self.min = min
|
||||
self.max = max
|
||||
|
||||
def generate(self, length):
|
||||
gen_dims = (length, ) + self.components_dims
|
||||
return torch.FloatTensor(gen_dims).uniform_(self.min, self.max)
|
||||
|
||||
|
||||
class PositionAwareInitializer(ComponentsInitializer):
|
||||
def __init__(self, positions):
|
||||
super().__init__()
|
||||
self.data = positions
|
||||
|
||||
|
||||
class SelectionInitializer(PositionAwareInitializer):
|
||||
def generate(self, length):
|
||||
indices = torch.LongTensor(length).random_(0, len(self.data))
|
||||
return self.data[indices]
|
||||
|
||||
|
||||
class MeanInitializer(PositionAwareInitializer):
|
||||
def generate(self, length):
|
||||
mean = torch.mean(self.data, dim=0)
|
||||
repeat_dim = [length] + [1] * len(mean.shape)
|
||||
return mean.repeat(repeat_dim)
|
||||
|
||||
|
||||
class ClassAwareInitializer(ComponentsInitializer):
|
||||
def __init__(self, positions, classes):
|
||||
super().__init__()
|
||||
self.data = positions
|
||||
self.classes = classes
|
||||
|
||||
self.names = torch.unique(self.classes)
|
||||
self.num_classes = len(self.names)
|
||||
|
||||
|
||||
class StratifiedMeanInitializer(ClassAwareInitializer):
|
||||
def __init__(self, positions, classes):
|
||||
super().__init__(positions, classes)
|
||||
|
||||
self.initializers = []
|
||||
for name in self.names:
|
||||
class_data = self.data[self.classes == name]
|
||||
class_initializer = MeanInitializer(class_data)
|
||||
self.initializers.append(class_initializer)
|
||||
|
||||
def generate(self, length):
|
||||
per_class = length // self.num_classes
|
||||
return torch.vstack(
|
||||
[init.generate(per_class) for init in self.initializers])
|
||||
|
||||
|
||||
class StratifiedSelectionInitializer(ClassAwareInitializer):
|
||||
def __init__(self, positions, classes):
|
||||
super().__init__(positions, classes)
|
||||
|
||||
self.initializers = []
|
||||
for name in self.names:
|
||||
class_data = self.data[self.classes == name]
|
||||
class_initializer = SelectionInitializer(class_data)
|
||||
self.initializers.append(class_initializer)
|
||||
|
||||
def generate(self, length):
|
||||
per_class = length // self.num_classes
|
||||
return torch.vstack(
|
||||
[init.generate(per_class) for init in self.initializers])
|
||||
|
||||
|
||||
# Labels
|
||||
class LabelsInitializer:
|
||||
def generate(self):
|
||||
pass
|
||||
|
||||
|
||||
class EqualLabelInitializer(LabelsInitializer):
|
||||
def __init__(self, classes, per_class):
|
||||
self.classes = classes
|
||||
self.per_class = per_class
|
||||
|
||||
def generate(self):
|
||||
return torch.arange(self.classes).repeat(self.per_class, 1).T.flatten()
|
||||
|
||||
|
||||
# Reasonings
|
||||
class ReasoningsInitializer:
|
||||
def generate(self, length):
|
||||
pass
|
||||
|
||||
|
||||
class ZeroReasoningsInitializer(ReasoningsInitializer):
|
||||
def __init__(self, classes, length):
|
||||
self.classes = classes
|
||||
self.length = length
|
||||
|
||||
def generate(self):
|
||||
return torch.zeros((self.length, self.classes, 2))
|
@ -39,6 +39,9 @@ class Prototypes1D(_Prototypes):
|
||||
one_hot_labels=False,
|
||||
**kwargs,
|
||||
):
|
||||
warnings.warn(
|
||||
PendingDeprecationWarning(
|
||||
"Prototypes1D will be replaced in future versions."))
|
||||
|
||||
# Convert tensors to python lists before processing
|
||||
if prototype_distribution is not None:
|
||||
|
Loading…
Reference in New Issue
Block a user