[Refactor] Add Modules for prior distrbutions
This commit is contained in:
parent
4f1c879528
commit
47d7f5831f
@ -1,5 +1,32 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
def gaussian(distance, variance):
|
# Functions
|
||||||
return torch.exp(-(distance * distance) / (2 * variance))
|
def gaussian(distances, variance):
|
||||||
|
return torch.exp(-(distances * distances) / (2 * variance))
|
||||||
|
|
||||||
|
|
||||||
|
def rank_scaled_gaussian(distances, lambd):
|
||||||
|
order = torch.argsort(distances, dim=1)
|
||||||
|
ranks = torch.argsort(order, dim=1)
|
||||||
|
|
||||||
|
return torch.exp(-torch.exp(-ranks / lambd) * distances)
|
||||||
|
|
||||||
|
|
||||||
|
# Modules
|
||||||
|
class GaussianPrior(torch.nn.Module):
|
||||||
|
def __init__(self, variance):
|
||||||
|
super().__init__()
|
||||||
|
self.variance = variance
|
||||||
|
|
||||||
|
def forward(self, distances):
|
||||||
|
return gaussian(distances, self.variance)
|
||||||
|
|
||||||
|
|
||||||
|
class RankScaledGaussianPrior(torch.nn.Module):
|
||||||
|
def __init__(self, lambd):
|
||||||
|
super().__init__()
|
||||||
|
self.lambd = lambd
|
||||||
|
|
||||||
|
def forward(self, distances):
|
||||||
|
return rank_scaled_gaussian(distances, self.lambd)
|
||||||
|
Loading…
Reference in New Issue
Block a user