Overload distribution argument in component initializers

The component initializers behave differently based on the type of the
`distribution` argument. If it is a Python
[list](https://docs.python.org/3/tutorial/datastructures.html), it is assumed
that there are as many entries in this list as there are classes, and the number
at each location of this list describes the number of prototypes to be used for
that particular class. So, `[1, 1, 1]` implies that we have three classes with
one prototype per class. If it is a Python
[tuple](https://docs.python.org/3/tutorial/datastructures.html), it a shorthand
of `(num_classes, prototypes_per_class)` is assumed. If it is a Python
[dictionary](https://docs.python.org/3/tutorial/datastructures.html), the
key-value pairs describe the class label and the number of prototypes for that
class respectively. So, `{0: 2, 1: 2, 2: 2}` implies that we have three classes
with labels `{1, 2, 3}`, each equipped with two prototypes.
This commit is contained in:
Jensun Ravichandran 2021-05-25 20:05:29 +02:00
parent 21e3e3b82d
commit 8a291f7bfb
2 changed files with 31 additions and 18 deletions

View File

@ -1,34 +1,33 @@
"""ProtoTorch components modules.""" """ProtoTorch components modules."""
import warnings import warnings
from typing import Tuple
import torch import torch
from prototorch.components.initializers import (ClassAwareInitializer, from prototorch.components.initializers import (ClassAwareInitializer,
ComponentsInitializer, ComponentsInitializer,
CustomLabelsInitializer,
EqualLabelsInitializer, EqualLabelsInitializer,
UnequalLabelsInitializer, UnequalLabelsInitializer,
ZeroReasoningsInitializer) ZeroReasoningsInitializer)
from prototorch.functions.initializers import get_initializer
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
class Components(torch.nn.Module): class Components(torch.nn.Module):
"""Components is a set of learnable Tensors.""" """Components is a set of learnable Tensors."""
def __init__(self, def __init__(self,
ncomps=None, num_components=None,
initializer=None, initializer=None,
*, *,
initialized_components=None): initialized_components=None):
super().__init__() super().__init__()
self.ncomps = ncomps self.num_components = num_components
# Ignore all initialization settings if initialized_components is given. # Ignore all initialization settings if initialized_components is given.
if initialized_components is not None: if initialized_components is not None:
self.register_parameter("_components", self.register_parameter("_components",
Parameter(initialized_components)) Parameter(initialized_components))
if ncomps is not None or initializer is not None: if num_components is not None or initializer is not None:
wmsg = "Arguments ignored while initializing Components" wmsg = "Arguments ignored while initializing Components"
warnings.warn(wmsg) warnings.warn(wmsg)
else: else:
@ -43,7 +42,7 @@ class Components(torch.nn.Module):
def _initialize_components(self, initializer): def _initialize_components(self, initializer):
self._precheck_initializer(initializer) self._precheck_initializer(initializer)
_components = initializer.generate(self.ncomps) _components = initializer.generate(self.num_components)
self.register_parameter("_components", Parameter(_components)) self.register_parameter("_components", Parameter(_components))
@property @property
@ -80,16 +79,20 @@ class LabeledComponents(Components):
def _initialize_components(self, initializer): def _initialize_components(self, initializer):
if isinstance(initializer, ClassAwareInitializer): if isinstance(initializer, ClassAwareInitializer):
self._precheck_initializer(initializer) self._precheck_initializer(initializer)
_components = initializer.generate(self.ncomps, self.distribution) _components = initializer.generate(self.num_components,
self.distribution)
self.register_parameter("_components", Parameter(_components)) self.register_parameter("_components", Parameter(_components))
else: else:
super()._initialize_components(initializer) super()._initialize_components(initializer)
def _initialize_labels(self, distribution): def _initialize_labels(self, distribution):
if type(distribution) == dict: if type(distribution) == dict:
if "num_classes" in distribution.keys():
labels = EqualLabelsInitializer( labels = EqualLabelsInitializer(
distribution["num_classes"], distribution["num_classes"],
distribution["prototypes_per_class"]) distribution["prototypes_per_class"])
else:
labels = CustomLabelsInitializer(distribution)
elif type(distribution) == tuple: elif type(distribution) == tuple:
num_classes, prototypes_per_class = distribution num_classes, prototypes_per_class = distribution
labels = EqualLabelsInitializer(num_classes, prototypes_per_class) labels = EqualLabelsInitializer(num_classes, prototypes_per_class)
@ -139,8 +142,8 @@ class ReasoningComponents(Components):
def _initialize_reasonings(self, reasonings): def _initialize_reasonings(self, reasonings):
if type(reasonings) == tuple: if type(reasonings) == tuple:
num_classes, ncomps = reasonings num_classes, num_components = reasonings
reasonings = ZeroReasoningsInitializer(num_classes, ncomps) reasonings = ZeroReasoningsInitializer(num_classes, num_components)
_reasonings = reasonings.generate() _reasonings = reasonings.generate()
self.register_parameter("_reasonings", _reasonings) self.register_parameter("_reasonings", _reasonings)

View File

@ -103,9 +103,9 @@ class ClassAwareInitializer(ComponentsInitializer):
def _get_samples_from_initializer(self, length, dist): def _get_samples_from_initializer(self, length, dist):
if not dist: if not dist:
per_class = length // self.num_classes per_class = length // self.num_classes
dist = self.num_classes * [per_class] dist = dict(zip(self.clabels, self.num_classes * [per_class]))
samples_list = [ samples_list = [
init.generate(n) for init, n in zip(self.initializers, dist) init.generate(n) for init, n in zip(self.initializers, dist.values())
] ]
out = torch.vstack(samples_list) out = torch.vstack(samples_list)
with torch.no_grad(): with torch.no_grad():
@ -174,10 +174,13 @@ class UnequalLabelsInitializer(LabelsInitializer):
def distribution(self): def distribution(self):
return self.dist return self.dist
def generate(self): def generate(self, clabels=None, dist=None):
if not clabels:
clabels = range(len(self.dist)) clabels = range(len(self.dist))
labels = list(chain(*[[i] * n for i, n in zip(clabels, self.dist)])) if not dist:
return torch.tensor(labels) dist = self.dist
labels = list(chain(*[[i] * n for i, n in zip(clabels, dist)]))
return torch.LongTensor(labels)
class EqualLabelsInitializer(LabelsInitializer): class EqualLabelsInitializer(LabelsInitializer):
@ -193,6 +196,13 @@ class EqualLabelsInitializer(LabelsInitializer):
return torch.arange(self.classes).repeat(self.per_class, 1).T.flatten() return torch.arange(self.classes).repeat(self.per_class, 1).T.flatten()
class CustomLabelsInitializer(UnequalLabelsInitializer):
def generate(self):
clabels = list(self.dist.keys())
dist = list(self.dist.values())
return super().generate(clabels, dist)
# Reasonings # Reasonings
class ReasoningsInitializer: class ReasoningsInitializer:
def generate(self, length): def generate(self, length):