Update initializers

This commit is contained in:
Jensun Ravichandran 2021-04-29 19:15:27 +02:00
parent 9b663477fd
commit b7d53aa5f1

View File

@ -1,11 +1,13 @@
import torch
"""ProtoTroch Initializers."""
from collections.abc import Iterable
import torch
# Components
class ComponentsInitializer:
def generate(self, number_of_components):
pass
raise NotImplementedError("Subclasses should implement this!")
class DimensionAwareInitializer(ComponentsInitializer):
@ -38,7 +40,7 @@ class UniformInitializer(DimensionAwareInitializer):
def generate(self, length):
gen_dims = (length, ) + self.components_dims
return torch.FloatTensor(gen_dims).uniform_(self.min, self.max)
return torch.ones(gen_dims).uniform_(self.min, self.max)
class PositionAwareInitializer(ComponentsInitializer):
@ -61,51 +63,62 @@ class MeanInitializer(PositionAwareInitializer):
class ClassAwareInitializer(ComponentsInitializer):
def __init__(self, positions, classes):
def __init__(self, data, labels):
super().__init__()
self.data = positions
self.classes = classes
self.data = data
self.labels = labels
self.names = torch.unique(self.classes)
self.num_classes = len(self.names)
self.clabels = torch.unique(self.labels)
self.num_classes = len(self.clabels)
class StratifiedMeanInitializer(ClassAwareInitializer):
def __init__(self, positions, classes):
super().__init__(positions, classes)
def __init__(self, data, labels):
super().__init__(data, labels)
self.initializers = []
for name in self.names:
class_data = self.data[self.classes == name]
for clabel in self.clabels:
class_data = self.data[self.labels == clabel]
class_initializer = MeanInitializer(class_data)
self.initializers.append(class_initializer)
def generate(self, length):
per_class = length // self.num_classes
return torch.vstack(
[init.generate(per_class) for init in self.initializers])
samples_list = [init.generate(per_class) for init in self.initializers]
return torch.vstack(samples_list)
class StratifiedSelectionInitializer(ClassAwareInitializer):
def __init__(self, positions, classes):
super().__init__(positions, classes)
def __init__(self, data, labels, noise=None):
super().__init__(data, labels)
self.noise = noise
self.initializers = []
for name in self.names:
class_data = self.data[self.classes == name]
for clabel in self.clabels:
class_data = self.data[self.labels == clabel]
class_initializer = SelectionInitializer(class_data)
self.initializers.append(class_initializer)
def add_noise(self, x):
"""Shifts some dimensions of the data randomly."""
n1 = torch.rand_like(x)
n2 = torch.rand_like(x)
mask = torch.bernoulli(n1) - torch.bernoulli(n2)
return x + (self.noise * mask)
def generate(self, length):
per_class = length // self.num_classes
return torch.vstack(
[init.generate(per_class) for init in self.initializers])
samples_list = [init.generate(per_class) for init in self.initializers]
samples = torch.vstack(samples_list)
if self.noise is not None:
samples = self.add_noise(samples)
return samples
# Labels
class LabelsInitializer:
def generate(self):
pass
raise NotImplementedError("Subclasses should implement this!")
class EqualLabelInitializer(LabelsInitializer):
@ -120,7 +133,7 @@ class EqualLabelInitializer(LabelsInitializer):
# Reasonings
class ReasoningsInitializer:
def generate(self, length):
pass
raise NotImplementedError("Subclasses should implement this!")
class ZeroReasoningsInitializer(ReasoningsInitializer):
@ -130,3 +143,9 @@ class ZeroReasoningsInitializer(ReasoningsInitializer):
def generate(self):
return torch.zeros((self.length, self.classes, 2))
# Aliases
SSI = StratifiedSampleInitializer = StratifiedSelectionInitializer
SMI = StratifiedMeanInitializer
Random = RandomInitializer = UniformInitializer