[FEATURE] Add standalone reasonings and CBC competition
This commit is contained in:
parent
0f450ed8a0
commit
6e8a52e371
@ -14,6 +14,7 @@ from .core import (
|
|||||||
components,
|
components,
|
||||||
distances,
|
distances,
|
||||||
initializers,
|
initializers,
|
||||||
|
similarities,
|
||||||
losses,
|
losses,
|
||||||
pooling,
|
pooling,
|
||||||
)
|
)
|
||||||
@ -31,6 +32,7 @@ __all_core__ = [
|
|||||||
"losses",
|
"losses",
|
||||||
"nn",
|
"nn",
|
||||||
"pooling",
|
"pooling",
|
||||||
|
"similarities",
|
||||||
"utils",
|
"utils",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -28,6 +28,24 @@ def knnc(distances: torch.Tensor,
|
|||||||
return winning_labels
|
return winning_labels
|
||||||
|
|
||||||
|
|
||||||
|
def cbcc(detections: torch.Tensor, reasonings: torch.Tensor):
|
||||||
|
"""Classification-By-Components Competition.
|
||||||
|
|
||||||
|
Returns probability distributions over the classes.
|
||||||
|
|
||||||
|
`detections` must be of shape [batch_size, num_components].
|
||||||
|
`reasonings` must be of shape [num_components, num_classes, 2].
|
||||||
|
|
||||||
|
"""
|
||||||
|
A, B = reasonings.permute(2, 1, 0).clamp(0, 1)
|
||||||
|
pk = A
|
||||||
|
nk = (1 - A) * B
|
||||||
|
numerator = (detections @ (pk - nk).T) + nk.sum(1)
|
||||||
|
probs = numerator / (pk + nk).sum(1)
|
||||||
|
# probs = probs.squeeze(0)
|
||||||
|
return probs
|
||||||
|
|
||||||
|
|
||||||
class WTAC(torch.nn.Module):
|
class WTAC(torch.nn.Module):
|
||||||
"""Winner-Takes-All-Competition Layer.
|
"""Winner-Takes-All-Competition Layer.
|
||||||
|
|
||||||
@ -63,3 +81,13 @@ class KNNC(torch.nn.Module):
|
|||||||
|
|
||||||
def extra_repr(self):
|
def extra_repr(self):
|
||||||
return f"k: {self.k}"
|
return f"k: {self.k}"
|
||||||
|
|
||||||
|
|
||||||
|
class CBCC(torch.nn.Module):
|
||||||
|
"""Classification-By-Components Competition.
|
||||||
|
|
||||||
|
Thin wrapper over the `cbcc` function.
|
||||||
|
|
||||||
|
"""
|
||||||
|
def forward(self, detections, reasonings):
|
||||||
|
return cbcc(detections, reasonings)
|
||||||
|
@ -13,6 +13,7 @@ from .initializers import (
|
|||||||
AbstractLabelsInitializer,
|
AbstractLabelsInitializer,
|
||||||
AbstractReasoningsInitializer,
|
AbstractReasoningsInitializer,
|
||||||
LabelsInitializer,
|
LabelsInitializer,
|
||||||
|
RandomReasoningsInitializer,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -112,7 +113,7 @@ class AbstractLabels(torch.nn.Module):
|
|||||||
"""Abstract class for all labels modules."""
|
"""Abstract class for all labels modules."""
|
||||||
@property
|
@property
|
||||||
def labels(self):
|
def labels(self):
|
||||||
return self._labels
|
return self._labels.cpu()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_labels(self):
|
def num_labels(self):
|
||||||
@ -174,6 +175,10 @@ class Labels(AbstractLabels):
|
|||||||
self._register_labels(_labels)
|
self._register_labels(_labels)
|
||||||
return mask
|
return mask
|
||||||
|
|
||||||
|
def forward(self):
|
||||||
|
"""Simply return the labels."""
|
||||||
|
return self._labels
|
||||||
|
|
||||||
|
|
||||||
class LabeledComponents(AbstractComponents):
|
class LabeledComponents(AbstractComponents):
|
||||||
"""A set of adaptable components and corresponding unadaptable labels."""
|
"""A set of adaptable components and corresponding unadaptable labels."""
|
||||||
@ -188,11 +193,6 @@ class LabeledComponents(AbstractComponents):
|
|||||||
self.add_components(distribution, components_initializer,
|
self.add_components(distribution, components_initializer,
|
||||||
labels_initializer)
|
labels_initializer)
|
||||||
|
|
||||||
@property
|
|
||||||
def labels(self):
|
|
||||||
"""Tensor containing the component labels."""
|
|
||||||
return self._labels
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def distribution(self):
|
def distribution(self):
|
||||||
unique, counts = torch.unique(self._labels,
|
unique, counts = torch.unique(self._labels,
|
||||||
@ -200,6 +200,15 @@ class LabeledComponents(AbstractComponents):
|
|||||||
return_counts=True)
|
return_counts=True)
|
||||||
return dict(zip(unique.tolist(), counts.tolist()))
|
return dict(zip(unique.tolist(), counts.tolist()))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_classes(self):
|
||||||
|
return len(self.distribution.keys())
|
||||||
|
|
||||||
|
@property
|
||||||
|
def labels(self):
|
||||||
|
"""Tensor containing the component labels."""
|
||||||
|
return self._labels.cpu()
|
||||||
|
|
||||||
def _register_labels(self, labels):
|
def _register_labels(self, labels):
|
||||||
self.register_buffer("_labels", labels)
|
self.register_buffer("_labels", labels)
|
||||||
|
|
||||||
@ -236,6 +245,64 @@ class LabeledComponents(AbstractComponents):
|
|||||||
return self._components, self._labels
|
return self._components, self._labels
|
||||||
|
|
||||||
|
|
||||||
|
class Reasonings(torch.nn.Module):
|
||||||
|
"""A set of standalone reasoning matrices.
|
||||||
|
|
||||||
|
The `reasonings` tensor is of shape [num_components, num_classes, 2].
|
||||||
|
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
distribution: Union[dict, list, tuple],
|
||||||
|
initializer:
|
||||||
|
AbstractReasoningsInitializer = RandomReasoningsInitializer(),
|
||||||
|
**kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_classes(self):
|
||||||
|
return self._reasonings.shape[1]
|
||||||
|
|
||||||
|
# @property
|
||||||
|
# def reasonings(self):
|
||||||
|
# """Tensor containing the reasoning matrices."""
|
||||||
|
# return self._reasonings.detach().cpu()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def reasonings(self):
|
||||||
|
with torch.no_grad():
|
||||||
|
A, B = self._reasonings.permute(2, 1, 0).clamp(0, 1)
|
||||||
|
pk = A
|
||||||
|
nk = (1 - pk) * B
|
||||||
|
ik = 1 - pk - nk
|
||||||
|
img = torch.cat([pk, nk, ik], dim=0).permute(1, 0, 2)
|
||||||
|
return img.unsqueeze(1).cpu()
|
||||||
|
|
||||||
|
def _register_reasonings(self, reasonings):
|
||||||
|
self.register_buffer("_reasonings", reasonings)
|
||||||
|
|
||||||
|
def add_reasonings(
|
||||||
|
self,
|
||||||
|
distribution: Union[dict, list, tuple],
|
||||||
|
initializer:
|
||||||
|
AbstractReasoningsInitializer = RandomReasoningsInitializer()):
|
||||||
|
"""Generate and add new reasonings."""
|
||||||
|
assert validate_initializer(initializer, AbstractReasoningsInitializer)
|
||||||
|
_reasonings, new_reasonings = gencat(self, "_reasonings", initializer,
|
||||||
|
distribution)
|
||||||
|
self._register_reasonings(_reasonings)
|
||||||
|
return new_reasonings
|
||||||
|
|
||||||
|
def remove_reasonings(self, indices):
|
||||||
|
"""Remove reasonings at specified indices."""
|
||||||
|
_reasonings, mask = removeind(self, "_reasonings", indices)
|
||||||
|
self._register_reasonings(_reasonings)
|
||||||
|
return mask
|
||||||
|
|
||||||
|
def forward(self):
|
||||||
|
"""Simply return the reasonings."""
|
||||||
|
return self._reasonings
|
||||||
|
|
||||||
|
|
||||||
class ReasoningComponents(AbstractComponents):
|
class ReasoningComponents(AbstractComponents):
|
||||||
"""A set of components and a corresponding adapatable reasoning matrices.
|
"""A set of components and a corresponding adapatable reasoning matrices.
|
||||||
|
|
||||||
@ -260,13 +327,8 @@ class ReasoningComponents(AbstractComponents):
|
|||||||
reasonings_initializer)
|
reasonings_initializer)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def reasonings(self):
|
def num_classes(self):
|
||||||
"""Returns Reasoning Matrix.
|
return self._reasonings.shape[1]
|
||||||
|
|
||||||
Dimension NxCx2
|
|
||||||
|
|
||||||
"""
|
|
||||||
return self._reasonings.detach().cpu()
|
|
||||||
|
|
||||||
def _register_reasonings(self, reasonings):
|
def _register_reasonings(self, reasonings):
|
||||||
self.register_parameter("_reasonings", Parameter(reasonings))
|
self.register_parameter("_reasonings", Parameter(reasonings))
|
||||||
|
Loading…
Reference in New Issue
Block a user