[BUGFIX] Fix typo
This commit is contained in:
@@ -10,7 +10,7 @@ from prototorch.functions.pooling import (stratified_max_pooling,
|
||||
class StratifiedSumPooling(torch.nn.Module):
|
||||
"""Thin wrapper over the `stratified_sum_pooling` function."""
|
||||
def forward(self, values, labels):
|
||||
return stratified_sum(values, labels)
|
||||
return stratified_sum_pooling(values, labels)
|
||||
|
||||
|
||||
class StratifiedProdPooling(torch.nn.Module):
|
||||
|
Reference in New Issue
Block a user