6 Commits

Author SHA1 Message Date
Alexander Engelsberger
09c80e2d54 Merge branch 'master' into kernel_distances 2021-05-11 16:10:56 +02:00
Alexander Engelsberger
65e0637b17 Fix RBF Kernel Dimensions. 2021-04-27 17:58:05 +02:00
Alexander Engelsberger
209f9e641b Fix kernel dimensions. 2021-04-27 16:56:56 +02:00
Alexander Engelsberger
ba537fe1d5 Automatic formatting. 2021-04-27 15:43:10 +02:00
Alexander Engelsberger
b0cd2de18e Batch Kernel. [Ineficient] 2021-04-27 15:38:34 +02:00
Alexander Engelsberger
7d353f5b5a Kernel Distances. 2021-04-27 12:06:15 +02:00
12 changed files with 239 additions and 20 deletions

View File

@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.4.3
current_version = 0.4.2
commit = True
tag = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)

View File

@@ -23,7 +23,7 @@ author = "Jensun Ravichandran"
# The full version, including alpha/beta/rc tags
#
release = "0.4.3"
release = "0.4.2"
# -- General configuration ---------------------------------------------------

View File

@@ -1,7 +1,7 @@
"""ProtoTorch package."""
# Core Setup
__version__ = "0.4.3"
__version__ = "0.4.2"
__all_core__ = [
"datasets",

View File

@@ -67,9 +67,8 @@ class LabeledComponents(Components):
*,
initialized_components=None):
if initialized_components is not None:
components, component_labels = initialized_components
super().__init__(initialized_components=components)
self._labels = component_labels
super().__init__(initialized_components=initialized_components[0])
self._labels = initialized_components[1]
else:
self._initialize_labels(distribution)
super().__init__(number_of_components=len(self._labels),

View File

@@ -1,6 +1,11 @@
"""ProtoTorch datasets."""
from .abstract import NumpyDataset
from .iris import Iris
from .spiral import Spiral
from .tecator import Tecator
__all__ = [
"NumpyDataset",
"Spiral",
"Tecator",
]

View File

@@ -3,6 +3,7 @@
import torch
# @torch.jit.script
def stratified_min(distances, labels):
clabels = torch.unique(labels, dim=0)
nclasses = clabels.size()[0]
@@ -30,14 +31,15 @@ def stratified_min(distances, labels):
return winning_distances.T # return with `batch_size` first
# @torch.jit.script
def wtac(distances, labels):
winning_indices = torch.min(distances, dim=1).indices
winning_labels = labels[winning_indices].squeeze()
return winning_labels
def knnc(distances, labels, k=1):
winning_indices = torch.topk(-distances, k=k, dim=1).indices
winning_labels = torch.mode(labels[winning_indices].squeeze(),
dim=1).values
# @torch.jit.script
def knnc(distances, labels, k):
winning_indices = torch.topk(-distances, k=k.item(), dim=1).indices
winning_labels = labels[winning_indices].squeeze()
return winning_labels

View File

@@ -3,8 +3,11 @@
import numpy as np
import torch
from prototorch.functions.helper import (_check_shapes, _int_and_mixed_shape,
equal_int_shape)
from prototorch.functions.helper import (
_check_shapes,
_int_and_mixed_shape,
equal_int_shape,
)
def squared_euclidean_distance(x, y):
@@ -261,5 +264,86 @@ def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10):
return diss.permute([1, 0, 2]).squeeze(-1)
class KernelDistance:
r"""Kernel Distance
Distance based on a kernel function.
"""
def __init__(self, kernel_fn):
self.kernel_fn = kernel_fn
def __call__(self, x_batch: torch.Tensor, y_batch: torch.Tensor):
return self._single_call(x_batch, y_batch)
def _single_call(self, x, y):
remove_dims = []
if len(x.shape) == 1:
x = x.unsqueeze(0)
remove_dims.append(0)
if len(y.shape) == 1:
y = y.unsqueeze(0)
remove_dims.append(-1)
output = self.kernel_fn(x, x).diag().unsqueeze(1) - 2 * self.kernel_fn(
x, y) + self.kernel_fn(y, y).diag()
for dim in remove_dims:
output.squeeze_(dim)
return torch.sqrt(output)
class BatchKernelDistance:
r"""Kernel Distance
Distance based on a kernel function.
"""
def __init__(self, kernel_fn):
self.kernel_fn = kernel_fn
def __call__(self, x_batch: torch.Tensor, y_batch: torch.Tensor):
remove_dims = 0
# Extend Single inputs
if len(x_batch.shape) == 1:
x_batch = x_batch.unsqueeze(0)
remove_dims += 1
if len(y_batch.shape) == 1:
y_batch = y_batch.unsqueeze(0)
remove_dims += 1
# Loop over batches
output = torch.FloatTensor(len(x_batch), len(y_batch))
for i, x in enumerate(x_batch):
for j, y in enumerate(y_batch):
output[i][j] = self._single_call(x, y)
for _ in range(remove_dims):
output.squeeze_(0)
return output
def _single_call(self, x, y):
kappa_xx = self.kernel_fn(x, x)
kappa_xy = self.kernel_fn(x, y)
kappa_yy = self.kernel_fn(y, y)
squared_distance = kappa_xx - 2 * kappa_xy + kappa_yy
return torch.sqrt(squared_distance)
class SquaredKernelDistance(KernelDistance):
r"""Squared Kernel Distance
Kernel distance without final squareroot.
"""
def single_call(self, x, y):
kappa_xx = self.kernel_fn(x, x)
kappa_xy = self.kernel_fn(x, y)
kappa_yy = self.kernel_fn(y, y)
return kappa_xx - 2 * kappa_xy + kappa_yy
# Aliases
sed = squared_euclidean_distance
sed = squared_euclidean_distance

View File

@@ -0,0 +1,28 @@
"""
Experimental Kernels
"""
import torch
class ExplicitKernel:
def __init__(self, projection=torch.nn.Identity()):
self.projection = projection
def __call__(self, x, y):
return self.projection(x) @ self.projection(y).T
class RadialBasisFunctionKernel:
def __init__(self, sigma) -> None:
self.s2 = sigma * sigma
def __call__(self, x, y):
remove_dim = False
if len(x.shape) > 1:
x = x.unsqueeze(1)
remove_dim = True
output = torch.exp(-torch.sum((x - y)**2, dim=-1) / (2 * self.s2))
if remove_dim:
output = output.squeeze(1)
return output

View File

@@ -1,8 +1,7 @@
import torch
from torch import nn
from prototorch.functions.distances import (euclidean_distance_matrix,
tangent_distance)
from prototorch.functions.distances import euclidean_distance_matrix, tangent_distance
from prototorch.functions.helper import _check_shapes, _int_and_mixed_shape
from prototorch.functions.normalization import orthogonalization
from prototorch.modules.prototypes import Prototypes1D

View File

@@ -23,7 +23,6 @@ INSTALL_REQUIRES = [
]
DATASETS = [
"requests",
"sklearn",
"tqdm",
]
DEV = ["bumpversion"]
@@ -43,7 +42,7 @@ ALL = DATASETS + DEV + DOCS + EXAMPLES + TESTS
setup(
name="prototorch",
version="0.4.3",
version="0.4.2",
description="Highly extensible, GPU-supported "
"Learning Vector Quantization (LVQ) toolbox "
"built using PyTorch and its nn API.",

View File

@@ -5,8 +5,13 @@ import unittest
import numpy as np
import torch
from prototorch.functions import (activations, competitions, distances,
initializers, losses)
from prototorch.functions import (
activations,
competitions,
distances,
initializers,
losses,
)
class TestActivations(unittest.TestCase):

98
tests/test_kernels.py Normal file
View File

@@ -0,0 +1,98 @@
"""ProtoTorch kernels test suite."""
import unittest
import numpy as np
import torch
from prototorch.functions.distances import KernelDistance
from prototorch.functions.kernels import ExplicitKernel, RadialBasisFunctionKernel
class TestExplicitKernel(unittest.TestCase):
def setUp(self):
self.single_x = torch.randn(1024)
self.single_y = torch.randn(1024)
self.batch_x = torch.randn(32, 1024)
self.batch_y = torch.randn(32, 1024)
def test_single_values(self):
kernel = ExplicitKernel()
self.assertEqual(
kernel(self.single_x, self.single_y).shape, torch.Size([]))
def test_single_batch(self):
kernel = ExplicitKernel()
self.assertEqual(
kernel(self.single_x, self.batch_y).shape, torch.Size([32]))
def test_batch_single(self):
kernel = ExplicitKernel()
self.assertEqual(
kernel(self.batch_x, self.single_y).shape, torch.Size([32]))
def test_batch_values(self):
kernel = ExplicitKernel()
self.assertEqual(
kernel(self.batch_x, self.batch_y).shape, torch.Size([32, 32]))
class TestRadialBasisFunctionKernel(unittest.TestCase):
def setUp(self):
self.single_x = torch.randn(1024)
self.single_y = torch.randn(1024)
self.batch_x = torch.randn(32, 1024)
self.batch_y = torch.randn(32, 1024)
def test_single_values(self):
kernel = RadialBasisFunctionKernel(1)
self.assertEqual(
kernel(self.single_x, self.single_y).shape, torch.Size([]))
def test_single_batch(self):
kernel = RadialBasisFunctionKernel(1)
self.assertEqual(
kernel(self.single_x, self.batch_y).shape, torch.Size([32]))
def test_batch_single(self):
kernel = RadialBasisFunctionKernel(1)
self.assertEqual(
kernel(self.batch_x, self.single_y).shape, torch.Size([32]))
def test_batch_values(self):
kernel = RadialBasisFunctionKernel(1)
self.assertEqual(
kernel(self.batch_x, self.batch_y).shape, torch.Size([32, 32]))
class TestKernelDistance(unittest.TestCase):
def setUp(self):
self.single_x = torch.randn(1024)
self.single_y = torch.randn(1024)
self.batch_x = torch.randn(32, 1024)
self.batch_y = torch.randn(32, 1024)
self.kernel = ExplicitKernel()
def test_single_values(self):
distance = KernelDistance(self.kernel)
self.assertEqual(
distance(self.single_x, self.single_y).shape, torch.Size([]))
def test_single_batch(self):
distance = KernelDistance(self.kernel)
self.assertEqual(
distance(self.single_x, self.batch_y).shape, torch.Size([32]))
def test_batch_single(self):
distance = KernelDistance(self.kernel)
self.assertEqual(
distance(self.batch_x, self.single_y).shape, torch.Size([32]))
def test_batch_values(self):
distance = KernelDistance(self.kernel)
self.assertEqual(
distance(self.batch_x, self.batch_y).shape, torch.Size([32, 32]))