From f78ff1a464852d129fe0537c94076cab32c29cb5 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Sun, 20 Jun 2021 18:56:06 +0200 Subject: [PATCH] fix(initializers): bug fixes in LT initializers --- prototorch/core/initializers.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/prototorch/core/initializers.py b/prototorch/core/initializers.py index fc5e83f..fa4299c 100644 --- a/prototorch/core/initializers.py +++ b/prototorch/core/initializers.py @@ -439,7 +439,9 @@ class AbstractDataAwareLTInitializer(AbstractLinearTransformInitializer): def __init__(self, data: torch.Tensor, noise: float = 0.0, - transform: Callable = torch.nn.Identity()): + transform: Callable = torch.nn.Identity(), + out_dim_first: bool = False): + super().__init__(out_dim_first) self.data = data self.noise = noise self.transform = transform @@ -454,7 +456,6 @@ class AbstractDataAwareLTInitializer(AbstractLinearTransformInitializer): class PCALinearTransformInitializer(AbstractDataAwareLTInitializer): """Initialize a matrix with Eigenvectors from the data.""" - @abstractmethod def generate(self, in_dim: int, out_dim: int): _, _, weights = torch.pca_lowrank(self.data, q=out_dim) return self.generate_end_hook(weights)