Add more competition functions
This commit is contained in:
		| @@ -1,63 +1,115 @@ | |||||||
| """ProtoTorch competition functions.""" | """ProtoTorch competition functions.""" | ||||||
|  |  | ||||||
|  | from typing import Callable | ||||||
|  |  | ||||||
| import torch | import torch | ||||||
|  |  | ||||||
|  |  | ||||||
| def stratified_sum( | def stratified_sum_v1(values: torch.Tensor, | ||||||
|         value: torch.Tensor, |                       labels: torch.LongTensor) -> (torch.Tensor): | ||||||
|         labels: torch.LongTensor) -> (torch.Tensor, torch.LongTensor): |     """Group-wise sum.""" | ||||||
|     """Group-wise sum""" |  | ||||||
|     uniques = labels.unique(sorted=True).tolist() |     uniques = labels.unique(sorted=True).tolist() | ||||||
|     labels = labels.tolist() |     labels = labels.tolist() | ||||||
|  |  | ||||||
|     key_val = {key: val for key, val in zip(uniques, range(len(uniques)))} |     key_val = {key: val for key, val in zip(uniques, range(len(uniques)))} | ||||||
|     labels = torch.LongTensor(list(map(key_val.get, labels))) |     labels = torch.LongTensor(list(map(key_val.get, labels))) | ||||||
|  |  | ||||||
|     labels = labels.view(labels.size(0), 1).expand(-1, value.size(1)) |     labels = labels.view(labels.size(0), 1).expand(-1, values.size(1)) | ||||||
|  |  | ||||||
|     unique_labels = labels.unique(dim=0) |     unique_labels = labels.unique(dim=0) | ||||||
|  |     print(f"{labels=}") | ||||||
|  |     print(f"{unique_labels=}") | ||||||
|  |  | ||||||
|     result = torch.zeros_like(unique_labels, dtype=torch.float).scatter_add_( |     result = torch.zeros_like(unique_labels, dtype=torch.float).scatter_add_( | ||||||
|         0, labels, value) |         0, labels, values) | ||||||
|     return result.T |     return result.T | ||||||
|  |  | ||||||
|  |  | ||||||
| def stratified_min(distances, labels): | def stratify_with(values: torch.Tensor, | ||||||
|     """Group-wise minimum""" |                   labels: torch.LongTensor, | ||||||
|     clabels = torch.unique(labels, dim=0) |                   fn: Callable, | ||||||
|  |                   fill_value: float = 0.0) -> (torch.Tensor): | ||||||
|  |     """Apply an arbitrary stratification strategy on the columns on `values`. | ||||||
|  |  | ||||||
|  |     The outputs correspond to sorted labels. | ||||||
|  |     """ | ||||||
|  |     clabels = torch.unique(labels, dim=0, sorted=True) | ||||||
|     num_classes = clabels.size()[0] |     num_classes = clabels.size()[0] | ||||||
|     if distances.size()[1] == num_classes: |     if values.size()[1] == num_classes: | ||||||
|         # skip if only one prototype per class |         # skip if stratification is trivial | ||||||
|         return distances |         return values | ||||||
|     batch_size = distances.size()[0] |     batch_size = values.size()[0] | ||||||
|     winning_distances = torch.zeros(num_classes, batch_size) |     winning_values = torch.zeros(num_classes, batch_size) | ||||||
|     inf = torch.full_like(distances.T, fill_value=float("inf")) |     filler = torch.full_like(values.T, fill_value=fill_value) | ||||||
|     # distances_to_wpluses = torch.where(matcher, distances, inf) |  | ||||||
|     for i, cl in enumerate(clabels): |     for i, cl in enumerate(clabels): | ||||||
|         # cdists = distances.T[labels == cl] |  | ||||||
|         matcher = torch.eq(labels.unsqueeze(dim=1), cl) |         matcher = torch.eq(labels.unsqueeze(dim=1), cl) | ||||||
|         if labels.ndim == 2: |         if labels.ndim == 2: | ||||||
|             # if the labels are one-hot vectors |             # if the labels are one-hot vectors | ||||||
|             matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes) |             matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes) | ||||||
|         cdists = torch.where(matcher, distances.T, inf).T |         cdists = torch.where(matcher, values.T, filler).T | ||||||
|         winning_distances[i] = torch.min(cdists, dim=1, |         winning_values[i] = fn(cdists) | ||||||
|                                          keepdim=True).values.squeeze() |  | ||||||
|     if labels.ndim == 2: |     if labels.ndim == 2: | ||||||
|         # Transpose to return with `batch_size` first and |         # Transpose to return with `batch_size` first and | ||||||
|         # reverse the columns to fix the ordering of the classes |         # reverse the columns to fix the ordering of the classes | ||||||
|         return torch.flip(winning_distances.T, dims=(1, )) |         return torch.flip(winning_values.T, dims=(1, )) | ||||||
|  |  | ||||||
|     return winning_distances.T  # return with `batch_size` first |     return winning_values.T  # return with `batch_size` first | ||||||
|  |  | ||||||
|  |  | ||||||
| def wtac(distances, labels): | def stratified_sum(values: torch.Tensor, | ||||||
|  |                    labels: torch.LongTensor) -> (torch.Tensor): | ||||||
|  |     """Group-wise sum.""" | ||||||
|  |     winning_values = stratify_with( | ||||||
|  |         values, | ||||||
|  |         labels, | ||||||
|  |         fn=lambda x: torch.sum(x, dim=1, keepdim=True).squeeze(), | ||||||
|  |         fill_value=0.0) | ||||||
|  |     return winning_values | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def stratified_min(values: torch.Tensor, | ||||||
|  |                    labels: torch.LongTensor) -> (torch.Tensor): | ||||||
|  |     """Group-wise minimum.""" | ||||||
|  |     winning_values = stratify_with( | ||||||
|  |         values, | ||||||
|  |         labels, | ||||||
|  |         fn=lambda x: torch.min(x, dim=1, keepdim=True).values.squeeze(), | ||||||
|  |         fill_value=float("inf")) | ||||||
|  |     return winning_values | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def stratified_max(values: torch.Tensor, | ||||||
|  |                    labels: torch.LongTensor) -> (torch.Tensor): | ||||||
|  |     """Group-wise maximum.""" | ||||||
|  |     winning_values = stratify_with( | ||||||
|  |         values, | ||||||
|  |         labels, | ||||||
|  |         fn=lambda x: torch.max(x, dim=1, keepdim=True).values.squeeze(), | ||||||
|  |         fill_value=-1.0 * float("inf")) | ||||||
|  |     return winning_values | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def stratified_prod(values: torch.Tensor, | ||||||
|  |                     labels: torch.LongTensor) -> (torch.Tensor): | ||||||
|  |     """Group-wise maximum.""" | ||||||
|  |     winning_values = stratify_with( | ||||||
|  |         values, | ||||||
|  |         labels, | ||||||
|  |         fn=lambda x: torch.prod(x, dim=1, keepdim=True).squeeze(), | ||||||
|  |         fill_value=1.0) | ||||||
|  |     return winning_values | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def wtac(distances: torch.Tensor, | ||||||
|  |          labels: torch.LongTensor) -> (torch.LongTensor): | ||||||
|     winning_indices = torch.min(distances, dim=1).indices |     winning_indices = torch.min(distances, dim=1).indices | ||||||
|     winning_labels = labels[winning_indices].squeeze() |     winning_labels = labels[winning_indices].squeeze() | ||||||
|     return winning_labels |     return winning_labels | ||||||
|  |  | ||||||
|  |  | ||||||
| def knnc(distances, labels, k=1): | def knnc(distances: torch.Tensor, | ||||||
|  |          labels: torch.LongTensor, | ||||||
|  |          k: int = 1) -> (torch.LongTensor): | ||||||
|     winning_indices = torch.topk(-distances, k=k, dim=1).indices |     winning_indices = torch.topk(-distances, k=k, dim=1).indices | ||||||
|     # winning_labels = torch.mode(labels[winning_indices].squeeze(), |  | ||||||
|     #                             dim=1).values |  | ||||||
|     winning_labels = torch.mode(labels[winning_indices], dim=1).values |     winning_labels = torch.mode(labels[winning_indices], dim=1).values | ||||||
|     return winning_labels |     return winning_labels | ||||||
|   | |||||||
| @@ -125,7 +125,7 @@ class TestCompetitions(unittest.TestCase): | |||||||
|                                                         decimal=5) |                                                         decimal=5) | ||||||
|         self.assertIsNone(mismatch) |         self.assertIsNone(mismatch) | ||||||
|  |  | ||||||
|     def test_stratified_min_simple(self): |     def test_stratified_min_trivial(self): | ||||||
|         d = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0, 1]]) |         d = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0, 1]]) | ||||||
|         labels = torch.tensor([0, 1, 2]) |         labels = torch.tensor([0, 1, 2]) | ||||||
|         actual = competitions.stratified_min(d, labels) |         actual = competitions.stratified_min(d, labels) | ||||||
| @@ -135,6 +135,58 @@ class TestCompetitions(unittest.TestCase): | |||||||
|                                                         decimal=5) |                                                         decimal=5) | ||||||
|         self.assertIsNone(mismatch) |         self.assertIsNone(mismatch) | ||||||
|  |  | ||||||
|  |     def test_stratified_max(self): | ||||||
|  |         d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]]) | ||||||
|  |         labels = torch.tensor([0, 0, 3, 2, 0]) | ||||||
|  |         actual = competitions.stratified_max(d, labels) | ||||||
|  |         desired = torch.tensor([[9.0, 3.0, 2.0], [9.0, 1.0, 0.0]]) | ||||||
|  |         mismatch = np.testing.assert_array_almost_equal(actual, | ||||||
|  |                                                         desired, | ||||||
|  |                                                         decimal=5) | ||||||
|  |         self.assertIsNone(mismatch) | ||||||
|  |  | ||||||
|  |     def test_stratified_max_one_hot(self): | ||||||
|  |         d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]]) | ||||||
|  |         labels = torch.tensor([0, 0, 2, 1, 0]) | ||||||
|  |         labels = torch.nn.functional.one_hot(labels, num_classes=3) | ||||||
|  |         actual = competitions.stratified_max(d, labels) | ||||||
|  |         desired = torch.tensor([[9.0, 3.0, 2.0], [9.0, 1.0, 0.0]]) | ||||||
|  |         mismatch = np.testing.assert_array_almost_equal(actual, | ||||||
|  |                                                         desired, | ||||||
|  |                                                         decimal=5) | ||||||
|  |         self.assertIsNone(mismatch) | ||||||
|  |  | ||||||
|  |     def test_stratified_sum(self): | ||||||
|  |         d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]]) | ||||||
|  |         labels = torch.LongTensor([0, 0, 1, 2]) | ||||||
|  |         actual = competitions.stratified_sum(d, labels) | ||||||
|  |         desired = torch.tensor([[1.0, 2.0, 3.0], [17.0, 0.0, 1.0]]) | ||||||
|  |         mismatch = np.testing.assert_array_almost_equal(actual, | ||||||
|  |                                                         desired, | ||||||
|  |                                                         decimal=5) | ||||||
|  |         self.assertIsNone(mismatch) | ||||||
|  |  | ||||||
|  |     def test_stratified_sum_one_hot(self): | ||||||
|  |         d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]]) | ||||||
|  |         labels = torch.tensor([0, 0, 1, 2]) | ||||||
|  |         labels = torch.eye(3)[labels] | ||||||
|  |         actual = competitions.stratified_sum(d, labels) | ||||||
|  |         desired = torch.tensor([[1.0, 2.0, 3.0], [17.0, 0.0, 1.0]]) | ||||||
|  |         mismatch = np.testing.assert_array_almost_equal(actual, | ||||||
|  |                                                         desired, | ||||||
|  |                                                         decimal=5) | ||||||
|  |         self.assertIsNone(mismatch) | ||||||
|  |  | ||||||
|  |     def test_stratified_prod(self): | ||||||
|  |         d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]]) | ||||||
|  |         labels = torch.tensor([0, 0, 3, 2, 0]) | ||||||
|  |         actual = competitions.stratified_prod(d, labels) | ||||||
|  |         desired = torch.tensor([[0.0, 3.0, 2.0], [504.0, 1.0, 0.0]]) | ||||||
|  |         mismatch = np.testing.assert_array_almost_equal(actual, | ||||||
|  |                                                         desired, | ||||||
|  |                                                         decimal=5) | ||||||
|  |         self.assertIsNone(mismatch) | ||||||
|  |  | ||||||
|     def test_knnc_k1(self): |     def test_knnc_k1(self): | ||||||
|         d = torch.tensor([[2.0, 3.0, 1.99, 3.01], [2.0, 3.0, 2.01, 3.0]]) |         d = torch.tensor([[2.0, 3.0, 1.99, 3.01], [2.0, 3.0, 2.01, 3.0]]) | ||||||
|         labels = torch.tensor([0, 1, 2, 3]) |         labels = torch.tensor([0, 1, 2, 3]) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user