Overload distribution
argument in component initializers
The component initializers behave differently based on the type of the `distribution` argument. If it is a Python [list](https://docs.python.org/3/tutorial/datastructures.html), it is assumed that there are as many entries in this list as there are classes, and the number at each location of this list describes the number of prototypes to be used for that particular class. So, `[1, 1, 1]` implies that we have three classes with one prototype per class. If it is a Python [tuple](https://docs.python.org/3/tutorial/datastructures.html), it a shorthand of `(num_classes, prototypes_per_class)` is assumed. If it is a Python [dictionary](https://docs.python.org/3/tutorial/datastructures.html), the key-value pairs describe the class label and the number of prototypes for that class respectively. So, `{0: 2, 1: 2, 2: 2}` implies that we have three classes with labels `{1, 2, 3}`, each equipped with two prototypes.
This commit is contained in:
parent
21e3e3b82d
commit
8a291f7bfb
@ -1,34 +1,33 @@
|
||||
"""ProtoTorch components modules."""
|
||||
|
||||
import warnings
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from prototorch.components.initializers import (ClassAwareInitializer,
|
||||
ComponentsInitializer,
|
||||
CustomLabelsInitializer,
|
||||
EqualLabelsInitializer,
|
||||
UnequalLabelsInitializer,
|
||||
ZeroReasoningsInitializer)
|
||||
from prototorch.functions.initializers import get_initializer
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
|
||||
class Components(torch.nn.Module):
|
||||
"""Components is a set of learnable Tensors."""
|
||||
def __init__(self,
|
||||
ncomps=None,
|
||||
num_components=None,
|
||||
initializer=None,
|
||||
*,
|
||||
initialized_components=None):
|
||||
super().__init__()
|
||||
|
||||
self.ncomps = ncomps
|
||||
self.num_components = num_components
|
||||
|
||||
# Ignore all initialization settings if initialized_components is given.
|
||||
if initialized_components is not None:
|
||||
self.register_parameter("_components",
|
||||
Parameter(initialized_components))
|
||||
if ncomps is not None or initializer is not None:
|
||||
if num_components is not None or initializer is not None:
|
||||
wmsg = "Arguments ignored while initializing Components"
|
||||
warnings.warn(wmsg)
|
||||
else:
|
||||
@ -43,7 +42,7 @@ class Components(torch.nn.Module):
|
||||
|
||||
def _initialize_components(self, initializer):
|
||||
self._precheck_initializer(initializer)
|
||||
_components = initializer.generate(self.ncomps)
|
||||
_components = initializer.generate(self.num_components)
|
||||
self.register_parameter("_components", Parameter(_components))
|
||||
|
||||
@property
|
||||
@ -80,16 +79,20 @@ class LabeledComponents(Components):
|
||||
def _initialize_components(self, initializer):
|
||||
if isinstance(initializer, ClassAwareInitializer):
|
||||
self._precheck_initializer(initializer)
|
||||
_components = initializer.generate(self.ncomps, self.distribution)
|
||||
_components = initializer.generate(self.num_components,
|
||||
self.distribution)
|
||||
self.register_parameter("_components", Parameter(_components))
|
||||
else:
|
||||
super()._initialize_components(initializer)
|
||||
|
||||
def _initialize_labels(self, distribution):
|
||||
if type(distribution) == dict:
|
||||
labels = EqualLabelsInitializer(
|
||||
distribution["num_classes"],
|
||||
distribution["prototypes_per_class"])
|
||||
if "num_classes" in distribution.keys():
|
||||
labels = EqualLabelsInitializer(
|
||||
distribution["num_classes"],
|
||||
distribution["prototypes_per_class"])
|
||||
else:
|
||||
labels = CustomLabelsInitializer(distribution)
|
||||
elif type(distribution) == tuple:
|
||||
num_classes, prototypes_per_class = distribution
|
||||
labels = EqualLabelsInitializer(num_classes, prototypes_per_class)
|
||||
@ -139,8 +142,8 @@ class ReasoningComponents(Components):
|
||||
|
||||
def _initialize_reasonings(self, reasonings):
|
||||
if type(reasonings) == tuple:
|
||||
num_classes, ncomps = reasonings
|
||||
reasonings = ZeroReasoningsInitializer(num_classes, ncomps)
|
||||
num_classes, num_components = reasonings
|
||||
reasonings = ZeroReasoningsInitializer(num_classes, num_components)
|
||||
|
||||
_reasonings = reasonings.generate()
|
||||
self.register_parameter("_reasonings", _reasonings)
|
||||
|
@ -103,9 +103,9 @@ class ClassAwareInitializer(ComponentsInitializer):
|
||||
def _get_samples_from_initializer(self, length, dist):
|
||||
if not dist:
|
||||
per_class = length // self.num_classes
|
||||
dist = self.num_classes * [per_class]
|
||||
dist = dict(zip(self.clabels, self.num_classes * [per_class]))
|
||||
samples_list = [
|
||||
init.generate(n) for init, n in zip(self.initializers, dist)
|
||||
init.generate(n) for init, n in zip(self.initializers, dist.values())
|
||||
]
|
||||
out = torch.vstack(samples_list)
|
||||
with torch.no_grad():
|
||||
@ -174,10 +174,13 @@ class UnequalLabelsInitializer(LabelsInitializer):
|
||||
def distribution(self):
|
||||
return self.dist
|
||||
|
||||
def generate(self):
|
||||
clabels = range(len(self.dist))
|
||||
labels = list(chain(*[[i] * n for i, n in zip(clabels, self.dist)]))
|
||||
return torch.tensor(labels)
|
||||
def generate(self, clabels=None, dist=None):
|
||||
if not clabels:
|
||||
clabels = range(len(self.dist))
|
||||
if not dist:
|
||||
dist = self.dist
|
||||
labels = list(chain(*[[i] * n for i, n in zip(clabels, dist)]))
|
||||
return torch.LongTensor(labels)
|
||||
|
||||
|
||||
class EqualLabelsInitializer(LabelsInitializer):
|
||||
@ -193,6 +196,13 @@ class EqualLabelsInitializer(LabelsInitializer):
|
||||
return torch.arange(self.classes).repeat(self.per_class, 1).T.flatten()
|
||||
|
||||
|
||||
class CustomLabelsInitializer(UnequalLabelsInitializer):
|
||||
def generate(self):
|
||||
clabels = list(self.dist.keys())
|
||||
dist = list(self.dist.values())
|
||||
return super().generate(clabels, dist)
|
||||
|
||||
|
||||
# Reasonings
|
||||
class ReasoningsInitializer:
|
||||
def generate(self, length):
|
||||
|
Loading…
Reference in New Issue
Block a user