chore: Allow no-self-use for some class members

Classes are used as common interface and connection to pytorch.
This commit is contained in:
Alexander Engelsberger 2021-06-21 14:29:25 +02:00
parent 597c9fc1ee
commit 3d76dffe3c
No known key found for this signature in database
GPG Key ID: BE3F5909FF0D83E3
3 changed files with 19 additions and 17 deletions

View File

@ -48,7 +48,7 @@ class WTAC(torch.nn.Module):
Thin wrapper over the `wtac` function.
"""
def forward(self, distances, labels):
def forward(self, distances, labels): # pylint: disable=no-self-use
return wtac(distances, labels)
@ -58,7 +58,7 @@ class LTAC(torch.nn.Module):
Thin wrapper over the `wtac` function.
"""
def forward(self, probs, labels):
def forward(self, probs, labels): # pylint: disable=no-self-use
return wtac(-1.0 * probs, labels)
@ -85,5 +85,5 @@ class CBCC(torch.nn.Module):
Thin wrapper over the `cbcc` function.
"""
def forward(self, detections, reasonings):
def forward(self, detections, reasonings): # pylint: disable=no-self-use
return cbcc(detections, reasonings)

View File

@ -303,17 +303,18 @@ class OneHotLabelsInitializer(LabelsInitializer):
# Reasonings
class AbstractReasoningsInitializer(ABC):
"""Abstract class for all reasonings initializers."""
def __init__(self, components_first: bool = True):
self.components_first = components_first
def compute_shape(self, distribution):
def compute_distribution_shape(distribution):
distribution = parse_distribution(distribution)
num_components = sum(distribution.values())
num_classes = len(distribution.keys())
return (num_components, num_classes, 2)
class AbstractReasoningsInitializer(ABC):
"""Abstract class for all reasonings initializers."""
def __init__(self, components_first: bool = True):
self.components_first = components_first
def generate_end_hook(self, reasonings):
if not self.components_first:
reasonings = reasonings.permute(2, 1, 0)
@ -349,7 +350,7 @@ class LiteralReasoningsInitializer(AbstractReasoningsInitializer):
class ZerosReasoningsInitializer(AbstractReasoningsInitializer):
"""Reasonings are all initialized with zeros."""
def generate(self, distribution: Union[dict, list, tuple]):
shape = self.compute_shape(distribution)
shape = compute_distribution_shape(distribution)
reasonings = torch.zeros(*shape)
reasonings = self.generate_end_hook(reasonings)
return reasonings
@ -358,7 +359,7 @@ class ZerosReasoningsInitializer(AbstractReasoningsInitializer):
class OnesReasoningsInitializer(AbstractReasoningsInitializer):
"""Reasonings are all initialized with ones."""
def generate(self, distribution: Union[dict, list, tuple]):
shape = self.compute_shape(distribution)
shape = compute_distribution_shape(distribution)
reasonings = torch.ones(*shape)
reasonings = self.generate_end_hook(reasonings)
return reasonings
@ -372,7 +373,7 @@ class RandomReasoningsInitializer(AbstractReasoningsInitializer):
self.maximum = maximum
def generate(self, distribution: Union[dict, list, tuple]):
shape = self.compute_shape(distribution)
shape = compute_distribution_shape(distribution)
reasonings = torch.ones(*shape).uniform_(self.minimum, self.maximum)
reasonings = self.generate_end_hook(reasonings)
return reasonings
@ -381,7 +382,8 @@ class RandomReasoningsInitializer(AbstractReasoningsInitializer):
class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer):
"""Each component reasons positively for exactly one class."""
def generate(self, distribution: Union[dict, list, tuple]):
num_components, num_classes, _ = self.compute_shape(distribution)
num_components, num_classes, _ = compute_distribution_shape(
distribution)
A = OneHotLabelsInitializer().generate(distribution)
B = torch.zeros(num_components, num_classes)
reasonings = torch.stack([A, B], dim=-1)

View File

@ -82,23 +82,23 @@ def stratified_prod_pooling(values: torch.Tensor,
class StratifiedSumPooling(torch.nn.Module):
"""Thin wrapper over the `stratified_sum_pooling` function."""
def forward(self, values, labels):
def forward(self, values, labels): # pylint: disable=no-self-use
return stratified_sum_pooling(values, labels)
class StratifiedProdPooling(torch.nn.Module):
"""Thin wrapper over the `stratified_prod_pooling` function."""
def forward(self, values, labels):
def forward(self, values, labels): # pylint: disable=no-self-use
return stratified_prod_pooling(values, labels)
class StratifiedMinPooling(torch.nn.Module):
"""Thin wrapper over the `stratified_min_pooling` function."""
def forward(self, values, labels):
def forward(self, values, labels): # pylint: disable=no-self-use
return stratified_min_pooling(values, labels)
class StratifiedMaxPooling(torch.nn.Module):
"""Thin wrapper over the `stratified_max_pooling` function."""
def forward(self, values, labels):
def forward(self, values, labels): # pylint: disable=no-self-use
return stratified_max_pooling(values, labels)