feat: add LiteralLinearTransformInitializer

This commit is contained in:
Jensun Ravichandran 2022-03-21 14:38:00 +01:00
parent 784a963527
commit 08b3f9bbb9
No known key found for this signature in database
GPG Key ID: 4E9348239810B51F
2 changed files with 11 additions and 2 deletions

View File

@ -460,11 +460,19 @@ class AbstractDataAwareLTInitializer(AbstractLinearTransformInitializer):
class PCALinearTransformInitializer(AbstractDataAwareLTInitializer):
"""Initialize a matrix with Eigenvectors from the data."""
def generate(self, in_dim: int, out_dim: int):
_, _, weights = torch.pca_lowrank(self.data, q=out_dim)
return self.generate_end_hook(weights)
class LiteralLinearTransformInitializer(AbstractDataAwareLTInitializer):
"""'Generate' the provided weights."""
def generate(self, in_dim: int, out_dim: int):
return self.generate_end_hook(self.data)
# Aliases - Components
CACI = ClassAwareCompInitializer
DACI = DataAwareCompInitializer
@ -497,3 +505,4 @@ Eye = EyeTransformInitializer
OLTI = OnesLinearTransformInitializer
ZLTI = ZerosLinearTransformInitializer
PCALTI = PCALinearTransformInitializer
LLTI = LiteralLinearTransformInitializer

View File

@ -23,6 +23,7 @@ INSTALL_REQUIRES = [
"torchvision>=0.7.1",
"numpy>=1.9.1",
"sklearn",
"matplotlib",
]
DATASETS = [
"requests",
@ -40,7 +41,6 @@ DOCS = [
"sphinx-autodoc-typehints",
]
EXAMPLES = [
"matplotlib",
"torchinfo",
]
TESTS = ["codecov", "pytest"]