[BUGFIX] Parse dictionary distribution appropirately

This commit is contained in:
Jensun Ravichandran 2021-05-25 20:52:39 +02:00
parent 8a291f7bfb
commit 9f5f0d12dd

View File

@ -103,9 +103,11 @@ class ClassAwareInitializer(ComponentsInitializer):
def _get_samples_from_initializer(self, length, dist): def _get_samples_from_initializer(self, length, dist):
if not dist: if not dist:
per_class = length // self.num_classes per_class = length // self.num_classes
dist = dict(zip(self.clabels, self.num_classes * [per_class])) dist = self.num_classes * [per_class]
if type(dist) == dict:
dist = dist.values()
samples_list = [ samples_list = [
init.generate(n) for init, n in zip(self.initializers, dist.values()) init.generate(n) for init, n in zip(self.initializers, dist)
] ]
out = torch.vstack(samples_list) out = torch.vstack(samples_list)
with torch.no_grad(): with torch.no_grad():