Add competition and pooling modules
This commit is contained in:
parent
0c28eda706
commit
b03c9b1d3c
@ -2,11 +2,4 @@
|
|||||||
|
|
||||||
from .activations import identity, sigmoid_beta, swish_beta
|
from .activations import identity, sigmoid_beta, swish_beta
|
||||||
from .competitions import knnc, wtac
|
from .competitions import knnc, wtac
|
||||||
|
from .pooling import *
|
||||||
__all__ = [
|
|
||||||
"identity",
|
|
||||||
"sigmoid_beta",
|
|
||||||
"swish_beta",
|
|
||||||
"knnc",
|
|
||||||
"wtac",
|
|
||||||
]
|
|
||||||
|
@ -1,107 +1,15 @@
|
|||||||
"""ProtoTorch competition functions."""
|
"""ProtoTorch competition functions."""
|
||||||
|
|
||||||
from typing import Callable
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
def stratified_sum_v1(values: torch.Tensor,
|
|
||||||
labels: torch.LongTensor) -> (torch.Tensor):
|
|
||||||
"""Group-wise sum."""
|
|
||||||
uniques = labels.unique(sorted=True).tolist()
|
|
||||||
labels = labels.tolist()
|
|
||||||
|
|
||||||
key_val = {key: val for key, val in zip(uniques, range(len(uniques)))}
|
|
||||||
labels = torch.LongTensor(list(map(key_val.get, labels)))
|
|
||||||
|
|
||||||
labels = labels.view(labels.size(0), 1).expand(-1, values.size(1))
|
|
||||||
|
|
||||||
unique_labels = labels.unique(dim=0)
|
|
||||||
print(f"{labels=}")
|
|
||||||
print(f"{unique_labels=}")
|
|
||||||
|
|
||||||
result = torch.zeros_like(unique_labels, dtype=torch.float).scatter_add_(
|
|
||||||
0, labels, values)
|
|
||||||
return result.T
|
|
||||||
|
|
||||||
|
|
||||||
def stratify_with(values: torch.Tensor,
|
|
||||||
labels: torch.LongTensor,
|
|
||||||
fn: Callable,
|
|
||||||
fill_value: float = 0.0) -> (torch.Tensor):
|
|
||||||
"""Apply an arbitrary stratification strategy on the columns on `values`.
|
|
||||||
|
|
||||||
The outputs correspond to sorted labels.
|
|
||||||
"""
|
|
||||||
clabels = torch.unique(labels, dim=0, sorted=True)
|
|
||||||
num_classes = clabels.size()[0]
|
|
||||||
if values.size()[1] == num_classes:
|
|
||||||
# skip if stratification is trivial
|
|
||||||
return values
|
|
||||||
batch_size = values.size()[0]
|
|
||||||
winning_values = torch.zeros(num_classes, batch_size, device=labels.device)
|
|
||||||
filler = torch.full_like(values.T, fill_value=fill_value)
|
|
||||||
for i, cl in enumerate(clabels):
|
|
||||||
matcher = torch.eq(labels.unsqueeze(dim=1), cl)
|
|
||||||
if labels.ndim == 2:
|
|
||||||
# if the labels are one-hot vectors
|
|
||||||
matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes)
|
|
||||||
cdists = torch.where(matcher, values.T, filler).T
|
|
||||||
winning_values[i] = fn(cdists)
|
|
||||||
if labels.ndim == 2:
|
|
||||||
# Transpose to return with `batch_size` first and
|
|
||||||
# reverse the columns to fix the ordering of the classes
|
|
||||||
return torch.flip(winning_values.T, dims=(1, ))
|
|
||||||
|
|
||||||
return winning_values.T # return with `batch_size` first
|
|
||||||
|
|
||||||
|
|
||||||
def stratified_sum(values: torch.Tensor,
|
|
||||||
labels: torch.LongTensor) -> (torch.Tensor):
|
|
||||||
"""Group-wise sum."""
|
|
||||||
winning_values = stratify_with(
|
|
||||||
values,
|
|
||||||
labels,
|
|
||||||
fn=lambda x: torch.sum(x, dim=1, keepdim=True).squeeze(),
|
|
||||||
fill_value=0.0)
|
|
||||||
return winning_values
|
|
||||||
|
|
||||||
|
|
||||||
def stratified_min(values: torch.Tensor,
|
|
||||||
labels: torch.LongTensor) -> (torch.Tensor):
|
|
||||||
"""Group-wise minimum."""
|
|
||||||
winning_values = stratify_with(
|
|
||||||
values,
|
|
||||||
labels,
|
|
||||||
fn=lambda x: torch.min(x, dim=1, keepdim=True).values.squeeze(),
|
|
||||||
fill_value=float("inf"))
|
|
||||||
return winning_values
|
|
||||||
|
|
||||||
|
|
||||||
def stratified_max(values: torch.Tensor,
|
|
||||||
labels: torch.LongTensor) -> (torch.Tensor):
|
|
||||||
"""Group-wise maximum."""
|
|
||||||
winning_values = stratify_with(
|
|
||||||
values,
|
|
||||||
labels,
|
|
||||||
fn=lambda x: torch.max(x, dim=1, keepdim=True).values.squeeze(),
|
|
||||||
fill_value=-1.0 * float("inf"))
|
|
||||||
return winning_values
|
|
||||||
|
|
||||||
|
|
||||||
def stratified_prod(values: torch.Tensor,
|
|
||||||
labels: torch.LongTensor) -> (torch.Tensor):
|
|
||||||
"""Group-wise maximum."""
|
|
||||||
winning_values = stratify_with(
|
|
||||||
values,
|
|
||||||
labels,
|
|
||||||
fn=lambda x: torch.prod(x, dim=1, keepdim=True).squeeze(),
|
|
||||||
fill_value=1.0)
|
|
||||||
return winning_values
|
|
||||||
|
|
||||||
|
|
||||||
def wtac(distances: torch.Tensor,
|
def wtac(distances: torch.Tensor,
|
||||||
labels: torch.LongTensor) -> (torch.LongTensor):
|
labels: torch.LongTensor) -> (torch.LongTensor):
|
||||||
|
"""Winner-Takes-All-Competition.
|
||||||
|
|
||||||
|
Returns the labels corresponding to the winners.
|
||||||
|
|
||||||
|
"""
|
||||||
winning_indices = torch.min(distances, dim=1).indices
|
winning_indices = torch.min(distances, dim=1).indices
|
||||||
winning_labels = labels[winning_indices].squeeze()
|
winning_labels = labels[winning_indices].squeeze()
|
||||||
return winning_labels
|
return winning_labels
|
||||||
@ -110,6 +18,11 @@ def wtac(distances: torch.Tensor,
|
|||||||
def knnc(distances: torch.Tensor,
|
def knnc(distances: torch.Tensor,
|
||||||
labels: torch.LongTensor,
|
labels: torch.LongTensor,
|
||||||
k: int = 1) -> (torch.LongTensor):
|
k: int = 1) -> (torch.LongTensor):
|
||||||
|
"""K-Nearest-Neighbors-Competition.
|
||||||
|
|
||||||
|
Returns the labels corresponding to the winners.
|
||||||
|
|
||||||
|
"""
|
||||||
winning_indices = torch.topk(-distances, k=k, dim=1).indices
|
winning_indices = torch.topk(-distances, k=k, dim=1).indices
|
||||||
winning_labels = torch.mode(labels[winning_indices], dim=1).values
|
winning_labels = torch.mode(labels[winning_indices], dim=1).values
|
||||||
return winning_labels
|
return winning_labels
|
||||||
|
80
prototorch/functions/pooling.py
Normal file
80
prototorch/functions/pooling.py
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
"""ProtoTorch pooling functions."""
|
||||||
|
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def stratify_with(values: torch.Tensor,
|
||||||
|
labels: torch.LongTensor,
|
||||||
|
fn: Callable,
|
||||||
|
fill_value: float = 0.0) -> (torch.Tensor):
|
||||||
|
"""Apply an arbitrary stratification strategy on the columns on `values`.
|
||||||
|
|
||||||
|
The outputs correspond to sorted labels.
|
||||||
|
"""
|
||||||
|
clabels = torch.unique(labels, dim=0, sorted=True)
|
||||||
|
num_classes = clabels.size()[0]
|
||||||
|
if values.size()[1] == num_classes:
|
||||||
|
# skip if stratification is trivial
|
||||||
|
return values
|
||||||
|
batch_size = values.size()[0]
|
||||||
|
winning_values = torch.zeros(num_classes, batch_size, device=labels.device)
|
||||||
|
filler = torch.full_like(values.T, fill_value=fill_value)
|
||||||
|
for i, cl in enumerate(clabels):
|
||||||
|
matcher = torch.eq(labels.unsqueeze(dim=1), cl)
|
||||||
|
if labels.ndim == 2:
|
||||||
|
# if the labels are one-hot vectors
|
||||||
|
matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes)
|
||||||
|
cdists = torch.where(matcher, values.T, filler).T
|
||||||
|
winning_values[i] = fn(cdists)
|
||||||
|
if labels.ndim == 2:
|
||||||
|
# Transpose to return with `batch_size` first and
|
||||||
|
# reverse the columns to fix the ordering of the classes
|
||||||
|
return torch.flip(winning_values.T, dims=(1, ))
|
||||||
|
|
||||||
|
return winning_values.T # return with `batch_size` first
|
||||||
|
|
||||||
|
|
||||||
|
def stratified_sum_pooling(values: torch.Tensor,
|
||||||
|
labels: torch.LongTensor) -> (torch.Tensor):
|
||||||
|
"""Group-wise sum."""
|
||||||
|
winning_values = stratify_with(
|
||||||
|
values,
|
||||||
|
labels,
|
||||||
|
fn=lambda x: torch.sum(x, dim=1, keepdim=True).squeeze(),
|
||||||
|
fill_value=0.0)
|
||||||
|
return winning_values
|
||||||
|
|
||||||
|
|
||||||
|
def stratified_min_pooling(values: torch.Tensor,
|
||||||
|
labels: torch.LongTensor) -> (torch.Tensor):
|
||||||
|
"""Group-wise minimum."""
|
||||||
|
winning_values = stratify_with(
|
||||||
|
values,
|
||||||
|
labels,
|
||||||
|
fn=lambda x: torch.min(x, dim=1, keepdim=True).values.squeeze(),
|
||||||
|
fill_value=float("inf"))
|
||||||
|
return winning_values
|
||||||
|
|
||||||
|
|
||||||
|
def stratified_max_pooling(values: torch.Tensor,
|
||||||
|
labels: torch.LongTensor) -> (torch.Tensor):
|
||||||
|
"""Group-wise maximum."""
|
||||||
|
winning_values = stratify_with(
|
||||||
|
values,
|
||||||
|
labels,
|
||||||
|
fn=lambda x: torch.max(x, dim=1, keepdim=True).values.squeeze(),
|
||||||
|
fill_value=-1.0 * float("inf"))
|
||||||
|
return winning_values
|
||||||
|
|
||||||
|
|
||||||
|
def stratified_prod_pooling(values: torch.Tensor,
|
||||||
|
labels: torch.LongTensor) -> (torch.Tensor):
|
||||||
|
"""Group-wise maximum."""
|
||||||
|
winning_values = stratify_with(
|
||||||
|
values,
|
||||||
|
labels,
|
||||||
|
fn=lambda x: torch.prod(x, dim=1, keepdim=True).squeeze(),
|
||||||
|
fill_value=1.0)
|
||||||
|
return winning_values
|
@ -1,3 +1,5 @@
|
|||||||
"""ProtoTorch modules."""
|
"""ProtoTorch modules."""
|
||||||
|
|
||||||
from .utils import LambdaLayer
|
from .competitions import *
|
||||||
|
from .pooling import *
|
||||||
|
from .wrappers import LambdaLayer, LossLayer
|
||||||
|
41
prototorch/modules/competitions.py
Normal file
41
prototorch/modules/competitions.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
"""ProtoTorch Competition Modules."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from prototorch.functions.competitions import knnc, wtac
|
||||||
|
|
||||||
|
|
||||||
|
class WTAC(torch.nn.Module):
|
||||||
|
"""Winner-Takes-All-Competition Layer.
|
||||||
|
|
||||||
|
Thin wrapper over the `wtac` function.
|
||||||
|
|
||||||
|
"""
|
||||||
|
def forward(self, distances, labels):
|
||||||
|
return wtac(distances, labels)
|
||||||
|
|
||||||
|
|
||||||
|
class LTAC(torch.nn.Module):
|
||||||
|
"""Loser-Takes-All-Competition Layer.
|
||||||
|
|
||||||
|
Thin wrapper over the `wtac` function.
|
||||||
|
|
||||||
|
"""
|
||||||
|
def forward(self, probs, labels):
|
||||||
|
return wtac(-1.0 * probs, labels)
|
||||||
|
|
||||||
|
|
||||||
|
class KNNC(torch.nn.Module):
|
||||||
|
"""K-Nearest-Neighbors-Competition.
|
||||||
|
|
||||||
|
Thin wrapper over the `knnc` function.
|
||||||
|
|
||||||
|
"""
|
||||||
|
def __init__(self, k=1, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.k = k
|
||||||
|
|
||||||
|
def forward(self, distances, labels):
|
||||||
|
return knnc(distances, labels, k=self.k)
|
||||||
|
|
||||||
|
def extra_repr(self):
|
||||||
|
return f"k: {self.k}"
|
31
prototorch/modules/pooling.py
Normal file
31
prototorch/modules/pooling.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
"""ProtoTorch Pooling Modules."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from prototorch.functions.pooling import (stratified_max_pooling,
|
||||||
|
stratified_min_pooling,
|
||||||
|
stratified_prod_pooling,
|
||||||
|
stratified_sum_pooling)
|
||||||
|
|
||||||
|
|
||||||
|
class StratifiedSumPooling(torch.nn.Module):
|
||||||
|
"""Thin wrapper over the `stratified_sum_pooling` function."""
|
||||||
|
def forward(self, values, labels):
|
||||||
|
return stratified_sum(values, labels)
|
||||||
|
|
||||||
|
|
||||||
|
class StratifiedProdPooling(torch.nn.Module):
|
||||||
|
"""Thin wrapper over the `stratified_prod_pooling` function."""
|
||||||
|
def forward(self, values, labels):
|
||||||
|
return stratified_prod_pooling(values, labels)
|
||||||
|
|
||||||
|
|
||||||
|
class StratifiedMinPooling(torch.nn.Module):
|
||||||
|
"""Thin wrapper over the `stratified_min_pooling` function."""
|
||||||
|
def forward(self, values, labels):
|
||||||
|
return stratified_min_pooling(values, labels)
|
||||||
|
|
||||||
|
|
||||||
|
class StratifiedMaxPooling(torch.nn.Module):
|
||||||
|
"""Thin wrapper over the `stratified_max_pooling` function."""
|
||||||
|
def forward(self, values, labels):
|
||||||
|
return stratified_max_pooling(values, labels)
|
Loading…
Reference in New Issue
Block a user