[BUGFIX] Probabilistic Models work on GPU now

This commit is contained in:
Alexander Engelsberger 2021-06-03 14:05:44 +02:00
parent 459f7c24be
commit 1b09b1d57b

View File

@ -36,7 +36,8 @@ class ProbabilisticLVQ(GLVQ):
distances = self._forward(x)
conditional = self.conditional_distribution(distances,
self.hparams.variance)
prior = (1. / self.num_prototypes) * torch.ones(self.num_prototypes)
prior = (1. / self.num_prototypes) * torch.ones(self.num_prototypes,
device=self.device)
posterior = conditional * prior
plabels = self.proto_layer._labels
y_pred = stratified_sum(posterior, plabels)