Apply transformations in component initializers

This commit is contained in:
Jensun Ravichandran 2021-05-17 16:58:22 +02:00
parent e73b70ceb7
commit dc6248413c

View File

@ -7,12 +7,11 @@ import torch
from torch.utils.data import DataLoader, Dataset from torch.utils.data import DataLoader, Dataset
def parse_init_arg(arg): def parse_data_arg(data):
if isinstance(arg, Dataset): if isinstance(data, Dataset):
data, labels = next(iter(DataLoader(arg, batch_size=len(arg)))) data, labels = next(iter(DataLoader(data, batch_size=len(data))))
# data = data.view(len(arg), -1) # flatten
else: else:
data, labels = arg data, labels = data
if not isinstance(data, torch.Tensor): if not isinstance(data, torch.Tensor):
wmsg = f"Converting data to {torch.Tensor}." wmsg = f"Converting data to {torch.Tensor}."
warnings.warn(wmsg) warnings.warn(wmsg)
@ -83,12 +82,14 @@ class MeanInitializer(PositionAwareInitializer):
class ClassAwareInitializer(ComponentsInitializer): class ClassAwareInitializer(ComponentsInitializer):
def __init__(self, arg): def __init__(self, data, transform=torch.nn.Identity()):
super().__init__() super().__init__()
data, labels = parse_init_arg(arg) data, labels = parse_data_arg(data)
self.data = data self.data = data
self.labels = labels self.labels = labels
self.transform = transform
self.clabels = torch.unique(self.labels) self.clabels = torch.unique(self.labels)
self.num_classes = len(self.clabels) self.num_classes = len(self.clabels)
@ -99,7 +100,10 @@ class ClassAwareInitializer(ComponentsInitializer):
samples_list = [ samples_list = [
init.generate(n) for init, n in zip(self.initializers, dist) init.generate(n) for init, n in zip(self.initializers, dist)
] ]
return torch.vstack(samples_list) out = torch.vstack(samples_list)
with torch.no_grad():
out = self.transform(out)
return out
def __del__(self): def __del__(self):
del self.data del self.data
@ -107,8 +111,8 @@ class ClassAwareInitializer(ComponentsInitializer):
class StratifiedMeanInitializer(ClassAwareInitializer): class StratifiedMeanInitializer(ClassAwareInitializer):
def __init__(self, arg): def __init__(self, data, **kwargs):
super().__init__(arg) super().__init__(data, **kwargs)
self.initializers = [] self.initializers = []
for clabel in self.clabels: for clabel in self.clabels:
@ -122,8 +126,8 @@ class StratifiedMeanInitializer(ClassAwareInitializer):
class StratifiedSelectionInitializer(ClassAwareInitializer): class StratifiedSelectionInitializer(ClassAwareInitializer):
def __init__(self, arg, *, noise=None): def __init__(self, data, noise=None, **kwargs):
super().__init__(arg) super().__init__(data, **kwargs)
self.noise = noise self.noise = noise
self.initializers = [] self.initializers = []
@ -132,7 +136,10 @@ class StratifiedSelectionInitializer(ClassAwareInitializer):
class_initializer = SelectionInitializer(class_data) class_initializer = SelectionInitializer(class_data)
self.initializers.append(class_initializer) self.initializers.append(class_initializer)
def add_noise(self, x): def add_noise_v1(self, x):
return x + self.noise
def add_noise_v2(self, x):
"""Shifts some dimensions of the data randomly.""" """Shifts some dimensions of the data randomly."""
n1 = torch.rand_like(x) n1 = torch.rand_like(x)
n2 = torch.rand_like(x) n2 = torch.rand_like(x)
@ -142,8 +149,7 @@ class StratifiedSelectionInitializer(ClassAwareInitializer):
def generate(self, length, dist=[]): def generate(self, length, dist=[]):
samples = self._get_samples_from_initializer(length, dist) samples = self._get_samples_from_initializer(length, dist)
if self.noise is not None: if self.noise is not None:
# samples = self.add_noise(samples) samples = self.add_noise_v1(samples)
samples = samples + self.noise
return samples return samples