fix(initializers): bug fixes in LT initializers
This commit is contained in:
parent
5a3dbfac2e
commit
f78ff1a464
@ -439,7 +439,9 @@ class AbstractDataAwareLTInitializer(AbstractLinearTransformInitializer):
|
|||||||
def __init__(self,
|
def __init__(self,
|
||||||
data: torch.Tensor,
|
data: torch.Tensor,
|
||||||
noise: float = 0.0,
|
noise: float = 0.0,
|
||||||
transform: Callable = torch.nn.Identity()):
|
transform: Callable = torch.nn.Identity(),
|
||||||
|
out_dim_first: bool = False):
|
||||||
|
super().__init__(out_dim_first)
|
||||||
self.data = data
|
self.data = data
|
||||||
self.noise = noise
|
self.noise = noise
|
||||||
self.transform = transform
|
self.transform = transform
|
||||||
@ -454,7 +456,6 @@ class AbstractDataAwareLTInitializer(AbstractLinearTransformInitializer):
|
|||||||
|
|
||||||
class PCALinearTransformInitializer(AbstractDataAwareLTInitializer):
|
class PCALinearTransformInitializer(AbstractDataAwareLTInitializer):
|
||||||
"""Initialize a matrix with Eigenvectors from the data."""
|
"""Initialize a matrix with Eigenvectors from the data."""
|
||||||
@abstractmethod
|
|
||||||
def generate(self, in_dim: int, out_dim: int):
|
def generate(self, in_dim: int, out_dim: int):
|
||||||
_, _, weights = torch.pca_lowrank(self.data, q=out_dim)
|
_, _, weights = torch.pca_lowrank(self.data, q=out_dim)
|
||||||
return self.generate_end_hook(weights)
|
return self.generate_end_hook(weights)
|
||||||
|
Loading…
Reference in New Issue
Block a user