Add competition and pooling modules
@@ -2,11 +2,4 @@

from .activations import identity, sigmoid_beta, swish_beta
from .competitions import knnc, wtac

__all__ = [
    "identity",
    "sigmoid_beta",
    "swish_beta",
    "knnc",
    "wtac",
]
from .pooling import *
@@ -1,107 +1,15 @@
"""ProtoTorch competition functions."""

from typing import Callable

import torch


def stratified_sum_v1(values: torch.Tensor,
                      labels: torch.LongTensor) -> (torch.Tensor):
    """Group-wise sum."""
    uniques = labels.unique(sorted=True).tolist()
    labels = labels.tolist()

    key_val = {key: val for key, val in zip(uniques, range(len(uniques)))}
    labels = torch.LongTensor(list(map(key_val.get, labels)))

    labels = labels.view(labels.size(0), 1).expand(-1, values.size(1))

    unique_labels = labels.unique(dim=0)
    print(f"{labels=}")
    print(f"{unique_labels=}")

    result = torch.zeros_like(unique_labels, dtype=torch.float).scatter_add_(
        0, labels, values)
    return result.T


def stratify_with(values: torch.Tensor,
                  labels: torch.LongTensor,
                  fn: Callable,
                  fill_value: float = 0.0) -> (torch.Tensor):
    """Apply an arbitrary stratification strategy on the columns of `values`.

    The outputs correspond to sorted labels.
    """
    clabels = torch.unique(labels, dim=0, sorted=True)
    num_classes = clabels.size()[0]
    if values.size()[1] == num_classes:
        # skip if stratification is trivial
        return values
    batch_size = values.size()[0]
    winning_values = torch.zeros(num_classes, batch_size, device=labels.device)
    filler = torch.full_like(values.T, fill_value=fill_value)
    for i, cl in enumerate(clabels):
        matcher = torch.eq(labels.unsqueeze(dim=1), cl)
        if labels.ndim == 2:
            # if the labels are one-hot vectors
            matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes)
        cdists = torch.where(matcher, values.T, filler).T
        winning_values[i] = fn(cdists)
    if labels.ndim == 2:
        # Transpose to return with `batch_size` first and
        # reverse the columns to fix the ordering of the classes
        return torch.flip(winning_values.T, dims=(1, ))

    return winning_values.T  # return with `batch_size` first


def stratified_sum(values: torch.Tensor,
                   labels: torch.LongTensor) -> (torch.Tensor):
    """Group-wise sum."""
    winning_values = stratify_with(
        values,
        labels,
        fn=lambda x: torch.sum(x, dim=1, keepdim=True).squeeze(),
        fill_value=0.0)
    return winning_values


def stratified_min(values: torch.Tensor,
                   labels: torch.LongTensor) -> (torch.Tensor):
    """Group-wise minimum."""
    winning_values = stratify_with(
        values,
        labels,
        fn=lambda x: torch.min(x, dim=1, keepdim=True).values.squeeze(),
        fill_value=float("inf"))
    return winning_values


def stratified_max(values: torch.Tensor,
                   labels: torch.LongTensor) -> (torch.Tensor):
    """Group-wise maximum."""
    winning_values = stratify_with(
        values,
        labels,
        fn=lambda x: torch.max(x, dim=1, keepdim=True).values.squeeze(),
        fill_value=-1.0 * float("inf"))
    return winning_values


def stratified_prod(values: torch.Tensor,
                    labels: torch.LongTensor) -> (torch.Tensor):
    """Group-wise product."""
    winning_values = stratify_with(
        values,
        labels,
        fn=lambda x: torch.prod(x, dim=1, keepdim=True).squeeze(),
        fill_value=1.0)
    return winning_values


def wtac(distances: torch.Tensor,
         labels: torch.LongTensor) -> (torch.LongTensor):
    """Winner-Takes-All-Competition.

    Returns the labels corresponding to the winners.

    """
    winning_indices = torch.min(distances, dim=1).indices
    winning_labels = labels[winning_indices].squeeze()
    return winning_labels
@@ -110,6 +18,11 @@ def wtac(distances: torch.Tensor,
def knnc(distances: torch.Tensor,
         labels: torch.LongTensor,
         k: int = 1) -> (torch.LongTensor):
    """K-Nearest-Neighbors-Competition.

    Returns the labels corresponding to the winners.

    """
    winning_indices = torch.topk(-distances, k=k, dim=1).indices
    winning_labels = torch.mode(labels[winning_indices], dim=1).values
    return winning_labels
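A usage sketch (not part of this commit; the tensors are invented for illustration): `wtac` takes a `(batch_size, num_prototypes)` distance matrix together with one label per prototype and returns the label of the closest prototype for each sample, while `knnc` takes a majority vote over the `k` closest prototypes.

import torch

from prototorch.functions.competitions import knnc, wtac

# Three samples, four prototypes; prototype labels are class indices.
distances = torch.tensor([
    [0.1, 0.9, 0.8, 0.7],
    [0.9, 0.2, 0.1, 0.8],
    [0.5, 0.6, 0.7, 0.1],
])
plabels = torch.LongTensor([0, 1, 1, 2])

print(wtac(distances, plabels))       # tensor([0, 1, 2])
print(knnc(distances, plabels, k=3))  # majority label among the 3 nearest prototypes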

prototorch/functions/pooling.py (new file, 80 lines)
@@ -0,0 +1,80 @@
"""ProtoTorch pooling functions."""

from typing import Callable

import torch


def stratify_with(values: torch.Tensor,
                  labels: torch.LongTensor,
                  fn: Callable,
                  fill_value: float = 0.0) -> (torch.Tensor):
    """Apply an arbitrary stratification strategy on the columns of `values`.

    The outputs correspond to sorted labels.
    """
    clabels = torch.unique(labels, dim=0, sorted=True)
    num_classes = clabels.size()[0]
    if values.size()[1] == num_classes:
        # skip if stratification is trivial
        return values
    batch_size = values.size()[0]
    winning_values = torch.zeros(num_classes, batch_size, device=labels.device)
    filler = torch.full_like(values.T, fill_value=fill_value)
    for i, cl in enumerate(clabels):
        matcher = torch.eq(labels.unsqueeze(dim=1), cl)
        if labels.ndim == 2:
            # if the labels are one-hot vectors
            matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes)
        cdists = torch.where(matcher, values.T, filler).T
        winning_values[i] = fn(cdists)
    if labels.ndim == 2:
        # Transpose to return with `batch_size` first and
        # reverse the columns to fix the ordering of the classes
        return torch.flip(winning_values.T, dims=(1, ))

    return winning_values.T  # return with `batch_size` first


def stratified_sum_pooling(values: torch.Tensor,
                           labels: torch.LongTensor) -> (torch.Tensor):
    """Group-wise sum."""
    winning_values = stratify_with(
        values,
        labels,
        fn=lambda x: torch.sum(x, dim=1, keepdim=True).squeeze(),
        fill_value=0.0)
    return winning_values


def stratified_min_pooling(values: torch.Tensor,
                           labels: torch.LongTensor) -> (torch.Tensor):
    """Group-wise minimum."""
    winning_values = stratify_with(
        values,
        labels,
        fn=lambda x: torch.min(x, dim=1, keepdim=True).values.squeeze(),
        fill_value=float("inf"))
    return winning_values


def stratified_max_pooling(values: torch.Tensor,
                           labels: torch.LongTensor) -> (torch.Tensor):
    """Group-wise maximum."""
    winning_values = stratify_with(
        values,
        labels,
        fn=lambda x: torch.max(x, dim=1, keepdim=True).values.squeeze(),
        fill_value=-1.0 * float("inf"))
    return winning_values


def stratified_prod_pooling(values: torch.Tensor,
                            labels: torch.LongTensor) -> (torch.Tensor):
    """Group-wise product."""
    winning_values = stratify_with(
        values,
        labels,
        fn=lambda x: torch.prod(x, dim=1, keepdim=True).squeeze(),
        fill_value=1.0)
    return winning_values
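A usage sketch for the pooling functions above (the numbers are invented): `values` has one column per prototype, `labels` assigns each column to a class, and the result has one column per class, reduced group-wise over that class's columns.

import torch

from prototorch.functions.pooling import (stratified_min_pooling,
                                          stratified_sum_pooling)

# Two samples, four prototypes; columns 0-1 belong to class 0,
# columns 2 and 3 to classes 1 and 2.
values = torch.tensor([
    [1.0, 2.0, 3.0, 4.0],
    [8.0, 7.0, 6.0, 5.0],
])
plabels = torch.LongTensor([0, 0, 1, 2])

print(stratified_min_pooling(values, plabels))  # -> [[1., 3., 4.], [7., 6., 5.]]
print(stratified_sum_pooling(values, plabels))  # -> [[3., 3., 4.], [15., 6., 5.]]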
@@ -1,3 +1,5 @@
"""ProtoTorch modules."""

from .utils import LambdaLayer
from .competitions import *
from .pooling import *
from .wrappers import LambdaLayer, LossLayer

prototorch/modules/competitions.py (new file, 41 lines)
@@ -0,0 +1,41 @@
"""ProtoTorch Competition Modules."""

import torch
from prototorch.functions.competitions import knnc, wtac


class WTAC(torch.nn.Module):
    """Winner-Takes-All-Competition Layer.

    Thin wrapper over the `wtac` function.

    """
    def forward(self, distances, labels):
        return wtac(distances, labels)


class LTAC(torch.nn.Module):
    """Loser-Takes-All-Competition Layer.

    Thin wrapper over the `wtac` function applied to negated inputs.

    """
    def forward(self, probs, labels):
        return wtac(-1.0 * probs, labels)


class KNNC(torch.nn.Module):
    """K-Nearest-Neighbors-Competition.

    Thin wrapper over the `knnc` function.

    """
    def __init__(self, k=1, **kwargs):
        super().__init__(**kwargs)
        self.k = k

    def forward(self, distances, labels):
        return knnc(distances, labels, k=self.k)

    def extra_repr(self):
        return f"k: {self.k}"
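A short sketch of these wrappers in use (random inputs, purely illustrative); the layers hold no parameters and simply defer to the functional API.

import torch

from prototorch.modules.competitions import KNNC, WTAC

distances = torch.rand(8, 5)                 # batch of 8, five prototypes
plabels = torch.LongTensor([0, 0, 1, 1, 2])  # one class label per prototype

wtac_layer = WTAC()
knnc_layer = KNNC(k=3)

print(wtac_layer(distances, plabels).shape)  # torch.Size([8])
print(knnc_layer)                            # KNNC(k: 3), via extra_repr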

prototorch/modules/pooling.py (new file, 31 lines)
@@ -0,0 +1,31 @@
"""ProtoTorch Pooling Modules."""

import torch
from prototorch.functions.pooling import (stratified_max_pooling,
                                          stratified_min_pooling,
                                          stratified_prod_pooling,
                                          stratified_sum_pooling)


class StratifiedSumPooling(torch.nn.Module):
    """Thin wrapper over the `stratified_sum_pooling` function."""
    def forward(self, values, labels):
        return stratified_sum_pooling(values, labels)


class StratifiedProdPooling(torch.nn.Module):
    """Thin wrapper over the `stratified_prod_pooling` function."""
    def forward(self, values, labels):
        return stratified_prod_pooling(values, labels)


class StratifiedMinPooling(torch.nn.Module):
    """Thin wrapper over the `stratified_min_pooling` function."""
    def forward(self, values, labels):
        return stratified_min_pooling(values, labels)


class StratifiedMaxPooling(torch.nn.Module):
    """Thin wrapper over the `stratified_max_pooling` function."""
    def forward(self, values, labels):
        return stratified_max_pooling(values, labels)
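Finally, a sketch of a pooling module in a typical spot (invented tensors; in an LVQ-style model the distances would come from a prototype layer): collapsing per-prototype distances into per-class distances.

import torch

from prototorch.modules.pooling import StratifiedMinPooling

distances = torch.rand(4, 6)                    # 4 samples, 6 prototypes
plabels = torch.LongTensor([0, 0, 1, 1, 2, 2])  # class of each prototype

pool = StratifiedMinPooling()
class_distances = pool(distances, plabels)
print(class_distances.shape)  # torch.Size([4, 3])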