[TEST] Add more tests

This commit is contained in:
Jensun Ravichandran
2021-06-14 14:45:14 +02:00
parent d2d6f31e7b
commit 668c9a1fb7
2 changed files with 43 additions and 19 deletions

View File

@@ -81,14 +81,9 @@ class ClassAwareCompInitializer(AbstractComponentsInitializer):
def generate(self, distribution: Union[dict, list, tuple]):
distribution = parse_distribution(distribution)
initializers = {
k: self.subinit_type(self.data[self.targets == k])
for k in distribution.keys()
}
components = torch.tensor([])
for k, v in distribution.items():
stratified_data = self.data[self.targets == k]
# skip transform here
initializer = self.subinit_type(
stratified_data,
noise=self.noise,
@@ -157,13 +152,14 @@ class UniformCompInitializer(OnesCompInitializer):
class RandomNormalCompInitializer(OnesCompInitializer):
"""Generate components by sampling from a standard normal distribution."""
def __init__(self, shape, scale=1.0):
def __init__(self, shape, shift=0.0, scale=1.0):
super().__init__(shape)
self.shift = shift
self.scale = scale
def generate(self, num_components: int):
ones = super().generate(num_components)
components = self.scale * torch.randn_like(ones)
components = self.scale * (torch.randn_like(ones) + self.shift)
return components