[FEATURE] Add standalone reasonings and CBC competition
This commit is contained in:
parent
0f450ed8a0
commit
6e8a52e371
@ -14,6 +14,7 @@ from .core import (
|
||||
components,
|
||||
distances,
|
||||
initializers,
|
||||
similarities,
|
||||
losses,
|
||||
pooling,
|
||||
)
|
||||
@ -31,6 +32,7 @@ __all_core__ = [
|
||||
"losses",
|
||||
"nn",
|
||||
"pooling",
|
||||
"similarities",
|
||||
"utils",
|
||||
]
|
||||
|
||||
|
@ -28,6 +28,24 @@ def knnc(distances: torch.Tensor,
|
||||
return winning_labels
|
||||
|
||||
|
||||
def cbcc(detections: torch.Tensor, reasonings: torch.Tensor):
|
||||
"""Classification-By-Components Competition.
|
||||
|
||||
Returns probability distributions over the classes.
|
||||
|
||||
`detections` must be of shape [batch_size, num_components].
|
||||
`reasonings` must be of shape [num_components, num_classes, 2].
|
||||
|
||||
"""
|
||||
A, B = reasonings.permute(2, 1, 0).clamp(0, 1)
|
||||
pk = A
|
||||
nk = (1 - A) * B
|
||||
numerator = (detections @ (pk - nk).T) + nk.sum(1)
|
||||
probs = numerator / (pk + nk).sum(1)
|
||||
# probs = probs.squeeze(0)
|
||||
return probs
|
||||
|
||||
|
||||
class WTAC(torch.nn.Module):
|
||||
"""Winner-Takes-All-Competition Layer.
|
||||
|
||||
@ -63,3 +81,13 @@ class KNNC(torch.nn.Module):
|
||||
|
||||
def extra_repr(self):
|
||||
return f"k: {self.k}"
|
||||
|
||||
|
||||
class CBCC(torch.nn.Module):
|
||||
"""Classification-By-Components Competition.
|
||||
|
||||
Thin wrapper over the `cbcc` function.
|
||||
|
||||
"""
|
||||
def forward(self, detections, reasonings):
|
||||
return cbcc(detections, reasonings)
|
||||
|
@ -13,6 +13,7 @@ from .initializers import (
|
||||
AbstractLabelsInitializer,
|
||||
AbstractReasoningsInitializer,
|
||||
LabelsInitializer,
|
||||
RandomReasoningsInitializer,
|
||||
)
|
||||
|
||||
|
||||
@ -112,7 +113,7 @@ class AbstractLabels(torch.nn.Module):
|
||||
"""Abstract class for all labels modules."""
|
||||
@property
|
||||
def labels(self):
|
||||
return self._labels
|
||||
return self._labels.cpu()
|
||||
|
||||
@property
|
||||
def num_labels(self):
|
||||
@ -174,6 +175,10 @@ class Labels(AbstractLabels):
|
||||
self._register_labels(_labels)
|
||||
return mask
|
||||
|
||||
def forward(self):
|
||||
"""Simply return the labels."""
|
||||
return self._labels
|
||||
|
||||
|
||||
class LabeledComponents(AbstractComponents):
|
||||
"""A set of adaptable components and corresponding unadaptable labels."""
|
||||
@ -188,11 +193,6 @@ class LabeledComponents(AbstractComponents):
|
||||
self.add_components(distribution, components_initializer,
|
||||
labels_initializer)
|
||||
|
||||
@property
|
||||
def labels(self):
|
||||
"""Tensor containing the component labels."""
|
||||
return self._labels
|
||||
|
||||
@property
|
||||
def distribution(self):
|
||||
unique, counts = torch.unique(self._labels,
|
||||
@ -200,6 +200,15 @@ class LabeledComponents(AbstractComponents):
|
||||
return_counts=True)
|
||||
return dict(zip(unique.tolist(), counts.tolist()))
|
||||
|
||||
@property
|
||||
def num_classes(self):
|
||||
return len(self.distribution.keys())
|
||||
|
||||
@property
|
||||
def labels(self):
|
||||
"""Tensor containing the component labels."""
|
||||
return self._labels.cpu()
|
||||
|
||||
def _register_labels(self, labels):
|
||||
self.register_buffer("_labels", labels)
|
||||
|
||||
@ -236,6 +245,64 @@ class LabeledComponents(AbstractComponents):
|
||||
return self._components, self._labels
|
||||
|
||||
|
||||
class Reasonings(torch.nn.Module):
|
||||
"""A set of standalone reasoning matrices.
|
||||
|
||||
The `reasonings` tensor is of shape [num_components, num_classes, 2].
|
||||
|
||||
"""
|
||||
def __init__(self,
|
||||
distribution: Union[dict, list, tuple],
|
||||
initializer:
|
||||
AbstractReasoningsInitializer = RandomReasoningsInitializer(),
|
||||
**kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@property
|
||||
def num_classes(self):
|
||||
return self._reasonings.shape[1]
|
||||
|
||||
# @property
|
||||
# def reasonings(self):
|
||||
# """Tensor containing the reasoning matrices."""
|
||||
# return self._reasonings.detach().cpu()
|
||||
|
||||
@property
|
||||
def reasonings(self):
|
||||
with torch.no_grad():
|
||||
A, B = self._reasonings.permute(2, 1, 0).clamp(0, 1)
|
||||
pk = A
|
||||
nk = (1 - pk) * B
|
||||
ik = 1 - pk - nk
|
||||
img = torch.cat([pk, nk, ik], dim=0).permute(1, 0, 2)
|
||||
return img.unsqueeze(1).cpu()
|
||||
|
||||
def _register_reasonings(self, reasonings):
|
||||
self.register_buffer("_reasonings", reasonings)
|
||||
|
||||
def add_reasonings(
|
||||
self,
|
||||
distribution: Union[dict, list, tuple],
|
||||
initializer:
|
||||
AbstractReasoningsInitializer = RandomReasoningsInitializer()):
|
||||
"""Generate and add new reasonings."""
|
||||
assert validate_initializer(initializer, AbstractReasoningsInitializer)
|
||||
_reasonings, new_reasonings = gencat(self, "_reasonings", initializer,
|
||||
distribution)
|
||||
self._register_reasonings(_reasonings)
|
||||
return new_reasonings
|
||||
|
||||
def remove_reasonings(self, indices):
|
||||
"""Remove reasonings at specified indices."""
|
||||
_reasonings, mask = removeind(self, "_reasonings", indices)
|
||||
self._register_reasonings(_reasonings)
|
||||
return mask
|
||||
|
||||
def forward(self):
|
||||
"""Simply return the reasonings."""
|
||||
return self._reasonings
|
||||
|
||||
|
||||
class ReasoningComponents(AbstractComponents):
|
||||
"""A set of components and a corresponding adapatable reasoning matrices.
|
||||
|
||||
@ -260,13 +327,8 @@ class ReasoningComponents(AbstractComponents):
|
||||
reasonings_initializer)
|
||||
|
||||
@property
|
||||
def reasonings(self):
|
||||
"""Returns Reasoning Matrix.
|
||||
|
||||
Dimension NxCx2
|
||||
|
||||
"""
|
||||
return self._reasonings.detach().cpu()
|
||||
def num_classes(self):
|
||||
return self._reasonings.shape[1]
|
||||
|
||||
def _register_reasonings(self, reasonings):
|
||||
self.register_parameter("_reasonings", Parameter(reasonings))
|
||||
|
Loading…
Reference in New Issue
Block a user