feat: warn user when component counts do not match
This commit is contained in:
parent
08b3f9bbb9
commit
ed5b9b6c62
@ -26,11 +26,18 @@ class LiteralCompInitializer(AbstractComponentsInitializer):
|
||||
Use this to 'generate' pre-initialized components elsewhere.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, components):
|
||||
self.components = components
|
||||
|
||||
def generate(self, num_components: int = 0):
|
||||
"""Ignore `num_components` and simply return `self.components`."""
|
||||
provided_num_components = len(self.components)
|
||||
if provided_num_components != num_components:
|
||||
wmsg = f"The number of components ({provided_num_components}) " \
|
||||
f"provided to {self.__class__.__name__} " \
|
||||
f"does not match the expected number ({num_components})."
|
||||
warnings.warn(wmsg)
|
||||
if not isinstance(self.components, torch.Tensor):
|
||||
wmsg = f"Converting components to {torch.Tensor}..."
|
||||
warnings.warn(wmsg)
|
||||
@ -40,6 +47,7 @@ class LiteralCompInitializer(AbstractComponentsInitializer):
|
||||
|
||||
class ShapeAwareCompInitializer(AbstractComponentsInitializer):
|
||||
"""Abstract class for all dimension-aware components initializers."""
|
||||
|
||||
def __init__(self, shape: Union[Iterable, int]):
|
||||
if isinstance(shape, Iterable):
|
||||
self.component_shape = tuple(shape)
|
||||
@ -53,6 +61,7 @@ class ShapeAwareCompInitializer(AbstractComponentsInitializer):
|
||||
|
||||
class ZerosCompInitializer(ShapeAwareCompInitializer):
|
||||
"""Generate zeros corresponding to the components shape."""
|
||||
|
||||
def generate(self, num_components: int):
|
||||
components = torch.zeros((num_components, ) + self.component_shape)
|
||||
return components
|
||||
@ -60,6 +69,7 @@ class ZerosCompInitializer(ShapeAwareCompInitializer):
|
||||
|
||||
class OnesCompInitializer(ShapeAwareCompInitializer):
|
||||
"""Generate ones corresponding to the components shape."""
|
||||
|
||||
def generate(self, num_components: int):
|
||||
components = torch.ones((num_components, ) + self.component_shape)
|
||||
return components
|
||||
@ -67,6 +77,7 @@ class OnesCompInitializer(ShapeAwareCompInitializer):
|
||||
|
||||
class FillValueCompInitializer(OnesCompInitializer):
|
||||
"""Generate components with the provided `fill_value`."""
|
||||
|
||||
def __init__(self, shape, fill_value: float = 1.0):
|
||||
super().__init__(shape)
|
||||
self.fill_value = fill_value
|
||||
@ -79,6 +90,7 @@ class FillValueCompInitializer(OnesCompInitializer):
|
||||
|
||||
class UniformCompInitializer(OnesCompInitializer):
|
||||
"""Generate components by sampling from a continuous uniform distribution."""
|
||||
|
||||
def __init__(self, shape, minimum=0.0, maximum=1.0, scale=1.0):
|
||||
super().__init__(shape)
|
||||
self.minimum = minimum
|
||||
@ -93,6 +105,7 @@ class UniformCompInitializer(OnesCompInitializer):
|
||||
|
||||
class RandomNormalCompInitializer(OnesCompInitializer):
|
||||
"""Generate components by sampling from a standard normal distribution."""
|
||||
|
||||
def __init__(self, shape, shift=0.0, scale=1.0):
|
||||
super().__init__(shape)
|
||||
self.shift = shift
|
||||
@ -113,6 +126,7 @@ class AbstractDataAwareCompInitializer(AbstractComponentsInitializer):
|
||||
`data` has to be a torch tensor.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
data: torch.Tensor,
|
||||
noise: float = 0.0,
|
||||
@ -137,6 +151,7 @@ class AbstractDataAwareCompInitializer(AbstractComponentsInitializer):
|
||||
|
||||
class DataAwareCompInitializer(AbstractDataAwareCompInitializer):
|
||||
"""'Generate' the components from the provided data."""
|
||||
|
||||
def generate(self, num_components: int = 0):
|
||||
"""Ignore `num_components` and simply return transformed `self.data`."""
|
||||
components = self.generate_end_hook(self.data)
|
||||
@ -145,6 +160,7 @@ class DataAwareCompInitializer(AbstractDataAwareCompInitializer):
|
||||
|
||||
class SelectionCompInitializer(AbstractDataAwareCompInitializer):
|
||||
"""Generate components by uniformly sampling from the provided data."""
|
||||
|
||||
def generate(self, num_components: int):
|
||||
indices = torch.LongTensor(num_components).random_(0, len(self.data))
|
||||
samples = self.data[indices]
|
||||
@ -154,6 +170,7 @@ class SelectionCompInitializer(AbstractDataAwareCompInitializer):
|
||||
|
||||
class MeanCompInitializer(AbstractDataAwareCompInitializer):
|
||||
"""Generate components by computing the mean of the provided data."""
|
||||
|
||||
def generate(self, num_components: int):
|
||||
mean = self.data.mean(dim=0)
|
||||
repeat_dim = [num_components] + [1] * len(mean.shape)
|
||||
@ -172,6 +189,7 @@ class AbstractClassAwareCompInitializer(AbstractComponentsInitializer):
|
||||
target tensors.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
data,
|
||||
noise: float = 0.0,
|
||||
@ -199,6 +217,7 @@ class AbstractClassAwareCompInitializer(AbstractComponentsInitializer):
|
||||
|
||||
class ClassAwareCompInitializer(AbstractClassAwareCompInitializer):
|
||||
"""'Generate' components from provided data and requested distribution."""
|
||||
|
||||
def generate(self, distribution: Union[dict, list, tuple]):
|
||||
"""Ignore `distribution` and simply return transformed `self.data`."""
|
||||
components = self.generate_end_hook(self.data)
|
||||
@ -207,6 +226,7 @@ class ClassAwareCompInitializer(AbstractClassAwareCompInitializer):
|
||||
|
||||
class AbstractStratifiedCompInitializer(AbstractClassAwareCompInitializer):
|
||||
"""Abstract class for all stratified components initializers."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def subinit_type(self) -> Type[AbstractDataAwareCompInitializer]:
|
||||
@ -231,6 +251,7 @@ class AbstractStratifiedCompInitializer(AbstractClassAwareCompInitializer):
|
||||
|
||||
class StratifiedSelectionCompInitializer(AbstractStratifiedCompInitializer):
|
||||
"""Generate components using stratified sampling from the provided data."""
|
||||
|
||||
@property
|
||||
def subinit_type(self):
|
||||
return SelectionCompInitializer
|
||||
@ -238,6 +259,7 @@ class StratifiedSelectionCompInitializer(AbstractStratifiedCompInitializer):
|
||||
|
||||
class StratifiedMeanCompInitializer(AbstractStratifiedCompInitializer):
|
||||
"""Generate components at stratified means of the provided data."""
|
||||
|
||||
@property
|
||||
def subinit_type(self):
|
||||
return MeanCompInitializer
|
||||
@ -246,6 +268,7 @@ class StratifiedMeanCompInitializer(AbstractStratifiedCompInitializer):
|
||||
# Labels
|
||||
class AbstractLabelsInitializer(ABC):
|
||||
"""Abstract class for all labels initializers."""
|
||||
|
||||
@abstractmethod
|
||||
def generate(self, distribution: Union[dict, list, tuple]):
|
||||
...
|
||||
@ -257,6 +280,7 @@ class LiteralLabelsInitializer(AbstractLabelsInitializer):
|
||||
Use this to 'generate' pre-initialized labels elsewhere.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, labels):
|
||||
self.labels = labels
|
||||
|
||||
@ -275,6 +299,7 @@ class LiteralLabelsInitializer(AbstractLabelsInitializer):
|
||||
|
||||
class DataAwareLabelsInitializer(AbstractLabelsInitializer):
|
||||
"""'Generate' the labels from a torch Dataset."""
|
||||
|
||||
def __init__(self, data):
|
||||
self.data, self.targets = parse_data_arg(data)
|
||||
|
||||
@ -285,6 +310,7 @@ class DataAwareLabelsInitializer(AbstractLabelsInitializer):
|
||||
|
||||
class LabelsInitializer(AbstractLabelsInitializer):
|
||||
"""Generate labels from `distribution`."""
|
||||
|
||||
def generate(self, distribution: Union[dict, list, tuple]):
|
||||
distribution = parse_distribution(distribution)
|
||||
labels_list = []
|
||||
@ -296,6 +322,7 @@ class LabelsInitializer(AbstractLabelsInitializer):
|
||||
|
||||
class OneHotLabelsInitializer(LabelsInitializer):
|
||||
"""Generate one-hot-encoded labels from `distribution`."""
|
||||
|
||||
def generate(self, distribution: Union[dict, list, tuple]):
|
||||
distribution = parse_distribution(distribution)
|
||||
num_classes = len(distribution.keys())
|
||||
@ -314,6 +341,7 @@ def compute_distribution_shape(distribution):
|
||||
|
||||
class AbstractReasoningsInitializer(ABC):
|
||||
"""Abstract class for all reasonings initializers."""
|
||||
|
||||
def __init__(self, components_first: bool = True):
|
||||
self.components_first = components_first
|
||||
|
||||
@ -334,6 +362,7 @@ class LiteralReasoningsInitializer(AbstractReasoningsInitializer):
|
||||
Use this to 'generate' pre-initialized reasonings elsewhere.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, reasonings, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.reasonings = reasonings
|
||||
@ -351,6 +380,7 @@ class LiteralReasoningsInitializer(AbstractReasoningsInitializer):
|
||||
|
||||
class ZerosReasoningsInitializer(AbstractReasoningsInitializer):
|
||||
"""Reasonings are all initialized with zeros."""
|
||||
|
||||
def generate(self, distribution: Union[dict, list, tuple]):
|
||||
shape = compute_distribution_shape(distribution)
|
||||
reasonings = torch.zeros(*shape)
|
||||
@ -360,6 +390,7 @@ class ZerosReasoningsInitializer(AbstractReasoningsInitializer):
|
||||
|
||||
class OnesReasoningsInitializer(AbstractReasoningsInitializer):
|
||||
"""Reasonings are all initialized with ones."""
|
||||
|
||||
def generate(self, distribution: Union[dict, list, tuple]):
|
||||
shape = compute_distribution_shape(distribution)
|
||||
reasonings = torch.ones(*shape)
|
||||
@ -369,6 +400,7 @@ class OnesReasoningsInitializer(AbstractReasoningsInitializer):
|
||||
|
||||
class RandomReasoningsInitializer(AbstractReasoningsInitializer):
|
||||
"""Reasonings are randomly initialized."""
|
||||
|
||||
def __init__(self, minimum=0.4, maximum=0.6, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.minimum = minimum
|
||||
@ -383,6 +415,7 @@ class RandomReasoningsInitializer(AbstractReasoningsInitializer):
|
||||
|
||||
class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer):
|
||||
"""Each component reasons positively for exactly one class."""
|
||||
|
||||
def generate(self, distribution: Union[dict, list, tuple]):
|
||||
num_components, num_classes, _ = compute_distribution_shape(
|
||||
distribution)
|
||||
@ -401,6 +434,7 @@ class AbstractTransformInitializer(ABC):
|
||||
|
||||
class AbstractLinearTransformInitializer(AbstractTransformInitializer):
|
||||
"""Abstract class for all linear transform initializers."""
|
||||
|
||||
def __init__(self, out_dim_first: bool = False):
|
||||
self.out_dim_first = out_dim_first
|
||||
|
||||
@ -417,6 +451,7 @@ class AbstractLinearTransformInitializer(AbstractTransformInitializer):
|
||||
|
||||
class ZerosLinearTransformInitializer(AbstractLinearTransformInitializer):
|
||||
"""Initialize a matrix with zeros."""
|
||||
|
||||
def generate(self, in_dim: int, out_dim: int):
|
||||
weights = torch.zeros(in_dim, out_dim)
|
||||
return self.generate_end_hook(weights)
|
||||
@ -424,6 +459,7 @@ class ZerosLinearTransformInitializer(AbstractLinearTransformInitializer):
|
||||
|
||||
class OnesLinearTransformInitializer(AbstractLinearTransformInitializer):
|
||||
"""Initialize a matrix with ones."""
|
||||
|
||||
def generate(self, in_dim: int, out_dim: int):
|
||||
weights = torch.ones(in_dim, out_dim)
|
||||
return self.generate_end_hook(weights)
|
||||
@ -431,6 +467,7 @@ class OnesLinearTransformInitializer(AbstractLinearTransformInitializer):
|
||||
|
||||
class EyeTransformInitializer(AbstractLinearTransformInitializer):
|
||||
"""Initialize a matrix with the largest possible identity matrix."""
|
||||
|
||||
def generate(self, in_dim: int, out_dim: int):
|
||||
weights = torch.zeros(in_dim, out_dim)
|
||||
I = torch.eye(min(in_dim, out_dim))
|
||||
@ -440,6 +477,7 @@ class EyeTransformInitializer(AbstractLinearTransformInitializer):
|
||||
|
||||
class AbstractDataAwareLTInitializer(AbstractLinearTransformInitializer):
|
||||
"""Abstract class for all data-aware linear transform initializers."""
|
||||
|
||||
def __init__(self,
|
||||
data: torch.Tensor,
|
||||
noise: float = 0.0,
|
||||
|
Loading…
Reference in New Issue
Block a user