Use dict for distribution

This change allows the use of LightningCLI.
This commit is contained in:
Alexander Engelsberger 2021-05-21 17:10:02 +02:00
parent 1e23ba05fa
commit aff7a385a3

View File

@ -86,7 +86,11 @@ class LabeledComponents(Components):
super()._initialize_components(initializer) super()._initialize_components(initializer)
def _initialize_labels(self, distribution): def _initialize_labels(self, distribution):
if type(distribution) == tuple: if type(distribution) == dict:
labels = EqualLabelsInitializer(
distribution["num_classes"],
distribution["prototypes_per_class"])
elif type(distribution) == tuple:
num_classes, prototypes_per_class = distribution num_classes, prototypes_per_class = distribution
labels = EqualLabelsInitializer(num_classes, prototypes_per_class) labels = EqualLabelsInitializer(num_classes, prototypes_per_class)
elif type(distribution) == list: elif type(distribution) == list: