diff --git a/prototorch/core/components.py b/prototorch/core/components.py index d0155a7..f1694ab 100644 --- a/prototorch/core/components.py +++ b/prototorch/core/components.py @@ -116,7 +116,7 @@ class AbstractLabels(torch.nn.Module): @property def num_labels(self): - return len(self.labels) + return len(self._labels) @property def unique_labels(self): @@ -193,6 +193,13 @@ class LabeledComponents(AbstractComponents): """Tensor containing the component labels.""" return self._labels + @property + def distribution(self): + unique, counts = torch.unique(self._labels, + sorted=True, + return_counts=True) + return dict(zip(unique.tolist(), counts.tolist())) + def _register_labels(self, labels): self.register_buffer("_labels", labels)