Make components dynamic

commit e61ae73749 (parent 040d1ee9e8)
@@ -12,6 +12,26 @@ from prototorch.components.initializers import (ClassAwareInitializer,
 from torch.nn.parameter import Parameter
 
 
+def get_labels_object(distribution):
+    if isinstance(distribution, dict):
+        if "num_classes" in distribution.keys():
+            labels = EqualLabelsInitializer(
+                distribution["num_classes"],
+                distribution["prototypes_per_class"])
+        else:
+            labels = CustomLabelsInitializer(distribution)
+    elif isinstance(distribution, tuple):
+        num_classes, prototypes_per_class = distribution
+        labels = EqualLabelsInitializer(num_classes, prototypes_per_class)
+    elif isinstance(distribution, list):
+        labels = UnequalLabelsInitializer(distribution)
+    else:
+        msg = f"`distribution` not understood. " \
+              f"You have provided: {distribution=}."
+        raise ValueError(msg)
+    return labels
+
+
 class Components(torch.nn.Module):
     """Components is a set of learnable Tensors."""
     def __init__(self,
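The new helper accepts three `distribution` formats. A minimal usage sketch (not part of the commit; the tensor in the last comment follows from UnequalLabelsInitializer further below):

labels = get_labels_object((3, 2))  # 3 classes, 2 prototypes per class
labels = get_labels_object({"num_classes": 3, "prototypes_per_class": 2})  # dict form
labels = get_labels_object([1, 2, 3])  # per-class prototype counts
labels.generate()  # -> tensor([0, 1, 1, 2, 2, 2])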
@@ -25,14 +45,16 @@ class Components(torch.nn.Module):
 
         # Ignore all initialization settings if initialized_components is given.
         if initialized_components is not None:
-            self.register_parameter("_components",
-                                    Parameter(initialized_components))
+            self._register_components(initialized_components)
             if num_components is not None or initializer is not None:
                 wmsg = "Arguments ignored while initializing Components"
                 warnings.warn(wmsg)
         else:
             self._initialize_components(initializer)
 
+    def _register_components(self, components):
+        self.register_parameter("_components", Parameter(components))
+
     def _precheck_initializer(self, initializer):
         if not isinstance(initializer, ComponentsInitializer):
             emsg = f"`initializer` has to be some subtype of " \
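A quick sketch of the pre-initialized path (shape made up): the tensor is registered directly, and any other arguments only trigger the warning above.

comps = Components(initialized_components=torch.zeros(3, 2))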
@@ -43,7 +65,13 @@ class Components(torch.nn.Module):
     def _initialize_components(self, initializer):
         self._precheck_initializer(initializer)
         _components = initializer.generate(self.num_components)
-        self.register_parameter("_components", Parameter(_components))
+        self._register_components(_components)
+
+    def increase_components(self, initializer, num=1):
+        self._precheck_initializer(initializer)
+        _new = initializer.generate(num)
+        _components = torch.cat([self._components, _new])
+        self._register_components(_components)
 
     @property
     def components(self):
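The new `increase_components` grows the parameter in place: generate `num` new rows, concatenate, re-register. A hedged sketch, where `init` stands for any concrete `ComponentsInitializer` (hypothetical name):

comps = Components(num_components=4, initializer=init)
comps.increase_components(init, num=2)
# _components is re-registered as a Parameter with 6 rows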
@@ -72,35 +100,48 @@ class LabeledComponents(Components):
             super().__init__(initialized_components=components)
             self._labels = component_labels
         else:
-            _labels = self._initialize_labels(distribution)
+            labels = get_labels_object(distribution)
+            self.distribution = labels.distribution
+            _labels = labels.generate()
             super().__init__(len(_labels), initializer=initializer)
-            self.register_buffer("_labels", _labels)
+            self._register_labels(_labels)
+
+    def _register_labels(self, labels):
+        self.register_buffer("_labels", labels)
+
+    def _update_distribution(self, distribution):
+        self.distribution = [
+            old + new for old, new in zip(self.distribution, distribution)
+        ]
 
     def _initialize_components(self, initializer):
         if isinstance(initializer, ClassAwareInitializer):
             self._precheck_initializer(initializer)
             _components = initializer.generate(self.num_components,
                                                self.distribution)
-            self.register_parameter("_components", Parameter(_components))
+            self._register_components(_components)
         else:
             super()._initialize_components(initializer)
 
-    def _initialize_labels(self, distribution):
-        if type(distribution) == dict:
-            if "num_classes" in distribution.keys():
-                labels = EqualLabelsInitializer(
-                    distribution["num_classes"],
-                    distribution["prototypes_per_class"])
-            else:
-                labels = CustomLabelsInitializer(distribution)
-        elif type(distribution) == tuple:
-            num_classes, prototypes_per_class = distribution
-            labels = EqualLabelsInitializer(num_classes, prototypes_per_class)
-        elif type(distribution) == list:
-            labels = UnequalLabelsInitializer(distribution)
-
-        self.distribution = labels.distribution
-        return labels.generate()
+    def increase_components(self, initializer, distribution=[1]):
+        self._precheck_initializer(initializer)
+
+        # Labels
+        labels = get_labels_object(distribution)
+        new_labels = labels.generate()
+        _labels = torch.cat([self._labels, new_labels])
+        self._register_labels(_labels)
+
+        # Components
+        if isinstance(initializer, ClassAwareInitializer):
+            _new = initializer.generate(len(new_labels), labels.distribution)
+        else:
+            _new = initializer.generate(len(new_labels))
+        _components = torch.cat([self._components, _new])
+        self._register_components(_components)
+
+        # Housekeeping
+        self._update_distribution(labels.distribution)
 
     @property
     def component_labels(self):
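`LabeledComponents.increase_components` extends labels and components together, then adds the new counts onto `self.distribution`. Since `_update_distribution` zips old and new counts elementwise, passing a count for every class (zeros included) keeps the bookkeeping aligned. A sketch with a hypothetical `init` and constructor keywords inferred from the hunk above:

lc = LabeledComponents(distribution=(3, 2), initializer=init)
lc.increase_components(init, distribution=[0, 1, 0])
# labels gain a single 1; self.distribution goes [2, 2, 2] -> [2, 3, 2]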
@@ -141,7 +182,7 @@ class ReasoningComponents(Components):
         super().__init__(len(self._reasonings), initializer=initializer)
 
     def _initialize_reasonings(self, reasonings):
-        if type(reasonings) == tuple:
+        if isinstance(reasonings, tuple):
             num_classes, num_components = reasonings
             reasonings = ZeroReasoningsInitializer(num_classes, num_components)
 
The remaining hunks are from the initializers module imported above as prototorch.components.initializers.

@@ -13,21 +13,30 @@ def parse_data_arg(data_arg):
 
     if isinstance(data_arg, DataLoader):
         data = torch.tensor([])
-        labels = torch.tensor([])
+        targets = torch.tensor([])
         for x, y in data_arg:
             data = torch.cat([data, x])
-            labels = torch.cat([labels, y])
+            targets = torch.cat([targets, y])
     else:
-        data, labels = data_arg
+        data, targets = data_arg
         if not isinstance(data, torch.Tensor):
             wmsg = f"Converting data to {torch.Tensor}."
             warnings.warn(wmsg)
             data = torch.Tensor(data)
-        if not isinstance(labels, torch.Tensor):
-            wmsg = f"Converting labels to {torch.Tensor}."
+        if not isinstance(targets, torch.Tensor):
+            wmsg = f"Converting targets to {torch.Tensor}."
             warnings.warn(wmsg)
-            labels = torch.Tensor(labels)
-    return data, labels
+            targets = torch.Tensor(targets)
+    return data, targets
+
+
+def get_subinitializers(data, targets, clabels, subinit_type):
+    initializers = dict()
+    for clabel in clabels:
+        class_data = data[targets == clabel]
+        class_initializer = subinit_type(class_data)
+        initializers[clabel] = class_initializer
+    return initializers
 
 
 # Components
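`get_subinitializers` builds one data-aware sub-initializer per class, keyed by class label; the dict-keyed sampling in the next hunk relies on exactly this mapping. A small sketch with made-up data:

data = torch.randn(6, 2)
targets = torch.tensor([0, 0, 1, 1, 2, 2])
inits = get_subinitializers(data, targets, [0, 1, 2], MeanInitializer)
proto = inits[1].generate(1)  # one mean-based prototype for class 1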
@@ -91,43 +100,37 @@ class MeanInitializer(DataAwareInitializer):
 class ClassAwareInitializer(ComponentsInitializer):
     def __init__(self, data, transform=torch.nn.Identity()):
         super().__init__()
-        data, labels = parse_data_arg(data)
+        data, targets = parse_data_arg(data)
         self.data = data
-        self.labels = labels
+        self.targets = targets
 
         self.transform = transform
 
-        self.clabels = torch.unique(self.labels)
+        self.clabels = torch.unique(self.targets).int().tolist()
         self.num_classes = len(self.clabels)
 
     def _get_samples_from_initializer(self, length, dist):
         if not dist:
             per_class = length // self.num_classes
-            dist = self.num_classes * [per_class]
-        if type(dist) == dict:
-            dist = dist.values()
-        samples_list = [
-            init.generate(n) for init, n in zip(self.initializers, dist)
-        ]
-        out = torch.vstack(samples_list)
+            dist = dict(zip(self.clabels, self.num_classes * [per_class]))
+        if isinstance(dist, list):
+            dist = dict(zip(self.clabels, dist))
+        samples = [self.initializers[k].generate(n) for k, n in dist.items()]
+        out = torch.vstack(samples)
         with torch.no_grad():
             out = self.transform(out)
         return out
 
     def __del__(self):
         del self.data
-        del self.labels
+        del self.targets
 
 
 class StratifiedMeanInitializer(ClassAwareInitializer):
     def __init__(self, data, **kwargs):
         super().__init__(data, **kwargs)
-
-        self.initializers = []
-        for clabel in self.clabels:
-            class_data = self.data[self.labels == clabel]
-            class_initializer = MeanInitializer(class_data)
-            self.initializers.append(class_initializer)
+        self.initializers = get_subinitializers(self.data, self.targets,
+                                                self.clabels, MeanInitializer)
 
     def generate(self, length, dist=[]):
         samples = self._get_samples_from_initializer(length, dist)
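With `clabels` now a plain list of ints and `dist` normalized to a dict keyed by those labels, the stratified initializers take either a list or no distribution at all. A sketch reusing the made-up data above:

sinit = StratifiedMeanInitializer((data, targets))
protos = sinit.generate(6)                  # equal split: 2 per class
protos = sinit.generate(6, dist=[1, 2, 3])  # list gets zipped with clabels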
@@ -138,12 +141,9 @@ class StratifiedSelectionInitializer(ClassAwareInitializer):
     def __init__(self, data, noise=None, **kwargs):
         super().__init__(data, **kwargs)
         self.noise = noise
-
-        self.initializers = []
-        for clabel in self.clabels:
-            class_data = self.data[self.labels == clabel]
-            class_initializer = SelectionInitializer(class_data)
-            self.initializers.append(class_initializer)
+        self.initializers = get_subinitializers(self.data, self.targets,
+                                                self.clabels,
+                                                SelectionInitializer)
 
     def add_noise_v1(self, x):
         return x + self.noise
@@ -181,8 +181,8 @@ class UnequalLabelsInitializer(LabelsInitializer):
         clabels = range(len(self.dist))
         if not dist:
             dist = self.dist
-        labels = list(chain(*[[i] * n for i, n in zip(clabels, dist)]))
-        return torch.LongTensor(labels)
+        targets = list(chain(*[[i] * n for i, n in zip(clabels, dist)]))
+        return torch.LongTensor(targets)
 
 
 class EqualLabelsInitializer(LabelsInitializer):