[BUGFIX] Fix typo
This commit is contained in:
@@ -10,7 +10,7 @@ from prototorch.functions.pooling import (stratified_max_pooling,
|
|||||||
class StratifiedSumPooling(torch.nn.Module):
|
class StratifiedSumPooling(torch.nn.Module):
|
||||||
"""Thin wrapper over the `stratified_sum_pooling` function."""
|
"""Thin wrapper over the `stratified_sum_pooling` function."""
|
||||||
def forward(self, values, labels):
|
def forward(self, values, labels):
|
||||||
return stratified_sum(values, labels)
|
return stratified_sum_pooling(values, labels)
|
||||||
|
|
||||||
|
|
||||||
class StratifiedProdPooling(torch.nn.Module):
|
class StratifiedProdPooling(torch.nn.Module):
|
||||||
|
Reference in New Issue
Block a user