feat: add LiteralLinearTransformInitializer

This commit is contained in:
Jensun Ravichandran 2022-03-21 14:38:00 +01:00
parent 784a963527
commit 08b3f9bbb9
No known key found for this signature in database
GPG Key ID: 4E9348239810B51F
2 changed files with 11 additions and 2 deletions

View File

@ -218,7 +218,7 @@ class AbstractStratifiedCompInitializer(AbstractClassAwareCompInitializer):
for k, v in distribution.items(): for k, v in distribution.items():
stratified_data = self.data[self.targets == k] stratified_data = self.data[self.targets == k]
if len(stratified_data) == 0: if len(stratified_data) == 0:
raise ValueError(f"No data available for class {k}.") raise ValueError(f"No data available for class {k}.")
initializer = self.subinit_type( initializer = self.subinit_type(
stratified_data, stratified_data,
noise=self.noise, noise=self.noise,
@ -460,11 +460,19 @@ class AbstractDataAwareLTInitializer(AbstractLinearTransformInitializer):
class PCALinearTransformInitializer(AbstractDataAwareLTInitializer): class PCALinearTransformInitializer(AbstractDataAwareLTInitializer):
"""Initialize a matrix with Eigenvectors from the data.""" """Initialize a matrix with Eigenvectors from the data."""
def generate(self, in_dim: int, out_dim: int): def generate(self, in_dim: int, out_dim: int):
_, _, weights = torch.pca_lowrank(self.data, q=out_dim) _, _, weights = torch.pca_lowrank(self.data, q=out_dim)
return self.generate_end_hook(weights) return self.generate_end_hook(weights)
class LiteralLinearTransformInitializer(AbstractDataAwareLTInitializer):
"""'Generate' the provided weights."""
def generate(self, in_dim: int, out_dim: int):
return self.generate_end_hook(self.data)
# Aliases - Components # Aliases - Components
CACI = ClassAwareCompInitializer CACI = ClassAwareCompInitializer
DACI = DataAwareCompInitializer DACI = DataAwareCompInitializer
@ -497,3 +505,4 @@ Eye = EyeTransformInitializer
OLTI = OnesLinearTransformInitializer OLTI = OnesLinearTransformInitializer
ZLTI = ZerosLinearTransformInitializer ZLTI = ZerosLinearTransformInitializer
PCALTI = PCALinearTransformInitializer PCALTI = PCALinearTransformInitializer
LLTI = LiteralLinearTransformInitializer

View File

@ -23,6 +23,7 @@ INSTALL_REQUIRES = [
"torchvision>=0.7.1", "torchvision>=0.7.1",
"numpy>=1.9.1", "numpy>=1.9.1",
"sklearn", "sklearn",
"matplotlib",
] ]
DATASETS = [ DATASETS = [
"requests", "requests",
@ -40,7 +41,6 @@ DOCS = [
"sphinx-autodoc-typehints", "sphinx-autodoc-typehints",
] ]
EXAMPLES = [ EXAMPLES = [
"matplotlib",
"torchinfo", "torchinfo",
] ]
TESTS = ["codecov", "pytest"] TESTS = ["codecov", "pytest"]