diff --git a/prototorch/core/initializers.py b/prototorch/core/initializers.py index 909628a..a21fad8 100644 --- a/prototorch/core/initializers.py +++ b/prototorch/core/initializers.py @@ -217,6 +217,8 @@ class AbstractStratifiedCompInitializer(AbstractClassAwareCompInitializer): components = torch.tensor([]) for k, v in distribution.items(): stratified_data = self.data[self.targets == k] + if len(stratified_data) == 0: + raise ValueError(f"No data available for class {k}.") initializer = self.subinit_type( stratified_data, noise=self.noise,