[QA] Fix for "no-self-use" (R0201)

This commit is contained in:
Jensun Ravichandran 2021-06-01 19:26:05 +02:00
parent e8e803e8ef
commit b1568a550a

View File

@ -32,6 +32,14 @@ def get_labels_object(distribution):
return labels return labels
def _precheck_initializer(self, initializer):
if not isinstance(initializer, ComponentsInitializer):
emsg = f"`initializer` has to be some subtype of " \
f"{ComponentsInitializer}. " \
f"You have provided: {initializer=} instead."
raise TypeError(emsg)
class Components(torch.nn.Module): class Components(torch.nn.Module):
"""Components is a set of learnable Tensors.""" """Components is a set of learnable Tensors."""
def __init__(self, def __init__(self,
@ -58,15 +66,8 @@ class Components(torch.nn.Module):
def _register_components(self, components): def _register_components(self, components):
self.register_parameter("_components", Parameter(components)) self.register_parameter("_components", Parameter(components))
def _precheck_initializer(self, initializer):
if not isinstance(initializer, ComponentsInitializer):
emsg = f"`initializer` has to be some subtype of " \
f"{ComponentsInitializer}. " \
f"You have provided: {initializer=} instead."
raise TypeError(emsg)
def _initialize_components(self, initializer, num_components): def _initialize_components(self, initializer, num_components):
self._precheck_initializer(initializer) _precheck_initializer(initializer)
_components = initializer.generate(num_components) _components = initializer.generate(num_components)
self._register_components(_components) self._register_components(_components)
@ -78,7 +79,7 @@ class Components(torch.nn.Module):
if initialized_components is not None: if initialized_components is not None:
_components = torch.cat([self._components, initialized_components]) _components = torch.cat([self._components, initialized_components])
else: else:
self._precheck_initializer(initializer) _precheck_initializer(initializer)
_new = initializer.generate(num) _new = initializer.generate(num)
_components = torch.cat([self._components, _new]) _components = torch.cat([self._components, _new])
self._register_components(_components) self._register_components(_components)
@ -128,12 +129,14 @@ class LabeledComponents(Components):
@property @property
def distribution(self): def distribution(self):
clabels, counts = torch.unique(self._labels, sorted=True, return_counts=True) clabels, counts = torch.unique(self._labels,
sorted=True,
return_counts=True)
return dict(zip(clabels.tolist(), counts.tolist())) return dict(zip(clabels.tolist(), counts.tolist()))
def _initialize_components(self, initializer, num_components): def _initialize_components(self, initializer, num_components):
if isinstance(initializer, ClassAwareInitializer): if isinstance(initializer, ClassAwareInitializer):
self._precheck_initializer(initializer) _precheck_initializer(initializer)
_components = initializer.generate(num_components, _components = initializer.generate(num_components,
self.initial_distribution) self.initial_distribution)
self._register_components(_components) self._register_components(_components)
@ -141,7 +144,7 @@ class LabeledComponents(Components):
super()._initialize_components(initializer, num_components) super()._initialize_components(initializer, num_components)
def add_components(self, initializer, distribution): def add_components(self, initializer, distribution):
self._precheck_initializer(initializer) _precheck_initializer(initializer)
# Labels # Labels
labels = get_labels_object(distribution) labels = get_labels_object(distribution)