[REFACTOR] Simplify ReasoningComponents

This commit is contained in:
Jensun Ravichandran 2021-06-14 14:44:36 +02:00
parent 6ad665f8c2
commit d2d6f31e7b

View File

@ -17,6 +17,7 @@ from .initializers import (
def validate_initializer(initializer, instanceof): def validate_initializer(initializer, instanceof):
"""Check if the initializer is valid."""
if not isinstance(initializer, instanceof): if not isinstance(initializer, instanceof):
emsg = f"`initializer` has to be an instance " \ emsg = f"`initializer` has to be an instance " \
f"of some subtype of {instanceof}. " \ f"of some subtype of {instanceof}. " \
@ -47,6 +48,17 @@ def removeind(ins, attr, indices):
return items, mask return items, mask
def get_cikwargs(init, distribution):
"""Return appropriate key-word arguments for a component initializer."""
if isinstance(init, ClassAwareCompInitializer):
cikwargs = dict(distribution=distribution)
else:
distribution = parse_distribution(distribution)
num_components = sum(distribution.values())
cikwargs = dict(num_components=num_components)
return cikwargs
class AbstractComponents(torch.nn.Module): class AbstractComponents(torch.nn.Module):
"""Abstract class for all components modules.""" """Abstract class for all components modules."""
@property @property
@ -194,12 +206,7 @@ class LabeledComponents(AbstractComponents):
AbstractComponentsInitializer) AbstractComponentsInitializer)
assert validate_initializer(labels_initializer, assert validate_initializer(labels_initializer,
AbstractLabelsInitializer) AbstractLabelsInitializer)
if isinstance(components_initializer, ClassAwareCompInitializer): cikwargs = get_cikwargs(components_initializer, distribution)
cikwargs = dict(distribution=distribution)
else:
distribution = parse_distribution(distribution)
num_components = sum(distribution.values())
cikwargs = dict(num_components=num_components)
_components, new_components = gencat(self, "_components", _components, new_components = gencat(self, "_components",
components_initializer, components_initializer,
**cikwargs) **cikwargs)
@ -259,47 +266,28 @@ class ReasoningComponents(AbstractComponents):
def add_components(self, distribution, components_initializer, def add_components(self, distribution, components_initializer,
reasonings_initializer: AbstractReasoningsInitializer): reasonings_initializer: AbstractReasoningsInitializer):
# Checks """Generate and add new components and reasonings."""
assert validate_initializer(components_initializer, assert validate_initializer(components_initializer,
AbstractComponentsInitializer) AbstractComponentsInitializer)
assert validate_initializer(reasonings_initializer, assert validate_initializer(reasonings_initializer,
AbstractReasoningsInitializer) AbstractReasoningsInitializer)
cikwargs = get_cikwargs(components_initializer, distribution)
distribution = parse_distribution(distribution) _components, new_components = gencat(self, "_components",
components_initializer,
# Generate new components **cikwargs)
if isinstance(components_initializer, ClassAwareCompInitializer): _reasonings, new_reasonings = gencat(self, "_reasonings",
new_components = components_initializer.generate(distribution) reasonings_initializer,
else: distribution)
num_components = sum(distribution.values())
new_components = components_initializer.generate(num_components)
# Generate new reasonings
new_reasonings = reasonings_initializer.generate(distribution)
# Register
if hasattr(self, "_components"):
_components = torch.cat([self._components, new_components])
else:
_components = new_components
if hasattr(self, "_reasonings"):
_reasonings = torch.cat([self._reasonings, new_reasonings])
else:
_reasonings = new_reasonings
self._register_components(_components) self._register_components(_components)
self._register_reasonings(_reasonings) self._register_reasonings(_reasonings)
return new_components, new_reasonings return new_components, new_reasonings
def remove_components(self, indices): def remove_components(self, indices):
"""Remove components and labels at specified indices.""" """Remove components and reasonings at specified indices."""
mask = torch.ones(self.num_components, dtype=torch.bool) _components, mask = removeind(self, "_components", indices)
mask[indices] = False _reasonings, mask = removeind(self, "_reasonings", indices)
_components = self._components[mask]
# TODO
# _reasonings = self._reasonings[mask]
self._register_components(_components) self._register_components(_components)
# self._register_reasonings(_reasonings) self._register_reasonings(_reasonings)
return mask return mask
def forward(self): def forward(self):