From ba537fe1d516b46a1c32834bceb16e9a67f76610 Mon Sep 17 00:00:00 2001 From: Alexander Engelsberger Date: Tue, 27 Apr 2021 15:43:10 +0200 Subject: [PATCH] Automatic formatting. --- examples/new_components.py | 1 - prototorch/components/__init__.py | 6 +++++- prototorch/components/components.py | 24 ++++++++++++++---------- prototorch/components/initializers.py | 3 ++- prototorch/functions/distances.py | 7 +++++-- prototorch/modules/models.py | 3 +-- tests/test_functions.py | 9 +++++++-- 7 files changed, 34 insertions(+), 19 deletions(-) diff --git a/examples/new_components.py b/examples/new_components.py index d4a2555..2ef9ecc 100644 --- a/examples/new_components.py +++ b/examples/new_components.py @@ -2,7 +2,6 @@ # DATASET # import torch - from sklearn.datasets import load_iris from sklearn.preprocessing import StandardScaler diff --git a/prototorch/components/__init__.py b/prototorch/components/__init__.py index 3ae0a51..df90b00 100644 --- a/prototorch/components/__init__.py +++ b/prototorch/components/__init__.py @@ -1,4 +1,8 @@ -from prototorch.components.components import Components, LabeledComponents, ReasoningComponents +from prototorch.components.components import ( + Components, + LabeledComponents, + ReasoningComponents, +) __all__ = [ "Components", diff --git a/prototorch/components/components.py b/prototorch/components/components.py index 267e22b..658a76f 100644 --- a/prototorch/components/components.py +++ b/prototorch/components/components.py @@ -1,24 +1,28 @@ """ProtoTorch components modules.""" -from typing import Tuple import warnings -from prototorch.components.initializers import EqualLabelInitializer, ZeroReasoningsInitializer + import torch from torch.nn.parameter import Parameter -from prototorch.functions.initializers import get_initializer +from prototorch.components.initializers import ( + EqualLabelInitializer, + ZeroReasoningsInitializer, +) class Components(torch.nn.Module): """ Components is a set of learnable Tensors. """ - def __init__(self, - number_of_components=None, - initializer=None, - *, - initialized_components=None, - dtype=torch.float32): + def __init__( + self, + number_of_components=None, + initializer=None, + *, + initialized_components=None, + dtype=torch.float32, + ): super().__init__() # Ignore all initialization settings if initialized_components is given. @@ -127,4 +131,4 @@ class ReasoningComponents(Components): return self._reasonings.detach().cpu() def forward(self): - return super().forward(), self._reasonings \ No newline at end of file + return super().forward(), self._reasonings diff --git a/prototorch/components/initializers.py b/prototorch/components/initializers.py index c9ca22a..9c5d4e2 100644 --- a/prototorch/components/initializers.py +++ b/prototorch/components/initializers.py @@ -1,6 +1,7 @@ -import torch from collections.abc import Iterable +import torch + # Components class ComponentsInitializer: diff --git a/prototorch/functions/distances.py b/prototorch/functions/distances.py index 5949d0b..094270d 100644 --- a/prototorch/functions/distances.py +++ b/prototorch/functions/distances.py @@ -3,8 +3,11 @@ import numpy as np import torch -from prototorch.functions.helper import (_check_shapes, _int_and_mixed_shape, - equal_int_shape) +from prototorch.functions.helper import ( + _check_shapes, + _int_and_mixed_shape, + equal_int_shape, +) def squared_euclidean_distance(x, y): diff --git a/prototorch/modules/models.py b/prototorch/modules/models.py index 8753bce..2f9c14a 100644 --- a/prototorch/modules/models.py +++ b/prototorch/modules/models.py @@ -1,8 +1,7 @@ import torch from torch import nn -from prototorch.functions.distances import (euclidean_distance_matrix, - tangent_distance) +from prototorch.functions.distances import euclidean_distance_matrix, tangent_distance from prototorch.functions.helper import _check_shapes, _int_and_mixed_shape from prototorch.functions.normalization import orthogonalization from prototorch.modules.prototypes import Prototypes1D diff --git a/tests/test_functions.py b/tests/test_functions.py index cdb39f2..4a67c6c 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -5,8 +5,13 @@ import unittest import numpy as np import torch -from prototorch.functions import (activations, competitions, distances, - initializers, losses) +from prototorch.functions import ( + activations, + competitions, + distances, + initializers, + losses, +) class TestActivations(unittest.TestCase):