Apply transformations in component initializers
This commit is contained in:
parent
e73b70ceb7
commit
dc6248413c
@ -7,12 +7,11 @@ import torch
|
|||||||
from torch.utils.data import DataLoader, Dataset
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
|
||||||
|
|
||||||
def parse_init_arg(arg):
|
def parse_data_arg(data):
|
||||||
if isinstance(arg, Dataset):
|
if isinstance(data, Dataset):
|
||||||
data, labels = next(iter(DataLoader(arg, batch_size=len(arg))))
|
data, labels = next(iter(DataLoader(data, batch_size=len(data))))
|
||||||
# data = data.view(len(arg), -1) # flatten
|
|
||||||
else:
|
else:
|
||||||
data, labels = arg
|
data, labels = data
|
||||||
if not isinstance(data, torch.Tensor):
|
if not isinstance(data, torch.Tensor):
|
||||||
wmsg = f"Converting data to {torch.Tensor}."
|
wmsg = f"Converting data to {torch.Tensor}."
|
||||||
warnings.warn(wmsg)
|
warnings.warn(wmsg)
|
||||||
@ -83,12 +82,14 @@ class MeanInitializer(PositionAwareInitializer):
|
|||||||
|
|
||||||
|
|
||||||
class ClassAwareInitializer(ComponentsInitializer):
|
class ClassAwareInitializer(ComponentsInitializer):
|
||||||
def __init__(self, arg):
|
def __init__(self, data, transform=torch.nn.Identity()):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
data, labels = parse_init_arg(arg)
|
data, labels = parse_data_arg(data)
|
||||||
self.data = data
|
self.data = data
|
||||||
self.labels = labels
|
self.labels = labels
|
||||||
|
|
||||||
|
self.transform = transform
|
||||||
|
|
||||||
self.clabels = torch.unique(self.labels)
|
self.clabels = torch.unique(self.labels)
|
||||||
self.num_classes = len(self.clabels)
|
self.num_classes = len(self.clabels)
|
||||||
|
|
||||||
@ -99,7 +100,10 @@ class ClassAwareInitializer(ComponentsInitializer):
|
|||||||
samples_list = [
|
samples_list = [
|
||||||
init.generate(n) for init, n in zip(self.initializers, dist)
|
init.generate(n) for init, n in zip(self.initializers, dist)
|
||||||
]
|
]
|
||||||
return torch.vstack(samples_list)
|
out = torch.vstack(samples_list)
|
||||||
|
with torch.no_grad():
|
||||||
|
out = self.transform(out)
|
||||||
|
return out
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
del self.data
|
del self.data
|
||||||
@ -107,8 +111,8 @@ class ClassAwareInitializer(ComponentsInitializer):
|
|||||||
|
|
||||||
|
|
||||||
class StratifiedMeanInitializer(ClassAwareInitializer):
|
class StratifiedMeanInitializer(ClassAwareInitializer):
|
||||||
def __init__(self, arg):
|
def __init__(self, data, **kwargs):
|
||||||
super().__init__(arg)
|
super().__init__(data, **kwargs)
|
||||||
|
|
||||||
self.initializers = []
|
self.initializers = []
|
||||||
for clabel in self.clabels:
|
for clabel in self.clabels:
|
||||||
@ -122,8 +126,8 @@ class StratifiedMeanInitializer(ClassAwareInitializer):
|
|||||||
|
|
||||||
|
|
||||||
class StratifiedSelectionInitializer(ClassAwareInitializer):
|
class StratifiedSelectionInitializer(ClassAwareInitializer):
|
||||||
def __init__(self, arg, *, noise=None):
|
def __init__(self, data, noise=None, **kwargs):
|
||||||
super().__init__(arg)
|
super().__init__(data, **kwargs)
|
||||||
self.noise = noise
|
self.noise = noise
|
||||||
|
|
||||||
self.initializers = []
|
self.initializers = []
|
||||||
@ -132,7 +136,10 @@ class StratifiedSelectionInitializer(ClassAwareInitializer):
|
|||||||
class_initializer = SelectionInitializer(class_data)
|
class_initializer = SelectionInitializer(class_data)
|
||||||
self.initializers.append(class_initializer)
|
self.initializers.append(class_initializer)
|
||||||
|
|
||||||
def add_noise(self, x):
|
def add_noise_v1(self, x):
|
||||||
|
return x + self.noise
|
||||||
|
|
||||||
|
def add_noise_v2(self, x):
|
||||||
"""Shifts some dimensions of the data randomly."""
|
"""Shifts some dimensions of the data randomly."""
|
||||||
n1 = torch.rand_like(x)
|
n1 = torch.rand_like(x)
|
||||||
n2 = torch.rand_like(x)
|
n2 = torch.rand_like(x)
|
||||||
@ -142,8 +149,7 @@ class StratifiedSelectionInitializer(ClassAwareInitializer):
|
|||||||
def generate(self, length, dist=[]):
|
def generate(self, length, dist=[]):
|
||||||
samples = self._get_samples_from_initializer(length, dist)
|
samples = self._get_samples_from_initializer(length, dist)
|
||||||
if self.noise is not None:
|
if self.noise is not None:
|
||||||
# samples = self.add_noise(samples)
|
samples = self.add_noise_v1(samples)
|
||||||
samples = samples + self.noise
|
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user