Apply transformations in component initializers
parent e73b70ceb7
commit dc6248413c
@@ -7,12 +7,11 @@ import torch
 from torch.utils.data import DataLoader, Dataset
 
 
-def parse_init_arg(arg):
-    if isinstance(arg, Dataset):
-        data, labels = next(iter(DataLoader(arg, batch_size=len(arg))))
-        # data = data.view(len(arg), -1) # flatten
+def parse_data_arg(data):
+    if isinstance(data, Dataset):
+        data, labels = next(iter(DataLoader(data, batch_size=len(data))))
     else:
-        data, labels = arg
+        data, labels = data
     if not isinstance(data, torch.Tensor):
         wmsg = f"Converting data to {torch.Tensor}."
         warnings.warn(wmsg)
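For reference, a minimal sketch of the two call forms the renamed helper accepts (toy values; the import path is hypothetical): a torch Dataset is unpacked through a single full-size DataLoader batch, while anything else is assumed to already be a (data, labels) pair.

import torch
from torch.utils.data import TensorDataset

from initializers import parse_data_arg  # hypothetical import path for the helper above

x = torch.randn(6, 2)           # toy features
y = torch.randint(0, 2, (6, ))  # toy labels

# Form 1: a torch Dataset is unpacked via DataLoader(..., batch_size=len(...)).
data, labels = parse_data_arg(TensorDataset(x, y))

# Form 2: a ready-made (data, labels) pair is passed through unchanged.
data, labels = parse_data_arg((x, y))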
@@ -83,12 +82,14 @@ class MeanInitializer(PositionAwareInitializer):
 
 
 class ClassAwareInitializer(ComponentsInitializer):
-    def __init__(self, arg):
+    def __init__(self, data, transform=torch.nn.Identity()):
         super().__init__()
-        data, labels = parse_init_arg(arg)
+        data, labels = parse_data_arg(data)
         self.data = data
         self.labels = labels
 
+        self.transform = transform
+
         self.clabels = torch.unique(self.labels)
         self.num_classes = len(self.clabels)
 
@@ -99,7 +100,10 @@ class ClassAwareInitializer(ComponentsInitializer):
         samples_list = [
             init.generate(n) for init, n in zip(self.initializers, dist)
         ]
-        return torch.vstack(samples_list)
+        out = torch.vstack(samples_list)
+        with torch.no_grad():
+            out = self.transform(out)
+        return out
 
     def __del__(self):
         del self.data
@@ -107,8 +111,8 @@ class ClassAwareInitializer(ComponentsInitializer):
 
 
 class StratifiedMeanInitializer(ClassAwareInitializer):
-    def __init__(self, arg):
-        super().__init__(arg)
+    def __init__(self, data, **kwargs):
+        super().__init__(data, **kwargs)
 
         self.initializers = []
         for clabel in self.clabels:
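Because **kwargs is forwarded to ClassAwareInitializer, the new transform hook can be passed straight through the stratified initializers. A minimal sketch, assuming a hypothetical import path and that StratifiedMeanInitializer exposes the same generate(length, dist=[]) signature shown for StratifiedSelectionInitializer further down:

import torch
from torch.utils.data import TensorDataset

from initializers import StratifiedMeanInitializer  # hypothetical import path

x = torch.randn(100, 3)
y = torch.randint(0, 2, (100, ))

# Any torch.nn module can serve as the transform (default: torch.nn.Identity());
# it is applied under torch.no_grad() after the per-class samples are stacked.
projection = torch.nn.Linear(3, 2)  # example transform, weights left untrained

init = StratifiedMeanInitializer(TensorDataset(x, y), transform=projection)
components = init.generate(4, dist=[2, 2])  # assumed call; two components per class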
@@ -122,8 +126,8 @@ class StratifiedMeanInitializer(ClassAwareInitializer):
 
 
 class StratifiedSelectionInitializer(ClassAwareInitializer):
-    def __init__(self, arg, *, noise=None):
-        super().__init__(arg)
+    def __init__(self, data, noise=None, **kwargs):
+        super().__init__(data, **kwargs)
         self.noise = noise
 
         self.initializers = []
@@ -132,7 +136,10 @@ class StratifiedSelectionInitializer(ClassAwareInitializer):
             class_initializer = SelectionInitializer(class_data)
             self.initializers.append(class_initializer)
 
-    def add_noise(self, x):
+    def add_noise_v1(self, x):
+        return x + self.noise
+
+    def add_noise_v2(self, x):
         """Shifts some dimensions of the data randomly."""
         n1 = torch.rand_like(x)
         n2 = torch.rand_like(x)
@@ -142,8 +149,7 @@ class StratifiedSelectionInitializer(ClassAwareInitializer):
     def generate(self, length, dist=[]):
         samples = self._get_samples_from_initializer(length, dist)
         if self.noise is not None:
-            # samples = self.add_noise(samples)
-            samples = samples + self.noise
+            samples = self.add_noise_v1(samples)
         return samples
 
 
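A companion sketch for the noise path (toy values; import path hypothetical): when noise is given, generate now routes through add_noise_v1, which simply adds self.noise to the stacked and transformed samples.

import torch
from torch.utils.data import TensorDataset

from initializers import StratifiedSelectionInitializer  # hypothetical import path

x = torch.randn(100, 3)
y = torch.randint(0, 3, (100, ))

# noise is applied as `samples + self.noise` by add_noise_v1, so a float or any
# tensor broadcastable to the samples works; 0.01 is just an example scale.
init = StratifiedSelectionInitializer(TensorDataset(x, y), noise=0.01)

# dist is zipped against the per-class initializers, i.e. two components per class here.
components = init.generate(6, dist=[2, 2, 2])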