feat: add LiteralLinearTransformInitializer
				
					
				
			This commit is contained in:
		| @@ -460,11 +460,19 @@ class AbstractDataAwareLTInitializer(AbstractLinearTransformInitializer): | |||||||
|  |  | ||||||
| class PCALinearTransformInitializer(AbstractDataAwareLTInitializer): | class PCALinearTransformInitializer(AbstractDataAwareLTInitializer): | ||||||
|     """Initialize a matrix with Eigenvectors from the data.""" |     """Initialize a matrix with Eigenvectors from the data.""" | ||||||
|  |  | ||||||
|     def generate(self, in_dim: int, out_dim: int): |     def generate(self, in_dim: int, out_dim: int): | ||||||
|         _, _, weights = torch.pca_lowrank(self.data, q=out_dim) |         _, _, weights = torch.pca_lowrank(self.data, q=out_dim) | ||||||
|         return self.generate_end_hook(weights) |         return self.generate_end_hook(weights) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class LiteralLinearTransformInitializer(AbstractDataAwareLTInitializer): | ||||||
|  |     """'Generate' the provided weights.""" | ||||||
|  |  | ||||||
|  |     def generate(self, in_dim: int, out_dim: int): | ||||||
|  |         return self.generate_end_hook(self.data) | ||||||
|  |  | ||||||
|  |  | ||||||
| # Aliases - Components | # Aliases - Components | ||||||
| CACI = ClassAwareCompInitializer | CACI = ClassAwareCompInitializer | ||||||
| DACI = DataAwareCompInitializer | DACI = DataAwareCompInitializer | ||||||
| @@ -497,3 +505,4 @@ Eye = EyeTransformInitializer | |||||||
| OLTI = OnesLinearTransformInitializer | OLTI = OnesLinearTransformInitializer | ||||||
| ZLTI = ZerosLinearTransformInitializer | ZLTI = ZerosLinearTransformInitializer | ||||||
| PCALTI = PCALinearTransformInitializer | PCALTI = PCALinearTransformInitializer | ||||||
|  | LLTI = LiteralLinearTransformInitializer | ||||||
|   | |||||||
							
								
								
									
										2
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								setup.py
									
									
									
									
									
								
							| @@ -23,6 +23,7 @@ INSTALL_REQUIRES = [ | |||||||
|     "torchvision>=0.7.1", |     "torchvision>=0.7.1", | ||||||
|     "numpy>=1.9.1", |     "numpy>=1.9.1", | ||||||
|     "sklearn", |     "sklearn", | ||||||
|  |     "matplotlib", | ||||||
| ] | ] | ||||||
| DATASETS = [ | DATASETS = [ | ||||||
|     "requests", |     "requests", | ||||||
| @@ -40,7 +41,6 @@ DOCS = [ | |||||||
|     "sphinx-autodoc-typehints", |     "sphinx-autodoc-typehints", | ||||||
| ] | ] | ||||||
| EXAMPLES = [ | EXAMPLES = [ | ||||||
|     "matplotlib", |  | ||||||
|     "torchinfo", |     "torchinfo", | ||||||
| ] | ] | ||||||
| TESTS = ["codecov", "pytest"] | TESTS = ["codecov", "pytest"] | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user