fix: labels where on cpu in forward pass

This commit is contained in:
Alexander Engelsberger
2021-08-05 09:14:32 +02:00
parent f8ad1d83eb
commit 0af8cf36f8
4 changed files with 7 additions and 10 deletions

View File

@@ -136,14 +136,14 @@ class SupervisedPrototypeModel(PrototypeModel):
def forward(self, x):
distances = self.compute_distances(x)
plabels = self.proto_layer.labels
_, plabels = self.proto_layer()
winning = stratified_min_pooling(distances, plabels)
y_pred = torch.nn.functional.softmin(winning)
return y_pred
def predict_from_distances(self, distances):
with torch.no_grad():
plabels = self.proto_layer.labels
_, plabels = self.proto_layer()
y_pred = self.competition_layer(distances, plabels)
return y_pred