feat: warn user when component counts do not match

This commit is contained in:
Jensun Ravichandran 2022-03-29 14:39:41 +02:00
parent 08b3f9bbb9
commit ed5b9b6c62
No known key found for this signature in database
GPG Key ID: 4E9348239810B51F

View File

@ -26,11 +26,18 @@ class LiteralCompInitializer(AbstractComponentsInitializer):
Use this to 'generate' pre-initialized components elsewhere.
"""
def __init__(self, components):
self.components = components
def generate(self, num_components: int = 0):
"""Ignore `num_components` and simply return `self.components`."""
provided_num_components = len(self.components)
if provided_num_components != num_components:
wmsg = f"The number of components ({provided_num_components}) " \
f"provided to {self.__class__.__name__} " \
f"does not match the expected number ({num_components})."
warnings.warn(wmsg)
if not isinstance(self.components, torch.Tensor):
wmsg = f"Converting components to {torch.Tensor}..."
warnings.warn(wmsg)
@ -40,6 +47,7 @@ class LiteralCompInitializer(AbstractComponentsInitializer):
class ShapeAwareCompInitializer(AbstractComponentsInitializer):
"""Abstract class for all dimension-aware components initializers."""
def __init__(self, shape: Union[Iterable, int]):
if isinstance(shape, Iterable):
self.component_shape = tuple(shape)
@ -53,6 +61,7 @@ class ShapeAwareCompInitializer(AbstractComponentsInitializer):
class ZerosCompInitializer(ShapeAwareCompInitializer):
"""Generate zeros corresponding to the components shape."""
def generate(self, num_components: int):
components = torch.zeros((num_components, ) + self.component_shape)
return components
@ -60,6 +69,7 @@ class ZerosCompInitializer(ShapeAwareCompInitializer):
class OnesCompInitializer(ShapeAwareCompInitializer):
"""Generate ones corresponding to the components shape."""
def generate(self, num_components: int):
components = torch.ones((num_components, ) + self.component_shape)
return components
@ -67,6 +77,7 @@ class OnesCompInitializer(ShapeAwareCompInitializer):
class FillValueCompInitializer(OnesCompInitializer):
"""Generate components with the provided `fill_value`."""
def __init__(self, shape, fill_value: float = 1.0):
super().__init__(shape)
self.fill_value = fill_value
@ -79,6 +90,7 @@ class FillValueCompInitializer(OnesCompInitializer):
class UniformCompInitializer(OnesCompInitializer):
"""Generate components by sampling from a continuous uniform distribution."""
def __init__(self, shape, minimum=0.0, maximum=1.0, scale=1.0):
super().__init__(shape)
self.minimum = minimum
@ -93,6 +105,7 @@ class UniformCompInitializer(OnesCompInitializer):
class RandomNormalCompInitializer(OnesCompInitializer):
"""Generate components by sampling from a standard normal distribution."""
def __init__(self, shape, shift=0.0, scale=1.0):
super().__init__(shape)
self.shift = shift
@ -113,6 +126,7 @@ class AbstractDataAwareCompInitializer(AbstractComponentsInitializer):
`data` has to be a torch tensor.
"""
def __init__(self,
data: torch.Tensor,
noise: float = 0.0,
@ -137,6 +151,7 @@ class AbstractDataAwareCompInitializer(AbstractComponentsInitializer):
class DataAwareCompInitializer(AbstractDataAwareCompInitializer):
"""'Generate' the components from the provided data."""
def generate(self, num_components: int = 0):
"""Ignore `num_components` and simply return transformed `self.data`."""
components = self.generate_end_hook(self.data)
@ -145,6 +160,7 @@ class DataAwareCompInitializer(AbstractDataAwareCompInitializer):
class SelectionCompInitializer(AbstractDataAwareCompInitializer):
"""Generate components by uniformly sampling from the provided data."""
def generate(self, num_components: int):
indices = torch.LongTensor(num_components).random_(0, len(self.data))
samples = self.data[indices]
@ -154,6 +170,7 @@ class SelectionCompInitializer(AbstractDataAwareCompInitializer):
class MeanCompInitializer(AbstractDataAwareCompInitializer):
"""Generate components by computing the mean of the provided data."""
def generate(self, num_components: int):
mean = self.data.mean(dim=0)
repeat_dim = [num_components] + [1] * len(mean.shape)
@ -172,6 +189,7 @@ class AbstractClassAwareCompInitializer(AbstractComponentsInitializer):
target tensors.
"""
def __init__(self,
data,
noise: float = 0.0,
@ -199,6 +217,7 @@ class AbstractClassAwareCompInitializer(AbstractComponentsInitializer):
class ClassAwareCompInitializer(AbstractClassAwareCompInitializer):
"""'Generate' components from provided data and requested distribution."""
def generate(self, distribution: Union[dict, list, tuple]):
"""Ignore `distribution` and simply return transformed `self.data`."""
components = self.generate_end_hook(self.data)
@ -207,6 +226,7 @@ class ClassAwareCompInitializer(AbstractClassAwareCompInitializer):
class AbstractStratifiedCompInitializer(AbstractClassAwareCompInitializer):
"""Abstract class for all stratified components initializers."""
@property
@abstractmethod
def subinit_type(self) -> Type[AbstractDataAwareCompInitializer]:
@ -231,6 +251,7 @@ class AbstractStratifiedCompInitializer(AbstractClassAwareCompInitializer):
class StratifiedSelectionCompInitializer(AbstractStratifiedCompInitializer):
"""Generate components using stratified sampling from the provided data."""
@property
def subinit_type(self):
return SelectionCompInitializer
@ -238,6 +259,7 @@ class StratifiedSelectionCompInitializer(AbstractStratifiedCompInitializer):
class StratifiedMeanCompInitializer(AbstractStratifiedCompInitializer):
"""Generate components at stratified means of the provided data."""
@property
def subinit_type(self):
return MeanCompInitializer
@ -246,6 +268,7 @@ class StratifiedMeanCompInitializer(AbstractStratifiedCompInitializer):
# Labels
class AbstractLabelsInitializer(ABC):
"""Abstract class for all labels initializers."""
@abstractmethod
def generate(self, distribution: Union[dict, list, tuple]):
...
@ -257,6 +280,7 @@ class LiteralLabelsInitializer(AbstractLabelsInitializer):
Use this to 'generate' pre-initialized labels elsewhere.
"""
def __init__(self, labels):
self.labels = labels
@ -275,6 +299,7 @@ class LiteralLabelsInitializer(AbstractLabelsInitializer):
class DataAwareLabelsInitializer(AbstractLabelsInitializer):
"""'Generate' the labels from a torch Dataset."""
def __init__(self, data):
self.data, self.targets = parse_data_arg(data)
@ -285,6 +310,7 @@ class DataAwareLabelsInitializer(AbstractLabelsInitializer):
class LabelsInitializer(AbstractLabelsInitializer):
"""Generate labels from `distribution`."""
def generate(self, distribution: Union[dict, list, tuple]):
distribution = parse_distribution(distribution)
labels_list = []
@ -296,6 +322,7 @@ class LabelsInitializer(AbstractLabelsInitializer):
class OneHotLabelsInitializer(LabelsInitializer):
"""Generate one-hot-encoded labels from `distribution`."""
def generate(self, distribution: Union[dict, list, tuple]):
distribution = parse_distribution(distribution)
num_classes = len(distribution.keys())
@ -314,6 +341,7 @@ def compute_distribution_shape(distribution):
class AbstractReasoningsInitializer(ABC):
"""Abstract class for all reasonings initializers."""
def __init__(self, components_first: bool = True):
self.components_first = components_first
@ -334,6 +362,7 @@ class LiteralReasoningsInitializer(AbstractReasoningsInitializer):
Use this to 'generate' pre-initialized reasonings elsewhere.
"""
def __init__(self, reasonings, **kwargs):
super().__init__(**kwargs)
self.reasonings = reasonings
@ -351,6 +380,7 @@ class LiteralReasoningsInitializer(AbstractReasoningsInitializer):
class ZerosReasoningsInitializer(AbstractReasoningsInitializer):
"""Reasonings are all initialized with zeros."""
def generate(self, distribution: Union[dict, list, tuple]):
shape = compute_distribution_shape(distribution)
reasonings = torch.zeros(*shape)
@ -360,6 +390,7 @@ class ZerosReasoningsInitializer(AbstractReasoningsInitializer):
class OnesReasoningsInitializer(AbstractReasoningsInitializer):
"""Reasonings are all initialized with ones."""
def generate(self, distribution: Union[dict, list, tuple]):
shape = compute_distribution_shape(distribution)
reasonings = torch.ones(*shape)
@ -369,6 +400,7 @@ class OnesReasoningsInitializer(AbstractReasoningsInitializer):
class RandomReasoningsInitializer(AbstractReasoningsInitializer):
"""Reasonings are randomly initialized."""
def __init__(self, minimum=0.4, maximum=0.6, **kwargs):
super().__init__(**kwargs)
self.minimum = minimum
@ -383,6 +415,7 @@ class RandomReasoningsInitializer(AbstractReasoningsInitializer):
class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer):
"""Each component reasons positively for exactly one class."""
def generate(self, distribution: Union[dict, list, tuple]):
num_components, num_classes, _ = compute_distribution_shape(
distribution)
@ -401,6 +434,7 @@ class AbstractTransformInitializer(ABC):
class AbstractLinearTransformInitializer(AbstractTransformInitializer):
"""Abstract class for all linear transform initializers."""
def __init__(self, out_dim_first: bool = False):
self.out_dim_first = out_dim_first
@ -417,6 +451,7 @@ class AbstractLinearTransformInitializer(AbstractTransformInitializer):
class ZerosLinearTransformInitializer(AbstractLinearTransformInitializer):
"""Initialize a matrix with zeros."""
def generate(self, in_dim: int, out_dim: int):
weights = torch.zeros(in_dim, out_dim)
return self.generate_end_hook(weights)
@ -424,6 +459,7 @@ class ZerosLinearTransformInitializer(AbstractLinearTransformInitializer):
class OnesLinearTransformInitializer(AbstractLinearTransformInitializer):
"""Initialize a matrix with ones."""
def generate(self, in_dim: int, out_dim: int):
weights = torch.ones(in_dim, out_dim)
return self.generate_end_hook(weights)
@ -431,6 +467,7 @@ class OnesLinearTransformInitializer(AbstractLinearTransformInitializer):
class EyeTransformInitializer(AbstractLinearTransformInitializer):
"""Initialize a matrix with the largest possible identity matrix."""
def generate(self, in_dim: int, out_dim: int):
weights = torch.zeros(in_dim, out_dim)
I = torch.eye(min(in_dim, out_dim))
@ -440,6 +477,7 @@ class EyeTransformInitializer(AbstractLinearTransformInitializer):
class AbstractDataAwareLTInitializer(AbstractLinearTransformInitializer):
"""Abstract class for all data-aware linear transform initializers."""
def __init__(self,
data: torch.Tensor,
noise: float = 0.0,