Compare commits
1 Commits
master
...
feature/tr
Author | SHA1 | Date | |
---|---|---|---|
|
17b45249f4 |
@ -1,4 +1,5 @@
|
|||||||
"""ProtoTroch Initializers."""
|
"""ProtoTroch Component and Label Initializers."""
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
"""ProtoTorch modules."""
|
"""ProtoTorch modules."""
|
||||||
|
|
||||||
from .competitions import *
|
from .competitions import *
|
||||||
|
from .initializers import *
|
||||||
from .pooling import *
|
from .pooling import *
|
||||||
|
from .transformations import *
|
||||||
from .wrappers import LambdaLayer, LossLayer
|
from .wrappers import LambdaLayer, LossLayer
|
||||||
|
61
prototorch/modules/initializers.py
Normal file
61
prototorch/modules/initializers.py
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
"""ProtoTroch Module Initializers."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
# Transformations
|
||||||
|
class MatrixInitializer(object):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
...
|
||||||
|
|
||||||
|
def generate(self, shape):
|
||||||
|
raise NotImplementedError("Subclasses should implement this!")
|
||||||
|
|
||||||
|
|
||||||
|
class ZerosInitializer(MatrixInitializer):
|
||||||
|
def generate(self, shape):
|
||||||
|
return torch.zeros(shape)
|
||||||
|
|
||||||
|
|
||||||
|
class OnesInitializer(MatrixInitializer):
|
||||||
|
def __init__(self, scale=1.0):
|
||||||
|
super().__init__()
|
||||||
|
self.scale = scale
|
||||||
|
|
||||||
|
def generate(self, shape):
|
||||||
|
return torch.ones(shape) * self.scale
|
||||||
|
|
||||||
|
|
||||||
|
class UniformInitializer(MatrixInitializer):
|
||||||
|
def __init__(self, minimum=0.0, maximum=1.0, scale=1.0):
|
||||||
|
super().__init__()
|
||||||
|
self.minimum = minimum
|
||||||
|
self.maximum = maximum
|
||||||
|
self.scale = scale
|
||||||
|
|
||||||
|
def generate(self, shape):
|
||||||
|
return torch.ones(shape).uniform_(self.minimum,
|
||||||
|
self.maximum) * self.scale
|
||||||
|
|
||||||
|
|
||||||
|
class DataAwareInitializer(MatrixInitializer):
|
||||||
|
def __init__(self, data, transform=torch.nn.Identity()):
|
||||||
|
super().__init__()
|
||||||
|
self.data = data
|
||||||
|
self.transform = transform
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
del self.data
|
||||||
|
|
||||||
|
|
||||||
|
class EigenVectorInitializer(DataAwareInitializer):
|
||||||
|
def generate(self, shape):
|
||||||
|
# TODO
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
# Aliases
|
||||||
|
EV = EigenVectorInitializer
|
||||||
|
Random = RandomInitializer = UniformInitializer
|
||||||
|
Zeros = ZerosInitializer
|
||||||
|
Ones = OnesInitializer
|
49
prototorch/modules/transformations.py
Normal file
49
prototorch/modules/transformations.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
"""ProtoTorch Transformation Layers."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
|
from .initializers import MatrixInitializer
|
||||||
|
|
||||||
|
|
||||||
|
def _precheck_initializer(initializer):
|
||||||
|
if not isinstance(initializer, MatrixInitializer):
|
||||||
|
emsg = f"`initializer` has to be some subtype of " \
|
||||||
|
f"{MatrixInitializer}. " \
|
||||||
|
f"You have provided: {initializer=} instead."
|
||||||
|
raise TypeError(emsg)
|
||||||
|
|
||||||
|
|
||||||
|
class Omega(torch.nn.Module):
|
||||||
|
"""The Omega mapping used in GMLVQ."""
|
||||||
|
def __init__(self,
|
||||||
|
num_replicas=1,
|
||||||
|
input_dim=None,
|
||||||
|
latent_dim=None,
|
||||||
|
initializer=None,
|
||||||
|
*,
|
||||||
|
initialized_weights=None):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if initialized_weights is not None:
|
||||||
|
self._register_weights(initialized_weights)
|
||||||
|
else:
|
||||||
|
if num_replicas == 1:
|
||||||
|
shape = (input_dim, latent_dim)
|
||||||
|
else:
|
||||||
|
shape = (num_replicas, input_dim, latent_dim)
|
||||||
|
self._initialize_weights(shape, initializer)
|
||||||
|
|
||||||
|
def _register_weights(self, weights):
|
||||||
|
self.register_parameter("_omega", Parameter(weights))
|
||||||
|
|
||||||
|
def _initialize_weights(self, shape, initializer):
|
||||||
|
_precheck_initializer(initializer)
|
||||||
|
_omega = initializer.generate(shape)
|
||||||
|
self._register_weights(_omega)
|
||||||
|
|
||||||
|
def forward(self):
|
||||||
|
return self._omega
|
||||||
|
|
||||||
|
def extra_repr(self):
|
||||||
|
return f"(omega): (shape: {tuple(self._omega.shape)})"
|
Loading…
Reference in New Issue
Block a user