[BUGFIX] Probabilistic Models work on GPU now
This commit is contained in:
		@@ -36,7 +36,8 @@ class ProbabilisticLVQ(GLVQ):
 | 
				
			|||||||
        distances = self._forward(x)
 | 
					        distances = self._forward(x)
 | 
				
			||||||
        conditional = self.conditional_distribution(distances,
 | 
					        conditional = self.conditional_distribution(distances,
 | 
				
			||||||
                                                    self.hparams.variance)
 | 
					                                                    self.hparams.variance)
 | 
				
			||||||
        prior = (1. / self.num_prototypes) * torch.ones(self.num_prototypes)
 | 
					        prior = (1. / self.num_prototypes) * torch.ones(self.num_prototypes,
 | 
				
			||||||
 | 
					                                                        device=self.device)
 | 
				
			||||||
        posterior = conditional * prior
 | 
					        posterior = conditional * prior
 | 
				
			||||||
        plabels = self.proto_layer._labels
 | 
					        plabels = self.proto_layer._labels
 | 
				
			||||||
        y_pred = stratified_sum(posterior, plabels)
 | 
					        y_pred = stratified_sum(posterior, plabels)
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user