fix(initializers): bug fixes in LT initializers

Jensun Ravichandran 2021-06-20 18:56:06 +02:00
parent 5a3dbfac2e
commit f78ff1a464


@@ -439,7 +439,9 @@ class AbstractDataAwareLTInitializer(AbstractLinearTransformInitializer):
     def __init__(self,
                  data: torch.Tensor,
                  noise: float = 0.0,
-                 transform: Callable = torch.nn.Identity()):
+                 transform: Callable = torch.nn.Identity(),
+                 out_dim_first: bool = False):
+        super().__init__(out_dim_first)
         self.data = data
         self.noise = noise
         self.transform = transform
@@ -454,7 +456,6 @@ class AbstractDataAwareLTInitializer(AbstractLinearTransformInitializer):
 
 class PCALinearTransformInitializer(AbstractDataAwareLTInitializer):
     """Initialize a matrix with Eigenvectors from the data."""
-    @abstractmethod
     def generate(self, in_dim: int, out_dim: int):
         _, _, weights = torch.pca_lowrank(self.data, q=out_dim)
         return self.generate_end_hook(weights)
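
For context, a minimal usage sketch of the fixed initializer, assuming the class can be imported from prototorch.initializers (the import path is an assumption; only the constructor and generate() signatures are taken from the diff above):

import torch

# Assumed import path; the actual module layout may differ across versions.
from prototorch.initializers import PCALinearTransformInitializer

# Toy data: 100 samples with 4 features each.
data = torch.randn(100, 4)

# out_dim_first is now forwarded to the parent class via super().__init__,
# so the flag actually reaches the base initializer.
initializer = PCALinearTransformInitializer(data, out_dim_first=False)

# With the stray @abstractmethod removed, generate() is concrete again and
# the class can be instantiated and used directly.
weights = initializer.generate(in_dim=4, out_dim=2)
print(weights.shape)  # presumably (4, 2) here; transposed when out_dim_first=True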