fix: All examples should work on CPU and GPU now

This commit is contained in:
Alexander Engelsberger
2021-08-05 11:20:02 +02:00
parent 0af8cf36f8
commit d7834e2cc0
5 changed files with 5 additions and 6 deletions

View File

@@ -96,7 +96,7 @@ class UnsupervisedPrototypeModel(PrototypeModel):
)
def compute_distances(self, x):
protos = self.proto_layer()
protos = self.proto_layer().type_as(x)
distances = self.distance_layer(x, protos)
return distances