Automatic formatting.
This commit is contained in:
parent
b0cd2de18e
commit
ba537fe1d5
@ -2,7 +2,6 @@
|
||||
# DATASET
|
||||
#
|
||||
import torch
|
||||
|
||||
from sklearn.datasets import load_iris
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
|
||||
|
@ -1,4 +1,8 @@
|
||||
from prototorch.components.components import Components, LabeledComponents, ReasoningComponents
|
||||
from prototorch.components.components import (
|
||||
Components,
|
||||
LabeledComponents,
|
||||
ReasoningComponents,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Components",
|
||||
|
@ -1,24 +1,28 @@
|
||||
"""ProtoTorch components modules."""
|
||||
|
||||
from typing import Tuple
|
||||
import warnings
|
||||
from prototorch.components.initializers import EqualLabelInitializer, ZeroReasoningsInitializer
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from prototorch.functions.initializers import get_initializer
|
||||
from prototorch.components.initializers import (
|
||||
EqualLabelInitializer,
|
||||
ZeroReasoningsInitializer,
|
||||
)
|
||||
|
||||
|
||||
class Components(torch.nn.Module):
|
||||
"""
|
||||
Components is a set of learnable Tensors.
|
||||
"""
|
||||
def __init__(self,
|
||||
def __init__(
|
||||
self,
|
||||
number_of_components=None,
|
||||
initializer=None,
|
||||
*,
|
||||
initialized_components=None,
|
||||
dtype=torch.float32):
|
||||
dtype=torch.float32,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Ignore all initialization settings if initialized_components is given.
|
||||
|
@ -1,6 +1,7 @@
|
||||
import torch
|
||||
from collections.abc import Iterable
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
# Components
|
||||
class ComponentsInitializer:
|
||||
|
@ -3,8 +3,11 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from prototorch.functions.helper import (_check_shapes, _int_and_mixed_shape,
|
||||
equal_int_shape)
|
||||
from prototorch.functions.helper import (
|
||||
_check_shapes,
|
||||
_int_and_mixed_shape,
|
||||
equal_int_shape,
|
||||
)
|
||||
|
||||
|
||||
def squared_euclidean_distance(x, y):
|
||||
|
@ -1,8 +1,7 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from prototorch.functions.distances import (euclidean_distance_matrix,
|
||||
tangent_distance)
|
||||
from prototorch.functions.distances import euclidean_distance_matrix, tangent_distance
|
||||
from prototorch.functions.helper import _check_shapes, _int_and_mixed_shape
|
||||
from prototorch.functions.normalization import orthogonalization
|
||||
from prototorch.modules.prototypes import Prototypes1D
|
||||
|
@ -5,8 +5,13 @@ import unittest
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from prototorch.functions import (activations, competitions, distances,
|
||||
initializers, losses)
|
||||
from prototorch.functions import (
|
||||
activations,
|
||||
competitions,
|
||||
distances,
|
||||
initializers,
|
||||
losses,
|
||||
)
|
||||
|
||||
|
||||
class TestActivations(unittest.TestCase):
|
||||
|
Loading…
Reference in New Issue
Block a user