[BUGFIX] Fix typo

This commit is contained in:
Jensun Ravichandran 2021-06-04 22:24:42 +02:00
parent b03c9b1d3c
commit 2272c55092

View File

@ -10,7 +10,7 @@ from prototorch.functions.pooling import (stratified_max_pooling,
class StratifiedSumPooling(torch.nn.Module):
"""Thin wrapper over the `stratified_sum_pooling` function."""
def forward(self, values, labels):
return stratified_sum(values, labels)
return stratified_sum_pooling(values, labels)
class StratifiedProdPooling(torch.nn.Module):