Automatic formatting.

This commit is contained in:
Alexander Engelsberger 2021-04-27 15:43:10 +02:00
parent b0cd2de18e
commit ba537fe1d5
7 changed files with 34 additions and 19 deletions

View File

@ -2,7 +2,6 @@
# DATASET
#
import torch
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler

View File

@ -1,4 +1,8 @@
from prototorch.components.components import Components, LabeledComponents, ReasoningComponents
from prototorch.components.components import (
Components,
LabeledComponents,
ReasoningComponents,
)
__all__ = [
"Components",

View File

@ -1,24 +1,28 @@
"""ProtoTorch components modules."""
from typing import Tuple
import warnings
from prototorch.components.initializers import EqualLabelInitializer, ZeroReasoningsInitializer
import torch
from torch.nn.parameter import Parameter
from prototorch.functions.initializers import get_initializer
from prototorch.components.initializers import (
EqualLabelInitializer,
ZeroReasoningsInitializer,
)
class Components(torch.nn.Module):
"""
Components is a set of learnable Tensors.
"""
def __init__(self,
number_of_components=None,
initializer=None,
*,
initialized_components=None,
dtype=torch.float32):
def __init__(
self,
number_of_components=None,
initializer=None,
*,
initialized_components=None,
dtype=torch.float32,
):
super().__init__()
# Ignore all initialization settings if initialized_components is given.
@ -127,4 +131,4 @@ class ReasoningComponents(Components):
return self._reasonings.detach().cpu()
def forward(self):
return super().forward(), self._reasonings
return super().forward(), self._reasonings

View File

@ -1,6 +1,7 @@
import torch
from collections.abc import Iterable
import torch
# Components
class ComponentsInitializer:

View File

@ -3,8 +3,11 @@
import numpy as np
import torch
from prototorch.functions.helper import (_check_shapes, _int_and_mixed_shape,
equal_int_shape)
from prototorch.functions.helper import (
_check_shapes,
_int_and_mixed_shape,
equal_int_shape,
)
def squared_euclidean_distance(x, y):

View File

@ -1,8 +1,7 @@
import torch
from torch import nn
from prototorch.functions.distances import (euclidean_distance_matrix,
tangent_distance)
from prototorch.functions.distances import euclidean_distance_matrix, tangent_distance
from prototorch.functions.helper import _check_shapes, _int_and_mixed_shape
from prototorch.functions.normalization import orthogonalization
from prototorch.modules.prototypes import Prototypes1D

View File

@ -5,8 +5,13 @@ import unittest
import numpy as np
import torch
from prototorch.functions import (activations, competitions, distances,
initializers, losses)
from prototorch.functions import (
activations,
competitions,
distances,
initializers,
losses,
)
class TestActivations(unittest.TestCase):