Add more competition functions
This commit is contained in:
		@@ -1,63 +1,115 @@
 | 
				
			|||||||
"""ProtoTorch competition functions."""
 | 
					"""ProtoTorch competition functions."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from typing import Callable
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def stratified_sum(
 | 
					def stratified_sum_v1(values: torch.Tensor,
 | 
				
			||||||
        value: torch.Tensor,
 | 
					                      labels: torch.LongTensor) -> (torch.Tensor):
 | 
				
			||||||
        labels: torch.LongTensor) -> (torch.Tensor, torch.LongTensor):
 | 
					    """Group-wise sum."""
 | 
				
			||||||
    """Group-wise sum"""
 | 
					 | 
				
			||||||
    uniques = labels.unique(sorted=True).tolist()
 | 
					    uniques = labels.unique(sorted=True).tolist()
 | 
				
			||||||
    labels = labels.tolist()
 | 
					    labels = labels.tolist()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    key_val = {key: val for key, val in zip(uniques, range(len(uniques)))}
 | 
					    key_val = {key: val for key, val in zip(uniques, range(len(uniques)))}
 | 
				
			||||||
    labels = torch.LongTensor(list(map(key_val.get, labels)))
 | 
					    labels = torch.LongTensor(list(map(key_val.get, labels)))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    labels = labels.view(labels.size(0), 1).expand(-1, value.size(1))
 | 
					    labels = labels.view(labels.size(0), 1).expand(-1, values.size(1))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    unique_labels = labels.unique(dim=0)
 | 
					    unique_labels = labels.unique(dim=0)
 | 
				
			||||||
 | 
					    print(f"{labels=}")
 | 
				
			||||||
 | 
					    print(f"{unique_labels=}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    result = torch.zeros_like(unique_labels, dtype=torch.float).scatter_add_(
 | 
					    result = torch.zeros_like(unique_labels, dtype=torch.float).scatter_add_(
 | 
				
			||||||
        0, labels, value)
 | 
					        0, labels, values)
 | 
				
			||||||
    return result.T
 | 
					    return result.T
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def stratified_min(distances, labels):
 | 
					def stratify_with(values: torch.Tensor,
 | 
				
			||||||
    """Group-wise minimum"""
 | 
					                  labels: torch.LongTensor,
 | 
				
			||||||
    clabels = torch.unique(labels, dim=0)
 | 
					                  fn: Callable,
 | 
				
			||||||
 | 
					                  fill_value: float = 0.0) -> (torch.Tensor):
 | 
				
			||||||
 | 
					    """Apply an arbitrary stratification strategy on the columns on `values`.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    The outputs correspond to sorted labels.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    clabels = torch.unique(labels, dim=0, sorted=True)
 | 
				
			||||||
    num_classes = clabels.size()[0]
 | 
					    num_classes = clabels.size()[0]
 | 
				
			||||||
    if distances.size()[1] == num_classes:
 | 
					    if values.size()[1] == num_classes:
 | 
				
			||||||
        # skip if only one prototype per class
 | 
					        # skip if stratification is trivial
 | 
				
			||||||
        return distances
 | 
					        return values
 | 
				
			||||||
    batch_size = distances.size()[0]
 | 
					    batch_size = values.size()[0]
 | 
				
			||||||
    winning_distances = torch.zeros(num_classes, batch_size)
 | 
					    winning_values = torch.zeros(num_classes, batch_size)
 | 
				
			||||||
    inf = torch.full_like(distances.T, fill_value=float("inf"))
 | 
					    filler = torch.full_like(values.T, fill_value=fill_value)
 | 
				
			||||||
    # distances_to_wpluses = torch.where(matcher, distances, inf)
 | 
					 | 
				
			||||||
    for i, cl in enumerate(clabels):
 | 
					    for i, cl in enumerate(clabels):
 | 
				
			||||||
        # cdists = distances.T[labels == cl]
 | 
					 | 
				
			||||||
        matcher = torch.eq(labels.unsqueeze(dim=1), cl)
 | 
					        matcher = torch.eq(labels.unsqueeze(dim=1), cl)
 | 
				
			||||||
        if labels.ndim == 2:
 | 
					        if labels.ndim == 2:
 | 
				
			||||||
            # if the labels are one-hot vectors
 | 
					            # if the labels are one-hot vectors
 | 
				
			||||||
            matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes)
 | 
					            matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes)
 | 
				
			||||||
        cdists = torch.where(matcher, distances.T, inf).T
 | 
					        cdists = torch.where(matcher, values.T, filler).T
 | 
				
			||||||
        winning_distances[i] = torch.min(cdists, dim=1,
 | 
					        winning_values[i] = fn(cdists)
 | 
				
			||||||
                                         keepdim=True).values.squeeze()
 | 
					 | 
				
			||||||
    if labels.ndim == 2:
 | 
					    if labels.ndim == 2:
 | 
				
			||||||
        # Transpose to return with `batch_size` first and
 | 
					        # Transpose to return with `batch_size` first and
 | 
				
			||||||
        # reverse the columns to fix the ordering of the classes
 | 
					        # reverse the columns to fix the ordering of the classes
 | 
				
			||||||
        return torch.flip(winning_distances.T, dims=(1, ))
 | 
					        return torch.flip(winning_values.T, dims=(1, ))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return winning_distances.T  # return with `batch_size` first
 | 
					    return winning_values.T  # return with `batch_size` first
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def wtac(distances, labels):
 | 
					def stratified_sum(values: torch.Tensor,
 | 
				
			||||||
 | 
					                   labels: torch.LongTensor) -> (torch.Tensor):
 | 
				
			||||||
 | 
					    """Group-wise sum."""
 | 
				
			||||||
 | 
					    winning_values = stratify_with(
 | 
				
			||||||
 | 
					        values,
 | 
				
			||||||
 | 
					        labels,
 | 
				
			||||||
 | 
					        fn=lambda x: torch.sum(x, dim=1, keepdim=True).squeeze(),
 | 
				
			||||||
 | 
					        fill_value=0.0)
 | 
				
			||||||
 | 
					    return winning_values
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def stratified_min(values: torch.Tensor,
 | 
				
			||||||
 | 
					                   labels: torch.LongTensor) -> (torch.Tensor):
 | 
				
			||||||
 | 
					    """Group-wise minimum."""
 | 
				
			||||||
 | 
					    winning_values = stratify_with(
 | 
				
			||||||
 | 
					        values,
 | 
				
			||||||
 | 
					        labels,
 | 
				
			||||||
 | 
					        fn=lambda x: torch.min(x, dim=1, keepdim=True).values.squeeze(),
 | 
				
			||||||
 | 
					        fill_value=float("inf"))
 | 
				
			||||||
 | 
					    return winning_values
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def stratified_max(values: torch.Tensor,
 | 
				
			||||||
 | 
					                   labels: torch.LongTensor) -> (torch.Tensor):
 | 
				
			||||||
 | 
					    """Group-wise maximum."""
 | 
				
			||||||
 | 
					    winning_values = stratify_with(
 | 
				
			||||||
 | 
					        values,
 | 
				
			||||||
 | 
					        labels,
 | 
				
			||||||
 | 
					        fn=lambda x: torch.max(x, dim=1, keepdim=True).values.squeeze(),
 | 
				
			||||||
 | 
					        fill_value=-1.0 * float("inf"))
 | 
				
			||||||
 | 
					    return winning_values
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def stratified_prod(values: torch.Tensor,
 | 
				
			||||||
 | 
					                    labels: torch.LongTensor) -> (torch.Tensor):
 | 
				
			||||||
 | 
					    """Group-wise maximum."""
 | 
				
			||||||
 | 
					    winning_values = stratify_with(
 | 
				
			||||||
 | 
					        values,
 | 
				
			||||||
 | 
					        labels,
 | 
				
			||||||
 | 
					        fn=lambda x: torch.prod(x, dim=1, keepdim=True).squeeze(),
 | 
				
			||||||
 | 
					        fill_value=1.0)
 | 
				
			||||||
 | 
					    return winning_values
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def wtac(distances: torch.Tensor,
 | 
				
			||||||
 | 
					         labels: torch.LongTensor) -> (torch.LongTensor):
 | 
				
			||||||
    winning_indices = torch.min(distances, dim=1).indices
 | 
					    winning_indices = torch.min(distances, dim=1).indices
 | 
				
			||||||
    winning_labels = labels[winning_indices].squeeze()
 | 
					    winning_labels = labels[winning_indices].squeeze()
 | 
				
			||||||
    return winning_labels
 | 
					    return winning_labels
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def knnc(distances, labels, k=1):
 | 
					def knnc(distances: torch.Tensor,
 | 
				
			||||||
 | 
					         labels: torch.LongTensor,
 | 
				
			||||||
 | 
					         k: int = 1) -> (torch.LongTensor):
 | 
				
			||||||
    winning_indices = torch.topk(-distances, k=k, dim=1).indices
 | 
					    winning_indices = torch.topk(-distances, k=k, dim=1).indices
 | 
				
			||||||
    # winning_labels = torch.mode(labels[winning_indices].squeeze(),
 | 
					 | 
				
			||||||
    #                             dim=1).values
 | 
					 | 
				
			||||||
    winning_labels = torch.mode(labels[winning_indices], dim=1).values
 | 
					    winning_labels = torch.mode(labels[winning_indices], dim=1).values
 | 
				
			||||||
    return winning_labels
 | 
					    return winning_labels
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -125,7 +125,7 @@ class TestCompetitions(unittest.TestCase):
 | 
				
			|||||||
                                                        decimal=5)
 | 
					                                                        decimal=5)
 | 
				
			||||||
        self.assertIsNone(mismatch)
 | 
					        self.assertIsNone(mismatch)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_stratified_min_simple(self):
 | 
					    def test_stratified_min_trivial(self):
 | 
				
			||||||
        d = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0, 1]])
 | 
					        d = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0, 1]])
 | 
				
			||||||
        labels = torch.tensor([0, 1, 2])
 | 
					        labels = torch.tensor([0, 1, 2])
 | 
				
			||||||
        actual = competitions.stratified_min(d, labels)
 | 
					        actual = competitions.stratified_min(d, labels)
 | 
				
			||||||
@@ -135,6 +135,58 @@ class TestCompetitions(unittest.TestCase):
 | 
				
			|||||||
                                                        decimal=5)
 | 
					                                                        decimal=5)
 | 
				
			||||||
        self.assertIsNone(mismatch)
 | 
					        self.assertIsNone(mismatch)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_stratified_max(self):
 | 
				
			||||||
 | 
					        d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]])
 | 
				
			||||||
 | 
					        labels = torch.tensor([0, 0, 3, 2, 0])
 | 
				
			||||||
 | 
					        actual = competitions.stratified_max(d, labels)
 | 
				
			||||||
 | 
					        desired = torch.tensor([[9.0, 3.0, 2.0], [9.0, 1.0, 0.0]])
 | 
				
			||||||
 | 
					        mismatch = np.testing.assert_array_almost_equal(actual,
 | 
				
			||||||
 | 
					                                                        desired,
 | 
				
			||||||
 | 
					                                                        decimal=5)
 | 
				
			||||||
 | 
					        self.assertIsNone(mismatch)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_stratified_max_one_hot(self):
 | 
				
			||||||
 | 
					        d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]])
 | 
				
			||||||
 | 
					        labels = torch.tensor([0, 0, 2, 1, 0])
 | 
				
			||||||
 | 
					        labels = torch.nn.functional.one_hot(labels, num_classes=3)
 | 
				
			||||||
 | 
					        actual = competitions.stratified_max(d, labels)
 | 
				
			||||||
 | 
					        desired = torch.tensor([[9.0, 3.0, 2.0], [9.0, 1.0, 0.0]])
 | 
				
			||||||
 | 
					        mismatch = np.testing.assert_array_almost_equal(actual,
 | 
				
			||||||
 | 
					                                                        desired,
 | 
				
			||||||
 | 
					                                                        decimal=5)
 | 
				
			||||||
 | 
					        self.assertIsNone(mismatch)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_stratified_sum(self):
 | 
				
			||||||
 | 
					        d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
 | 
				
			||||||
 | 
					        labels = torch.LongTensor([0, 0, 1, 2])
 | 
				
			||||||
 | 
					        actual = competitions.stratified_sum(d, labels)
 | 
				
			||||||
 | 
					        desired = torch.tensor([[1.0, 2.0, 3.0], [17.0, 0.0, 1.0]])
 | 
				
			||||||
 | 
					        mismatch = np.testing.assert_array_almost_equal(actual,
 | 
				
			||||||
 | 
					                                                        desired,
 | 
				
			||||||
 | 
					                                                        decimal=5)
 | 
				
			||||||
 | 
					        self.assertIsNone(mismatch)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_stratified_sum_one_hot(self):
 | 
				
			||||||
 | 
					        d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
 | 
				
			||||||
 | 
					        labels = torch.tensor([0, 0, 1, 2])
 | 
				
			||||||
 | 
					        labels = torch.eye(3)[labels]
 | 
				
			||||||
 | 
					        actual = competitions.stratified_sum(d, labels)
 | 
				
			||||||
 | 
					        desired = torch.tensor([[1.0, 2.0, 3.0], [17.0, 0.0, 1.0]])
 | 
				
			||||||
 | 
					        mismatch = np.testing.assert_array_almost_equal(actual,
 | 
				
			||||||
 | 
					                                                        desired,
 | 
				
			||||||
 | 
					                                                        decimal=5)
 | 
				
			||||||
 | 
					        self.assertIsNone(mismatch)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_stratified_prod(self):
 | 
				
			||||||
 | 
					        d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]])
 | 
				
			||||||
 | 
					        labels = torch.tensor([0, 0, 3, 2, 0])
 | 
				
			||||||
 | 
					        actual = competitions.stratified_prod(d, labels)
 | 
				
			||||||
 | 
					        desired = torch.tensor([[0.0, 3.0, 2.0], [504.0, 1.0, 0.0]])
 | 
				
			||||||
 | 
					        mismatch = np.testing.assert_array_almost_equal(actual,
 | 
				
			||||||
 | 
					                                                        desired,
 | 
				
			||||||
 | 
					                                                        decimal=5)
 | 
				
			||||||
 | 
					        self.assertIsNone(mismatch)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_knnc_k1(self):
 | 
					    def test_knnc_k1(self):
 | 
				
			||||||
        d = torch.tensor([[2.0, 3.0, 1.99, 3.01], [2.0, 3.0, 2.01, 3.0]])
 | 
					        d = torch.tensor([[2.0, 3.0, 1.99, 3.01], [2.0, 3.0, 2.01, 3.0]])
 | 
				
			||||||
        labels = torch.tensor([0, 1, 2, 3])
 | 
					        labels = torch.tensor([0, 1, 2, 3])
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user