Add LambdaLayer

This commit is contained in:
Jensun Ravichandran 2021-05-31 16:47:20 +02:00
parent e61ae73749
commit 8227525c82
2 changed files with 19 additions and 1 deletions

View File

@ -1 +1,3 @@
"""ProtoTorch modules."""
"""ProtoTorch modules."""
from .utils import LambdaLayer

View File

@ -0,0 +1,16 @@
"""ProtoTorch utility modules."""
import torch
class LambdaLayer(torch.nn.Module):
def __init__(self, fn, name=None):
super().__init__()
self.fn = fn
self.name = name or fn.__name__ # lambda fns get <lambda>
def forward(self, *args, **kwargs):
return self.fn(*args, **kwargs)
def extra_repr(self):
return self.name