Update components

This commit is contained in:
Jensun Ravichandran 2021-04-29 18:06:26 +02:00
parent a70166280a
commit 9b663477fd
3 changed files with 41 additions and 45 deletions

View File

@ -1,8 +1,10 @@
#
# DATASET
#
import torch
"""This example script shows the usage of the new components architecture.
Serialization/deserialization also works as expected.
"""
# DATASET
import torch
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
@ -16,9 +18,7 @@ x_train = torch.Tensor(x_train)
y_train = torch.Tensor(y_train)
num_classes = len(torch.unique(y_train))
#
# CREATE NEW COMPONENTS
#
from prototorch.components import *
from prototorch.components.initializers import *
@ -33,9 +33,7 @@ components = ReasoningComponents(
(3, 6), StratifiedSelectionInitializer(x_train, y_train))
print(components())
#
# TEST SERIALIZATION
#
import io
save = io.BytesIO()
@ -53,8 +51,8 @@ serialized_prototypes = torch.load(save)
assert torch.all(prototypes.components == serialized_prototypes.components
), "Serialization of Components failed."
assert torch.all(prototypes.labels == serialized_prototypes.labels
), "Serialization of Components failed."
assert torch.all(prototypes.component_labels == serialized_prototypes.
component_labels), "Serialization of Components failed."
save = io.BytesIO()
torch.save(components, save)

View File

@ -1,7 +1,2 @@
from prototorch.components.components import Components, LabeledComponents, ReasoningComponents
__all__ = [
"Components",
"LabeledComponents",
"ReasoningComponents",
]
from prototorch.components.components import *
from prototorch.components.initializers import *

View File

@ -1,18 +1,18 @@
"""ProtoTorch components modules."""
from typing import Tuple
import warnings
from prototorch.components.initializers import EqualLabelInitializer, ZeroReasoningsInitializer
import torch
from torch.nn.parameter import Parameter
from typing import Tuple
import torch
from prototorch.components.initializers import (ComponentsInitializer,
EqualLabelInitializer,
ZeroReasoningsInitializer)
from prototorch.functions.initializers import get_initializer
from torch.nn.parameter import Parameter
class Components(torch.nn.Module):
"""
Components is a set of learnable Tensors.
"""
"""Components is a set of learnable Tensors."""
def __init__(self,
number_of_components=None,
initializer=None,
@ -31,14 +31,16 @@ class Components(torch.nn.Module):
self._initialize_components(number_of_components, initializer)
def _initialize_components(self, number_of_components, initializer):
if not isinstance(initializer, ComponentsInitializer):
emsg = f"`initializer` has to be some kind of `ComponentsInitializer`. " \
f"You provided: {initializer=} instead."
raise TypeError(emsg)
self._components = Parameter(
initializer.generate(number_of_components))
@property
def components(self):
"""
Tensor containing the component tensors.
"""
"""Tensor containing the component tensors."""
return self._components.detach().cpu()
def forward(self):
@ -49,8 +51,8 @@ class Components(torch.nn.Module):
class LabeledComponents(Components):
"""
LabeledComponents generate a set of components and a set of labels.
"""LabeledComponents generate a set of components and a set of labels.
Every Component has a label assigned.
"""
def __init__(self,
@ -62,11 +64,11 @@ class LabeledComponents(Components):
super().__init__(initialized_components=initialized_components[0])
self._labels = initialized_components[1]
else:
self._initialize_labels(labels, initializer)
self._initialize_labels(labels)
super().__init__(number_of_components=len(self._labels),
initializer=initializer)
def _initialize_labels(self, labels, initializer):
def _initialize_labels(self, labels):
if type(labels) == tuple:
num_classes, prototypes_per_class = labels
labels = EqualLabelInitializer(num_classes, prototypes_per_class)
@ -74,10 +76,8 @@ class LabeledComponents(Components):
self._labels = labels.generate()
@property
def labels(self):
"""
Tensor containing the component tensors.
"""
def component_labels(self):
"""Tensor containing the component tensors."""
return self._labels.detach().cpu()
def forward(self):
@ -85,16 +85,19 @@ class LabeledComponents(Components):
class ReasoningComponents(Components):
"""
ReasoningComponents generate a set of components and a set of reasoning matrices.
"""ReasoningComponents generate a set of components and a set of reasoning matrices.
Every Component has a reasoning matrix assigned.
A reasoning matrix is a Nx2 matrix, where N is the number of Classes.
The first element is called positive reasoning :math:`p`, the second negative reasoning :math:`n`.
A components can reason in favour (positive) of a class, against (negative) a class or not at all (neutral).
A reasoning matrix is a Nx2 matrix, where N is the number of Classes. The
first element is called positive reasoning :math:`p`, the second negative
reasoning :math:`n`. A components can reason in favour (positive) of a
class, against (negative) a class or not at all (neutral).
It holds that :math:`0 \leq n \leq 1`, :math:`0 \leq p \leq 1` and :math:`0
\leq n+p \leq 1`. Therefore :math:`n` and :math:`p` are two elements of a
three element probability distribution.
It holds that :math:`0 \leq n \leq 1`, :math:`0 \leq p \leq 1` and :math:`0 \leq n+p \leq 1`.
Therefore :math:`n` and :math:`p` are two elements of a three element probability distribution.
"""
def __init__(self,
reasonings=None,
@ -119,12 +122,12 @@ class ReasoningComponents(Components):
@property
def reasonings(self):
"""
Returns Reasoning Matrix.
"""Returns Reasoning Matrix.
Dimension NxCx2
"""
return self._reasonings.detach().cpu()
def forward(self):
return super().forward(), self._reasonings
return super().forward(), self._reasonings