From 08b3f9bbb9744f074ab5dc319d3d514e38e629ac Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Mon, 21 Mar 2022 14:38:00 +0100 Subject: [PATCH] feat: add `LiteralLinearTransformInitializer` --- prototorch/core/initializers.py | 11 ++++++++++- setup.py | 2 +- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/prototorch/core/initializers.py b/prototorch/core/initializers.py index a21fad8..2f9137d 100644 --- a/prototorch/core/initializers.py +++ b/prototorch/core/initializers.py @@ -218,7 +218,7 @@ class AbstractStratifiedCompInitializer(AbstractClassAwareCompInitializer): for k, v in distribution.items(): stratified_data = self.data[self.targets == k] if len(stratified_data) == 0: - raise ValueError(f"No data available for class {k}.") + raise ValueError(f"No data available for class {k}.") initializer = self.subinit_type( stratified_data, noise=self.noise, @@ -460,11 +460,19 @@ class AbstractDataAwareLTInitializer(AbstractLinearTransformInitializer): class PCALinearTransformInitializer(AbstractDataAwareLTInitializer): """Initialize a matrix with Eigenvectors from the data.""" + def generate(self, in_dim: int, out_dim: int): _, _, weights = torch.pca_lowrank(self.data, q=out_dim) return self.generate_end_hook(weights) +class LiteralLinearTransformInitializer(AbstractDataAwareLTInitializer): + """'Generate' the provided weights.""" + + def generate(self, in_dim: int, out_dim: int): + return self.generate_end_hook(self.data) + + # Aliases - Components CACI = ClassAwareCompInitializer DACI = DataAwareCompInitializer @@ -497,3 +505,4 @@ Eye = EyeTransformInitializer OLTI = OnesLinearTransformInitializer ZLTI = ZerosLinearTransformInitializer PCALTI = PCALinearTransformInitializer +LLTI = LiteralLinearTransformInitializer diff --git a/setup.py b/setup.py index b8bcc9e..6ed792b 100644 --- a/setup.py +++ b/setup.py @@ -23,6 +23,7 @@ INSTALL_REQUIRES = [ "torchvision>=0.7.1", "numpy>=1.9.1", "sklearn", + "matplotlib", ] DATASETS = [ "requests", @@ -40,7 +41,6 @@ DOCS = [ "sphinx-autodoc-typehints", ] EXAMPLES = [ - "matplotlib", "torchinfo", ] TESTS = ["codecov", "pytest"]