Add minor cosmetic changes
This commit is contained in:
parent
a167565857
commit
dab91e471a
@ -1,4 +1,4 @@
|
|||||||
"""ProtoTorch abstract datasets
|
"""ProtoTorch abstract dataset classes.
|
||||||
|
|
||||||
Based on `torchvision.VisionDataset` and `torchvision.MNIST`
|
Based on `torchvision.VisionDataset` and `torchvision.MNIST`
|
||||||
|
|
||||||
@ -13,7 +13,7 @@ import torch
|
|||||||
|
|
||||||
|
|
||||||
class Dataset(torch.utils.data.Dataset):
|
class Dataset(torch.utils.data.Dataset):
|
||||||
"""Abstract dataset class to be inherited"""
|
"""Abstract dataset class to be inherited."""
|
||||||
_repr_indent = 2
|
_repr_indent = 2
|
||||||
|
|
||||||
def __init__(self, root):
|
def __init__(self, root):
|
||||||
@ -29,7 +29,7 @@ class Dataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
|
|
||||||
class ProtoDataset(Dataset):
|
class ProtoDataset(Dataset):
|
||||||
"""Abstract dataset class to be inherited"""
|
"""Abstract dataset class to be inherited."""
|
||||||
training_file = 'training.pt'
|
training_file = 'training.pt'
|
||||||
test_file = 'test.pt'
|
test_file = 'test.pt'
|
||||||
|
|
||||||
|
@ -16,7 +16,7 @@ def register_activation(function):
|
|||||||
|
|
||||||
@register_activation
|
@register_activation
|
||||||
# @torch.jit.script
|
# @torch.jit.script
|
||||||
def identity(x, beta=torch.tensor([0])):
|
def identity(x, beta=torch.tensor(0)):
|
||||||
"""Identity activation function.
|
"""Identity activation function.
|
||||||
|
|
||||||
Definition:
|
Definition:
|
||||||
@ -27,7 +27,7 @@ def identity(x, beta=torch.tensor([0])):
|
|||||||
|
|
||||||
@register_activation
|
@register_activation
|
||||||
# @torch.jit.script
|
# @torch.jit.script
|
||||||
def sigmoid_beta(x, beta=torch.tensor([10])):
|
def sigmoid_beta(x, beta=torch.tensor(10)):
|
||||||
r"""Sigmoid activation function with scaling.
|
r"""Sigmoid activation function with scaling.
|
||||||
|
|
||||||
Definition:
|
Definition:
|
||||||
@ -42,7 +42,7 @@ def sigmoid_beta(x, beta=torch.tensor([10])):
|
|||||||
|
|
||||||
@register_activation
|
@register_activation
|
||||||
# @torch.jit.script
|
# @torch.jit.script
|
||||||
def swish_beta(x, beta=torch.tensor([10])):
|
def swish_beta(x, beta=torch.tensor(10)):
|
||||||
r"""Swish activation function with scaling.
|
r"""Swish activation function with scaling.
|
||||||
|
|
||||||
Definition:
|
Definition:
|
||||||
|
Loading…
Reference in New Issue
Block a user