chore: Allow no-self-use for some class members
Classes are used as common interface and connection to pytorch.
This commit is contained in:
parent
597c9fc1ee
commit
3d76dffe3c
@ -48,7 +48,7 @@ class WTAC(torch.nn.Module):
|
|||||||
Thin wrapper over the `wtac` function.
|
Thin wrapper over the `wtac` function.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
def forward(self, distances, labels):
|
def forward(self, distances, labels): # pylint: disable=no-self-use
|
||||||
return wtac(distances, labels)
|
return wtac(distances, labels)
|
||||||
|
|
||||||
|
|
||||||
@ -58,7 +58,7 @@ class LTAC(torch.nn.Module):
|
|||||||
Thin wrapper over the `wtac` function.
|
Thin wrapper over the `wtac` function.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
def forward(self, probs, labels):
|
def forward(self, probs, labels): # pylint: disable=no-self-use
|
||||||
return wtac(-1.0 * probs, labels)
|
return wtac(-1.0 * probs, labels)
|
||||||
|
|
||||||
|
|
||||||
@ -85,5 +85,5 @@ class CBCC(torch.nn.Module):
|
|||||||
Thin wrapper over the `cbcc` function.
|
Thin wrapper over the `cbcc` function.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
def forward(self, detections, reasonings):
|
def forward(self, detections, reasonings): # pylint: disable=no-self-use
|
||||||
return cbcc(detections, reasonings)
|
return cbcc(detections, reasonings)
|
||||||
|
@ -303,17 +303,18 @@ class OneHotLabelsInitializer(LabelsInitializer):
|
|||||||
|
|
||||||
|
|
||||||
# Reasonings
|
# Reasonings
|
||||||
|
def compute_distribution_shape(distribution):
|
||||||
|
distribution = parse_distribution(distribution)
|
||||||
|
num_components = sum(distribution.values())
|
||||||
|
num_classes = len(distribution.keys())
|
||||||
|
return (num_components, num_classes, 2)
|
||||||
|
|
||||||
|
|
||||||
class AbstractReasoningsInitializer(ABC):
|
class AbstractReasoningsInitializer(ABC):
|
||||||
"""Abstract class for all reasonings initializers."""
|
"""Abstract class for all reasonings initializers."""
|
||||||
def __init__(self, components_first: bool = True):
|
def __init__(self, components_first: bool = True):
|
||||||
self.components_first = components_first
|
self.components_first = components_first
|
||||||
|
|
||||||
def compute_shape(self, distribution):
|
|
||||||
distribution = parse_distribution(distribution)
|
|
||||||
num_components = sum(distribution.values())
|
|
||||||
num_classes = len(distribution.keys())
|
|
||||||
return (num_components, num_classes, 2)
|
|
||||||
|
|
||||||
def generate_end_hook(self, reasonings):
|
def generate_end_hook(self, reasonings):
|
||||||
if not self.components_first:
|
if not self.components_first:
|
||||||
reasonings = reasonings.permute(2, 1, 0)
|
reasonings = reasonings.permute(2, 1, 0)
|
||||||
@ -349,7 +350,7 @@ class LiteralReasoningsInitializer(AbstractReasoningsInitializer):
|
|||||||
class ZerosReasoningsInitializer(AbstractReasoningsInitializer):
|
class ZerosReasoningsInitializer(AbstractReasoningsInitializer):
|
||||||
"""Reasonings are all initialized with zeros."""
|
"""Reasonings are all initialized with zeros."""
|
||||||
def generate(self, distribution: Union[dict, list, tuple]):
|
def generate(self, distribution: Union[dict, list, tuple]):
|
||||||
shape = self.compute_shape(distribution)
|
shape = compute_distribution_shape(distribution)
|
||||||
reasonings = torch.zeros(*shape)
|
reasonings = torch.zeros(*shape)
|
||||||
reasonings = self.generate_end_hook(reasonings)
|
reasonings = self.generate_end_hook(reasonings)
|
||||||
return reasonings
|
return reasonings
|
||||||
@ -358,7 +359,7 @@ class ZerosReasoningsInitializer(AbstractReasoningsInitializer):
|
|||||||
class OnesReasoningsInitializer(AbstractReasoningsInitializer):
|
class OnesReasoningsInitializer(AbstractReasoningsInitializer):
|
||||||
"""Reasonings are all initialized with ones."""
|
"""Reasonings are all initialized with ones."""
|
||||||
def generate(self, distribution: Union[dict, list, tuple]):
|
def generate(self, distribution: Union[dict, list, tuple]):
|
||||||
shape = self.compute_shape(distribution)
|
shape = compute_distribution_shape(distribution)
|
||||||
reasonings = torch.ones(*shape)
|
reasonings = torch.ones(*shape)
|
||||||
reasonings = self.generate_end_hook(reasonings)
|
reasonings = self.generate_end_hook(reasonings)
|
||||||
return reasonings
|
return reasonings
|
||||||
@ -372,7 +373,7 @@ class RandomReasoningsInitializer(AbstractReasoningsInitializer):
|
|||||||
self.maximum = maximum
|
self.maximum = maximum
|
||||||
|
|
||||||
def generate(self, distribution: Union[dict, list, tuple]):
|
def generate(self, distribution: Union[dict, list, tuple]):
|
||||||
shape = self.compute_shape(distribution)
|
shape = compute_distribution_shape(distribution)
|
||||||
reasonings = torch.ones(*shape).uniform_(self.minimum, self.maximum)
|
reasonings = torch.ones(*shape).uniform_(self.minimum, self.maximum)
|
||||||
reasonings = self.generate_end_hook(reasonings)
|
reasonings = self.generate_end_hook(reasonings)
|
||||||
return reasonings
|
return reasonings
|
||||||
@ -381,7 +382,8 @@ class RandomReasoningsInitializer(AbstractReasoningsInitializer):
|
|||||||
class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer):
|
class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer):
|
||||||
"""Each component reasons positively for exactly one class."""
|
"""Each component reasons positively for exactly one class."""
|
||||||
def generate(self, distribution: Union[dict, list, tuple]):
|
def generate(self, distribution: Union[dict, list, tuple]):
|
||||||
num_components, num_classes, _ = self.compute_shape(distribution)
|
num_components, num_classes, _ = compute_distribution_shape(
|
||||||
|
distribution)
|
||||||
A = OneHotLabelsInitializer().generate(distribution)
|
A = OneHotLabelsInitializer().generate(distribution)
|
||||||
B = torch.zeros(num_components, num_classes)
|
B = torch.zeros(num_components, num_classes)
|
||||||
reasonings = torch.stack([A, B], dim=-1)
|
reasonings = torch.stack([A, B], dim=-1)
|
||||||
|
@ -82,23 +82,23 @@ def stratified_prod_pooling(values: torch.Tensor,
|
|||||||
|
|
||||||
class StratifiedSumPooling(torch.nn.Module):
|
class StratifiedSumPooling(torch.nn.Module):
|
||||||
"""Thin wrapper over the `stratified_sum_pooling` function."""
|
"""Thin wrapper over the `stratified_sum_pooling` function."""
|
||||||
def forward(self, values, labels):
|
def forward(self, values, labels): # pylint: disable=no-self-use
|
||||||
return stratified_sum_pooling(values, labels)
|
return stratified_sum_pooling(values, labels)
|
||||||
|
|
||||||
|
|
||||||
class StratifiedProdPooling(torch.nn.Module):
|
class StratifiedProdPooling(torch.nn.Module):
|
||||||
"""Thin wrapper over the `stratified_prod_pooling` function."""
|
"""Thin wrapper over the `stratified_prod_pooling` function."""
|
||||||
def forward(self, values, labels):
|
def forward(self, values, labels): # pylint: disable=no-self-use
|
||||||
return stratified_prod_pooling(values, labels)
|
return stratified_prod_pooling(values, labels)
|
||||||
|
|
||||||
|
|
||||||
class StratifiedMinPooling(torch.nn.Module):
|
class StratifiedMinPooling(torch.nn.Module):
|
||||||
"""Thin wrapper over the `stratified_min_pooling` function."""
|
"""Thin wrapper over the `stratified_min_pooling` function."""
|
||||||
def forward(self, values, labels):
|
def forward(self, values, labels): # pylint: disable=no-self-use
|
||||||
return stratified_min_pooling(values, labels)
|
return stratified_min_pooling(values, labels)
|
||||||
|
|
||||||
|
|
||||||
class StratifiedMaxPooling(torch.nn.Module):
|
class StratifiedMaxPooling(torch.nn.Module):
|
||||||
"""Thin wrapper over the `stratified_max_pooling` function."""
|
"""Thin wrapper over the `stratified_max_pooling` function."""
|
||||||
def forward(self, values, labels):
|
def forward(self, values, labels): # pylint: disable=no-self-use
|
||||||
return stratified_max_pooling(values, labels)
|
return stratified_max_pooling(values, labels)
|
||||||
|
Loading…
Reference in New Issue
Block a user