Overload distribution
argument in component initializers
The component initializers behave differently based on the type of the `distribution` argument. If it is a Python [list](https://docs.python.org/3/tutorial/datastructures.html), it is assumed that there are as many entries in this list as there are classes, and the number at each location of this list describes the number of prototypes to be used for that particular class. So, `[1, 1, 1]` implies that we have three classes with one prototype per class. If it is a Python [tuple](https://docs.python.org/3/tutorial/datastructures.html), it a shorthand of `(num_classes, prototypes_per_class)` is assumed. If it is a Python [dictionary](https://docs.python.org/3/tutorial/datastructures.html), the key-value pairs describe the class label and the number of prototypes for that class respectively. So, `{0: 2, 1: 2, 2: 2}` implies that we have three classes with labels `{1, 2, 3}`, each equipped with two prototypes.
This commit is contained in:
parent
21e3e3b82d
commit
8a291f7bfb
@ -1,34 +1,33 @@
|
|||||||
"""ProtoTorch components modules."""
|
"""ProtoTorch components modules."""
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Tuple
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from prototorch.components.initializers import (ClassAwareInitializer,
|
from prototorch.components.initializers import (ClassAwareInitializer,
|
||||||
ComponentsInitializer,
|
ComponentsInitializer,
|
||||||
|
CustomLabelsInitializer,
|
||||||
EqualLabelsInitializer,
|
EqualLabelsInitializer,
|
||||||
UnequalLabelsInitializer,
|
UnequalLabelsInitializer,
|
||||||
ZeroReasoningsInitializer)
|
ZeroReasoningsInitializer)
|
||||||
from prototorch.functions.initializers import get_initializer
|
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
|
|
||||||
class Components(torch.nn.Module):
|
class Components(torch.nn.Module):
|
||||||
"""Components is a set of learnable Tensors."""
|
"""Components is a set of learnable Tensors."""
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
ncomps=None,
|
num_components=None,
|
||||||
initializer=None,
|
initializer=None,
|
||||||
*,
|
*,
|
||||||
initialized_components=None):
|
initialized_components=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.ncomps = ncomps
|
self.num_components = num_components
|
||||||
|
|
||||||
# Ignore all initialization settings if initialized_components is given.
|
# Ignore all initialization settings if initialized_components is given.
|
||||||
if initialized_components is not None:
|
if initialized_components is not None:
|
||||||
self.register_parameter("_components",
|
self.register_parameter("_components",
|
||||||
Parameter(initialized_components))
|
Parameter(initialized_components))
|
||||||
if ncomps is not None or initializer is not None:
|
if num_components is not None or initializer is not None:
|
||||||
wmsg = "Arguments ignored while initializing Components"
|
wmsg = "Arguments ignored while initializing Components"
|
||||||
warnings.warn(wmsg)
|
warnings.warn(wmsg)
|
||||||
else:
|
else:
|
||||||
@ -43,7 +42,7 @@ class Components(torch.nn.Module):
|
|||||||
|
|
||||||
def _initialize_components(self, initializer):
|
def _initialize_components(self, initializer):
|
||||||
self._precheck_initializer(initializer)
|
self._precheck_initializer(initializer)
|
||||||
_components = initializer.generate(self.ncomps)
|
_components = initializer.generate(self.num_components)
|
||||||
self.register_parameter("_components", Parameter(_components))
|
self.register_parameter("_components", Parameter(_components))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -80,16 +79,20 @@ class LabeledComponents(Components):
|
|||||||
def _initialize_components(self, initializer):
|
def _initialize_components(self, initializer):
|
||||||
if isinstance(initializer, ClassAwareInitializer):
|
if isinstance(initializer, ClassAwareInitializer):
|
||||||
self._precheck_initializer(initializer)
|
self._precheck_initializer(initializer)
|
||||||
_components = initializer.generate(self.ncomps, self.distribution)
|
_components = initializer.generate(self.num_components,
|
||||||
|
self.distribution)
|
||||||
self.register_parameter("_components", Parameter(_components))
|
self.register_parameter("_components", Parameter(_components))
|
||||||
else:
|
else:
|
||||||
super()._initialize_components(initializer)
|
super()._initialize_components(initializer)
|
||||||
|
|
||||||
def _initialize_labels(self, distribution):
|
def _initialize_labels(self, distribution):
|
||||||
if type(distribution) == dict:
|
if type(distribution) == dict:
|
||||||
labels = EqualLabelsInitializer(
|
if "num_classes" in distribution.keys():
|
||||||
distribution["num_classes"],
|
labels = EqualLabelsInitializer(
|
||||||
distribution["prototypes_per_class"])
|
distribution["num_classes"],
|
||||||
|
distribution["prototypes_per_class"])
|
||||||
|
else:
|
||||||
|
labels = CustomLabelsInitializer(distribution)
|
||||||
elif type(distribution) == tuple:
|
elif type(distribution) == tuple:
|
||||||
num_classes, prototypes_per_class = distribution
|
num_classes, prototypes_per_class = distribution
|
||||||
labels = EqualLabelsInitializer(num_classes, prototypes_per_class)
|
labels = EqualLabelsInitializer(num_classes, prototypes_per_class)
|
||||||
@ -139,8 +142,8 @@ class ReasoningComponents(Components):
|
|||||||
|
|
||||||
def _initialize_reasonings(self, reasonings):
|
def _initialize_reasonings(self, reasonings):
|
||||||
if type(reasonings) == tuple:
|
if type(reasonings) == tuple:
|
||||||
num_classes, ncomps = reasonings
|
num_classes, num_components = reasonings
|
||||||
reasonings = ZeroReasoningsInitializer(num_classes, ncomps)
|
reasonings = ZeroReasoningsInitializer(num_classes, num_components)
|
||||||
|
|
||||||
_reasonings = reasonings.generate()
|
_reasonings = reasonings.generate()
|
||||||
self.register_parameter("_reasonings", _reasonings)
|
self.register_parameter("_reasonings", _reasonings)
|
||||||
|
@ -103,9 +103,9 @@ class ClassAwareInitializer(ComponentsInitializer):
|
|||||||
def _get_samples_from_initializer(self, length, dist):
|
def _get_samples_from_initializer(self, length, dist):
|
||||||
if not dist:
|
if not dist:
|
||||||
per_class = length // self.num_classes
|
per_class = length // self.num_classes
|
||||||
dist = self.num_classes * [per_class]
|
dist = dict(zip(self.clabels, self.num_classes * [per_class]))
|
||||||
samples_list = [
|
samples_list = [
|
||||||
init.generate(n) for init, n in zip(self.initializers, dist)
|
init.generate(n) for init, n in zip(self.initializers, dist.values())
|
||||||
]
|
]
|
||||||
out = torch.vstack(samples_list)
|
out = torch.vstack(samples_list)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@ -174,10 +174,13 @@ class UnequalLabelsInitializer(LabelsInitializer):
|
|||||||
def distribution(self):
|
def distribution(self):
|
||||||
return self.dist
|
return self.dist
|
||||||
|
|
||||||
def generate(self):
|
def generate(self, clabels=None, dist=None):
|
||||||
clabels = range(len(self.dist))
|
if not clabels:
|
||||||
labels = list(chain(*[[i] * n for i, n in zip(clabels, self.dist)]))
|
clabels = range(len(self.dist))
|
||||||
return torch.tensor(labels)
|
if not dist:
|
||||||
|
dist = self.dist
|
||||||
|
labels = list(chain(*[[i] * n for i, n in zip(clabels, dist)]))
|
||||||
|
return torch.LongTensor(labels)
|
||||||
|
|
||||||
|
|
||||||
class EqualLabelsInitializer(LabelsInitializer):
|
class EqualLabelsInitializer(LabelsInitializer):
|
||||||
@ -193,6 +196,13 @@ class EqualLabelsInitializer(LabelsInitializer):
|
|||||||
return torch.arange(self.classes).repeat(self.per_class, 1).T.flatten()
|
return torch.arange(self.classes).repeat(self.per_class, 1).T.flatten()
|
||||||
|
|
||||||
|
|
||||||
|
class CustomLabelsInitializer(UnequalLabelsInitializer):
|
||||||
|
def generate(self):
|
||||||
|
clabels = list(self.dist.keys())
|
||||||
|
dist = list(self.dist.values())
|
||||||
|
return super().generate(clabels, dist)
|
||||||
|
|
||||||
|
|
||||||
# Reasonings
|
# Reasonings
|
||||||
class ReasoningsInitializer:
|
class ReasoningsInitializer:
|
||||||
def generate(self, length):
|
def generate(self, length):
|
||||||
|
Loading…
Reference in New Issue
Block a user