Apply transformations in component initializers
@@ -7,12 +7,11 @@ import torch
 from torch.utils.data import DataLoader, Dataset
 
 
-def parse_init_arg(arg):
-    if isinstance(arg, Dataset):
-        data, labels = next(iter(DataLoader(arg, batch_size=len(arg))))
-        # data = data.view(len(arg), -1)  # flatten
+def parse_data_arg(data):
+    if isinstance(data, Dataset):
+        data, labels = next(iter(DataLoader(data, batch_size=len(data))))
     else:
-        data, labels = arg
+        data, labels = data
         if not isinstance(data, torch.Tensor):
             wmsg = f"Converting data to {torch.Tensor}."
             warnings.warn(wmsg)
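The renamed helper accepts either a torch Dataset or an already-unpacked (data, labels) pair. A short usage sketch, assuming parse_data_arg from the hunk above is in scope and returns the pair (as the constructor further down relies on):

import torch
from torch.utils.data import TensorDataset

x = torch.randn(10, 3)
y = torch.randint(0, 2, (10, ))

# A Dataset is consumed in a single batch through DataLoader ...
data, labels = parse_data_arg(TensorDataset(x, y))

# ... while an already-unpacked (data, labels) tuple is used directly.
data, labels = parse_data_arg((x, y))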
@@ -83,12 +82,14 @@ class MeanInitializer(PositionAwareInitializer):
 
 
 class ClassAwareInitializer(ComponentsInitializer):
-    def __init__(self, arg):
+    def __init__(self, data, transform=torch.nn.Identity()):
         super().__init__()
-        data, labels = parse_init_arg(arg)
+        data, labels = parse_data_arg(data)
         self.data = data
         self.labels = labels
 
+        self.transform = transform
+
         self.clabels = torch.unique(self.labels)
         self.num_classes = len(self.clabels)
 
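The new transform keyword defaults to torch.nn.Identity(), so existing callers see unchanged behaviour; any module that maps a tensor of components to a tensor can be substituted. A minimal standalone sketch (plain PyTorch, no classes from this diff):

import torch

components = torch.randn(5, 4)      # stands in for stacked component vectors

identity = torch.nn.Identity()      # the default: a no-op
assert torch.equal(identity(components), components)

projection = torch.nn.Linear(4, 2)  # a possible custom transform
projected = projection(components)  # shape (5, 2)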
@@ -99,7 +100,10 @@ class ClassAwareInitializer(ComponentsInitializer):
         samples_list = [
             init.generate(n) for init, n in zip(self.initializers, dist)
         ]
-        return torch.vstack(samples_list)
+        out = torch.vstack(samples_list)
+        with torch.no_grad():
+            out = self.transform(out)
+        return out
 
     def __del__(self):
         del self.data
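Wrapping the transform call in torch.no_grad() means the generated components carry no autograd history from the transform's parameters. A standalone sketch of that effect:

import torch

transform = torch.nn.Linear(4, 2)   # stands in for self.transform
stacked = torch.randn(6, 4)         # stands in for torch.vstack(samples_list)

with torch.no_grad():
    out = transform(stacked)

print(out.requires_grad)            # False: no graph is recorded here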
@@ -107,8 +111,8 @@ class ClassAwareInitializer(ComponentsInitializer):
 
 
 class StratifiedMeanInitializer(ClassAwareInitializer):
-    def __init__(self, arg):
-        super().__init__(arg)
+    def __init__(self, data, **kwargs):
+        super().__init__(data, **kwargs)
 
         self.initializers = []
         for clabel in self.clabels:
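Because the subclass now forwards **kwargs to ClassAwareInitializer, a transform can be supplied directly at construction. A sketch, assuming StratifiedMeanInitializer is importable from the module this diff touches:

import torch

x = torch.randn(100, 4)
y = torch.randint(0, 3, (100, ))

projection = torch.nn.Linear(4, 2)
init = StratifiedMeanInitializer((x, y), transform=projection)
# e.g. init.generate(9, dist=[3, 3, 3]) would now yield projected 2-d means
# (signature assumed to match the generate() shown further down).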
@@ -122,8 +126,8 @@ class StratifiedMeanInitializer(ClassAwareInitializer):
 
 
 class StratifiedSelectionInitializer(ClassAwareInitializer):
-    def __init__(self, arg, *, noise=None):
-        super().__init__(arg)
+    def __init__(self, data, noise=None, **kwargs):
+        super().__init__(data, **kwargs)
         self.noise = noise
 
         self.initializers = []
@@ -132,7 +136,10 @@ class StratifiedSelectionInitializer(ClassAwareInitializer):
             class_initializer = SelectionInitializer(class_data)
             self.initializers.append(class_initializer)
 
-    def add_noise(self, x):
+    def add_noise_v1(self, x):
+        return x + self.noise
+
+    def add_noise_v2(self, x):
         """Shifts some dimensions of the data randomly."""
         n1 = torch.rand_like(x)
         n2 = torch.rand_like(x)
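add_noise_v1 is the old inline behaviour factored into a method: a plain additive shift of every entry (the body of add_noise_v2, which shifts only some dimensions, is cut off by this hunk). A trivial standalone sketch of the v1 shift:

import torch

x = torch.zeros(2, 3)
noise = 0.1

shifted = x + noise   # what add_noise_v1 does: an elementwise (broadcast) shift
print(shifted)        # a (2, 3) tensor filled with 0.1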
@@ -142,8 +149,7 @@ class StratifiedSelectionInitializer(ClassAwareInitializer):
     def generate(self, length, dist=[]):
         samples = self._get_samples_from_initializer(length, dist)
         if self.noise is not None:
-            # samples = self.add_noise(samples)
-            samples = samples + self.noise
+            samples = self.add_noise_v1(samples)
         return samples
 
 
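An end-to-end sketch of the updated noise path. The import of StratifiedSelectionInitializer and the exact behaviour of the per-class SelectionInitializer (referenced above but not shown) are assumptions:

import torch

x = torch.randn(60, 2)
y = torch.randint(0, 3, (60, ))

init = StratifiedSelectionInitializer((x, y), noise=0.01)

# dist gives the per-class component counts consumed by
# _get_samples_from_initializer; the result is shifted via add_noise_v1.
components = init.generate(6, dist=[2, 2, 2])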