diff --git a/prototorch/models/abstract.py b/prototorch/models/abstract.py index c6c1164..a49a87c 100644 --- a/prototorch/models/abstract.py +++ b/prototorch/models/abstract.py @@ -135,7 +135,7 @@ class SupervisedPrototypeModel(PrototypeModel): distances = self.compute_distances(x) _, plabels = self.proto_layer() winning = stratified_min_pooling(distances, plabels) - y_pred = torch.nn.functional.softmin(winning) + y_pred = torch.nn.functional.softmin(winning, dim=1) return y_pred def predict_from_distances(self, distances):