Make components dynamic
This commit is contained in:
parent 040d1ee9e8
commit e61ae73749
@@ -12,6 +12,26 @@ from prototorch.components.initializers import (ClassAwareInitializer,
 from torch.nn.parameter import Parameter
 
 
+def get_labels_object(distribution):
+    if isinstance(distribution, dict):
+        if "num_classes" in distribution.keys():
+            labels = EqualLabelsInitializer(
+                distribution["num_classes"],
+                distribution["prototypes_per_class"])
+        else:
+            labels = CustomLabelsInitializer(distribution)
+    elif isinstance(distribution, tuple):
+        num_classes, prototypes_per_class = distribution
+        labels = EqualLabelsInitializer(num_classes, prototypes_per_class)
+    elif isinstance(distribution, list):
+        labels = UnequalLabelsInitializer(distribution)
+    else:
+        msg = f"`distribution` not understood. " \
+            f"You have provided: {distribution=}."
+        raise ValueError(msg)
+    return labels
+
+
 class Components(torch.nn.Module):
     """Components is a set of learnable Tensors."""
     def __init__(self,
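Review note: the new module-level dispatcher accepts three `distribution` formats. A minimal usage sketch, inferred from the branches above and assuming the names from this file are in scope (the returned objects expose `.distribution` and `.generate()`, as used further down in `LabeledComponents`):

# Sketch only; not part of the commit.
# dict form -> EqualLabelsInitializer(3, 2)
labels = get_labels_object({"num_classes": 3, "prototypes_per_class": 2})

# tuple shorthand for the same configuration
labels = get_labels_object((3, 2))

# list of per-class counts -> UnequalLabelsInitializer([1, 2, 3])
labels = get_labels_object([1, 2, 3])

print(labels.distribution)  # presumably the per-class prototype counts
print(labels.generate())    # LongTensor of labels, e.g. tensor([0, 1, 1, 2, 2, 2])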
@@ -25,14 +45,16 @@ class Components(torch.nn.Module):
 
         # Ignore all initialization settings if initialized_components is given.
         if initialized_components is not None:
-            self.register_parameter("_components",
-                                    Parameter(initialized_components))
+            self._register_components(initialized_components)
             if num_components is not None or initializer is not None:
                 wmsg = "Arguments ignored while initializing Components"
                 warnings.warn(wmsg)
         else:
             self._initialize_components(initializer)
 
+    def _register_components(self, components):
+        self.register_parameter("_components", Parameter(components))
+
     def _precheck_initializer(self, initializer):
         if not isinstance(initializer, ComponentsInitializer):
             emsg = f"`initializer` has to be some subtype of " \
@@ -43,7 +65,13 @@ class Components(torch.nn.Module):
     def _initialize_components(self, initializer):
         self._precheck_initializer(initializer)
         _components = initializer.generate(self.num_components)
-        self.register_parameter("_components", Parameter(_components))
+        self._register_components(_components)
+
+    def increase_components(self, initializer, num=1):
+        self._precheck_initializer(initializer)
+        _new = initializer.generate(num)
+        _components = torch.cat([self._components, _new])
+        self._register_components(_components)
 
     @property
     def components(self):
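Review note: `increase_components` is the core of the "dynamic" behaviour, appending freshly generated rows and re-registering the parameter. A minimal sketch, assuming a toy initializer (`ZeroInit` is a hypothetical stand-in, not part of prototorch, and the import paths are assumed):

import torch
from prototorch.components import Components  # import path assumed
from prototorch.components.initializers import ComponentsInitializer

class ZeroInit(ComponentsInitializer):  # hypothetical stand-in
    def __init__(self, dims):
        self.dims = dims

    def generate(self, length):
        return torch.zeros(length, self.dims)

comps = Components(6, initializer=ZeroInit(2))
comps.increase_components(ZeroInit(2), num=3)
print(comps.components.shape)  # torch.Size([9, 2]): re-registered with 9 rows

Because `_register_components` creates a fresh Parameter, an optimizer constructed before the call keeps pointing at the old tensor; re-create it (or add the new parameter to it) after growing.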
@@ -72,35 +100,48 @@ class LabeledComponents(Components):
             super().__init__(initialized_components=components)
             self._labels = component_labels
         else:
-            _labels = self._initialize_labels(distribution)
+            labels = get_labels_object(distribution)
+            self.distribution = labels.distribution
+            _labels = labels.generate()
             super().__init__(len(_labels), initializer=initializer)
-            self.register_buffer("_labels", _labels)
+            self._register_labels(_labels)
+
+    def _register_labels(self, labels):
+        self.register_buffer("_labels", labels)
+
+    def _update_distribution(self, distribution):
+        self.distribution = [
+            old + new for old, new in zip(self.distribution, distribution)
+        ]
 
     def _initialize_components(self, initializer):
         if isinstance(initializer, ClassAwareInitializer):
             self._precheck_initializer(initializer)
             _components = initializer.generate(self.num_components,
                                                self.distribution)
-            self.register_parameter("_components", Parameter(_components))
+            self._register_components(_components)
         else:
             super()._initialize_components(initializer)
 
-    def _initialize_labels(self, distribution):
-        if type(distribution) == dict:
-            if "num_classes" in distribution.keys():
-                labels = EqualLabelsInitializer(
-                    distribution["num_classes"],
-                    distribution["prototypes_per_class"])
-            else:
-                labels = CustomLabelsInitializer(distribution)
-        elif type(distribution) == tuple:
-            num_classes, prototypes_per_class = distribution
-            labels = EqualLabelsInitializer(num_classes, prototypes_per_class)
-        elif type(distribution) == list:
-            labels = UnequalLabelsInitializer(distribution)
-
-        self.distribution = labels.distribution
-        return labels.generate()
+    def increase_components(self, initializer, distribution=[1]):
+        self._precheck_initializer(initializer)
+
+        # Labels
+        labels = get_labels_object(distribution)
+        new_labels = labels.generate()
+        _labels = torch.cat([self._labels, new_labels])
+        self._register_labels(_labels)
+
+        # Components
+        if isinstance(initializer, ClassAwareInitializer):
+            _new = initializer.generate(len(new_labels), labels.distribution)
+        else:
+            _new = initializer.generate(len(new_labels))
+        _components = torch.cat([self._components, _new])
+        self._register_components(_components)
+
+        # Housekeeping
+        self._update_distribution(labels.distribution)
 
     @property
     def component_labels(self):
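Review note: for `LabeledComponents` the growth path also appends labels and accumulates `self.distribution`. A sketch reusing the hypothetical `ZeroInit` from above, assuming the constructor takes `distribution` and `initializer` as the else-branch suggests:

lc = LabeledComponents(distribution=(3, 2), initializer=ZeroInit(2))
print(lc.distribution)  # presumably [2, 2, 2]

# add one more prototype for each of the three classes
lc.increase_components(ZeroInit(2), distribution=[1, 1, 1])
print(lc.distribution)      # [3, 3, 3] after _update_distribution
print(lc.component_labels)  # labels now include the appended prototypes

Note that `_update_distribution` zips the old and new count lists, so the `distribution` argument is expected to cover the same leading classes as the existing one.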
@@ -141,7 +182,7 @@ class ReasoningComponents(Components):
         super().__init__(len(self._reasonings), initializer=initializer)
 
     def _initialize_reasonings(self, reasonings):
-        if type(reasonings) == tuple:
+        if isinstance(reasonings, tuple):
             num_classes, num_components = reasonings
             reasonings = ZeroReasoningsInitializer(num_classes, num_components)
 
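Review note: the `type(...) == tuple` to `isinstance(...)` switch is not purely cosmetic; `isinstance` also accepts tuple subclasses. A quick illustration (not from the commit):

from collections import namedtuple

ReasoningSpec = namedtuple("ReasoningSpec", ["num_classes", "num_components"])
spec = ReasoningSpec(3, 9)

print(type(spec) == tuple)      # False: a namedtuple is a subclass, not tuple itself
print(isinstance(spec, tuple))  # True: the new check also accepts it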
@@ -13,21 +13,30 @@ def parse_data_arg(data_arg):
 
     if isinstance(data_arg, DataLoader):
         data = torch.tensor([])
-        labels = torch.tensor([])
+        targets = torch.tensor([])
         for x, y in data_arg:
             data = torch.cat([data, x])
-            labels = torch.cat([labels, y])
+            targets = torch.cat([targets, y])
     else:
-        data, labels = data_arg
+        data, targets = data_arg
         if not isinstance(data, torch.Tensor):
             wmsg = f"Converting data to {torch.Tensor}."
             warnings.warn(wmsg)
             data = torch.Tensor(data)
-        if not isinstance(labels, torch.Tensor):
-            wmsg = f"Converting labels to {torch.Tensor}."
+        if not isinstance(targets, torch.Tensor):
+            wmsg = f"Converting targets to {torch.Tensor}."
             warnings.warn(wmsg)
-            labels = torch.Tensor(labels)
-    return data, labels
+            targets = torch.Tensor(targets)
+    return data, targets
+
+
+def get_subinitializers(data, targets, clabels, subinit_type):
+    initializers = dict()
+    for clabel in clabels:
+        class_data = data[targets == clabel]
+        class_initializer = subinit_type(class_data)
+        initializers[clabel] = class_initializer
+    return initializers
 
 
 # Components
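Review note: `get_subinitializers` factors out the per-class loop previously duplicated in both stratified initializers. A minimal sketch of what it builds, with the module's names in scope (`MeanInitializer` semantics assumed from its name):

import torch

data = torch.Tensor([[0., 0.], [0., 2.], [4., 4.], [4., 6.]])
targets = torch.LongTensor([0, 0, 1, 1])

inits = get_subinitializers(data, targets, clabels=[0, 1],
                            subinit_type=MeanInitializer)
# inits maps each class label to a sub-initializer over that class's rows:
# {0: MeanInitializer(class-0 rows), 1: MeanInitializer(class-1 rows)}
print(inits[0].generate(1))  # assumed: the class-0 mean, here [0., 1.]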
@@ -91,43 +100,37 @@ class MeanInitializer(DataAwareInitializer):
 class ClassAwareInitializer(ComponentsInitializer):
     def __init__(self, data, transform=torch.nn.Identity()):
         super().__init__()
-        data, labels = parse_data_arg(data)
+        data, targets = parse_data_arg(data)
         self.data = data
-        self.labels = labels
+        self.targets = targets
 
         self.transform = transform
 
-        self.clabels = torch.unique(self.labels)
+        self.clabels = torch.unique(self.targets).int().tolist()
         self.num_classes = len(self.clabels)
 
     def _get_samples_from_initializer(self, length, dist):
         if not dist:
             per_class = length // self.num_classes
-            dist = self.num_classes * [per_class]
-        if type(dist) == dict:
-            dist = dist.values()
-        samples_list = [
-            init.generate(n) for init, n in zip(self.initializers, dist)
-        ]
-        out = torch.vstack(samples_list)
+            dist = dict(zip(self.clabels, self.num_classes * [per_class]))
+        if isinstance(dist, list):
+            dist = dict(zip(self.clabels, dist))
+        samples = [self.initializers[k].generate(n) for k, n in dist.items()]
+        out = torch.vstack(samples)
         with torch.no_grad():
             out = self.transform(out)
         return out
 
     def __del__(self):
         del self.data
-        del self.labels
+        del self.targets
 
 
 class StratifiedMeanInitializer(ClassAwareInitializer):
     def __init__(self, data, **kwargs):
         super().__init__(data, **kwargs)
-
-        self.initializers = []
-        for clabel in self.clabels:
-            class_data = self.data[self.labels == clabel]
-            class_initializer = MeanInitializer(class_data)
-            self.initializers.append(class_initializer)
+        self.initializers = get_subinitializers(self.data, self.targets,
+                                                self.clabels, MeanInitializer)
 
     def generate(self, length, dist=[]):
         samples = self._get_samples_from_initializer(length, dist)
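Review note: `_get_samples_from_initializer` now normalizes `dist` into a dict keyed by the (now plain-int) class labels, which keeps per-class counts aligned with the dict of sub-initializers. A sketch of the three accepted forms, continuing the toy `data`/`targets` above:

smi = StratifiedMeanInitializer((data, targets))

protos = smi.generate(4)               # no dist: 4 // 2 classes = 2 per class
protos = smi.generate(4, dist=[3, 1])  # list: zipped onto self.clabels
protos = smi.generate(4, dist={0: 3})  # dict: only listed classes; length unused
print(protos.shape)                    # torch.Size([3, 2]) for the dict form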
@@ -138,12 +141,9 @@ class StratifiedSelectionInitializer(ClassAwareInitializer):
     def __init__(self, data, noise=None, **kwargs):
         super().__init__(data, **kwargs)
         self.noise = noise
-
-        self.initializers = []
-        for clabel in self.clabels:
-            class_data = self.data[self.labels == clabel]
-            class_initializer = SelectionInitializer(class_data)
-            self.initializers.append(class_initializer)
+        self.initializers = get_subinitializers(self.data, self.targets,
+                                                self.clabels,
+                                                SelectionInitializer)
 
     def add_noise_v1(self, x):
         return x + self.noise
@@ -181,8 +181,8 @@ class UnequalLabelsInitializer(LabelsInitializer):
         clabels = range(len(self.dist))
         if not dist:
             dist = self.dist
-        labels = list(chain(*[[i] * n for i, n in zip(clabels, dist)]))
-        return torch.LongTensor(labels)
+        targets = list(chain(*[[i] * n for i, n in zip(clabels, dist)]))
+        return torch.LongTensor(targets)
 
 
 class EqualLabelsInitializer(LabelsInitializer):
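Review note: the `chain` expression repeats each class index by its count. A sketch (assuming `generate()` can be called without arguments, as the `if not dist` fallback suggests):

init = UnequalLabelsInitializer([2, 1])
print(init.generate())  # tensor([0, 0, 1])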