diff --git a/prototorch/components/initializers.py b/prototorch/components/initializers.py index e1dc938..78e7b4e 100644 --- a/prototorch/components/initializers.py +++ b/prototorch/components/initializers.py @@ -132,7 +132,7 @@ class StratifiedMeanInitializer(ClassAwareInitializer): self.initializers = get_subinitializers(self.data, self.targets, self.clabels, MeanInitializer) - def generate(self, length, dist=[]): + def generate(self, length, dist): samples = self._get_samples_from_initializer(length, dist) return samples