Make components dynamic

Jensun Ravichandran 2021-05-31 00:31:40 +02:00
parent 040d1ee9e8
commit e61ae73749
2 changed files with 95 additions and 54 deletions

Changed file 1 of 2

@@ -12,6 +12,26 @@ from prototorch.components.initializers import (ClassAwareInitializer,
 from torch.nn.parameter import Parameter
 
 
+def get_labels_object(distribution):
+    if isinstance(distribution, dict):
+        if "num_classes" in distribution.keys():
+            labels = EqualLabelsInitializer(
+                distribution["num_classes"],
+                distribution["prototypes_per_class"])
+        else:
+            labels = CustomLabelsInitializer(distribution)
+    elif isinstance(distribution, tuple):
+        num_classes, prototypes_per_class = distribution
+        labels = EqualLabelsInitializer(num_classes, prototypes_per_class)
+    elif isinstance(distribution, list):
+        labels = UnequalLabelsInitializer(distribution)
+    else:
+        msg = f"`distribution` not understood. " \
+              f"You have provided: {distribution=}."
+        raise ValueError(msg)
+    return labels
+
+
 class Components(torch.nn.Module):
     """Components is a set of learnable Tensors."""
     def __init__(self,
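
Reviewer's sketch of the three `distribution` formats the extracted helper accepts (values are illustrative, not from the commit; the class-keyed reading of the `CustomLabelsInitializer` branch is an assumption, since that class is not shown in this diff):

    get_labels_object({"num_classes": 3, "prototypes_per_class": 2})  # EqualLabelsInitializer(3, 2)
    get_labels_object((3, 2))         # same as above, tuple shorthand
    get_labels_object([2, 2, 2])      # UnequalLabelsInitializer: per-class counts
    get_labels_object({0: 2, 1: 4})   # CustomLabelsInitializer (assumed: label -> count)
    # get_labels_object("3x2")        # anything else raises ValueError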
@@ -25,14 +45,16 @@ class Components(torch.nn.Module):
 
         # Ignore all initialization settings if initialized_components is given.
         if initialized_components is not None:
-            self.register_parameter("_components",
-                                    Parameter(initialized_components))
+            self._register_components(initialized_components)
             if num_components is not None or initializer is not None:
                 wmsg = "Arguments ignored while initializing Components"
                 warnings.warn(wmsg)
         else:
             self._initialize_components(initializer)
 
+    def _register_components(self, components):
+        self.register_parameter("_components", Parameter(components))
+
     def _precheck_initializer(self, initializer):
         if not isinstance(initializer, ComponentsInitializer):
             emsg = f"`initializer` has to be some subtype of " \
@@ -43,7 +65,13 @@ class Components(torch.nn.Module):
     def _initialize_components(self, initializer):
         self._precheck_initializer(initializer)
         _components = initializer.generate(self.num_components)
-        self.register_parameter("_components", Parameter(_components))
+        self._register_components(_components)
+
+    def increase_components(self, initializer, num=1):
+        self._precheck_initializer(initializer)
+        _new = initializer.generate(num)
+        _components = torch.cat([self._components, _new])
+        self._register_components(_components)
 
     @property
     def components(self):
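
This hunk is the core of the "dynamic" behavior for plain `Components`: new tensors are concatenated and the parameter is re-registered. A hypothetical usage sketch, assuming the first positional argument is `num_components`; `my_initializer` is a placeholder for any `ComponentsInitializer` subtype:

    comps = Components(4, initializer=my_initializer)
    comps.increase_components(my_initializer, num=2)  # appends 2 generated tensors
    # "_components" is re-registered as a Parameter with 6 rows; an optimizer
    # built against the old Parameter would presumably need to be re-created.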
@@ -72,35 +100,48 @@ class LabeledComponents(Components):
             super().__init__(initialized_components=components)
             self._labels = component_labels
         else:
-            _labels = self._initialize_labels(distribution)
+            labels = get_labels_object(distribution)
+            self.distribution = labels.distribution
+            _labels = labels.generate()
             super().__init__(len(_labels), initializer=initializer)
-            self.register_buffer("_labels", _labels)
+            self._register_labels(_labels)
+
+    def _register_labels(self, labels):
+        self.register_buffer("_labels", labels)
+
+    def _update_distribution(self, distribution):
+        self.distribution = [
+            old + new for old, new in zip(self.distribution, distribution)
+        ]
 
     def _initialize_components(self, initializer):
         if isinstance(initializer, ClassAwareInitializer):
             self._precheck_initializer(initializer)
             _components = initializer.generate(self.num_components,
                                                self.distribution)
-            self.register_parameter("_components", Parameter(_components))
+            self._register_components(_components)
         else:
             super()._initialize_components(initializer)
 
-    def _initialize_labels(self, distribution):
-        if type(distribution) == dict:
-            if "num_classes" in distribution.keys():
-                labels = EqualLabelsInitializer(
-                    distribution["num_classes"],
-                    distribution["prototypes_per_class"])
-            else:
-                labels = CustomLabelsInitializer(distribution)
-        elif type(distribution) == tuple:
-            num_classes, prototypes_per_class = distribution
-            labels = EqualLabelsInitializer(num_classes, prototypes_per_class)
-        elif type(distribution) == list:
-            labels = UnequalLabelsInitializer(distribution)
-        self.distribution = labels.distribution
-        return labels.generate()
+    def increase_components(self, initializer, distribution=[1]):
+        self._precheck_initializer(initializer)
+
+        # Labels
+        labels = get_labels_object(distribution)
+        new_labels = labels.generate()
+        _labels = torch.cat([self._labels, new_labels])
+        self._register_labels(_labels)
+
+        # Components
+        if isinstance(initializer, ClassAwareInitializer):
+            _new = initializer.generate(len(new_labels), labels.distribution)
+        else:
+            _new = initializer.generate(len(new_labels))
+        _components = torch.cat([self._components, _new])
+        self._register_components(_components)
+
+        # Housekeeping
+        self._update_distribution(labels.distribution)
 
     @property
     def component_labels(self):
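
A hypothetical sketch of the labeled variant, which grows labels, components, and the distribution bookkeeping together (`my_initializer` again a placeholder):

    lc = LabeledComponents(distribution=[2, 2], initializer=my_initializer)
    lc.increase_components(my_initializer, distribution=[1, 2])
    # new labels [0, 1, 1] are appended; self.distribution: [2, 2] -> [3, 4]

One caveat worth noting: because `_update_distribution` uses `zip`, a `distribution` argument shorter than `self.distribution` (including the default `[1]`) truncates the bookkeeping list to the shorter length, so callers likely want to pass a full-length per-class list.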
@@ -141,7 +182,7 @@ class ReasoningComponents(Components):
         super().__init__(len(self._reasonings), initializer=initializer)
 
     def _initialize_reasonings(self, reasonings):
-        if type(reasonings) == tuple:
+        if isinstance(reasonings, tuple):
             num_classes, num_components = reasonings
             reasonings = ZeroReasoningsInitializer(num_classes, num_components)

Changed file 2 of 2

@@ -13,21 +13,30 @@ def parse_data_arg(data_arg):
     if isinstance(data_arg, DataLoader):
         data = torch.tensor([])
-        labels = torch.tensor([])
+        targets = torch.tensor([])
         for x, y in data_arg:
             data = torch.cat([data, x])
-            labels = torch.cat([labels, y])
+            targets = torch.cat([targets, y])
     else:
-        data, labels = data_arg
+        data, targets = data_arg
         if not isinstance(data, torch.Tensor):
             wmsg = f"Converting data to {torch.Tensor}."
             warnings.warn(wmsg)
             data = torch.Tensor(data)
-        if not isinstance(labels, torch.Tensor):
-            wmsg = f"Converting labels to {torch.Tensor}."
+        if not isinstance(targets, torch.Tensor):
+            wmsg = f"Converting targets to {torch.Tensor}."
             warnings.warn(wmsg)
-            labels = torch.Tensor(labels)
-    return data, labels
+            targets = torch.Tensor(targets)
+    return data, targets
+
+
+def get_subinitializers(data, targets, clabels, subinit_type):
+    initializers = dict()
+    for clabel in clabels:
+        class_data = data[targets == clabel]
+        class_initializer = subinit_type(class_data)
+        initializers[clabel] = (class_initializer)
+    return initializers
 
 
 # Components
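
For orientation, a toy illustration of the new `get_subinitializers` helper (data values invented; assumes `MeanInitializer.generate(n)` yields n samples built from the class data, consistent with its use below):

    import torch  # toy example, not part of the commit
    data = torch.tensor([[0.0], [1.0], [10.0], [11.0]])
    targets = torch.tensor([0, 0, 1, 1])
    inits = get_subinitializers(data, targets, [0, 1], MeanInitializer)
    # inits == {0: MeanInitializer(<class-0 rows>), 1: MeanInitializer(<class-1 rows>)}
    inits[1].generate(2)  # two samples from the class-1 sub-initializer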
@@ -91,43 +100,37 @@ class MeanInitializer(DataAwareInitializer):
 
 class ClassAwareInitializer(ComponentsInitializer):
     def __init__(self, data, transform=torch.nn.Identity()):
         super().__init__()
-        data, labels = parse_data_arg(data)
+        data, targets = parse_data_arg(data)
         self.data = data
-        self.labels = labels
+        self.targets = targets
         self.transform = transform
 
-        self.clabels = torch.unique(self.labels)
+        self.clabels = torch.unique(self.targets).int().tolist()
         self.num_classes = len(self.clabels)
 
     def _get_samples_from_initializer(self, length, dist):
         if not dist:
             per_class = length // self.num_classes
-            dist = self.num_classes * [per_class]
-        if type(dist) == dict:
-            dist = dist.values()
-        samples_list = [
-            init.generate(n) for init, n in zip(self.initializers, dist)
-        ]
-        out = torch.vstack(samples_list)
+            dist = dict(zip(self.clabels, self.num_classes * [per_class]))
+        if isinstance(dist, list):
+            dist = dict(zip(self.clabels, dist))
+        samples = [self.initializers[k].generate(n) for k, n in dist.items()]
+        out = torch.vstack(samples)
         with torch.no_grad():
             out = self.transform(out)
         return out
 
     def __del__(self):
         del self.data
-        del self.labels
+        del self.targets
 
 
 class StratifiedMeanInitializer(ClassAwareInitializer):
     def __init__(self, data, **kwargs):
         super().__init__(data, **kwargs)
-
-        self.initializers = []
-        for clabel in self.clabels:
-            class_data = self.data[self.labels == clabel]
-            class_initializer = MeanInitializer(class_data)
-            self.initializers.append(class_initializer)
+        self.initializers = get_subinitializers(self.data, self.targets,
+                                                self.clabels, MeanInitializer)
 
     def generate(self, length, dist=[]):
         samples = self._get_samples_from_initializer(length, dist)
@@ -138,12 +141,9 @@ class StratifiedSelectionInitializer(ClassAwareInitializer):
     def __init__(self, data, noise=None, **kwargs):
         super().__init__(data, **kwargs)
         self.noise = noise
-
-        self.initializers = []
-        for clabel in self.clabels:
-            class_data = self.data[self.labels == clabel]
-            class_initializer = SelectionInitializer(class_data)
-            self.initializers.append(class_initializer)
+        self.initializers = get_subinitializers(self.data, self.targets,
+                                                self.clabels,
+                                                SelectionInitializer)
 
     def add_noise_v1(self, x):
         return x + self.noise
@@ -181,8 +181,8 @@ class UnequalLabelsInitializer(LabelsInitializer):
         clabels = range(len(self.dist))
         if not dist:
             dist = self.dist
-        labels = list(chain(*[[i] * n for i, n in zip(clabels, dist)]))
-        return torch.LongTensor(labels)
+        targets = list(chain(*[[i] * n for i, n in zip(clabels, dist)]))
+        return torch.LongTensor(targets)
 
 
 class EqualLabelsInitializer(LabelsInitializer):
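
For reference, the rename does not change behavior; assuming `generate` may be called without an explicit `dist` (as `LabeledComponents` does via `labels.generate()` above):

    init = UnequalLabelsInitializer([2, 3])
    init.generate()  # -> tensor([0, 0, 1, 1, 1]): class i repeated dist[i] times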