[BUGFIX] Probabilistic Models work on GPU now
This commit is contained in:
parent
459f7c24be
commit
1b09b1d57b
@ -36,7 +36,8 @@ class ProbabilisticLVQ(GLVQ):
|
||||
distances = self._forward(x)
|
||||
conditional = self.conditional_distribution(distances,
|
||||
self.hparams.variance)
|
||||
prior = (1. / self.num_prototypes) * torch.ones(self.num_prototypes)
|
||||
prior = (1. / self.num_prototypes) * torch.ones(self.num_prototypes,
|
||||
device=self.device)
|
||||
posterior = conditional * prior
|
||||
plabels = self.proto_layer._labels
|
||||
y_pred = stratified_sum(posterior, plabels)
|
||||
|
Loading…
Reference in New Issue
Block a user