diff --git a/prototorch/core/initializers.py b/prototorch/core/initializers.py index fc5e83f..fa4299c 100644 --- a/prototorch/core/initializers.py +++ b/prototorch/core/initializers.py @@ -439,7 +439,9 @@ class AbstractDataAwareLTInitializer(AbstractLinearTransformInitializer): def __init__(self, data: torch.Tensor, noise: float = 0.0, - transform: Callable = torch.nn.Identity()): + transform: Callable = torch.nn.Identity(), + out_dim_first: bool = False): + super().__init__(out_dim_first) self.data = data self.noise = noise self.transform = transform @@ -454,7 +456,6 @@ class AbstractDataAwareLTInitializer(AbstractLinearTransformInitializer): class PCALinearTransformInitializer(AbstractDataAwareLTInitializer): """Initialize a matrix with Eigenvectors from the data.""" - @abstractmethod def generate(self, in_dim: int, out_dim: int): _, _, weights = torch.pca_lowrank(self.data, q=out_dim) return self.generate_end_hook(weights)