Support for unequal prototype distributions

Jensun Ravichandran 2021-05-11 16:11:11 +02:00
parent a864cf5d4d
commit 7bb93f027a
2 changed files with 60 additions and 18 deletions

prototorch/components/components.py

@@ -4,8 +4,10 @@ import warnings
 from typing import Tuple
 
 import torch
-from prototorch.components.initializers import (ComponentsInitializer,
-                                                EqualLabelInitializer,
+from prototorch.components.initializers import (ClassAwareInitializer,
+                                                ComponentsInitializer,
+                                                EqualLabelsInitializer,
+                                                UnequalLabelsInitializer,
                                                 ZeroReasoningsInitializer)
 from prototorch.functions.initializers import get_initializer
 from torch.nn.parameter import Parameter
@@ -30,12 +32,15 @@ class Components(torch.nn.Module):
         else:
             self._initialize_components(number_of_components, initializer)
 
-    def _initialize_components(self, number_of_components, initializer):
+    def _precheck_initializer(self, initializer):
         if not isinstance(initializer, ComponentsInitializer):
             emsg = f"`initializer` has to be some subtype of " \
                 f"{ComponentsInitializer}. " \
                 f"You have provided: {initializer=} instead."
             raise TypeError(emsg)
+
+    def _initialize_components(self, number_of_components, initializer):
+        self._precheck_initializer(initializer)
         self._components = Parameter(
             initializer.generate(number_of_components))
 
@@ -57,7 +62,7 @@ class LabeledComponents(Components):
     Every Component has a label assigned.
     """
     def __init__(self,
-                 labels=None,
+                 distribution=None,
                  initializer=None,
                  *,
                  initialized_components=None):
@@ -65,15 +70,27 @@ class LabeledComponents(Components):
             super().__init__(initialized_components=initialized_components[0])
             self._labels = initialized_components[1]
         else:
-            self._initialize_labels(labels)
+            self._initialize_labels(distribution)
             super().__init__(number_of_components=len(self._labels),
                              initializer=initializer)
 
-    def _initialize_labels(self, labels):
-        if type(labels) == tuple:
-            num_classes, prototypes_per_class = labels
-            labels = EqualLabelInitializer(num_classes, prototypes_per_class)
+    def _initialize_components(self, number_of_components, initializer):
+        if isinstance(initializer, ClassAwareInitializer):
+            self._precheck_initializer(initializer)
+            self._components = Parameter(
+                initializer.generate(number_of_components, self.distribution))
+        else:
+            super()._initialize_components(number_of_components,
+                                           initializer)
+
+    def _initialize_labels(self, distribution):
+        if type(distribution) == tuple:
+            num_classes, prototypes_per_class = distribution
+            labels = EqualLabelsInitializer(num_classes, prototypes_per_class)
+        elif type(distribution) == list:
+            labels = UnequalLabelsInitializer(distribution)
 
+        self.distribution = labels.distribution
         self._labels = labels.generate()
 
     @property

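For orientation, a minimal usage sketch of the new distribution argument: the data, the variable names, and the exact StratifiedMeanInitializer constructor argument are assumptions for illustration, not taken from this commit.

    import torch
    from prototorch.components import LabeledComponents
    from prototorch.components.initializers import StratifiedMeanInitializer

    x = torch.randn(9, 2)                          # toy inputs
    y = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2])  # toy class labels

    # Equal case: a (num_classes, prototypes_per_class) tuple -> 3 * 2 components.
    equal = LabeledComponents(distribution=(3, 2),
                              initializer=StratifiedMeanInitializer((x, y)))

    # Unequal case: a per-class list -> 1 + 2 + 3 = 6 components.
    unequal = LabeledComponents(distribution=[1, 2, 3],
                                initializer=StratifiedMeanInitializer((x, y)))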
prototorch/components/initializers.py

@@ -1,6 +1,7 @@
 """ProtoTorch Initializers."""
 import warnings
 from collections.abc import Iterable
+from itertools import chain
 
 import torch
 from torch.utils.data import DataLoader, Dataset
@@ -91,6 +92,15 @@ class ClassAwareInitializer(ComponentsInitializer):
         self.clabels = torch.unique(self.labels)
         self.num_classes = len(self.clabels)
 
+    def _get_samples_from_initializer(self, length, dist):
+        if not dist:
+            per_class = length // self.num_classes
+            dist = self.num_classes * [per_class]
+        samples_list = [
+            init.generate(n) for init, n in zip(self.initializers, dist)
+        ]
+        return torch.vstack(samples_list)
+
 
 class StratifiedMeanInitializer(ClassAwareInitializer):
     def __init__(self, arg):
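Traced by hand, the empty-dist fallback in _get_samples_from_initializer splits length equally across classes; note that the integer division silently drops any remainder. A sketch of the arithmetic only:

    num_classes = 3
    length = 7
    per_class = length // num_classes  # 2
    dist = num_classes * [per_class]   # [2, 2, 2] -> 6 samples generated, not 7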
@@ -102,10 +112,9 @@ class StratifiedMeanInitializer(ClassAwareInitializer):
             class_initializer = MeanInitializer(class_data)
             self.initializers.append(class_initializer)
 
-    def generate(self, length):
-        per_class = length // self.num_classes
-        samples_list = [init.generate(per_class) for init in self.initializers]
-        return torch.vstack(samples_list)
+    def generate(self, length, dist=[]):
+        samples = self._get_samples_from_initializer(length, dist)
+        return samples
 
 
 class StratifiedSelectionInitializer(ClassAwareInitializer):
@@ -126,10 +135,8 @@ class StratifiedSelectionInitializer(ClassAwareInitializer):
         mask = torch.bernoulli(n1) - torch.bernoulli(n2)
         return x + (self.noise * mask)
 
-    def generate(self, length):
-        per_class = length // self.num_classes
-        samples_list = [init.generate(per_class) for init in self.initializers]
-        samples = torch.vstack(samples_list)
+    def generate(self, length, dist=[]):
+        samples = self._get_samples_from_initializer(length, dist)
         if self.noise is not None:
             # samples = self.add_noise(samples)
             samples = samples + self.noise
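Both stratified generators now delegate to the shared helper, so an explicit dist overrides the equal split. A hedged sketch, assuming init is an already-constructed StratifiedMeanInitializer over three classes:

    protos = init.generate(6)                  # fallback: 2 prototypes per class
    protos = init.generate(6, dist=[1, 2, 3])  # unequal: 1 + 2 + 3 = 6 rows
    # With a non-empty dist, `length` is effectively ignored;
    # the row count is sum(dist).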
@@ -142,11 +149,29 @@ class LabelsInitializer:
         raise NotImplementedError("Subclasses should implement this!")
 
 
-class EqualLabelInitializer(LabelsInitializer):
+class UnequalLabelsInitializer(LabelsInitializer):
+    def __init__(self, dist):
+        self.dist = dist
+
+    @property
+    def distribution(self):
+        return self.dist
+
+    def generate(self):
+        clabels = range(len(self.dist))
+        labels = list(chain(*[[i] * n for i, n in zip(clabels, self.dist)]))
+        return torch.tensor(labels)
+
+
+class EqualLabelsInitializer(LabelsInitializer):
     def __init__(self, classes, per_class):
         self.classes = classes
         self.per_class = per_class
 
+    @property
+    def distribution(self):
+        return self.classes * [self.per_class]
+
     def generate(self):
         return torch.arange(self.classes).repeat(self.per_class, 1).T.flatten()
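The two label initializers can be checked by hand against the generate implementations above:

    from prototorch.components.initializers import (EqualLabelsInitializer,
                                                    UnequalLabelsInitializer)

    UnequalLabelsInitializer([2, 1, 3]).generate()
    # tensor([0, 0, 1, 2, 2, 2])

    EqualLabelsInitializer(3, 2).generate()
    # tensor([0, 0, 1, 1, 2, 2]); .distribution == [2, 2, 2]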