[REFACTOR] Simplify ReasoningComponents
This commit is contained in:
		@@ -17,6 +17,7 @@ from .initializers import (
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def validate_initializer(initializer, instanceof):
 | 
					def validate_initializer(initializer, instanceof):
 | 
				
			||||||
 | 
					    """Check if the initializer is valid."""
 | 
				
			||||||
    if not isinstance(initializer, instanceof):
 | 
					    if not isinstance(initializer, instanceof):
 | 
				
			||||||
        emsg = f"`initializer` has to be an instance " \
 | 
					        emsg = f"`initializer` has to be an instance " \
 | 
				
			||||||
            f"of some subtype of {instanceof}. " \
 | 
					            f"of some subtype of {instanceof}. " \
 | 
				
			||||||
@@ -47,6 +48,17 @@ def removeind(ins, attr, indices):
 | 
				
			|||||||
    return items, mask
 | 
					    return items, mask
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def get_cikwargs(init, distribution):
 | 
				
			||||||
 | 
					    """Return appropriate key-word arguments for a component initializer."""
 | 
				
			||||||
 | 
					    if isinstance(init, ClassAwareCompInitializer):
 | 
				
			||||||
 | 
					        cikwargs = dict(distribution=distribution)
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        distribution = parse_distribution(distribution)
 | 
				
			||||||
 | 
					        num_components = sum(distribution.values())
 | 
				
			||||||
 | 
					        cikwargs = dict(num_components=num_components)
 | 
				
			||||||
 | 
					    return cikwargs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class AbstractComponents(torch.nn.Module):
 | 
					class AbstractComponents(torch.nn.Module):
 | 
				
			||||||
    """Abstract class for all components modules."""
 | 
					    """Abstract class for all components modules."""
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
@@ -194,12 +206,7 @@ class LabeledComponents(AbstractComponents):
 | 
				
			|||||||
                                    AbstractComponentsInitializer)
 | 
					                                    AbstractComponentsInitializer)
 | 
				
			||||||
        assert validate_initializer(labels_initializer,
 | 
					        assert validate_initializer(labels_initializer,
 | 
				
			||||||
                                    AbstractLabelsInitializer)
 | 
					                                    AbstractLabelsInitializer)
 | 
				
			||||||
        if isinstance(components_initializer, ClassAwareCompInitializer):
 | 
					        cikwargs = get_cikwargs(components_initializer, distribution)
 | 
				
			||||||
            cikwargs = dict(distribution=distribution)
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            distribution = parse_distribution(distribution)
 | 
					 | 
				
			||||||
            num_components = sum(distribution.values())
 | 
					 | 
				
			||||||
            cikwargs = dict(num_components=num_components)
 | 
					 | 
				
			||||||
        _components, new_components = gencat(self, "_components",
 | 
					        _components, new_components = gencat(self, "_components",
 | 
				
			||||||
                                             components_initializer,
 | 
					                                             components_initializer,
 | 
				
			||||||
                                             **cikwargs)
 | 
					                                             **cikwargs)
 | 
				
			||||||
@@ -259,47 +266,28 @@ class ReasoningComponents(AbstractComponents):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def add_components(self, distribution, components_initializer,
 | 
					    def add_components(self, distribution, components_initializer,
 | 
				
			||||||
                       reasonings_initializer: AbstractReasoningsInitializer):
 | 
					                       reasonings_initializer: AbstractReasoningsInitializer):
 | 
				
			||||||
        # Checks
 | 
					        """Generate and add new components and reasonings."""
 | 
				
			||||||
        assert validate_initializer(components_initializer,
 | 
					        assert validate_initializer(components_initializer,
 | 
				
			||||||
                                    AbstractComponentsInitializer)
 | 
					                                    AbstractComponentsInitializer)
 | 
				
			||||||
        assert validate_initializer(reasonings_initializer,
 | 
					        assert validate_initializer(reasonings_initializer,
 | 
				
			||||||
                                    AbstractReasoningsInitializer)
 | 
					                                    AbstractReasoningsInitializer)
 | 
				
			||||||
 | 
					        cikwargs = get_cikwargs(components_initializer, distribution)
 | 
				
			||||||
        distribution = parse_distribution(distribution)
 | 
					        _components, new_components = gencat(self, "_components",
 | 
				
			||||||
 | 
					                                             components_initializer,
 | 
				
			||||||
        # Generate new components
 | 
					                                             **cikwargs)
 | 
				
			||||||
        if isinstance(components_initializer, ClassAwareCompInitializer):
 | 
					        _reasonings, new_reasonings = gencat(self, "_reasonings",
 | 
				
			||||||
            new_components = components_initializer.generate(distribution)
 | 
					                                             reasonings_initializer,
 | 
				
			||||||
        else:
 | 
					                                             distribution)
 | 
				
			||||||
            num_components = sum(distribution.values())
 | 
					 | 
				
			||||||
            new_components = components_initializer.generate(num_components)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Generate new reasonings
 | 
					 | 
				
			||||||
        new_reasonings = reasonings_initializer.generate(distribution)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Register
 | 
					 | 
				
			||||||
        if hasattr(self, "_components"):
 | 
					 | 
				
			||||||
            _components = torch.cat([self._components, new_components])
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            _components = new_components
 | 
					 | 
				
			||||||
        if hasattr(self, "_reasonings"):
 | 
					 | 
				
			||||||
            _reasonings = torch.cat([self._reasonings, new_reasonings])
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            _reasonings = new_reasonings
 | 
					 | 
				
			||||||
        self._register_components(_components)
 | 
					        self._register_components(_components)
 | 
				
			||||||
        self._register_reasonings(_reasonings)
 | 
					        self._register_reasonings(_reasonings)
 | 
				
			||||||
 | 
					 | 
				
			||||||
        return new_components, new_reasonings
 | 
					        return new_components, new_reasonings
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def remove_components(self, indices):
 | 
					    def remove_components(self, indices):
 | 
				
			||||||
        """Remove components and labels at specified indices."""
 | 
					        """Remove components and reasonings at specified indices."""
 | 
				
			||||||
        mask = torch.ones(self.num_components, dtype=torch.bool)
 | 
					        _components, mask = removeind(self, "_components", indices)
 | 
				
			||||||
        mask[indices] = False
 | 
					        _reasonings, mask = removeind(self, "_reasonings", indices)
 | 
				
			||||||
        _components = self._components[mask]
 | 
					 | 
				
			||||||
        # TODO
 | 
					 | 
				
			||||||
        # _reasonings = self._reasonings[mask]
 | 
					 | 
				
			||||||
        self._register_components(_components)
 | 
					        self._register_components(_components)
 | 
				
			||||||
        # self._register_reasonings(_reasonings)
 | 
					        self._register_reasonings(_reasonings)
 | 
				
			||||||
        return mask
 | 
					        return mask
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def forward(self):
 | 
					    def forward(self):
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user