[FEATURE] Add transforms

This commit is contained in:
Jensun Ravichandran
2021-06-16 21:53:36 +02:00
parent c95f91cc29
commit 7a6da0c5fc
4 changed files with 105 additions and 2 deletions

View File

@@ -7,3 +7,4 @@ from .initializers import *
from .losses import *
from .pooling import *
from .similarities import *
from .transforms import *

View File

@@ -313,7 +313,7 @@ class AbstractReasoningsInitializer(ABC):
@abstractmethod
def generate(self, distribution: Union[dict, list, tuple]):
...
return generate_end_hook(...)
return self.generate_end_hook(...)
class LiteralReasoningsInitializer(AbstractReasoningsInitializer):
@@ -380,6 +380,51 @@ class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer):
return reasonings
# Transforms
class AbstractTransformInitializer(ABC):
"""Abstract class for all transform initializers."""
...
class AbstractLinearTransformInitializer(AbstractTransformInitializer):
"""Abstract class for all linear transform initializers."""
def __init__(self, out_dim_first: bool = False):
self.out_dim_first = out_dim_first
def generate_end_hook(self, weights):
if self.out_dim_first:
weights = weights.permute(1, 0)
return weights
@abstractmethod
def generate(self, in_dim: int, out_dim: int):
...
return self.generate_end_hook(...)
class ZerosLinearTransformInitializer(AbstractLinearTransformInitializer):
"""Initialize a matrix with zeros."""
def generate(self, in_dim: int, out_dim: int):
weights = torch.zeros(in_dim, out_dim)
return self.generate_end_hook(weights)
class OnesLinearTransformInitializer(AbstractLinearTransformInitializer):
"""Initialize a matrix with ones."""
def generate(self, in_dim: int, out_dim: int):
weights = torch.ones(in_dim, out_dim)
return self.generate_end_hook(weights)
class EyeTransformInitializer(AbstractLinearTransformInitializer):
"""Initialize a matrix with the largest possible identity matrix."""
def generate(self, in_dim: int, out_dim: int):
weights = torch.zeros(in_dim, out_dim)
I = torch.eye(min(in_dim, out_dim))
weights[:I.shape[0], :I.shape[1]] = I
return self.generate_end_hook(weights)
# Aliases - Components
CACI = ClassAwareCompInitializer
DACI = DataAwareCompInitializer
@@ -406,3 +451,8 @@ ORI = OnesReasoningsInitializer
PPRI = PurePositiveReasoningsInitializer
RRI = RandomReasoningsInitializer
ZRI = ZerosReasoningsInitializer
# Aliases - Transforms
Eye = EyeTransformInitializer
OLTI = OnesLinearTransformInitializer
ZLTI = ZerosLinearTransformInitializer