Add scaffolding
This commit is contained in:
		| @@ -1,4 +1,5 @@ | ||||
| """ProtoTroch Initializers.""" | ||||
| """ProtoTroch Component and Label Initializers.""" | ||||
|  | ||||
| import warnings | ||||
| from collections.abc import Iterable | ||||
| from itertools import chain | ||||
|   | ||||
| @@ -1,5 +1,7 @@ | ||||
| """ProtoTorch modules.""" | ||||
|  | ||||
| from .competitions import * | ||||
| from .initializers import * | ||||
| from .pooling import * | ||||
| from .transformations import * | ||||
| from .wrappers import LambdaLayer, LossLayer | ||||
|   | ||||
							
								
								
									
										61
									
								
								prototorch/modules/initializers.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										61
									
								
								prototorch/modules/initializers.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,61 @@ | ||||
| """ProtoTroch Module Initializers.""" | ||||
|  | ||||
| import torch | ||||
|  | ||||
|  | ||||
| # Transformations | ||||
| class MatrixInitializer(object): | ||||
|     def __init__(self, *args, **kwargs): | ||||
|         ... | ||||
|  | ||||
|     def generate(self, shape): | ||||
|         raise NotImplementedError("Subclasses should implement this!") | ||||
|  | ||||
|  | ||||
| class ZerosInitializer(MatrixInitializer): | ||||
|     def generate(self, shape): | ||||
|         return torch.zeros(shape) | ||||
|  | ||||
|  | ||||
| class OnesInitializer(MatrixInitializer): | ||||
|     def __init__(self, scale=1.0): | ||||
|         super().__init__() | ||||
|         self.scale = scale | ||||
|  | ||||
|     def generate(self, shape): | ||||
|         return torch.ones(shape) * self.scale | ||||
|  | ||||
|  | ||||
| class UniformInitializer(MatrixInitializer): | ||||
|     def __init__(self, minimum=0.0, maximum=1.0, scale=1.0): | ||||
|         super().__init__() | ||||
|         self.minimum = minimum | ||||
|         self.maximum = maximum | ||||
|         self.scale = scale | ||||
|  | ||||
|     def generate(self, shape): | ||||
|         return torch.ones(shape).uniform_(self.minimum, | ||||
|                                           self.maximum) * self.scale | ||||
|  | ||||
|  | ||||
| class DataAwareInitializer(MatrixInitializer): | ||||
|     def __init__(self, data, transform=torch.nn.Identity()): | ||||
|         super().__init__() | ||||
|         self.data = data | ||||
|         self.transform = transform | ||||
|  | ||||
|     def __del__(self): | ||||
|         del self.data | ||||
|  | ||||
|  | ||||
| class EigenVectorInitializer(DataAwareInitializer): | ||||
|     def generate(self, shape): | ||||
|         # TODO | ||||
|         raise NotImplementedError() | ||||
|  | ||||
|  | ||||
| # Aliases | ||||
| EV = EigenVectorInitializer | ||||
| Random = RandomInitializer = UniformInitializer | ||||
| Zeros = ZerosInitializer | ||||
| Ones = OnesInitializer | ||||
							
								
								
									
										49
									
								
								prototorch/modules/transformations.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										49
									
								
								prototorch/modules/transformations.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,49 @@ | ||||
| """ProtoTorch Transformation Layers.""" | ||||
|  | ||||
| import torch | ||||
| from torch.nn.parameter import Parameter | ||||
|  | ||||
| from .initializers import MatrixInitializer | ||||
|  | ||||
|  | ||||
| def _precheck_initializer(initializer): | ||||
|     if not isinstance(initializer, MatrixInitializer): | ||||
|         emsg = f"`initializer` has to be some subtype of " \ | ||||
|             f"{MatrixInitializer}. " \ | ||||
|             f"You have provided: {initializer=} instead." | ||||
|         raise TypeError(emsg) | ||||
|  | ||||
|  | ||||
| class Omega(torch.nn.Module): | ||||
|     """The Omega mapping used in GMLVQ.""" | ||||
|     def __init__(self, | ||||
|                  num_replicas=1, | ||||
|                  input_dim=None, | ||||
|                  latent_dim=None, | ||||
|                  initializer=None, | ||||
|                  *, | ||||
|                  initialized_weights=None): | ||||
|         super().__init__() | ||||
|  | ||||
|         if initialized_weights is not None: | ||||
|             self._register_weights(initialized_weights) | ||||
|         else: | ||||
|             if num_replicas == 1: | ||||
|                 shape = (input_dim, latent_dim) | ||||
|             else: | ||||
|                 shape = (num_replicas, input_dim, latent_dim) | ||||
|             self._initialize_weights(shape, initializer) | ||||
|  | ||||
|     def _register_weights(self, weights): | ||||
|         self.register_parameter("_omega", Parameter(weights)) | ||||
|  | ||||
|     def _initialize_weights(self, shape, initializer): | ||||
|         _precheck_initializer(initializer) | ||||
|         _omega = initializer.generate(shape) | ||||
|         self._register_weights(_omega) | ||||
|  | ||||
|     def forward(self): | ||||
|         return self._omega | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         return f"(omega): (shape: {tuple(self._omega.shape)})" | ||||
		Reference in New Issue
	
	Block a user