[REFACTOR] Simplify ReasoningComponents

This commit is contained in:
Jensun Ravichandran 2021-06-14 14:44:36 +02:00
parent 6ad665f8c2
commit d2d6f31e7b

View File

@ -17,6 +17,7 @@ from .initializers import (
def validate_initializer(initializer, instanceof):
"""Check if the initializer is valid."""
if not isinstance(initializer, instanceof):
emsg = f"`initializer` has to be an instance " \
f"of some subtype of {instanceof}. " \
@ -47,6 +48,17 @@ def removeind(ins, attr, indices):
return items, mask
def get_cikwargs(init, distribution):
"""Return appropriate key-word arguments for a component initializer."""
if isinstance(init, ClassAwareCompInitializer):
cikwargs = dict(distribution=distribution)
else:
distribution = parse_distribution(distribution)
num_components = sum(distribution.values())
cikwargs = dict(num_components=num_components)
return cikwargs
class AbstractComponents(torch.nn.Module):
"""Abstract class for all components modules."""
@property
@ -194,12 +206,7 @@ class LabeledComponents(AbstractComponents):
AbstractComponentsInitializer)
assert validate_initializer(labels_initializer,
AbstractLabelsInitializer)
if isinstance(components_initializer, ClassAwareCompInitializer):
cikwargs = dict(distribution=distribution)
else:
distribution = parse_distribution(distribution)
num_components = sum(distribution.values())
cikwargs = dict(num_components=num_components)
cikwargs = get_cikwargs(components_initializer, distribution)
_components, new_components = gencat(self, "_components",
components_initializer,
**cikwargs)
@ -259,47 +266,28 @@ class ReasoningComponents(AbstractComponents):
def add_components(self, distribution, components_initializer,
reasonings_initializer: AbstractReasoningsInitializer):
# Checks
"""Generate and add new components and reasonings."""
assert validate_initializer(components_initializer,
AbstractComponentsInitializer)
assert validate_initializer(reasonings_initializer,
AbstractReasoningsInitializer)
distribution = parse_distribution(distribution)
# Generate new components
if isinstance(components_initializer, ClassAwareCompInitializer):
new_components = components_initializer.generate(distribution)
else:
num_components = sum(distribution.values())
new_components = components_initializer.generate(num_components)
# Generate new reasonings
new_reasonings = reasonings_initializer.generate(distribution)
# Register
if hasattr(self, "_components"):
_components = torch.cat([self._components, new_components])
else:
_components = new_components
if hasattr(self, "_reasonings"):
_reasonings = torch.cat([self._reasonings, new_reasonings])
else:
_reasonings = new_reasonings
cikwargs = get_cikwargs(components_initializer, distribution)
_components, new_components = gencat(self, "_components",
components_initializer,
**cikwargs)
_reasonings, new_reasonings = gencat(self, "_reasonings",
reasonings_initializer,
distribution)
self._register_components(_components)
self._register_reasonings(_reasonings)
return new_components, new_reasonings
def remove_components(self, indices):
"""Remove components and labels at specified indices."""
mask = torch.ones(self.num_components, dtype=torch.bool)
mask[indices] = False
_components = self._components[mask]
# TODO
# _reasonings = self._reasonings[mask]
"""Remove components and reasonings at specified indices."""
_components, mask = removeind(self, "_components", indices)
_reasonings, mask = removeind(self, "_reasonings", indices)
self._register_components(_components)
# self._register_reasonings(_reasonings)
self._register_reasonings(_reasonings)
return mask
def forward(self):