Use dict for distribution
This change allows the use of LightningCLI.
This commit is contained in:
parent
1e23ba05fa
commit
aff7a385a3
@ -86,7 +86,11 @@ class LabeledComponents(Components):
|
|||||||
super()._initialize_components(initializer)
|
super()._initialize_components(initializer)
|
||||||
|
|
||||||
def _initialize_labels(self, distribution):
|
def _initialize_labels(self, distribution):
|
||||||
if type(distribution) == tuple:
|
if type(distribution) == dict:
|
||||||
|
labels = EqualLabelsInitializer(
|
||||||
|
distribution["num_classes"],
|
||||||
|
distribution["prototypes_per_class"])
|
||||||
|
elif type(distribution) == tuple:
|
||||||
num_classes, prototypes_per_class = distribution
|
num_classes, prototypes_per_class = distribution
|
||||||
labels = EqualLabelsInitializer(num_classes, prototypes_per_class)
|
labels = EqualLabelsInitializer(num_classes, prototypes_per_class)
|
||||||
elif type(distribution) == list:
|
elif type(distribution) == list:
|
||||||
|
Loading…
Reference in New Issue
Block a user