[FEATURE] Add transforms
This commit is contained in:
@@ -7,3 +7,4 @@ from .initializers import *
|
||||
from .losses import *
|
||||
from .pooling import *
|
||||
from .similarities import *
|
||||
from .transforms import *
|
||||
|
@@ -313,7 +313,7 @@ class AbstractReasoningsInitializer(ABC):
|
||||
@abstractmethod
|
||||
def generate(self, distribution: Union[dict, list, tuple]):
|
||||
...
|
||||
return generate_end_hook(...)
|
||||
return self.generate_end_hook(...)
|
||||
|
||||
|
||||
class LiteralReasoningsInitializer(AbstractReasoningsInitializer):
|
||||
@@ -380,6 +380,51 @@ class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer):
|
||||
return reasonings
|
||||
|
||||
|
||||
# Transforms
|
||||
class AbstractTransformInitializer(ABC):
|
||||
"""Abstract class for all transform initializers."""
|
||||
...
|
||||
|
||||
|
||||
class AbstractLinearTransformInitializer(AbstractTransformInitializer):
|
||||
"""Abstract class for all linear transform initializers."""
|
||||
def __init__(self, out_dim_first: bool = False):
|
||||
self.out_dim_first = out_dim_first
|
||||
|
||||
def generate_end_hook(self, weights):
|
||||
if self.out_dim_first:
|
||||
weights = weights.permute(1, 0)
|
||||
return weights
|
||||
|
||||
@abstractmethod
|
||||
def generate(self, in_dim: int, out_dim: int):
|
||||
...
|
||||
return self.generate_end_hook(...)
|
||||
|
||||
|
||||
class ZerosLinearTransformInitializer(AbstractLinearTransformInitializer):
|
||||
"""Initialize a matrix with zeros."""
|
||||
def generate(self, in_dim: int, out_dim: int):
|
||||
weights = torch.zeros(in_dim, out_dim)
|
||||
return self.generate_end_hook(weights)
|
||||
|
||||
|
||||
class OnesLinearTransformInitializer(AbstractLinearTransformInitializer):
|
||||
"""Initialize a matrix with ones."""
|
||||
def generate(self, in_dim: int, out_dim: int):
|
||||
weights = torch.ones(in_dim, out_dim)
|
||||
return self.generate_end_hook(weights)
|
||||
|
||||
|
||||
class EyeTransformInitializer(AbstractLinearTransformInitializer):
|
||||
"""Initialize a matrix with the largest possible identity matrix."""
|
||||
def generate(self, in_dim: int, out_dim: int):
|
||||
weights = torch.zeros(in_dim, out_dim)
|
||||
I = torch.eye(min(in_dim, out_dim))
|
||||
weights[:I.shape[0], :I.shape[1]] = I
|
||||
return self.generate_end_hook(weights)
|
||||
|
||||
|
||||
# Aliases - Components
|
||||
CACI = ClassAwareCompInitializer
|
||||
DACI = DataAwareCompInitializer
|
||||
@@ -406,3 +451,8 @@ ORI = OnesReasoningsInitializer
|
||||
PPRI = PurePositiveReasoningsInitializer
|
||||
RRI = RandomReasoningsInitializer
|
||||
ZRI = ZerosReasoningsInitializer
|
||||
|
||||
# Aliases - Transforms
|
||||
Eye = EyeTransformInitializer
|
||||
OLTI = OnesLinearTransformInitializer
|
||||
ZLTI = ZerosLinearTransformInitializer
|
||||
|
Reference in New Issue
Block a user