Update initializers
commit b7d53aa5f1 (parent 9b663477fd)
@@ -1,11 +1,13 @@
-import torch
+"""ProtoTorch Initializers."""
+from collections.abc import Iterable
+
+import torch


 # Components
 class ComponentsInitializer:
     def generate(self, number_of_components):
-        pass
+        raise NotImplementedError("Subclasses should implement this!")


 class DimensionAwareInitializer(ComponentsInitializer):
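Note: swapping the silent `pass` for `raise NotImplementedError(...)` turns a forgotten override into a loud failure instead of a quiet `None` return. A minimal, self-contained sketch of the changed contract (the `OnesInitializer` subclass is hypothetical, not part of this commit):

import torch

class ComponentsInitializer:  # base class as of this commit
    def generate(self, number_of_components):
        raise NotImplementedError("Subclasses should implement this!")

class OnesInitializer(ComponentsInitializer):  # hypothetical subclass
    def generate(self, number_of_components):
        return torch.ones(number_of_components, 2)

print(OnesInitializer().generate(3).shape)  # torch.Size([3, 2])
try:
    ComponentsInitializer().generate(3)  # used to return None silently
except NotImplementedError as e:
    print(e)  # Subclasses should implement this!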
@@ -38,7 +40,7 @@ class UniformInitializer(DimensionAwareInitializer):

     def generate(self, length):
         gen_dims = (length, ) + self.components_dims
-        return torch.FloatTensor(gen_dims).uniform_(self.min, self.max)
+        return torch.ones(gen_dims).uniform_(self.min, self.max)


 class PositionAwareInitializer(ComponentsInitializer):
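Note: `torch.FloatTensor(gen_dims)` is the legacy constructor and hands back uninitialized memory; since `uniform_()` overwrites every element in place anyway, `torch.ones(gen_dims)` is a behavior-preserving swap that avoids the legacy API. The rest of the class sits outside this hunk, so the sketch below assumes an `__init__` that stores `components_dims`, `min`, and `max`:

import torch

class UniformInitializer:  # constructor signature is an assumption
    def __init__(self, components_dims, min=0.0, max=1.0):
        self.components_dims = tuple(components_dims)
        self.min = min
        self.max = max

    def generate(self, length):
        gen_dims = (length, ) + self.components_dims
        # ones() allocates defined values; uniform_() overwrites them in place
        return torch.ones(gen_dims).uniform_(self.min, self.max)

print(UniformInitializer((2, ), min=-1.0, max=1.0).generate(4).shape)
# torch.Size([4, 2])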
@@ -61,51 +63,62 @@ class MeanInitializer(PositionAwareInitializer):


 class ClassAwareInitializer(ComponentsInitializer):
-    def __init__(self, positions, classes):
+    def __init__(self, data, labels):
         super().__init__()
-        self.data = positions
-        self.classes = classes
+        self.data = data
+        self.labels = labels

-        self.names = torch.unique(self.classes)
-        self.num_classes = len(self.names)
+        self.clabels = torch.unique(self.labels)
+        self.num_classes = len(self.clabels)


 class StratifiedMeanInitializer(ClassAwareInitializer):
-    def __init__(self, positions, classes):
-        super().__init__(positions, classes)
+    def __init__(self, data, labels):
+        super().__init__(data, labels)

         self.initializers = []
-        for name in self.names:
-            class_data = self.data[self.classes == name]
+        for clabel in self.clabels:
+            class_data = self.data[self.labels == clabel]
             class_initializer = MeanInitializer(class_data)
             self.initializers.append(class_initializer)

     def generate(self, length):
         per_class = length // self.num_classes
-        return torch.vstack(
-            [init.generate(per_class) for init in self.initializers])
+        samples_list = [init.generate(per_class) for init in self.initializers]
+        return torch.vstack(samples_list)


 class StratifiedSelectionInitializer(ClassAwareInitializer):
-    def __init__(self, positions, classes):
-        super().__init__(positions, classes)
+    def __init__(self, data, labels, noise=None):
+        super().__init__(data, labels)
+        self.noise = noise

         self.initializers = []
-        for name in self.names:
-            class_data = self.data[self.classes == name]
+        for clabel in self.clabels:
+            class_data = self.data[self.labels == clabel]
             class_initializer = SelectionInitializer(class_data)
             self.initializers.append(class_initializer)

+    def add_noise(self, x):
+        """Shifts some dimensions of the data randomly."""
+        n1 = torch.rand_like(x)
+        n2 = torch.rand_like(x)
+        mask = torch.bernoulli(n1) - torch.bernoulli(n2)
+        return x + (self.noise * mask)
+
     def generate(self, length):
         per_class = length // self.num_classes
-        return torch.vstack(
-            [init.generate(per_class) for init in self.initializers])
+        samples_list = [init.generate(per_class) for init in self.initializers]
+        samples = torch.vstack(samples_list)
+        if self.noise is not None:
+            samples = self.add_noise(samples)
+        return samples


 # Labels
 class LabelsInitializer:
     def generate(self):
-        pass
+        raise NotImplementedError("Subclasses should implement this!")


 class EqualLabelInitializer(LabelsInitializer):
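Note: two things in this hunk are worth spelling out. First, the class-aware loop groups rows by boolean-masking the data with each unique label. Second, `add_noise` builds its mask from the difference of two Bernoulli draws, so every entry of the stacked samples moves by `+noise`, `-noise`, or not at all; also, `length // self.num_classes` truncates, so e.g. `generate(7)` over 3 classes yields only 6 components. A toy demonstration of both mechanisms (the tensors are made up for illustration):

import torch

# Per-class grouping, as in the ClassAwareInitializer subclasses:
data = torch.Tensor([[1., 1.], [2., 2.], [9., 9.], [11., 11.]])
labels = torch.Tensor([0, 0, 1, 1])
for clabel in torch.unique(labels):
    class_data = data[labels == clabel]  # boolean mask selects one class
    print(clabel.item(), class_data.mean(dim=0))
# 0.0 tensor([1.5000, 1.5000])
# 1.0 tensor([10., 10.])

# The noise mask: bernoulli(u1) - bernoulli(u2) with u1, u2 ~ U(0, 1)
# yields entries in {-1, 0, +1}, i.e. shift down, keep, or shift up.
x = torch.zeros(2, 3)
noise = 0.1
mask = torch.bernoulli(torch.rand_like(x)) - torch.bernoulli(torch.rand_like(x))
print(x + noise * mask)  # each entry is -0.1, 0.0, or 0.1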
@@ -120,7 +133,7 @@ class EqualLabelInitializer(LabelsInitializer):

 # Reasonings
 class ReasoningsInitializer:
     def generate(self, length):
-        pass
+        raise NotImplementedError("Subclasses should implement this!")


 class ZeroReasoningsInitializer(ReasoningsInitializer):
@@ -130,3 +143,9 @@ class ZeroReasoningsInitializer(ReasoningsInitializer):

     def generate(self):
         return torch.zeros((self.length, self.classes, 2))
+
+
+# Aliases
+SSI = StratifiedSampleInitializer = StratifiedSelectionInitializer
+SMI = StratifiedMeanInitializer
+Random = RandomInitializer = UniformInitializer
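Note: the new aliases are plain name bindings, so the short and long names refer to the same class objects. A quick identity check (the `initializers` import name is hypothetical; substitute the module's real path):

import initializers as init  # hypothetical import name for this file

assert init.SSI is init.StratifiedSelectionInitializer
assert init.SSI is init.StratifiedSampleInitializer
assert init.SMI is init.StratifiedMeanInitializer
assert init.Random is init.UniformInitializer
# e.g. components = init.SMI(data, labels).generate(6)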