[FEATURE] Add distribution property to LabeledComponents

This commit is contained in:
Jensun Ravichandran 2021-06-14 21:08:48 +02:00
parent d45e71256c
commit 1f458ac0cc

View File

@ -116,7 +116,7 @@ class AbstractLabels(torch.nn.Module):
@property @property
def num_labels(self): def num_labels(self):
return len(self.labels) return len(self._labels)
@property @property
def unique_labels(self): def unique_labels(self):
@ -193,6 +193,13 @@ class LabeledComponents(AbstractComponents):
"""Tensor containing the component labels.""" """Tensor containing the component labels."""
return self._labels return self._labels
@property
def distribution(self):
unique, counts = torch.unique(self._labels,
sorted=True,
return_counts=True)
return dict(zip(unique.tolist(), counts.tolist()))
def _register_labels(self, labels): def _register_labels(self, labels):
self.register_buffer("_labels", labels) self.register_buffer("_labels", labels)