Update initializers

parent 9b663477fd
commit b7d53aa5f1
@@ -1,11 +1,13 @@
-import torch
+"""ProtoTorch Initializers."""
 from collections.abc import Iterable
 
+import torch
+
 
 # Components
 class ComponentsInitializer:
     def generate(self, number_of_components):
-        pass
+        raise NotImplementedError("Subclasses should implement this!")
 
 
 class DimensionAwareInitializer(ComponentsInitializer):
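Note: swapping `pass` for `raise NotImplementedError(...)` turns the base class into an explicit interface, so an initializer that forgets to override `generate` now fails loudly instead of silently returning None. A minimal standalone sketch of that contract (the `OnesInitializer` subclass is a hypothetical example, not part of this commit):

import torch

class ComponentsInitializer:
    def generate(self, number_of_components):
        raise NotImplementedError("Subclasses should implement this!")

# Hypothetical subclass, for illustration only.
class OnesInitializer(ComponentsInitializer):
    def __init__(self, components_dims):
        self.components_dims = components_dims

    def generate(self, number_of_components):
        # One all-ones component per requested prototype.
        return torch.ones((number_of_components, ) + self.components_dims)

OnesInitializer((2, )).generate(3).shape  # torch.Size([3, 2])
ComponentsInitializer().generate(3)       # raises NotImplementedError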
@@ -38,7 +40,7 @@ class UniformInitializer(DimensionAwareInitializer):
 
     def generate(self, length):
         gen_dims = (length, ) + self.components_dims
-        return torch.FloatTensor(gen_dims).uniform_(self.min, self.max)
+        return torch.ones(gen_dims).uniform_(self.min, self.max)
 
 
 class PositionAwareInitializer(ComponentsInitializer):
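Note: `torch.FloatTensor(gen_dims)` is a legacy constructor that allocates uninitialized memory, while `torch.ones(gen_dims)` is a well-defined factory; since the in-place `uniform_` overwrites every entry anyway, the returned samples are distributed identically. A quick standalone check, assuming `self.min` and `self.max` of 0 and 1:

import torch

gen_dims = (4, ) + (2, 3)                        # (length, ) + components_dims
samples = torch.ones(gen_dims).uniform_(0., 1.)
assert samples.shape == gen_dims                 # torch.Size compares equal to tuples
assert 0. <= samples.min() and samples.max() <= 1.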
@@ -61,51 +63,62 @@ class MeanInitializer(PositionAwareInitializer):
 
 
 class ClassAwareInitializer(ComponentsInitializer):
-    def __init__(self, positions, classes):
+    def __init__(self, data, labels):
         super().__init__()
-        self.data = positions
-        self.classes = classes
+        self.data = data
+        self.labels = labels
 
-        self.names = torch.unique(self.classes)
-        self.num_classes = len(self.names)
+        self.clabels = torch.unique(self.labels)
+        self.num_classes = len(self.clabels)
 
 
 class StratifiedMeanInitializer(ClassAwareInitializer):
-    def __init__(self, positions, classes):
-        super().__init__(positions, classes)
+    def __init__(self, data, labels):
+        super().__init__(data, labels)
 
         self.initializers = []
-        for name in self.names:
-            class_data = self.data[self.classes == name]
+        for clabel in self.clabels:
+            class_data = self.data[self.labels == clabel]
             class_initializer = MeanInitializer(class_data)
             self.initializers.append(class_initializer)
 
     def generate(self, length):
         per_class = length // self.num_classes
-        return torch.vstack(
-            [init.generate(per_class) for init in self.initializers])
+        samples_list = [init.generate(per_class) for init in self.initializers]
+        return torch.vstack(samples_list)
 
 
 class StratifiedSelectionInitializer(ClassAwareInitializer):
-    def __init__(self, positions, classes):
-        super().__init__(positions, classes)
+    def __init__(self, data, labels, noise=None):
+        super().__init__(data, labels)
+        self.noise = noise
 
         self.initializers = []
-        for name in self.names:
-            class_data = self.data[self.classes == name]
+        for clabel in self.clabels:
+            class_data = self.data[self.labels == clabel]
             class_initializer = SelectionInitializer(class_data)
             self.initializers.append(class_initializer)
 
+    def add_noise(self, x):
+        """Shifts some dimensions of the data randomly."""
+        n1 = torch.rand_like(x)
+        n2 = torch.rand_like(x)
+        mask = torch.bernoulli(n1) - torch.bernoulli(n2)
+        return x + (self.noise * mask)
+
     def generate(self, length):
         per_class = length // self.num_classes
-        return torch.vstack(
-            [init.generate(per_class) for init in self.initializers])
+        samples_list = [init.generate(per_class) for init in self.initializers]
+        samples = torch.vstack(samples_list)
+        if self.noise is not None:
+            samples = self.add_noise(samples)
+        return samples
 
 
 # Labels
 class LabelsInitializer:
     def generate(self):
-        pass
+        raise NotImplementedError("Subclasses should implement this!")
 
 
 class EqualLabelInitializer(LabelsInitializer):
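Note: the new `add_noise` hook perturbs the selected prototypes. `torch.bernoulli(torch.rand_like(x))` draws an independent 0/1 mask, so the difference of two such masks lands in {-1, 0, +1} and each entry is shifted by `-noise`, `0`, or `+noise`. Also note that `per_class = length // self.num_classes` truncates, so when `length` is not divisible by the number of classes, fewer than `length` components are generated. A standalone sketch of the noise behavior:

import torch

def add_noise(x, noise):
    """Shift each entry of x by -noise, 0, or +noise at random."""
    n1 = torch.rand_like(x)
    n2 = torch.rand_like(x)
    mask = torch.bernoulli(n1) - torch.bernoulli(n2)  # entries in {-1, 0, 1}
    return x + (noise * mask)

shifted = add_noise(torch.zeros(5, 3), noise=0.1)
# Every entry of `shifted` is now one of -0.1, 0.0, or +0.1.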
@@ -120,7 +133,7 @@ class EqualLabelInitializer(LabelsInitializer):
 # Reasonings
 class ReasoningsInitializer:
     def generate(self, length):
-        pass
+        raise NotImplementedError("Subclasses should implement this!")
 
 
 class ZeroReasoningsInitializer(ReasoningsInitializer):
@@ -130,3 +143,9 @@ class ZeroReasoningsInitializer(ReasoningsInitializer):
 
     def generate(self):
         return torch.zeros((self.length, self.classes, 2))
+
+
+# Aliases
+SSI = StratifiedSampleInitializer = StratifiedSelectionInitializer
+SMI = StratifiedMeanInitializer
+Random = RandomInitializer = UniformInitializer
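Note: the chained assignments also introduce `StratifiedSampleInitializer` and `RandomInitializer` as additional names for the same classes. A short usage sketch of the renamed API and the new aliases (toy data; assumes `SMI` and `SSI` are imported from this module and that `MeanInitializer`/`SelectionInitializer` behave as their names suggest):

import torch

data = torch.randn(10, 2)                 # 10 samples, 2 features
labels = torch.tensor([0] * 5 + [1] * 5)  # 2 classes

protos = SMI(data, labels).generate(4)    # 2 mean-prototypes per class
noisy = SSI(data, labels, noise=0.01).generate(4)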