Add stratified sum as competition
For example used in RSLVQ
This commit is contained in:
parent
0ba09db6fe
commit
62726df278
@ -3,7 +3,26 @@
|
||||
import torch
|
||||
|
||||
|
||||
def stratified_sum(
|
||||
value: torch.Tensor,
|
||||
labels: torch.LongTensor) -> (torch.Tensor, torch.LongTensor):
|
||||
"""Group-wise sum"""
|
||||
uniques = labels.unique(sorted=True).tolist()
|
||||
labels = labels.tolist()
|
||||
|
||||
key_val = {key: val for key, val in zip(uniques, range(len(uniques)))}
|
||||
labels = torch.LongTensor(list(map(key_val.get, labels)))
|
||||
|
||||
labels = labels.view(labels.size(0), 1).expand(-1, value.size(1))
|
||||
|
||||
unique_labels = labels.unique(dim=0)
|
||||
result = torch.zeros_like(unique_labels, dtype=torch.float).scatter_add_(
|
||||
0, labels, value)
|
||||
return result.T
|
||||
|
||||
|
||||
def stratified_min(distances, labels):
|
||||
"""Group-wise minimum"""
|
||||
clabels = torch.unique(labels, dim=0)
|
||||
num_classes = clabels.size()[0]
|
||||
if distances.size()[1] == num_classes:
|
||||
|
Loading…
Reference in New Issue
Block a user