Automatic formatting.

This commit is contained in:
Alexander Engelsberger 2021-04-27 15:43:10 +02:00
parent b0cd2de18e
commit ba537fe1d5
7 changed files with 34 additions and 19 deletions

View File

@ -2,7 +2,6 @@
# DATASET # DATASET
# #
import torch import torch
from sklearn.datasets import load_iris from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler from sklearn.preprocessing import StandardScaler

View File

@ -1,4 +1,8 @@
from prototorch.components.components import Components, LabeledComponents, ReasoningComponents from prototorch.components.components import (
Components,
LabeledComponents,
ReasoningComponents,
)
__all__ = [ __all__ = [
"Components", "Components",

View File

@ -1,24 +1,28 @@
"""ProtoTorch components modules.""" """ProtoTorch components modules."""
from typing import Tuple
import warnings import warnings
from prototorch.components.initializers import EqualLabelInitializer, ZeroReasoningsInitializer
import torch import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from prototorch.functions.initializers import get_initializer from prototorch.components.initializers import (
EqualLabelInitializer,
ZeroReasoningsInitializer,
)
class Components(torch.nn.Module): class Components(torch.nn.Module):
""" """
Components is a set of learnable Tensors. Components is a set of learnable Tensors.
""" """
def __init__(self, def __init__(
number_of_components=None, self,
initializer=None, number_of_components=None,
*, initializer=None,
initialized_components=None, *,
dtype=torch.float32): initialized_components=None,
dtype=torch.float32,
):
super().__init__() super().__init__()
# Ignore all initialization settings if initialized_components is given. # Ignore all initialization settings if initialized_components is given.
@ -127,4 +131,4 @@ class ReasoningComponents(Components):
return self._reasonings.detach().cpu() return self._reasonings.detach().cpu()
def forward(self): def forward(self):
return super().forward(), self._reasonings return super().forward(), self._reasonings

View File

@ -1,6 +1,7 @@
import torch
from collections.abc import Iterable from collections.abc import Iterable
import torch
# Components # Components
class ComponentsInitializer: class ComponentsInitializer:

View File

@ -3,8 +3,11 @@
import numpy as np import numpy as np
import torch import torch
from prototorch.functions.helper import (_check_shapes, _int_and_mixed_shape, from prototorch.functions.helper import (
equal_int_shape) _check_shapes,
_int_and_mixed_shape,
equal_int_shape,
)
def squared_euclidean_distance(x, y): def squared_euclidean_distance(x, y):

View File

@ -1,8 +1,7 @@
import torch import torch
from torch import nn from torch import nn
from prototorch.functions.distances import (euclidean_distance_matrix, from prototorch.functions.distances import euclidean_distance_matrix, tangent_distance
tangent_distance)
from prototorch.functions.helper import _check_shapes, _int_and_mixed_shape from prototorch.functions.helper import _check_shapes, _int_and_mixed_shape
from prototorch.functions.normalization import orthogonalization from prototorch.functions.normalization import orthogonalization
from prototorch.modules.prototypes import Prototypes1D from prototorch.modules.prototypes import Prototypes1D

View File

@ -5,8 +5,13 @@ import unittest
import numpy as np import numpy as np
import torch import torch
from prototorch.functions import (activations, competitions, distances, from prototorch.functions import (
initializers, losses) activations,
competitions,
distances,
initializers,
losses,
)
class TestActivations(unittest.TestCase): class TestActivations(unittest.TestCase):