chore: Allow no-self-use for some class members
Classes are used as common interface and connection to pytorch.
This commit is contained in:
parent
597c9fc1ee
commit
3d76dffe3c
@ -48,7 +48,7 @@ class WTAC(torch.nn.Module):
|
||||
Thin wrapper over the `wtac` function.
|
||||
|
||||
"""
|
||||
def forward(self, distances, labels):
|
||||
def forward(self, distances, labels): # pylint: disable=no-self-use
|
||||
return wtac(distances, labels)
|
||||
|
||||
|
||||
@ -58,7 +58,7 @@ class LTAC(torch.nn.Module):
|
||||
Thin wrapper over the `wtac` function.
|
||||
|
||||
"""
|
||||
def forward(self, probs, labels):
|
||||
def forward(self, probs, labels): # pylint: disable=no-self-use
|
||||
return wtac(-1.0 * probs, labels)
|
||||
|
||||
|
||||
@ -85,5 +85,5 @@ class CBCC(torch.nn.Module):
|
||||
Thin wrapper over the `cbcc` function.
|
||||
|
||||
"""
|
||||
def forward(self, detections, reasonings):
|
||||
def forward(self, detections, reasonings): # pylint: disable=no-self-use
|
||||
return cbcc(detections, reasonings)
|
||||
|
@ -303,17 +303,18 @@ class OneHotLabelsInitializer(LabelsInitializer):
|
||||
|
||||
|
||||
# Reasonings
|
||||
class AbstractReasoningsInitializer(ABC):
|
||||
"""Abstract class for all reasonings initializers."""
|
||||
def __init__(self, components_first: bool = True):
|
||||
self.components_first = components_first
|
||||
|
||||
def compute_shape(self, distribution):
|
||||
def compute_distribution_shape(distribution):
|
||||
distribution = parse_distribution(distribution)
|
||||
num_components = sum(distribution.values())
|
||||
num_classes = len(distribution.keys())
|
||||
return (num_components, num_classes, 2)
|
||||
|
||||
|
||||
class AbstractReasoningsInitializer(ABC):
|
||||
"""Abstract class for all reasonings initializers."""
|
||||
def __init__(self, components_first: bool = True):
|
||||
self.components_first = components_first
|
||||
|
||||
def generate_end_hook(self, reasonings):
|
||||
if not self.components_first:
|
||||
reasonings = reasonings.permute(2, 1, 0)
|
||||
@ -349,7 +350,7 @@ class LiteralReasoningsInitializer(AbstractReasoningsInitializer):
|
||||
class ZerosReasoningsInitializer(AbstractReasoningsInitializer):
|
||||
"""Reasonings are all initialized with zeros."""
|
||||
def generate(self, distribution: Union[dict, list, tuple]):
|
||||
shape = self.compute_shape(distribution)
|
||||
shape = compute_distribution_shape(distribution)
|
||||
reasonings = torch.zeros(*shape)
|
||||
reasonings = self.generate_end_hook(reasonings)
|
||||
return reasonings
|
||||
@ -358,7 +359,7 @@ class ZerosReasoningsInitializer(AbstractReasoningsInitializer):
|
||||
class OnesReasoningsInitializer(AbstractReasoningsInitializer):
|
||||
"""Reasonings are all initialized with ones."""
|
||||
def generate(self, distribution: Union[dict, list, tuple]):
|
||||
shape = self.compute_shape(distribution)
|
||||
shape = compute_distribution_shape(distribution)
|
||||
reasonings = torch.ones(*shape)
|
||||
reasonings = self.generate_end_hook(reasonings)
|
||||
return reasonings
|
||||
@ -372,7 +373,7 @@ class RandomReasoningsInitializer(AbstractReasoningsInitializer):
|
||||
self.maximum = maximum
|
||||
|
||||
def generate(self, distribution: Union[dict, list, tuple]):
|
||||
shape = self.compute_shape(distribution)
|
||||
shape = compute_distribution_shape(distribution)
|
||||
reasonings = torch.ones(*shape).uniform_(self.minimum, self.maximum)
|
||||
reasonings = self.generate_end_hook(reasonings)
|
||||
return reasonings
|
||||
@ -381,7 +382,8 @@ class RandomReasoningsInitializer(AbstractReasoningsInitializer):
|
||||
class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer):
|
||||
"""Each component reasons positively for exactly one class."""
|
||||
def generate(self, distribution: Union[dict, list, tuple]):
|
||||
num_components, num_classes, _ = self.compute_shape(distribution)
|
||||
num_components, num_classes, _ = compute_distribution_shape(
|
||||
distribution)
|
||||
A = OneHotLabelsInitializer().generate(distribution)
|
||||
B = torch.zeros(num_components, num_classes)
|
||||
reasonings = torch.stack([A, B], dim=-1)
|
||||
|
@ -82,23 +82,23 @@ def stratified_prod_pooling(values: torch.Tensor,
|
||||
|
||||
class StratifiedSumPooling(torch.nn.Module):
|
||||
"""Thin wrapper over the `stratified_sum_pooling` function."""
|
||||
def forward(self, values, labels):
|
||||
def forward(self, values, labels): # pylint: disable=no-self-use
|
||||
return stratified_sum_pooling(values, labels)
|
||||
|
||||
|
||||
class StratifiedProdPooling(torch.nn.Module):
|
||||
"""Thin wrapper over the `stratified_prod_pooling` function."""
|
||||
def forward(self, values, labels):
|
||||
def forward(self, values, labels): # pylint: disable=no-self-use
|
||||
return stratified_prod_pooling(values, labels)
|
||||
|
||||
|
||||
class StratifiedMinPooling(torch.nn.Module):
|
||||
"""Thin wrapper over the `stratified_min_pooling` function."""
|
||||
def forward(self, values, labels):
|
||||
def forward(self, values, labels): # pylint: disable=no-self-use
|
||||
return stratified_min_pooling(values, labels)
|
||||
|
||||
|
||||
class StratifiedMaxPooling(torch.nn.Module):
|
||||
"""Thin wrapper over the `stratified_max_pooling` function."""
|
||||
def forward(self, values, labels):
|
||||
def forward(self, values, labels): # pylint: disable=no-self-use
|
||||
return stratified_max_pooling(values, labels)
|
||||
|
Loading…
Reference in New Issue
Block a user