Add stratified sum as competition

For example used in RSLVQ
This commit is contained in:
Alexander Engelsberger 2021-05-28 16:49:39 +02:00
parent 0ba09db6fe
commit 62726df278

View File

@ -3,7 +3,26 @@
import torch
def stratified_sum(
value: torch.Tensor,
labels: torch.LongTensor) -> (torch.Tensor, torch.LongTensor):
"""Group-wise sum"""
uniques = labels.unique(sorted=True).tolist()
labels = labels.tolist()
key_val = {key: val for key, val in zip(uniques, range(len(uniques)))}
labels = torch.LongTensor(list(map(key_val.get, labels)))
labels = labels.view(labels.size(0), 1).expand(-1, value.size(1))
unique_labels = labels.unique(dim=0)
result = torch.zeros_like(unique_labels, dtype=torch.float).scatter_add_(
0, labels, value)
return result.T
def stratified_min(distances, labels):
"""Group-wise minimum"""
clabels = torch.unique(labels, dim=0)
num_classes = clabels.size()[0]
if distances.size()[1] == num_classes: