fix(warning): specify dimension explicitly when calling softmin

This commit is contained in:
Jensun Ravichandran 2021-11-16 10:19:31 +01:00
parent 4232d0ed2a
commit 1d26226a2f
No known key found for this signature in database
GPG Key ID: 4E9348239810B51F

View File

@ -135,7 +135,7 @@ class SupervisedPrototypeModel(PrototypeModel):
distances = self.compute_distances(x)
_, plabels = self.proto_layer()
winning = stratified_min_pooling(distances, plabels)
y_pred = torch.nn.functional.softmin(winning)
y_pred = torch.nn.functional.softmin(winning, dim=1)
return y_pred
def predict_from_distances(self, distances):