[BUGFIX] Parse dictionary distribution appropirately
This commit is contained in:
parent
8a291f7bfb
commit
9f5f0d12dd
@ -103,9 +103,11 @@ class ClassAwareInitializer(ComponentsInitializer):
|
||||
def _get_samples_from_initializer(self, length, dist):
|
||||
if not dist:
|
||||
per_class = length // self.num_classes
|
||||
dist = dict(zip(self.clabels, self.num_classes * [per_class]))
|
||||
dist = self.num_classes * [per_class]
|
||||
if type(dist) == dict:
|
||||
dist = dist.values()
|
||||
samples_list = [
|
||||
init.generate(n) for init, n in zip(self.initializers, dist.values())
|
||||
init.generate(n) for init, n in zip(self.initializers, dist)
|
||||
]
|
||||
out = torch.vstack(samples_list)
|
||||
with torch.no_grad():
|
||||
|
Loading…
Reference in New Issue
Block a user