Automatic formatting.
This commit is contained in:
parent
b0cd2de18e
commit
ba537fe1d5
@ -2,7 +2,6 @@
|
|||||||
# DATASET
|
# DATASET
|
||||||
#
|
#
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sklearn.datasets import load_iris
|
from sklearn.datasets import load_iris
|
||||||
from sklearn.preprocessing import StandardScaler
|
from sklearn.preprocessing import StandardScaler
|
||||||
|
|
||||||
|
@ -1,4 +1,8 @@
|
|||||||
from prototorch.components.components import Components, LabeledComponents, ReasoningComponents
|
from prototorch.components.components import (
|
||||||
|
Components,
|
||||||
|
LabeledComponents,
|
||||||
|
ReasoningComponents,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Components",
|
"Components",
|
||||||
|
@ -1,24 +1,28 @@
|
|||||||
"""ProtoTorch components modules."""
|
"""ProtoTorch components modules."""
|
||||||
|
|
||||||
from typing import Tuple
|
|
||||||
import warnings
|
import warnings
|
||||||
from prototorch.components.initializers import EqualLabelInitializer, ZeroReasoningsInitializer
|
|
||||||
import torch
|
import torch
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
from prototorch.functions.initializers import get_initializer
|
from prototorch.components.initializers import (
|
||||||
|
EqualLabelInitializer,
|
||||||
|
ZeroReasoningsInitializer,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Components(torch.nn.Module):
|
class Components(torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
Components is a set of learnable Tensors.
|
Components is a set of learnable Tensors.
|
||||||
"""
|
"""
|
||||||
def __init__(self,
|
def __init__(
|
||||||
number_of_components=None,
|
self,
|
||||||
initializer=None,
|
number_of_components=None,
|
||||||
*,
|
initializer=None,
|
||||||
initialized_components=None,
|
*,
|
||||||
dtype=torch.float32):
|
initialized_components=None,
|
||||||
|
dtype=torch.float32,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# Ignore all initialization settings if initialized_components is given.
|
# Ignore all initialization settings if initialized_components is given.
|
||||||
@ -127,4 +131,4 @@ class ReasoningComponents(Components):
|
|||||||
return self._reasonings.detach().cpu()
|
return self._reasonings.detach().cpu()
|
||||||
|
|
||||||
def forward(self):
|
def forward(self):
|
||||||
return super().forward(), self._reasonings
|
return super().forward(), self._reasonings
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import torch
|
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
# Components
|
# Components
|
||||||
class ComponentsInitializer:
|
class ComponentsInitializer:
|
||||||
|
@ -3,8 +3,11 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from prototorch.functions.helper import (_check_shapes, _int_and_mixed_shape,
|
from prototorch.functions.helper import (
|
||||||
equal_int_shape)
|
_check_shapes,
|
||||||
|
_int_and_mixed_shape,
|
||||||
|
equal_int_shape,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def squared_euclidean_distance(x, y):
|
def squared_euclidean_distance(x, y):
|
||||||
|
@ -1,8 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from prototorch.functions.distances import (euclidean_distance_matrix,
|
from prototorch.functions.distances import euclidean_distance_matrix, tangent_distance
|
||||||
tangent_distance)
|
|
||||||
from prototorch.functions.helper import _check_shapes, _int_and_mixed_shape
|
from prototorch.functions.helper import _check_shapes, _int_and_mixed_shape
|
||||||
from prototorch.functions.normalization import orthogonalization
|
from prototorch.functions.normalization import orthogonalization
|
||||||
from prototorch.modules.prototypes import Prototypes1D
|
from prototorch.modules.prototypes import Prototypes1D
|
||||||
|
@ -5,8 +5,13 @@ import unittest
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from prototorch.functions import (activations, competitions, distances,
|
from prototorch.functions import (
|
||||||
initializers, losses)
|
activations,
|
||||||
|
competitions,
|
||||||
|
distances,
|
||||||
|
initializers,
|
||||||
|
losses,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestActivations(unittest.TestCase):
|
class TestActivations(unittest.TestCase):
|
||||||
|
Loading…
Reference in New Issue
Block a user