[QA] Fix for "no-self-use" (R0201)
This commit is contained in:
parent
e8e803e8ef
commit
b1568a550a
@ -32,6 +32,14 @@ def get_labels_object(distribution):
|
||||
return labels
|
||||
|
||||
|
||||
def _precheck_initializer(self, initializer):
|
||||
if not isinstance(initializer, ComponentsInitializer):
|
||||
emsg = f"`initializer` has to be some subtype of " \
|
||||
f"{ComponentsInitializer}. " \
|
||||
f"You have provided: {initializer=} instead."
|
||||
raise TypeError(emsg)
|
||||
|
||||
|
||||
class Components(torch.nn.Module):
|
||||
"""Components is a set of learnable Tensors."""
|
||||
def __init__(self,
|
||||
@ -58,15 +66,8 @@ class Components(torch.nn.Module):
|
||||
def _register_components(self, components):
|
||||
self.register_parameter("_components", Parameter(components))
|
||||
|
||||
def _precheck_initializer(self, initializer):
|
||||
if not isinstance(initializer, ComponentsInitializer):
|
||||
emsg = f"`initializer` has to be some subtype of " \
|
||||
f"{ComponentsInitializer}. " \
|
||||
f"You have provided: {initializer=} instead."
|
||||
raise TypeError(emsg)
|
||||
|
||||
def _initialize_components(self, initializer, num_components):
|
||||
self._precheck_initializer(initializer)
|
||||
_precheck_initializer(initializer)
|
||||
_components = initializer.generate(num_components)
|
||||
self._register_components(_components)
|
||||
|
||||
@ -78,7 +79,7 @@ class Components(torch.nn.Module):
|
||||
if initialized_components is not None:
|
||||
_components = torch.cat([self._components, initialized_components])
|
||||
else:
|
||||
self._precheck_initializer(initializer)
|
||||
_precheck_initializer(initializer)
|
||||
_new = initializer.generate(num)
|
||||
_components = torch.cat([self._components, _new])
|
||||
self._register_components(_components)
|
||||
@ -128,12 +129,14 @@ class LabeledComponents(Components):
|
||||
|
||||
@property
|
||||
def distribution(self):
|
||||
clabels, counts = torch.unique(self._labels, sorted=True, return_counts=True)
|
||||
clabels, counts = torch.unique(self._labels,
|
||||
sorted=True,
|
||||
return_counts=True)
|
||||
return dict(zip(clabels.tolist(), counts.tolist()))
|
||||
|
||||
def _initialize_components(self, initializer, num_components):
|
||||
if isinstance(initializer, ClassAwareInitializer):
|
||||
self._precheck_initializer(initializer)
|
||||
_precheck_initializer(initializer)
|
||||
_components = initializer.generate(num_components,
|
||||
self.initial_distribution)
|
||||
self._register_components(_components)
|
||||
@ -141,7 +144,7 @@ class LabeledComponents(Components):
|
||||
super()._initialize_components(initializer, num_components)
|
||||
|
||||
def add_components(self, initializer, distribution):
|
||||
self._precheck_initializer(initializer)
|
||||
_precheck_initializer(initializer)
|
||||
|
||||
# Labels
|
||||
labels = get_labels_object(distribution)
|
||||
|
Loading…
Reference in New Issue
Block a user