[BUGFIX] Fix typo
This commit is contained in:
parent
b03c9b1d3c
commit
2272c55092
@ -10,7 +10,7 @@ from prototorch.functions.pooling import (stratified_max_pooling,
|
||||
class StratifiedSumPooling(torch.nn.Module):
|
||||
"""Thin wrapper over the `stratified_sum_pooling` function."""
|
||||
def forward(self, values, labels):
|
||||
return stratified_sum(values, labels)
|
||||
return stratified_sum_pooling(values, labels)
|
||||
|
||||
|
||||
class StratifiedProdPooling(torch.nn.Module):
|
||||
|
Loading…
Reference in New Issue
Block a user