[FEATURE] Add distribution property to LabeledComponents
This commit is contained in:
parent
d45e71256c
commit
1f458ac0cc
@ -116,7 +116,7 @@ class AbstractLabels(torch.nn.Module):
|
||||
|
||||
@property
|
||||
def num_labels(self):
|
||||
return len(self.labels)
|
||||
return len(self._labels)
|
||||
|
||||
@property
|
||||
def unique_labels(self):
|
||||
@ -193,6 +193,13 @@ class LabeledComponents(AbstractComponents):
|
||||
"""Tensor containing the component labels."""
|
||||
return self._labels
|
||||
|
||||
@property
|
||||
def distribution(self):
|
||||
unique, counts = torch.unique(self._labels,
|
||||
sorted=True,
|
||||
return_counts=True)
|
||||
return dict(zip(unique.tolist(), counts.tolist()))
|
||||
|
||||
def _register_labels(self, labels):
|
||||
self.register_buffer("_labels", labels)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user