fix(initializers): bug fixes in LT initializers
This commit is contained in:
parent
5a3dbfac2e
commit
f78ff1a464
@ -439,7 +439,9 @@ class AbstractDataAwareLTInitializer(AbstractLinearTransformInitializer):
|
||||
def __init__(self,
|
||||
data: torch.Tensor,
|
||||
noise: float = 0.0,
|
||||
transform: Callable = torch.nn.Identity()):
|
||||
transform: Callable = torch.nn.Identity(),
|
||||
out_dim_first: bool = False):
|
||||
super().__init__(out_dim_first)
|
||||
self.data = data
|
||||
self.noise = noise
|
||||
self.transform = transform
|
||||
@ -454,7 +456,6 @@ class AbstractDataAwareLTInitializer(AbstractLinearTransformInitializer):
|
||||
|
||||
class PCALinearTransformInitializer(AbstractDataAwareLTInitializer):
|
||||
"""Initialize a matrix with Eigenvectors from the data."""
|
||||
@abstractmethod
|
||||
def generate(self, in_dim: int, out_dim: int):
|
||||
_, _, weights = torch.pca_lowrank(self.data, q=out_dim)
|
||||
return self.generate_end_hook(weights)
|
||||
|
Loading…
Reference in New Issue
Block a user