Compare commits
1 Commits
master
...
feature/tr
Author | SHA1 | Date | |
---|---|---|---|
|
17b45249f4 |
@ -1,4 +1,5 @@
|
||||
"""ProtoTroch Initializers."""
|
||||
"""ProtoTroch Component and Label Initializers."""
|
||||
|
||||
import warnings
|
||||
from collections.abc import Iterable
|
||||
from itertools import chain
|
||||
|
@ -1,5 +1,7 @@
|
||||
"""ProtoTorch modules."""
|
||||
|
||||
from .competitions import *
|
||||
from .initializers import *
|
||||
from .pooling import *
|
||||
from .transformations import *
|
||||
from .wrappers import LambdaLayer, LossLayer
|
||||
|
61
prototorch/modules/initializers.py
Normal file
61
prototorch/modules/initializers.py
Normal file
@ -0,0 +1,61 @@
|
||||
"""ProtoTroch Module Initializers."""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
# Transformations
|
||||
class MatrixInitializer(object):
|
||||
def __init__(self, *args, **kwargs):
|
||||
...
|
||||
|
||||
def generate(self, shape):
|
||||
raise NotImplementedError("Subclasses should implement this!")
|
||||
|
||||
|
||||
class ZerosInitializer(MatrixInitializer):
|
||||
def generate(self, shape):
|
||||
return torch.zeros(shape)
|
||||
|
||||
|
||||
class OnesInitializer(MatrixInitializer):
|
||||
def __init__(self, scale=1.0):
|
||||
super().__init__()
|
||||
self.scale = scale
|
||||
|
||||
def generate(self, shape):
|
||||
return torch.ones(shape) * self.scale
|
||||
|
||||
|
||||
class UniformInitializer(MatrixInitializer):
|
||||
def __init__(self, minimum=0.0, maximum=1.0, scale=1.0):
|
||||
super().__init__()
|
||||
self.minimum = minimum
|
||||
self.maximum = maximum
|
||||
self.scale = scale
|
||||
|
||||
def generate(self, shape):
|
||||
return torch.ones(shape).uniform_(self.minimum,
|
||||
self.maximum) * self.scale
|
||||
|
||||
|
||||
class DataAwareInitializer(MatrixInitializer):
|
||||
def __init__(self, data, transform=torch.nn.Identity()):
|
||||
super().__init__()
|
||||
self.data = data
|
||||
self.transform = transform
|
||||
|
||||
def __del__(self):
|
||||
del self.data
|
||||
|
||||
|
||||
class EigenVectorInitializer(DataAwareInitializer):
|
||||
def generate(self, shape):
|
||||
# TODO
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
# Aliases
|
||||
EV = EigenVectorInitializer
|
||||
Random = RandomInitializer = UniformInitializer
|
||||
Zeros = ZerosInitializer
|
||||
Ones = OnesInitializer
|
49
prototorch/modules/transformations.py
Normal file
49
prototorch/modules/transformations.py
Normal file
@ -0,0 +1,49 @@
|
||||
"""ProtoTorch Transformation Layers."""
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from .initializers import MatrixInitializer
|
||||
|
||||
|
||||
def _precheck_initializer(initializer):
|
||||
if not isinstance(initializer, MatrixInitializer):
|
||||
emsg = f"`initializer` has to be some subtype of " \
|
||||
f"{MatrixInitializer}. " \
|
||||
f"You have provided: {initializer=} instead."
|
||||
raise TypeError(emsg)
|
||||
|
||||
|
||||
class Omega(torch.nn.Module):
|
||||
"""The Omega mapping used in GMLVQ."""
|
||||
def __init__(self,
|
||||
num_replicas=1,
|
||||
input_dim=None,
|
||||
latent_dim=None,
|
||||
initializer=None,
|
||||
*,
|
||||
initialized_weights=None):
|
||||
super().__init__()
|
||||
|
||||
if initialized_weights is not None:
|
||||
self._register_weights(initialized_weights)
|
||||
else:
|
||||
if num_replicas == 1:
|
||||
shape = (input_dim, latent_dim)
|
||||
else:
|
||||
shape = (num_replicas, input_dim, latent_dim)
|
||||
self._initialize_weights(shape, initializer)
|
||||
|
||||
def _register_weights(self, weights):
|
||||
self.register_parameter("_omega", Parameter(weights))
|
||||
|
||||
def _initialize_weights(self, shape, initializer):
|
||||
_precheck_initializer(initializer)
|
||||
_omega = initializer.generate(shape)
|
||||
self._register_weights(_omega)
|
||||
|
||||
def forward(self):
|
||||
return self._omega
|
||||
|
||||
def extra_repr(self):
|
||||
return f"(omega): (shape: {tuple(self._omega.shape)})"
|
Loading…
Reference in New Issue
Block a user