Update components

This commit is contained in:
Jensun Ravichandran 2021-04-29 18:06:26 +02:00
parent a70166280a
commit 9b663477fd
3 changed files with 41 additions and 45 deletions

View File

@ -1,8 +1,10 @@
# """This example script shows the usage of the new components architecture.
# DATASET
#
import torch
Serialization/deserialization also works as expected.
"""
# DATASET
import torch
from sklearn.datasets import load_iris from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler from sklearn.preprocessing import StandardScaler
@ -16,9 +18,7 @@ x_train = torch.Tensor(x_train)
y_train = torch.Tensor(y_train) y_train = torch.Tensor(y_train)
num_classes = len(torch.unique(y_train)) num_classes = len(torch.unique(y_train))
#
# CREATE NEW COMPONENTS # CREATE NEW COMPONENTS
#
from prototorch.components import * from prototorch.components import *
from prototorch.components.initializers import * from prototorch.components.initializers import *
@ -33,9 +33,7 @@ components = ReasoningComponents(
(3, 6), StratifiedSelectionInitializer(x_train, y_train)) (3, 6), StratifiedSelectionInitializer(x_train, y_train))
print(components()) print(components())
#
# TEST SERIALIZATION # TEST SERIALIZATION
#
import io import io
save = io.BytesIO() save = io.BytesIO()
@ -53,8 +51,8 @@ serialized_prototypes = torch.load(save)
assert torch.all(prototypes.components == serialized_prototypes.components assert torch.all(prototypes.components == serialized_prototypes.components
), "Serialization of Components failed." ), "Serialization of Components failed."
assert torch.all(prototypes.labels == serialized_prototypes.labels assert torch.all(prototypes.component_labels == serialized_prototypes.
), "Serialization of Components failed." component_labels), "Serialization of Components failed."
save = io.BytesIO() save = io.BytesIO()
torch.save(components, save) torch.save(components, save)

View File

@ -1,7 +1,2 @@
from prototorch.components.components import Components, LabeledComponents, ReasoningComponents from prototorch.components.components import *
from prototorch.components.initializers import *
__all__ = [
"Components",
"LabeledComponents",
"ReasoningComponents",
]

View File

@ -1,18 +1,18 @@
"""ProtoTorch components modules.""" """ProtoTorch components modules."""
from typing import Tuple
import warnings import warnings
from prototorch.components.initializers import EqualLabelInitializer, ZeroReasoningsInitializer from typing import Tuple
import torch
from torch.nn.parameter import Parameter
import torch
from prototorch.components.initializers import (ComponentsInitializer,
EqualLabelInitializer,
ZeroReasoningsInitializer)
from prototorch.functions.initializers import get_initializer from prototorch.functions.initializers import get_initializer
from torch.nn.parameter import Parameter
class Components(torch.nn.Module): class Components(torch.nn.Module):
""" """Components is a set of learnable Tensors."""
Components is a set of learnable Tensors.
"""
def __init__(self, def __init__(self,
number_of_components=None, number_of_components=None,
initializer=None, initializer=None,
@ -31,14 +31,16 @@ class Components(torch.nn.Module):
self._initialize_components(number_of_components, initializer) self._initialize_components(number_of_components, initializer)
def _initialize_components(self, number_of_components, initializer): def _initialize_components(self, number_of_components, initializer):
if not isinstance(initializer, ComponentsInitializer):
emsg = f"`initializer` has to be some kind of `ComponentsInitializer`. " \
f"You provided: {initializer=} instead."
raise TypeError(emsg)
self._components = Parameter( self._components = Parameter(
initializer.generate(number_of_components)) initializer.generate(number_of_components))
@property @property
def components(self): def components(self):
""" """Tensor containing the component tensors."""
Tensor containing the component tensors.
"""
return self._components.detach().cpu() return self._components.detach().cpu()
def forward(self): def forward(self):
@ -49,8 +51,8 @@ class Components(torch.nn.Module):
class LabeledComponents(Components): class LabeledComponents(Components):
""" """LabeledComponents generate a set of components and a set of labels.
LabeledComponents generate a set of components and a set of labels.
Every Component has a label assigned. Every Component has a label assigned.
""" """
def __init__(self, def __init__(self,
@ -62,11 +64,11 @@ class LabeledComponents(Components):
super().__init__(initialized_components=initialized_components[0]) super().__init__(initialized_components=initialized_components[0])
self._labels = initialized_components[1] self._labels = initialized_components[1]
else: else:
self._initialize_labels(labels, initializer) self._initialize_labels(labels)
super().__init__(number_of_components=len(self._labels), super().__init__(number_of_components=len(self._labels),
initializer=initializer) initializer=initializer)
def _initialize_labels(self, labels, initializer): def _initialize_labels(self, labels):
if type(labels) == tuple: if type(labels) == tuple:
num_classes, prototypes_per_class = labels num_classes, prototypes_per_class = labels
labels = EqualLabelInitializer(num_classes, prototypes_per_class) labels = EqualLabelInitializer(num_classes, prototypes_per_class)
@ -74,10 +76,8 @@ class LabeledComponents(Components):
self._labels = labels.generate() self._labels = labels.generate()
@property @property
def labels(self): def component_labels(self):
""" """Tensor containing the component tensors."""
Tensor containing the component tensors.
"""
return self._labels.detach().cpu() return self._labels.detach().cpu()
def forward(self): def forward(self):
@ -85,16 +85,19 @@ class LabeledComponents(Components):
class ReasoningComponents(Components): class ReasoningComponents(Components):
""" """ReasoningComponents generate a set of components and a set of reasoning matrices.
ReasoningComponents generate a set of components and a set of reasoning matrices.
Every Component has a reasoning matrix assigned. Every Component has a reasoning matrix assigned.
A reasoning matrix is a Nx2 matrix, where N is the number of Classes. A reasoning matrix is a Nx2 matrix, where N is the number of Classes. The
The first element is called positive reasoning :math:`p`, the second negative reasoning :math:`n`. first element is called positive reasoning :math:`p`, the second negative
A components can reason in favour (positive) of a class, against (negative) a class or not at all (neutral). reasoning :math:`n`. A components can reason in favour (positive) of a
class, against (negative) a class or not at all (neutral).
It holds that :math:`0 \leq n \leq 1`, :math:`0 \leq p \leq 1` and :math:`0
\leq n+p \leq 1`. Therefore :math:`n` and :math:`p` are two elements of a
three element probability distribution.
It holds that :math:`0 \leq n \leq 1`, :math:`0 \leq p \leq 1` and :math:`0 \leq n+p \leq 1`.
Therefore :math:`n` and :math:`p` are two elements of a three element probability distribution.
""" """
def __init__(self, def __init__(self,
reasonings=None, reasonings=None,
@ -119,12 +122,12 @@ class ReasoningComponents(Components):
@property @property
def reasonings(self): def reasonings(self):
""" """Returns Reasoning Matrix.
Returns Reasoning Matrix.
Dimension NxCx2 Dimension NxCx2
""" """
return self._reasonings.detach().cpu() return self._reasonings.detach().cpu()
def forward(self): def forward(self):
return super().forward(), self._reasonings return super().forward(), self._reasonings