[BUGFIX] Parse dictionary distribution appropirately
This commit is contained in:
parent
8a291f7bfb
commit
9f5f0d12dd
@ -103,9 +103,11 @@ class ClassAwareInitializer(ComponentsInitializer):
|
|||||||
def _get_samples_from_initializer(self, length, dist):
|
def _get_samples_from_initializer(self, length, dist):
|
||||||
if not dist:
|
if not dist:
|
||||||
per_class = length // self.num_classes
|
per_class = length // self.num_classes
|
||||||
dist = dict(zip(self.clabels, self.num_classes * [per_class]))
|
dist = self.num_classes * [per_class]
|
||||||
|
if type(dist) == dict:
|
||||||
|
dist = dist.values()
|
||||||
samples_list = [
|
samples_list = [
|
||||||
init.generate(n) for init, n in zip(self.initializers, dist.values())
|
init.generate(n) for init, n in zip(self.initializers, dist)
|
||||||
]
|
]
|
||||||
out = torch.vstack(samples_list)
|
out = torch.vstack(samples_list)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
Loading…
Reference in New Issue
Block a user