feat: add LiteralLinearTransformInitializer
This commit is contained in:
parent
784a963527
commit
08b3f9bbb9
@ -218,7 +218,7 @@ class AbstractStratifiedCompInitializer(AbstractClassAwareCompInitializer):
|
|||||||
for k, v in distribution.items():
|
for k, v in distribution.items():
|
||||||
stratified_data = self.data[self.targets == k]
|
stratified_data = self.data[self.targets == k]
|
||||||
if len(stratified_data) == 0:
|
if len(stratified_data) == 0:
|
||||||
raise ValueError(f"No data available for class {k}.")
|
raise ValueError(f"No data available for class {k}.")
|
||||||
initializer = self.subinit_type(
|
initializer = self.subinit_type(
|
||||||
stratified_data,
|
stratified_data,
|
||||||
noise=self.noise,
|
noise=self.noise,
|
||||||
@ -460,11 +460,19 @@ class AbstractDataAwareLTInitializer(AbstractLinearTransformInitializer):
|
|||||||
|
|
||||||
class PCALinearTransformInitializer(AbstractDataAwareLTInitializer):
|
class PCALinearTransformInitializer(AbstractDataAwareLTInitializer):
|
||||||
"""Initialize a matrix with Eigenvectors from the data."""
|
"""Initialize a matrix with Eigenvectors from the data."""
|
||||||
|
|
||||||
def generate(self, in_dim: int, out_dim: int):
|
def generate(self, in_dim: int, out_dim: int):
|
||||||
_, _, weights = torch.pca_lowrank(self.data, q=out_dim)
|
_, _, weights = torch.pca_lowrank(self.data, q=out_dim)
|
||||||
return self.generate_end_hook(weights)
|
return self.generate_end_hook(weights)
|
||||||
|
|
||||||
|
|
||||||
|
class LiteralLinearTransformInitializer(AbstractDataAwareLTInitializer):
|
||||||
|
"""'Generate' the provided weights."""
|
||||||
|
|
||||||
|
def generate(self, in_dim: int, out_dim: int):
|
||||||
|
return self.generate_end_hook(self.data)
|
||||||
|
|
||||||
|
|
||||||
# Aliases - Components
|
# Aliases - Components
|
||||||
CACI = ClassAwareCompInitializer
|
CACI = ClassAwareCompInitializer
|
||||||
DACI = DataAwareCompInitializer
|
DACI = DataAwareCompInitializer
|
||||||
@ -497,3 +505,4 @@ Eye = EyeTransformInitializer
|
|||||||
OLTI = OnesLinearTransformInitializer
|
OLTI = OnesLinearTransformInitializer
|
||||||
ZLTI = ZerosLinearTransformInitializer
|
ZLTI = ZerosLinearTransformInitializer
|
||||||
PCALTI = PCALinearTransformInitializer
|
PCALTI = PCALinearTransformInitializer
|
||||||
|
LLTI = LiteralLinearTransformInitializer
|
||||||
|
2
setup.py
2
setup.py
@ -23,6 +23,7 @@ INSTALL_REQUIRES = [
|
|||||||
"torchvision>=0.7.1",
|
"torchvision>=0.7.1",
|
||||||
"numpy>=1.9.1",
|
"numpy>=1.9.1",
|
||||||
"sklearn",
|
"sklearn",
|
||||||
|
"matplotlib",
|
||||||
]
|
]
|
||||||
DATASETS = [
|
DATASETS = [
|
||||||
"requests",
|
"requests",
|
||||||
@ -40,7 +41,6 @@ DOCS = [
|
|||||||
"sphinx-autodoc-typehints",
|
"sphinx-autodoc-typehints",
|
||||||
]
|
]
|
||||||
EXAMPLES = [
|
EXAMPLES = [
|
||||||
"matplotlib",
|
|
||||||
"torchinfo",
|
"torchinfo",
|
||||||
]
|
]
|
||||||
TESTS = ["codecov", "pytest"]
|
TESTS = ["codecov", "pytest"]
|
||||||
|
Loading…
Reference in New Issue
Block a user