Apply transformations in component initializers

This commit is contained in:
Jensun Ravichandran 2021-05-17 16:58:22 +02:00
parent e73b70ceb7
commit dc6248413c

View File

@ -7,12 +7,11 @@ import torch
from torch.utils.data import DataLoader, Dataset
def parse_init_arg(arg):
if isinstance(arg, Dataset):
data, labels = next(iter(DataLoader(arg, batch_size=len(arg))))
# data = data.view(len(arg), -1) # flatten
def parse_data_arg(data):
if isinstance(data, Dataset):
data, labels = next(iter(DataLoader(data, batch_size=len(data))))
else:
data, labels = arg
data, labels = data
if not isinstance(data, torch.Tensor):
wmsg = f"Converting data to {torch.Tensor}."
warnings.warn(wmsg)
@ -83,12 +82,14 @@ class MeanInitializer(PositionAwareInitializer):
class ClassAwareInitializer(ComponentsInitializer):
def __init__(self, arg):
def __init__(self, data, transform=torch.nn.Identity()):
super().__init__()
data, labels = parse_init_arg(arg)
data, labels = parse_data_arg(data)
self.data = data
self.labels = labels
self.transform = transform
self.clabels = torch.unique(self.labels)
self.num_classes = len(self.clabels)
@ -99,7 +100,10 @@ class ClassAwareInitializer(ComponentsInitializer):
samples_list = [
init.generate(n) for init, n in zip(self.initializers, dist)
]
return torch.vstack(samples_list)
out = torch.vstack(samples_list)
with torch.no_grad():
out = self.transform(out)
return out
def __del__(self):
del self.data
@ -107,8 +111,8 @@ class ClassAwareInitializer(ComponentsInitializer):
class StratifiedMeanInitializer(ClassAwareInitializer):
def __init__(self, arg):
super().__init__(arg)
def __init__(self, data, **kwargs):
super().__init__(data, **kwargs)
self.initializers = []
for clabel in self.clabels:
@ -122,8 +126,8 @@ class StratifiedMeanInitializer(ClassAwareInitializer):
class StratifiedSelectionInitializer(ClassAwareInitializer):
def __init__(self, arg, *, noise=None):
super().__init__(arg)
def __init__(self, data, noise=None, **kwargs):
super().__init__(data, **kwargs)
self.noise = noise
self.initializers = []
@ -132,7 +136,10 @@ class StratifiedSelectionInitializer(ClassAwareInitializer):
class_initializer = SelectionInitializer(class_data)
self.initializers.append(class_initializer)
def add_noise(self, x):
def add_noise_v1(self, x):
return x + self.noise
def add_noise_v2(self, x):
"""Shifts some dimensions of the data randomly."""
n1 = torch.rand_like(x)
n2 = torch.rand_like(x)
@ -142,8 +149,7 @@ class StratifiedSelectionInitializer(ClassAwareInitializer):
def generate(self, length, dist=[]):
samples = self._get_samples_from_initializer(length, dist)
if self.noise is not None:
# samples = self.add_noise(samples)
samples = samples + self.noise
samples = self.add_noise_v1(samples)
return samples