Overload distribution argument in component initializers

The component initializers behave differently based on the type of the
`distribution` argument. If it is a Python
[list](https://docs.python.org/3/tutorial/datastructures.html), it is assumed
to contain one entry per class, where the number at each position gives the
number of prototypes to be used for that particular class. So, `[1, 1, 1]`
implies that we have three classes with one prototype per class. If it is a
Python [tuple](https://docs.python.org/3/tutorial/datastructures.html), the
shorthand `(num_classes, prototypes_per_class)` is assumed. If it is a Python
[dictionary](https://docs.python.org/3/tutorial/datastructures.html), the
key-value pairs describe the class label and the number of prototypes for that
class, respectively. So, `{0: 2, 1: 2, 2: 2}` implies that we have three
classes with labels `{0, 1, 2}`, each equipped with two prototypes.
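
To make the overloads concrete, here is a minimal sketch of the three
equivalent spellings for three classes with two prototypes each (illustrative
values only):

```python
# Three equivalent ways to describe the same prototype distribution:
distribution = [2, 2, 2]            # list: one entry per class
distribution = (3, 2)               # tuple: (num_classes, prototypes_per_class)
distribution = {0: 2, 1: 2, 2: 2}   # dict: {class_label: num_prototypes}
```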
Jensun Ravichandran committed 2021-05-25 20:05:29 +02:00
commit 8a291f7bfb (parent 21e3e3b82d)
2 changed files with 31 additions and 18 deletions

prototorch/components/components.py

@@ -1,34 +1,33 @@
 """ProtoTorch components modules."""

 import warnings
 from typing import Tuple

 import torch
 from prototorch.components.initializers import (ClassAwareInitializer,
                                                 ComponentsInitializer,
+                                                CustomLabelsInitializer,
                                                 EqualLabelsInitializer,
                                                 UnequalLabelsInitializer,
                                                 ZeroReasoningsInitializer)
 from prototorch.functions.initializers import get_initializer
 from torch.nn.parameter import Parameter


 class Components(torch.nn.Module):
     """Components is a set of learnable Tensors."""
     def __init__(self,
-                 ncomps=None,
+                 num_components=None,
                  initializer=None,
                  *,
                  initialized_components=None):
         super().__init__()
-        self.ncomps = ncomps
+        self.num_components = num_components

         # Ignore all initialization settings if initialized_components is given.
         if initialized_components is not None:
             self.register_parameter("_components",
                                     Parameter(initialized_components))
-            if ncomps is not None or initializer is not None:
+            if num_components is not None or initializer is not None:
                 wmsg = "Arguments ignored while initializing Components"
                 warnings.warn(wmsg)
         else:
@@ -43,7 +42,7 @@ class Components(torch.nn.Module):
     def _initialize_components(self, initializer):
         self._precheck_initializer(initializer)
-        _components = initializer.generate(self.ncomps)
+        _components = initializer.generate(self.num_components)
         self.register_parameter("_components", Parameter(_components))

     @property
@@ -80,16 +79,20 @@ class LabeledComponents(Components):
     def _initialize_components(self, initializer):
         if isinstance(initializer, ClassAwareInitializer):
             self._precheck_initializer(initializer)
-            _components = initializer.generate(self.ncomps, self.distribution)
+            _components = initializer.generate(self.num_components,
+                                               self.distribution)
             self.register_parameter("_components", Parameter(_components))
         else:
             super()._initialize_components(initializer)

     def _initialize_labels(self, distribution):
         if type(distribution) == dict:
-            labels = EqualLabelsInitializer(
-                distribution["num_classes"],
-                distribution["prototypes_per_class"])
+            if "num_classes" in distribution.keys():
+                labels = EqualLabelsInitializer(
+                    distribution["num_classes"],
+                    distribution["prototypes_per_class"])
+            else:
+                labels = CustomLabelsInitializer(distribution)
         elif type(distribution) == tuple:
             num_classes, prototypes_per_class = distribution
             labels = EqualLabelsInitializer(num_classes, prototypes_per_class)
@@ -139,8 +142,8 @@ class ReasoningComponents(Components):
     def _initialize_reasonings(self, reasonings):
         if type(reasonings) == tuple:
-            num_classes, ncomps = reasonings
-            reasonings = ZeroReasoningsInitializer(num_classes, ncomps)
+            num_classes, num_components = reasonings
+            reasonings = ZeroReasoningsInitializer(num_classes, num_components)

         _reasonings = reasonings.generate()
         self.register_parameter("_reasonings", _reasonings)

prototorch/components/initializers.py

@@ -103,9 +103,9 @@ class ClassAwareInitializer(ComponentsInitializer):
     def _get_samples_from_initializer(self, length, dist):
         if not dist:
             per_class = length // self.num_classes
-            dist = self.num_classes * [per_class]
+            dist = dict(zip(self.clabels, self.num_classes * [per_class]))
         samples_list = [
-            init.generate(n) for init, n in zip(self.initializers, dist)
+            init.generate(n) for init, n in zip(self.initializers, dist.values())
         ]
         out = torch.vstack(samples_list)
         with torch.no_grad():
@@ -174,10 +174,13 @@ class UnequalLabelsInitializer(LabelsInitializer):
     def distribution(self):
         return self.dist

-    def generate(self):
-        clabels = range(len(self.dist))
-        labels = list(chain(*[[i] * n for i, n in zip(clabels, self.dist)]))
-        return torch.tensor(labels)
+    def generate(self, clabels=None, dist=None):
+        if not clabels:
+            clabels = range(len(self.dist))
+        if not dist:
+            dist = self.dist
+        labels = list(chain(*[[i] * n for i, n in zip(clabels, dist)]))
+        return torch.LongTensor(labels)


 class EqualLabelsInitializer(LabelsInitializer):
@@ -193,6 +196,13 @@ class EqualLabelsInitializer(LabelsInitializer):
         return torch.arange(self.classes).repeat(self.per_class, 1).T.flatten()


+class CustomLabelsInitializer(UnequalLabelsInitializer):
+    def generate(self):
+        clabels = list(self.dist.keys())
+        dist = list(self.dist.values())
+        return super().generate(clabels, dist)
+
+
 # Reasonings
 class ReasoningsInitializer:
     def generate(self, length):
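
For reference, a standalone sketch of the label generation that the new
`CustomLabelsInitializer` performs for a dict distribution. This
re-implements the `generate` logic from the diff outside the class for
illustration; `generate_labels` is a hypothetical helper, not part of the
library:

```python
from itertools import chain

import torch


def generate_labels(distribution: dict) -> torch.Tensor:
    # Mirrors CustomLabelsInitializer.generate(): class labels come from
    # the dict keys, prototype counts from the dict values.
    clabels = list(distribution.keys())
    dist = list(distribution.values())
    labels = list(chain(*[[c] * n for c, n in zip(clabels, dist)]))
    return torch.LongTensor(labels)


print(generate_labels({0: 2, 1: 2, 2: 2}))
# tensor([0, 0, 1, 1, 2, 2])
```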