Add more competition functions
This commit is contained in:
parent
8227525c82
commit
946cda00d2
@ -1,63 +1,115 @@
|
||||
"""ProtoTorch competition functions."""
|
||||
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def stratified_sum(
|
||||
value: torch.Tensor,
|
||||
labels: torch.LongTensor) -> (torch.Tensor, torch.LongTensor):
|
||||
"""Group-wise sum"""
|
||||
def stratified_sum_v1(values: torch.Tensor,
|
||||
labels: torch.LongTensor) -> (torch.Tensor):
|
||||
"""Group-wise sum."""
|
||||
uniques = labels.unique(sorted=True).tolist()
|
||||
labels = labels.tolist()
|
||||
|
||||
key_val = {key: val for key, val in zip(uniques, range(len(uniques)))}
|
||||
labels = torch.LongTensor(list(map(key_val.get, labels)))
|
||||
|
||||
labels = labels.view(labels.size(0), 1).expand(-1, value.size(1))
|
||||
labels = labels.view(labels.size(0), 1).expand(-1, values.size(1))
|
||||
|
||||
unique_labels = labels.unique(dim=0)
|
||||
print(f"{labels=}")
|
||||
print(f"{unique_labels=}")
|
||||
|
||||
result = torch.zeros_like(unique_labels, dtype=torch.float).scatter_add_(
|
||||
0, labels, value)
|
||||
0, labels, values)
|
||||
return result.T
|
||||
|
||||
|
||||
def stratified_min(distances, labels):
|
||||
"""Group-wise minimum"""
|
||||
clabels = torch.unique(labels, dim=0)
|
||||
def stratify_with(values: torch.Tensor,
|
||||
labels: torch.LongTensor,
|
||||
fn: Callable,
|
||||
fill_value: float = 0.0) -> (torch.Tensor):
|
||||
"""Apply an arbitrary stratification strategy on the columns on `values`.
|
||||
|
||||
The outputs correspond to sorted labels.
|
||||
"""
|
||||
clabels = torch.unique(labels, dim=0, sorted=True)
|
||||
num_classes = clabels.size()[0]
|
||||
if distances.size()[1] == num_classes:
|
||||
# skip if only one prototype per class
|
||||
return distances
|
||||
batch_size = distances.size()[0]
|
||||
winning_distances = torch.zeros(num_classes, batch_size)
|
||||
inf = torch.full_like(distances.T, fill_value=float("inf"))
|
||||
# distances_to_wpluses = torch.where(matcher, distances, inf)
|
||||
if values.size()[1] == num_classes:
|
||||
# skip if stratification is trivial
|
||||
return values
|
||||
batch_size = values.size()[0]
|
||||
winning_values = torch.zeros(num_classes, batch_size)
|
||||
filler = torch.full_like(values.T, fill_value=fill_value)
|
||||
for i, cl in enumerate(clabels):
|
||||
# cdists = distances.T[labels == cl]
|
||||
matcher = torch.eq(labels.unsqueeze(dim=1), cl)
|
||||
if labels.ndim == 2:
|
||||
# if the labels are one-hot vectors
|
||||
matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes)
|
||||
cdists = torch.where(matcher, distances.T, inf).T
|
||||
winning_distances[i] = torch.min(cdists, dim=1,
|
||||
keepdim=True).values.squeeze()
|
||||
cdists = torch.where(matcher, values.T, filler).T
|
||||
winning_values[i] = fn(cdists)
|
||||
if labels.ndim == 2:
|
||||
# Transpose to return with `batch_size` first and
|
||||
# reverse the columns to fix the ordering of the classes
|
||||
return torch.flip(winning_distances.T, dims=(1, ))
|
||||
return torch.flip(winning_values.T, dims=(1, ))
|
||||
|
||||
return winning_distances.T # return with `batch_size` first
|
||||
return winning_values.T # return with `batch_size` first
|
||||
|
||||
|
||||
def wtac(distances, labels):
|
||||
def stratified_sum(values: torch.Tensor,
|
||||
labels: torch.LongTensor) -> (torch.Tensor):
|
||||
"""Group-wise sum."""
|
||||
winning_values = stratify_with(
|
||||
values,
|
||||
labels,
|
||||
fn=lambda x: torch.sum(x, dim=1, keepdim=True).squeeze(),
|
||||
fill_value=0.0)
|
||||
return winning_values
|
||||
|
||||
|
||||
def stratified_min(values: torch.Tensor,
|
||||
labels: torch.LongTensor) -> (torch.Tensor):
|
||||
"""Group-wise minimum."""
|
||||
winning_values = stratify_with(
|
||||
values,
|
||||
labels,
|
||||
fn=lambda x: torch.min(x, dim=1, keepdim=True).values.squeeze(),
|
||||
fill_value=float("inf"))
|
||||
return winning_values
|
||||
|
||||
|
||||
def stratified_max(values: torch.Tensor,
|
||||
labels: torch.LongTensor) -> (torch.Tensor):
|
||||
"""Group-wise maximum."""
|
||||
winning_values = stratify_with(
|
||||
values,
|
||||
labels,
|
||||
fn=lambda x: torch.max(x, dim=1, keepdim=True).values.squeeze(),
|
||||
fill_value=-1.0 * float("inf"))
|
||||
return winning_values
|
||||
|
||||
|
||||
def stratified_prod(values: torch.Tensor,
|
||||
labels: torch.LongTensor) -> (torch.Tensor):
|
||||
"""Group-wise maximum."""
|
||||
winning_values = stratify_with(
|
||||
values,
|
||||
labels,
|
||||
fn=lambda x: torch.prod(x, dim=1, keepdim=True).squeeze(),
|
||||
fill_value=1.0)
|
||||
return winning_values
|
||||
|
||||
|
||||
def wtac(distances: torch.Tensor,
|
||||
labels: torch.LongTensor) -> (torch.LongTensor):
|
||||
winning_indices = torch.min(distances, dim=1).indices
|
||||
winning_labels = labels[winning_indices].squeeze()
|
||||
return winning_labels
|
||||
|
||||
|
||||
def knnc(distances, labels, k=1):
|
||||
def knnc(distances: torch.Tensor,
|
||||
labels: torch.LongTensor,
|
||||
k: int = 1) -> (torch.LongTensor):
|
||||
winning_indices = torch.topk(-distances, k=k, dim=1).indices
|
||||
# winning_labels = torch.mode(labels[winning_indices].squeeze(),
|
||||
# dim=1).values
|
||||
winning_labels = torch.mode(labels[winning_indices], dim=1).values
|
||||
return winning_labels
|
||||
|
@ -125,7 +125,7 @@ class TestCompetitions(unittest.TestCase):
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_stratified_min_simple(self):
|
||||
def test_stratified_min_trivial(self):
|
||||
d = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0, 1]])
|
||||
labels = torch.tensor([0, 1, 2])
|
||||
actual = competitions.stratified_min(d, labels)
|
||||
@ -135,6 +135,58 @@ class TestCompetitions(unittest.TestCase):
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_stratified_max(self):
|
||||
d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]])
|
||||
labels = torch.tensor([0, 0, 3, 2, 0])
|
||||
actual = competitions.stratified_max(d, labels)
|
||||
desired = torch.tensor([[9.0, 3.0, 2.0], [9.0, 1.0, 0.0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_stratified_max_one_hot(self):
|
||||
d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]])
|
||||
labels = torch.tensor([0, 0, 2, 1, 0])
|
||||
labels = torch.nn.functional.one_hot(labels, num_classes=3)
|
||||
actual = competitions.stratified_max(d, labels)
|
||||
desired = torch.tensor([[9.0, 3.0, 2.0], [9.0, 1.0, 0.0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_stratified_sum(self):
|
||||
d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
|
||||
labels = torch.LongTensor([0, 0, 1, 2])
|
||||
actual = competitions.stratified_sum(d, labels)
|
||||
desired = torch.tensor([[1.0, 2.0, 3.0], [17.0, 0.0, 1.0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_stratified_sum_one_hot(self):
|
||||
d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
|
||||
labels = torch.tensor([0, 0, 1, 2])
|
||||
labels = torch.eye(3)[labels]
|
||||
actual = competitions.stratified_sum(d, labels)
|
||||
desired = torch.tensor([[1.0, 2.0, 3.0], [17.0, 0.0, 1.0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_stratified_prod(self):
|
||||
d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]])
|
||||
labels = torch.tensor([0, 0, 3, 2, 0])
|
||||
actual = competitions.stratified_prod(d, labels)
|
||||
desired = torch.tensor([[0.0, 3.0, 2.0], [504.0, 1.0, 0.0]])
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_knnc_k1(self):
|
||||
d = torch.tensor([[2.0, 3.0, 1.99, 3.01], [2.0, 3.0, 2.01, 3.0]])
|
||||
labels = torch.tensor([0, 1, 2, 3])
|
||||
|
Loading…
Reference in New Issue
Block a user