Add stratified sum as competition
For example used in RSLVQ
This commit is contained in:
parent
0ba09db6fe
commit
62726df278
@ -3,7 +3,26 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def stratified_sum(
|
||||||
|
value: torch.Tensor,
|
||||||
|
labels: torch.LongTensor) -> (torch.Tensor, torch.LongTensor):
|
||||||
|
"""Group-wise sum"""
|
||||||
|
uniques = labels.unique(sorted=True).tolist()
|
||||||
|
labels = labels.tolist()
|
||||||
|
|
||||||
|
key_val = {key: val for key, val in zip(uniques, range(len(uniques)))}
|
||||||
|
labels = torch.LongTensor(list(map(key_val.get, labels)))
|
||||||
|
|
||||||
|
labels = labels.view(labels.size(0), 1).expand(-1, value.size(1))
|
||||||
|
|
||||||
|
unique_labels = labels.unique(dim=0)
|
||||||
|
result = torch.zeros_like(unique_labels, dtype=torch.float).scatter_add_(
|
||||||
|
0, labels, value)
|
||||||
|
return result.T
|
||||||
|
|
||||||
|
|
||||||
def stratified_min(distances, labels):
|
def stratified_min(distances, labels):
|
||||||
|
"""Group-wise minimum"""
|
||||||
clabels = torch.unique(labels, dim=0)
|
clabels = torch.unique(labels, dim=0)
|
||||||
num_classes = clabels.size()[0]
|
num_classes = clabels.size()[0]
|
||||||
if distances.size()[1] == num_classes:
|
if distances.size()[1] == num_classes:
|
||||||
|
Loading…
Reference in New Issue
Block a user