[QA] Fix for "no-self-use" (R0201)

This commit is contained in:
Jensun Ravichandran 2021-06-01 19:26:05 +02:00
parent e8e803e8ef
commit b1568a550a

View File

@ -32,6 +32,14 @@ def get_labels_object(distribution):
return labels
def _precheck_initializer(self, initializer):
if not isinstance(initializer, ComponentsInitializer):
emsg = f"`initializer` has to be some subtype of " \
f"{ComponentsInitializer}. " \
f"You have provided: {initializer=} instead."
raise TypeError(emsg)
class Components(torch.nn.Module):
"""Components is a set of learnable Tensors."""
def __init__(self,
@ -58,15 +66,8 @@ class Components(torch.nn.Module):
def _register_components(self, components):
self.register_parameter("_components", Parameter(components))
def _precheck_initializer(self, initializer):
if not isinstance(initializer, ComponentsInitializer):
emsg = f"`initializer` has to be some subtype of " \
f"{ComponentsInitializer}. " \
f"You have provided: {initializer=} instead."
raise TypeError(emsg)
def _initialize_components(self, initializer, num_components):
self._precheck_initializer(initializer)
_precheck_initializer(initializer)
_components = initializer.generate(num_components)
self._register_components(_components)
@ -78,7 +79,7 @@ class Components(torch.nn.Module):
if initialized_components is not None:
_components = torch.cat([self._components, initialized_components])
else:
self._precheck_initializer(initializer)
_precheck_initializer(initializer)
_new = initializer.generate(num)
_components = torch.cat([self._components, _new])
self._register_components(_components)
@ -128,12 +129,14 @@ class LabeledComponents(Components):
@property
def distribution(self):
clabels, counts = torch.unique(self._labels, sorted=True, return_counts=True)
clabels, counts = torch.unique(self._labels,
sorted=True,
return_counts=True)
return dict(zip(clabels.tolist(), counts.tolist()))
def _initialize_components(self, initializer, num_components):
if isinstance(initializer, ClassAwareInitializer):
self._precheck_initializer(initializer)
_precheck_initializer(initializer)
_components = initializer.generate(num_components,
self.initial_distribution)
self._register_components(_components)
@ -141,7 +144,7 @@ class LabeledComponents(Components):
super()._initialize_components(initializer, num_components)
def add_components(self, initializer, distribution):
self._precheck_initializer(initializer)
_precheck_initializer(initializer)
# Labels
labels = get_labels_object(distribution)