[BUGFIX] Probabilistic Models work on GPU now
This commit is contained in:
parent
459f7c24be
commit
1b09b1d57b
@ -36,7 +36,8 @@ class ProbabilisticLVQ(GLVQ):
|
|||||||
distances = self._forward(x)
|
distances = self._forward(x)
|
||||||
conditional = self.conditional_distribution(distances,
|
conditional = self.conditional_distribution(distances,
|
||||||
self.hparams.variance)
|
self.hparams.variance)
|
||||||
prior = (1. / self.num_prototypes) * torch.ones(self.num_prototypes)
|
prior = (1. / self.num_prototypes) * torch.ones(self.num_prototypes,
|
||||||
|
device=self.device)
|
||||||
posterior = conditional * prior
|
posterior = conditional * prior
|
||||||
plabels = self.proto_layer._labels
|
plabels = self.proto_layer._labels
|
||||||
y_pred = stratified_sum(posterior, plabels)
|
y_pred = stratified_sum(posterior, plabels)
|
||||||
|
Loading…
Reference in New Issue
Block a user