Add more competition functions

Jensun Ravichandran 2021-06-01 12:37:21 +02:00
parent 8227525c82
commit 946cda00d2
2 changed files with 131 additions and 27 deletions

View File

@@ -1,63 +1,115 @@
 """ProtoTorch competition functions."""
 
+from typing import Callable
+
 import torch
 
 
-def stratified_sum(
-        value: torch.Tensor,
-        labels: torch.LongTensor) -> (torch.Tensor, torch.LongTensor):
-    """Group-wise sum"""
+def stratified_sum_v1(values: torch.Tensor,
+                      labels: torch.LongTensor) -> (torch.Tensor):
+    """Group-wise sum."""
     uniques = labels.unique(sorted=True).tolist()
     labels = labels.tolist()
     key_val = {key: val for key, val in zip(uniques, range(len(uniques)))}
     labels = torch.LongTensor(list(map(key_val.get, labels)))
-    labels = labels.view(labels.size(0), 1).expand(-1, value.size(1))
+    labels = labels.view(labels.size(0), 1).expand(-1, values.size(1))
     unique_labels = labels.unique(dim=0)
-    print(f"{labels=}")
-    print(f"{unique_labels=}")
     result = torch.zeros_like(unique_labels, dtype=torch.float).scatter_add_(
-        0, labels, value)
+        0, labels, values)
     return result.T
 
 
-def stratified_min(distances, labels):
-    """Group-wise minimum"""
-    clabels = torch.unique(labels, dim=0)
+def stratify_with(values: torch.Tensor,
+                  labels: torch.LongTensor,
+                  fn: Callable,
+                  fill_value: float = 0.0) -> (torch.Tensor):
+    """Apply an arbitrary stratification strategy on the columns of `values`.
+
+    The outputs correspond to sorted labels.
+    """
+    clabels = torch.unique(labels, dim=0, sorted=True)
     num_classes = clabels.size()[0]
-    if distances.size()[1] == num_classes:
-        # skip if only one prototype per class
-        return distances
-    batch_size = distances.size()[0]
-    winning_distances = torch.zeros(num_classes, batch_size)
-    inf = torch.full_like(distances.T, fill_value=float("inf"))
-    # distances_to_wpluses = torch.where(matcher, distances, inf)
+    if values.size()[1] == num_classes:
+        # skip if stratification is trivial
+        return values
+    batch_size = values.size()[0]
+    winning_values = torch.zeros(num_classes, batch_size)
+    filler = torch.full_like(values.T, fill_value=fill_value)
     for i, cl in enumerate(clabels):
-        # cdists = distances.T[labels == cl]
         matcher = torch.eq(labels.unsqueeze(dim=1), cl)
         if labels.ndim == 2:
             # if the labels are one-hot vectors
             matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes)
-        cdists = torch.where(matcher, distances.T, inf).T
-        winning_distances[i] = torch.min(cdists, dim=1,
-                                         keepdim=True).values.squeeze()
+        cdists = torch.where(matcher, values.T, filler).T
+        winning_values[i] = fn(cdists)
     if labels.ndim == 2:
         # Transpose to return with `batch_size` first and
         # reverse the columns to fix the ordering of the classes
-        return torch.flip(winning_distances.T, dims=(1, ))
-    return winning_distances.T  # return with `batch_size` first
+        return torch.flip(winning_values.T, dims=(1, ))
+    return winning_values.T  # return with `batch_size` first
 
 
-def wtac(distances, labels):
+def stratified_sum(values: torch.Tensor,
+                   labels: torch.LongTensor) -> (torch.Tensor):
+    """Group-wise sum."""
+    winning_values = stratify_with(
+        values,
+        labels,
+        fn=lambda x: torch.sum(x, dim=1, keepdim=True).squeeze(),
+        fill_value=0.0)
+    return winning_values
+
+
+def stratified_min(values: torch.Tensor,
+                   labels: torch.LongTensor) -> (torch.Tensor):
+    """Group-wise minimum."""
+    winning_values = stratify_with(
+        values,
+        labels,
+        fn=lambda x: torch.min(x, dim=1, keepdim=True).values.squeeze(),
+        fill_value=float("inf"))
+    return winning_values
+
+
+def stratified_max(values: torch.Tensor,
+                   labels: torch.LongTensor) -> (torch.Tensor):
+    """Group-wise maximum."""
+    winning_values = stratify_with(
+        values,
+        labels,
+        fn=lambda x: torch.max(x, dim=1, keepdim=True).values.squeeze(),
+        fill_value=-1.0 * float("inf"))
+    return winning_values
+
+
+def stratified_prod(values: torch.Tensor,
+                    labels: torch.LongTensor) -> (torch.Tensor):
+    """Group-wise product."""
+    winning_values = stratify_with(
+        values,
+        labels,
+        fn=lambda x: torch.prod(x, dim=1, keepdim=True).squeeze(),
+        fill_value=1.0)
+    return winning_values
+
+
+def wtac(distances: torch.Tensor,
+         labels: torch.LongTensor) -> (torch.LongTensor):
     winning_indices = torch.min(distances, dim=1).indices
     winning_labels = labels[winning_indices].squeeze()
     return winning_labels
 
 
-def knnc(distances, labels, k=1):
+def knnc(distances: torch.Tensor,
+         labels: torch.LongTensor,
+         k: int = 1) -> (torch.LongTensor):
     winning_indices = torch.topk(-distances, k=k, dim=1).indices
-    # winning_labels = torch.mode(labels[winning_indices].squeeze(),
-    #                             dim=1).values
     winning_labels = torch.mode(labels[winning_indices], dim=1).values
     return winning_labels
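
For orientation, here is a minimal usage sketch (not part of the commit) of the new group-wise reductions. It assumes the module above is importable as prototorch.functions.competitions; the distance matrix and labels are the ones from the unit tests below, so the stratified_max output matches the test's desired tensor, while the stratified_min output is worked out here for illustration.

import torch

from prototorch.functions import competitions  # assumed import path

# Two samples, five prototype columns; each column carries a class label.
d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0],
                  [9.0, 8.0, 0.0, 1.0, 7.0]])
labels = torch.tensor([0, 0, 3, 2, 0])

# Result columns follow the sorted class labels (0, 2, 3).
competitions.stratified_max(d, labels)
# tensor([[9., 3., 2.],
#         [9., 1., 0.]])

# stratified_min fills non-matching columns with +inf before reducing.
competitions.stratified_min(d, labels)
# tensor([[0., 3., 2.],
#         [7., 1., 0.]])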

View File

@@ -125,7 +125,7 @@ class TestCompetitions(unittest.TestCase):
                                                         decimal=5)
         self.assertIsNone(mismatch)
 
-    def test_stratified_min_simple(self):
+    def test_stratified_min_trivial(self):
         d = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0, 1]])
         labels = torch.tensor([0, 1, 2])
         actual = competitions.stratified_min(d, labels)
@@ -135,6 +135,58 @@ class TestCompetitions(unittest.TestCase):
                                                         decimal=5)
         self.assertIsNone(mismatch)
 
+    def test_stratified_max(self):
+        d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]])
+        labels = torch.tensor([0, 0, 3, 2, 0])
+        actual = competitions.stratified_max(d, labels)
+        desired = torch.tensor([[9.0, 3.0, 2.0], [9.0, 1.0, 0.0]])
+        mismatch = np.testing.assert_array_almost_equal(actual,
+                                                        desired,
+                                                        decimal=5)
+        self.assertIsNone(mismatch)
+
+    def test_stratified_max_one_hot(self):
+        d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]])
+        labels = torch.tensor([0, 0, 2, 1, 0])
+        labels = torch.nn.functional.one_hot(labels, num_classes=3)
+        actual = competitions.stratified_max(d, labels)
+        desired = torch.tensor([[9.0, 3.0, 2.0], [9.0, 1.0, 0.0]])
+        mismatch = np.testing.assert_array_almost_equal(actual,
+                                                        desired,
+                                                        decimal=5)
+        self.assertIsNone(mismatch)
+
+    def test_stratified_sum(self):
+        d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
+        labels = torch.LongTensor([0, 0, 1, 2])
+        actual = competitions.stratified_sum(d, labels)
+        desired = torch.tensor([[1.0, 2.0, 3.0], [17.0, 0.0, 1.0]])
+        mismatch = np.testing.assert_array_almost_equal(actual,
+                                                        desired,
+                                                        decimal=5)
+        self.assertIsNone(mismatch)
+
+    def test_stratified_sum_one_hot(self):
+        d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
+        labels = torch.tensor([0, 0, 1, 2])
+        labels = torch.eye(3)[labels]
+        actual = competitions.stratified_sum(d, labels)
+        desired = torch.tensor([[1.0, 2.0, 3.0], [17.0, 0.0, 1.0]])
+        mismatch = np.testing.assert_array_almost_equal(actual,
+                                                        desired,
+                                                        decimal=5)
+        self.assertIsNone(mismatch)
+
+    def test_stratified_prod(self):
+        d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]])
+        labels = torch.tensor([0, 0, 3, 2, 0])
+        actual = competitions.stratified_prod(d, labels)
+        desired = torch.tensor([[0.0, 3.0, 2.0], [504.0, 1.0, 0.0]])
+        mismatch = np.testing.assert_array_almost_equal(actual,
+                                                        desired,
+                                                        decimal=5)
+        self.assertIsNone(mismatch)
+
     def test_knnc_k1(self):
         d = torch.tensor([[2.0, 3.0, 1.99, 3.01], [2.0, 3.0, 2.01, 3.0]])
         labels = torch.tensor([0, 1, 2, 3])
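
As a quick sanity check on the desired tensor in test_stratified_prod above, here is a small hand-rolled sketch (not the library call) of the same grouping using plain boolean indexing; classes are reduced in sorted label order, which is the column order the stratified functions return.

import torch

d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0],
                  [9.0, 8.0, 0.0, 1.0, 7.0]])
labels = torch.tensor([0, 0, 3, 2, 0])

# Product over the columns of each class, classes in sorted order (0, 2, 3):
# class 0 -> columns 0, 1, 4; class 2 -> column 3; class 3 -> column 2.
expected = torch.stack(
    [d[:, labels == c].prod(dim=1) for c in labels.unique(sorted=True)],
    dim=1)
print(expected)
# tensor([[  0.,   3.,   2.],
#         [504.,   1.,   0.]])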