diff --git a/prototorch/functions/__init__.py b/prototorch/functions/__init__.py index 1139b07..9b3b993 100644 --- a/prototorch/functions/__init__.py +++ b/prototorch/functions/__init__.py @@ -2,11 +2,4 @@ from .activations import identity, sigmoid_beta, swish_beta from .competitions import knnc, wtac - -__all__ = [ - "identity", - "sigmoid_beta", - "swish_beta", - "knnc", - "wtac", -] +from .pooling import * diff --git a/prototorch/functions/competitions.py b/prototorch/functions/competitions.py index 0c4e2fa..326d510 100644 --- a/prototorch/functions/competitions.py +++ b/prototorch/functions/competitions.py @@ -1,107 +1,15 @@ """ProtoTorch competition functions.""" -from typing import Callable - import torch -def stratified_sum_v1(values: torch.Tensor, - labels: torch.LongTensor) -> (torch.Tensor): - """Group-wise sum.""" - uniques = labels.unique(sorted=True).tolist() - labels = labels.tolist() - - key_val = {key: val for key, val in zip(uniques, range(len(uniques)))} - labels = torch.LongTensor(list(map(key_val.get, labels))) - - labels = labels.view(labels.size(0), 1).expand(-1, values.size(1)) - - unique_labels = labels.unique(dim=0) - print(f"{labels=}") - print(f"{unique_labels=}") - - result = torch.zeros_like(unique_labels, dtype=torch.float).scatter_add_( - 0, labels, values) - return result.T - - -def stratify_with(values: torch.Tensor, - labels: torch.LongTensor, - fn: Callable, - fill_value: float = 0.0) -> (torch.Tensor): - """Apply an arbitrary stratification strategy on the columns on `values`. - - The outputs correspond to sorted labels. - """ - clabels = torch.unique(labels, dim=0, sorted=True) - num_classes = clabels.size()[0] - if values.size()[1] == num_classes: - # skip if stratification is trivial - return values - batch_size = values.size()[0] - winning_values = torch.zeros(num_classes, batch_size, device=labels.device) - filler = torch.full_like(values.T, fill_value=fill_value) - for i, cl in enumerate(clabels): - matcher = torch.eq(labels.unsqueeze(dim=1), cl) - if labels.ndim == 2: - # if the labels are one-hot vectors - matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes) - cdists = torch.where(matcher, values.T, filler).T - winning_values[i] = fn(cdists) - if labels.ndim == 2: - # Transpose to return with `batch_size` first and - # reverse the columns to fix the ordering of the classes - return torch.flip(winning_values.T, dims=(1, )) - - return winning_values.T # return with `batch_size` first - - -def stratified_sum(values: torch.Tensor, - labels: torch.LongTensor) -> (torch.Tensor): - """Group-wise sum.""" - winning_values = stratify_with( - values, - labels, - fn=lambda x: torch.sum(x, dim=1, keepdim=True).squeeze(), - fill_value=0.0) - return winning_values - - -def stratified_min(values: torch.Tensor, - labels: torch.LongTensor) -> (torch.Tensor): - """Group-wise minimum.""" - winning_values = stratify_with( - values, - labels, - fn=lambda x: torch.min(x, dim=1, keepdim=True).values.squeeze(), - fill_value=float("inf")) - return winning_values - - -def stratified_max(values: torch.Tensor, - labels: torch.LongTensor) -> (torch.Tensor): - """Group-wise maximum.""" - winning_values = stratify_with( - values, - labels, - fn=lambda x: torch.max(x, dim=1, keepdim=True).values.squeeze(), - fill_value=-1.0 * float("inf")) - return winning_values - - -def stratified_prod(values: torch.Tensor, - labels: torch.LongTensor) -> (torch.Tensor): - """Group-wise maximum.""" - winning_values = stratify_with( - values, - labels, - fn=lambda x: torch.prod(x, dim=1, keepdim=True).squeeze(), - fill_value=1.0) - return winning_values - - def wtac(distances: torch.Tensor, labels: torch.LongTensor) -> (torch.LongTensor): + """Winner-Takes-All-Competition. + + Returns the labels corresponding to the winners. + + """ winning_indices = torch.min(distances, dim=1).indices winning_labels = labels[winning_indices].squeeze() return winning_labels @@ -110,6 +18,11 @@ def wtac(distances: torch.Tensor, def knnc(distances: torch.Tensor, labels: torch.LongTensor, k: int = 1) -> (torch.LongTensor): + """K-Nearest-Neighbors-Competition. + + Returns the labels corresponding to the winners. + + """ winning_indices = torch.topk(-distances, k=k, dim=1).indices winning_labels = torch.mode(labels[winning_indices], dim=1).values return winning_labels diff --git a/prototorch/functions/pooling.py b/prototorch/functions/pooling.py new file mode 100644 index 0000000..6dd427e --- /dev/null +++ b/prototorch/functions/pooling.py @@ -0,0 +1,80 @@ +"""ProtoTorch pooling functions.""" + +from typing import Callable + +import torch + + +def stratify_with(values: torch.Tensor, + labels: torch.LongTensor, + fn: Callable, + fill_value: float = 0.0) -> (torch.Tensor): + """Apply an arbitrary stratification strategy on the columns on `values`. + + The outputs correspond to sorted labels. + """ + clabels = torch.unique(labels, dim=0, sorted=True) + num_classes = clabels.size()[0] + if values.size()[1] == num_classes: + # skip if stratification is trivial + return values + batch_size = values.size()[0] + winning_values = torch.zeros(num_classes, batch_size, device=labels.device) + filler = torch.full_like(values.T, fill_value=fill_value) + for i, cl in enumerate(clabels): + matcher = torch.eq(labels.unsqueeze(dim=1), cl) + if labels.ndim == 2: + # if the labels are one-hot vectors + matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes) + cdists = torch.where(matcher, values.T, filler).T + winning_values[i] = fn(cdists) + if labels.ndim == 2: + # Transpose to return with `batch_size` first and + # reverse the columns to fix the ordering of the classes + return torch.flip(winning_values.T, dims=(1, )) + + return winning_values.T # return with `batch_size` first + + +def stratified_sum_pooling(values: torch.Tensor, + labels: torch.LongTensor) -> (torch.Tensor): + """Group-wise sum.""" + winning_values = stratify_with( + values, + labels, + fn=lambda x: torch.sum(x, dim=1, keepdim=True).squeeze(), + fill_value=0.0) + return winning_values + + +def stratified_min_pooling(values: torch.Tensor, + labels: torch.LongTensor) -> (torch.Tensor): + """Group-wise minimum.""" + winning_values = stratify_with( + values, + labels, + fn=lambda x: torch.min(x, dim=1, keepdim=True).values.squeeze(), + fill_value=float("inf")) + return winning_values + + +def stratified_max_pooling(values: torch.Tensor, + labels: torch.LongTensor) -> (torch.Tensor): + """Group-wise maximum.""" + winning_values = stratify_with( + values, + labels, + fn=lambda x: torch.max(x, dim=1, keepdim=True).values.squeeze(), + fill_value=-1.0 * float("inf")) + return winning_values + + +def stratified_prod_pooling(values: torch.Tensor, + labels: torch.LongTensor) -> (torch.Tensor): + """Group-wise maximum.""" + winning_values = stratify_with( + values, + labels, + fn=lambda x: torch.prod(x, dim=1, keepdim=True).squeeze(), + fill_value=1.0) + return winning_values diff --git a/prototorch/modules/__init__.py b/prototorch/modules/__init__.py index 7f05925..fc7ab87 100644 --- a/prototorch/modules/__init__.py +++ b/prototorch/modules/__init__.py @@ -1,3 +1,5 @@ """ProtoTorch modules.""" -from .utils import LambdaLayer +from .competitions import * +from .pooling import * +from .wrappers import LambdaLayer, LossLayer diff --git a/prototorch/modules/competitions.py b/prototorch/modules/competitions.py new file mode 100644 index 0000000..a15631a --- /dev/null +++ b/prototorch/modules/competitions.py @@ -0,0 +1,41 @@ +"""ProtoTorch Competition Modules.""" + +import torch +from prototorch.functions.competitions import knnc, wtac + + +class WTAC(torch.nn.Module): + """Winner-Takes-All-Competition Layer. + + Thin wrapper over the `wtac` function. + + """ + def forward(self, distances, labels): + return wtac(distances, labels) + + +class LTAC(torch.nn.Module): + """Loser-Takes-All-Competition Layer. + + Thin wrapper over the `wtac` function. + + """ + def forward(self, probs, labels): + return wtac(-1.0 * probs, labels) + + +class KNNC(torch.nn.Module): + """K-Nearest-Neighbors-Competition. + + Thin wrapper over the `knnc` function. + + """ + def __init__(self, k=1, **kwargs): + super().__init__(**kwargs) + self.k = k + + def forward(self, distances, labels): + return knnc(distances, labels, k=self.k) + + def extra_repr(self): + return f"k: {self.k}" diff --git a/prototorch/modules/pooling.py b/prototorch/modules/pooling.py new file mode 100644 index 0000000..78ce24e --- /dev/null +++ b/prototorch/modules/pooling.py @@ -0,0 +1,31 @@ +"""ProtoTorch Pooling Modules.""" + +import torch +from prototorch.functions.pooling import (stratified_max_pooling, + stratified_min_pooling, + stratified_prod_pooling, + stratified_sum_pooling) + + +class StratifiedSumPooling(torch.nn.Module): + """Thin wrapper over the `stratified_sum_pooling` function.""" + def forward(self, values, labels): + return stratified_sum(values, labels) + + +class StratifiedProdPooling(torch.nn.Module): + """Thin wrapper over the `stratified_prod_pooling` function.""" + def forward(self, values, labels): + return stratified_prod_pooling(values, labels) + + +class StratifiedMinPooling(torch.nn.Module): + """Thin wrapper over the `stratified_min_pooling` function.""" + def forward(self, values, labels): + return stratified_min_pooling(values, labels) + + +class StratifiedMaxPooling(torch.nn.Module): + """Thin wrapper over the `stratified_max_pooling` function.""" + def forward(self, values, labels): + return stratified_max_pooling(values, labels)