Make components dynamic

Jensun Ravichandran 2021-05-31 00:31:40 +02:00
parent 040d1ee9e8
commit e61ae73749
2 changed files with 95 additions and 54 deletions
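In short: Components gains a _register_components helper and an increase_components method, LabeledComponents learns to grow its labels and per-class distribution alongside its prototypes, and the initializers module renames labels to targets and deduplicates the per-class initializer construction into get_subinitializers. The pattern that makes the growth work is re-registering a freshly concatenated Parameter. A minimal standalone sketch of that pattern (GrowableComponents is a hypothetical stand-in, not part of prototorch):

import torch
from torch.nn.parameter import Parameter


class GrowableComponents(torch.nn.Module):
    """Hypothetical stand-in reduced to the re-registration pattern."""
    def __init__(self, initial):
        super().__init__()
        self._register_components(initial)

    def _register_components(self, components):
        # register_parameter overwrites any previous "_components" entry,
        # so the module always exposes exactly one learnable tensor.
        self.register_parameter("_components", Parameter(components))

    def increase_components(self, new):
        # Grow by concatenating old and new rows, then re-register.
        # .data detaches the old parameter so the cat result stays a leaf.
        self._register_components(torch.cat([self._components.data, new]))


m = GrowableComponents(torch.zeros(3, 2))
m.increase_components(torch.ones(2, 2))
print(m._components.shape)  # torch.Size([5, 2])

One consequence of re-registering: an optimizer created before the call still holds the old Parameter object, so it has to be rebuilt (or given a new param group) after growing.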

prototorch/components/components.py

@@ -12,6 +12,26 @@ from prototorch.components.initializers import (ClassAwareInitializer,
 from torch.nn.parameter import Parameter
 
 
+def get_labels_object(distribution):
+    if isinstance(distribution, dict):
+        if "num_classes" in distribution.keys():
+            labels = EqualLabelsInitializer(
+                distribution["num_classes"],
+                distribution["prototypes_per_class"])
+        else:
+            labels = CustomLabelsInitializer(distribution)
+    elif isinstance(distribution, tuple):
+        num_classes, prototypes_per_class = distribution
+        labels = EqualLabelsInitializer(num_classes, prototypes_per_class)
+    elif isinstance(distribution, list):
+        labels = UnequalLabelsInitializer(distribution)
+    else:
+        msg = f"`distribution` not understood." \
+            f"You have provided: {distribution=}."
+        raise ValueError(msg)
+    return labels
+
+
 class Components(torch.nn.Module):
     """Components is a set of learnable Tensors."""
 
     def __init__(self,
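The new module-level get_labels_object centralizes the distribution parsing that previously lived in LabeledComponents._initialize_labels. The accepted shapes, sketched below with illustrative calls against the function added above (the label-initializer classes are the ones imported at the top of this file; each returned object exposes .distribution and .generate()):

# dict with "num_classes": equal prototypes per class
get_labels_object({"num_classes": 3, "prototypes_per_class": 2})
# any other dict: explicit per-class counts (CustomLabelsInitializer)
get_labels_object({0: 1, 1: 3})
# tuple: (num_classes, prototypes_per_class)
get_labels_object((3, 2))
# list: one count per class index (UnequalLabelsInitializer)
get_labels_object([1, 2, 2])
# anything else raises ValueError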
@@ -25,14 +45,16 @@ class Components(torch.nn.Module):
         # Ignore all initialization settings if initialized_components is given.
         if initialized_components is not None:
-            self.register_parameter("_components",
-                                    Parameter(initialized_components))
+            self._register_components(initialized_components)
             if num_components is not None or initializer is not None:
                 wmsg = "Arguments ignored while initializing Components"
                 warnings.warn(wmsg)
         else:
             self._initialize_components(initializer)
 
+    def _register_components(self, components):
+        self.register_parameter("_components", Parameter(components))
+
     def _precheck_initializer(self, initializer):
         if not isinstance(initializer, ComponentsInitializer):
             emsg = f"`initializer` has to be some subtype of " \
@@ -43,7 +65,13 @@ class Components(torch.nn.Module):
     def _initialize_components(self, initializer):
         self._precheck_initializer(initializer)
         _components = initializer.generate(self.num_components)
-        self.register_parameter("_components", Parameter(_components))
+        self._register_components(_components)
+
+    def increase_components(self, initializer, num=1):
+        self._precheck_initializer(initializer)
+        _new = initializer.generate(num)
+        _components = torch.cat([self._components, _new])
+        self._register_components(_components)
 
     @property
     def components(self):
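With increase_components in place, a Components module can be grown after construction. A hedged usage sketch (the Components constructor is only partially visible in this diff, so the import locations and the exact positional/keyword arguments are assumptions; SelectionInitializer is one of the data-aware initializers touched in the second file):

import torch
from prototorch.components import Components  # import path assumed
from prototorch.components.initializers import SelectionInitializer

data = torch.randn(100, 2)
comps = Components(3, initializer=SelectionInitializer(data))  # kwargs assumed
comps.increase_components(SelectionInitializer(data), num=2)
print(comps.num_components)  # 5; the old Parameter was replaced, so any
                             # pre-existing optimizer must be re-created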
@@ -72,35 +100,48 @@ class LabeledComponents(Components):
             super().__init__(initialized_components=components)
             self._labels = component_labels
         else:
-            _labels = self._initialize_labels(distribution)
+            labels = get_labels_object(distribution)
+            self.distribution = labels.distribution
+            _labels = labels.generate()
             super().__init__(len(_labels), initializer=initializer)
-            self.register_buffer("_labels", _labels)
+            self._register_labels(_labels)
+
+    def _register_labels(self, labels):
+        self.register_buffer("_labels", labels)
+
+    def _update_distribution(self, distribution):
+        self.distribution = [
+            old + new for old, new in zip(self.distribution, distribution)
+        ]
 
     def _initialize_components(self, initializer):
         if isinstance(initializer, ClassAwareInitializer):
             self._precheck_initializer(initializer)
             _components = initializer.generate(self.num_components,
                                                self.distribution)
-            self.register_parameter("_components", Parameter(_components))
+            self._register_components(_components)
         else:
             super()._initialize_components(initializer)
 
-    def _initialize_labels(self, distribution):
-        if type(distribution) == dict:
-            if "num_classes" in distribution.keys():
-                labels = EqualLabelsInitializer(
-                    distribution["num_classes"],
-                    distribution["prototypes_per_class"])
-            else:
-                labels = CustomLabelsInitializer(distribution)
-        elif type(distribution) == tuple:
-            num_classes, prototypes_per_class = distribution
-            labels = EqualLabelsInitializer(num_classes, prototypes_per_class)
-        elif type(distribution) == list:
-            labels = UnequalLabelsInitializer(distribution)
+    def increase_components(self, initializer, distribution=[1]):
+        self._precheck_initializer(initializer)
 
-        self.distribution = labels.distribution
-        return labels.generate()
+        # Labels
+        labels = get_labels_object(distribution)
+        new_labels = labels.generate()
+        _labels = torch.cat([self._labels, new_labels])
+        self._register_labels(_labels)
+
+        # Components
+        if isinstance(initializer, ClassAwareInitializer):
+            _new = initializer.generate(len(new_labels), labels.distribution)
+        else:
+            _new = initializer.generate(len(new_labels))
+        _components = torch.cat([self._components, _new])
+        self._register_components(_components)
+
+        # Housekeeping
+        self._update_distribution(labels.distribution)
 
     @property
     def component_labels(self):
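LabeledComponents.increase_components mirrors the base-class method but also grows the label buffer and folds the new per-class counts into self.distribution via _update_distribution, which is a plain elementwise sum:

old = [2, 2, 3]   # prototypes per class before the call
new = [1, 0, 1]   # counts generated for this call
print([o + n for o, n in zip(old, new)])  # [3, 2, 4]

Note that zip truncates to the shorter list, so a distribution argument mentioning more classes than currently exist would silently drop the surplus counts.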
@@ -141,7 +182,7 @@ class ReasoningComponents(Components):
         super().__init__(len(self._reasonings), initializer=initializer)
 
     def _initialize_reasonings(self, reasonings):
-        if type(reasonings) == tuple:
+        if isinstance(reasonings, tuple):
             num_classes, num_components = reasonings
             reasonings = ZeroReasoningsInitializer(num_classes, num_components)

prototorch/components/initializers.py

@@ -13,21 +13,30 @@ def parse_data_arg(data_arg):
     if isinstance(data_arg, DataLoader):
         data = torch.tensor([])
-        labels = torch.tensor([])
+        targets = torch.tensor([])
         for x, y in data_arg:
             data = torch.cat([data, x])
-            labels = torch.cat([labels, y])
+            targets = torch.cat([targets, y])
     else:
-        data, labels = data_arg
+        data, targets = data_arg
         if not isinstance(data, torch.Tensor):
             wmsg = f"Converting data to {torch.Tensor}."
             warnings.warn(wmsg)
             data = torch.Tensor(data)
-        if not isinstance(labels, torch.Tensor):
-            wmsg = f"Converting labels to {torch.Tensor}."
+        if not isinstance(targets, torch.Tensor):
+            wmsg = f"Converting targets to {torch.Tensor}."
             warnings.warn(wmsg)
-            labels = torch.Tensor(labels)
-    return data, labels
+            targets = torch.Tensor(targets)
+    return data, targets
+
+
+def get_subinitializers(data, targets, clabels, subinit_type):
+    initializers = dict()
+    for clabel in clabels:
+        class_data = data[targets == clabel]
+        class_initializer = subinit_type(class_data)
+        initializers[clabel] = (class_initializer)
+    return initializers
 
 
 # Components
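get_subinitializers replaces the duplicated per-class loops in the two stratified initializers below: it slices the data by class label and wraps each slice in a fresh sub-initializer, keyed by that label. A small self-contained check (MeanStandIn is a hypothetical stand-in for MeanInitializer, whose definition is outside this diff):

import torch


def get_subinitializers(data, targets, clabels, subinit_type):
    # Same logic as the function added above.
    initializers = dict()
    for clabel in clabels:
        initializers[clabel] = subinit_type(data[targets == clabel])
    return initializers


class MeanStandIn:
    """Hypothetical: generates n copies of the class mean."""
    def __init__(self, data):
        self.mean = data.mean(dim=0)

    def generate(self, n):
        return self.mean.repeat(n, 1)


data = torch.tensor([[0.0, 0.0], [2.0, 2.0], [5.0, 5.0]])
targets = torch.tensor([0, 0, 1])
inits = get_subinitializers(data, targets, [0, 1], MeanStandIn)
print(inits[0].generate(2))  # two copies of the class-0 mean [1., 1.]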
@@ -91,43 +100,37 @@ class MeanInitializer(DataAwareInitializer):
 
 class ClassAwareInitializer(ComponentsInitializer):
     def __init__(self, data, transform=torch.nn.Identity()):
         super().__init__()
-        data, labels = parse_data_arg(data)
+        data, targets = parse_data_arg(data)
         self.data = data
-        self.labels = labels
+        self.targets = targets
         self.transform = transform
 
-        self.clabels = torch.unique(self.labels)
+        self.clabels = torch.unique(self.targets).int().tolist()
         self.num_classes = len(self.clabels)
 
     def _get_samples_from_initializer(self, length, dist):
         if not dist:
             per_class = length // self.num_classes
-            dist = self.num_classes * [per_class]
-        if type(dist) == dict:
-            dist = dist.values()
-        samples_list = [
-            init.generate(n) for init, n in zip(self.initializers, dist)
-        ]
-        out = torch.vstack(samples_list)
+            dist = dict(zip(self.clabels, self.num_classes * [per_class]))
+        if isinstance(dist, list):
+            dist = dict(zip(self.clabels, dist))
+        samples = [self.initializers[k].generate(n) for k, n in dist.items()]
+        out = torch.vstack(samples)
         with torch.no_grad():
             out = self.transform(out)
         return out
 
     def __del__(self):
         del self.data
-        del self.labels
+        del self.targets
 
 
 class StratifiedMeanInitializer(ClassAwareInitializer):
     def __init__(self, data, **kwargs):
         super().__init__(data, **kwargs)
-        self.initializers = []
-        for clabel in self.clabels:
-            class_data = self.data[self.labels == clabel]
-            class_initializer = MeanInitializer(class_data)
-            self.initializers.append(class_initializer)
+        self.initializers = get_subinitializers(self.data, self.targets,
+                                                self.clabels, MeanInitializer)
 
     def generate(self, length, dist=[]):
         samples = self._get_samples_from_initializer(length, dist)
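The reworked _get_samples_from_initializer normalizes every dist format into a dict keyed by class label, which keeps the initializers dict from get_subinitializers lined up even when the labels are not 0..k-1; this also explains why clabels is now a plain Python int list rather than a tensor. Illustrative values:

clabels = [0, 2, 5]            # e.g. torch.unique(targets).int().tolist()
num_classes = len(clabels)

# No dist given: spread `length` equally across classes.
length = 6
per_class = length // num_classes
print(dict(zip(clabels, num_classes * [per_class])))  # {0: 2, 2: 2, 5: 2}

# List dist: pair counts with the actual class labels.
print(dict(zip(clabels, [1, 2, 3])))                  # {0: 1, 2: 2, 5: 3}
# A dict dist is used as-is; sampling is self.initializers[k].generate(n).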
@@ -138,12 +141,9 @@ class StratifiedSelectionInitializer(ClassAwareInitializer):
     def __init__(self, data, noise=None, **kwargs):
         super().__init__(data, **kwargs)
         self.noise = noise
-        self.initializers = []
-        for clabel in self.clabels:
-            class_data = self.data[self.labels == clabel]
-            class_initializer = SelectionInitializer(class_data)
-            self.initializers.append(class_initializer)
+        self.initializers = get_subinitializers(self.data, self.targets,
+                                                self.clabels,
+                                                SelectionInitializer)
 
     def add_noise_v1(self, x):
         return x + self.noise
@@ -181,8 +181,8 @@ class UnequalLabelsInitializer(LabelsInitializer):
         clabels = range(len(self.dist))
         if not dist:
             dist = self.dist
-        labels = list(chain(*[[i] * n for i, n in zip(clabels, dist)]))
-        return torch.LongTensor(labels)
+        targets = list(chain(*[[i] * n for i, n in zip(clabels, dist)]))
+        return torch.LongTensor(targets)
 
 
 class EqualLabelsInitializer(LabelsInitializer):