From 503ef0e05fdd640ee0990d3a99d6680d4dd7b199 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Mon, 17 May 2021 16:58:57 +0200 Subject: [PATCH] Cleanup components --- prototorch/components/components.py | 33 ++++++++++++++--------------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/prototorch/components/components.py b/prototorch/components/components.py index 4b976cd..870c0fa 100644 --- a/prototorch/components/components.py +++ b/prototorch/components/components.py @@ -16,21 +16,23 @@ from torch.nn.parameter import Parameter class Components(torch.nn.Module): """Components is a set of learnable Tensors.""" def __init__(self, - number_of_components=None, + ncomps=None, initializer=None, *, initialized_components=None): super().__init__() + self.ncomps = ncomps + # Ignore all initialization settings if initialized_components is given. if initialized_components is not None: self.register_parameter("_components", Parameter(initialized_components)) - if number_of_components is not None or initializer is not None: + if ncomps is not None or initializer is not None: wmsg = "Arguments ignored while initializing Components" warnings.warn(wmsg) else: - self._initialize_components(number_of_components, initializer) + self._initialize_components(initializer) def _precheck_initializer(self, initializer): if not isinstance(initializer, ComponentsInitializer): @@ -39,9 +41,9 @@ class Components(torch.nn.Module): f"You have provided: {initializer=} instead." raise TypeError(emsg) - def _initialize_components(self, number_of_components, initializer): + def _initialize_components(self, initializer): self._precheck_initializer(initializer) - _components = initializer.generate(number_of_components) + _components = initializer.generate(self.ncomps) self.register_parameter("_components", Parameter(_components)) @property @@ -72,17 +74,16 @@ class LabeledComponents(Components): self._labels = component_labels else: _labels = self._initialize_labels(distribution) - super().__init__(number_of_components=len(_labels), - initializer=initializer) - self.register_buffer('_labels', _labels) + super().__init__(len(_labels), initializer=initializer) + self.register_buffer("_labels", _labels) - def _initialize_components(self, number_of_components, initializer): + def _initialize_components(self, initializer): if isinstance(initializer, ClassAwareInitializer): self._precheck_initializer(initializer) - self._components = Parameter( - initializer.generate(number_of_components, self.distribution)) + _components = initializer.generate(self.ncomps, self.distribution) + self.register_parameter("_components", Parameter(_components)) else: - super()._initialize_components(number_of_components, initializer) + super()._initialize_components(initializer) def _initialize_labels(self, distribution): if type(distribution) == tuple: @@ -130,14 +131,12 @@ class ReasoningComponents(Components): self.register_parameter("_reasonings", reasonings) else: self._initialize_reasonings(reasonings) - super().__init__(number_of_components=len(self._reasonings), - initializer=initializer) + super().__init__(len(self._reasonings), initializer=initializer) def _initialize_reasonings(self, reasonings): if type(reasonings) == tuple: - num_classes, number_of_components = reasonings - reasonings = ZeroReasoningsInitializer(num_classes, - number_of_components) + nclasses, ncomps = reasonings + reasonings = ZeroReasoningsInitializer(nclasses, ncomps) _reasonings = reasonings.generate() self.register_parameter("_reasonings", _reasonings)