[FEATURE] Add distribution property to LabeledComponents
This commit is contained in:
parent
d45e71256c
commit
1f458ac0cc
@ -116,7 +116,7 @@ class AbstractLabels(torch.nn.Module):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def num_labels(self):
|
def num_labels(self):
|
||||||
return len(self.labels)
|
return len(self._labels)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def unique_labels(self):
|
def unique_labels(self):
|
||||||
@ -193,6 +193,13 @@ class LabeledComponents(AbstractComponents):
|
|||||||
"""Tensor containing the component labels."""
|
"""Tensor containing the component labels."""
|
||||||
return self._labels
|
return self._labels
|
||||||
|
|
||||||
|
@property
|
||||||
|
def distribution(self):
|
||||||
|
unique, counts = torch.unique(self._labels,
|
||||||
|
sorted=True,
|
||||||
|
return_counts=True)
|
||||||
|
return dict(zip(unique.tolist(), counts.tolist()))
|
||||||
|
|
||||||
def _register_labels(self, labels):
|
def _register_labels(self, labels):
|
||||||
self.register_buffer("_labels", labels)
|
self.register_buffer("_labels", labels)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user