fix: All examples should work on CPU and GPU now
This commit is contained in:
@@ -96,7 +96,7 @@ class UnsupervisedPrototypeModel(PrototypeModel):
|
||||
)
|
||||
|
||||
def compute_distances(self, x):
|
||||
protos = self.proto_layer()
|
||||
protos = self.proto_layer().type_as(x)
|
||||
distances = self.distance_layer(x, protos)
|
||||
return distances
|
||||
|
||||
|
Reference in New Issue
Block a user