[TEST] Add more tests
This commit is contained in:
@@ -81,14 +81,9 @@ class ClassAwareCompInitializer(AbstractComponentsInitializer):
|
||||
|
||||
def generate(self, distribution: Union[dict, list, tuple]):
|
||||
distribution = parse_distribution(distribution)
|
||||
initializers = {
|
||||
k: self.subinit_type(self.data[self.targets == k])
|
||||
for k in distribution.keys()
|
||||
}
|
||||
components = torch.tensor([])
|
||||
for k, v in distribution.items():
|
||||
stratified_data = self.data[self.targets == k]
|
||||
# skip transform here
|
||||
initializer = self.subinit_type(
|
||||
stratified_data,
|
||||
noise=self.noise,
|
||||
@@ -157,13 +152,14 @@ class UniformCompInitializer(OnesCompInitializer):
|
||||
|
||||
class RandomNormalCompInitializer(OnesCompInitializer):
|
||||
"""Generate components by sampling from a standard normal distribution."""
|
||||
def __init__(self, shape, scale=1.0):
|
||||
def __init__(self, shape, shift=0.0, scale=1.0):
|
||||
super().__init__(shape)
|
||||
self.shift = shift
|
||||
self.scale = scale
|
||||
|
||||
def generate(self, num_components: int):
|
||||
ones = super().generate(num_components)
|
||||
components = self.scale * torch.randn_like(ones)
|
||||
components = self.scale * (torch.randn_like(ones) + self.shift)
|
||||
return components
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user