Fix divide-by-zero in example

Jensun Ravichandran
2020-09-23 15:29:26 +02:00
parent 3e6aa6a20b
commit d5ab9c3771
3 changed files with 19 additions and 11 deletions

@@ -76,7 +76,11 @@ def stratified_mean(x_train, y_train, prototype_distribution, one_hot=True):
 @register_initializer
-def stratified_random(x_train, y_train, prototype_distribution, one_hot=True):
+def stratified_random(x_train,
+                      y_train,
+                      prototype_distribution,
+                      one_hot=True,
+                      epsilon=1e-7):
     nprotos = torch.sum(prototype_distribution)
     pdim = x_train.shape[1]
     protos = torch.empty(nprotos, pdim)
@@ -89,7 +93,7 @@ def stratified_random(x_train, y_train, prototype_distribution, one_hot=True):
         xl = x_train[matcher]
         rand_index = torch.zeros(1).long().random_(0, xl.shape[0] - 1)
         random_xl = xl[rand_index]
-        protos[i] = random_xl
+        protos[i] = random_xl + epsilon
     plabels = labels_from(prototype_distribution, one_hot=one_hot)
     return protos, plabels
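
For context on the commit message: GLVQ-style losses divide by the sum of the distances to the closest correct-class and the closest wrong-class prototype, so a prototype drawn exactly onto a training sample can make that denominator zero. The sketch below reproduces the failure and shows how the small epsilon offset avoids it; the helper relative_distance and the setup are illustrative, not the library's actual glvq_loss code.

import torch

# Hedged sketch of the relative-distance term used by GLVQ-type losses;
# the function name and setup are illustrative, not the library's code.
def relative_distance(d_correct, d_wrong):
    return (d_correct - d_wrong) / (d_correct + d_wrong)

x = torch.tensor([1.0, 2.0])
p_correct = x.clone()   # correct-class prototype drawn exactly onto the sample
p_wrong = x.clone()     # degenerate draw: wrong-class prototype on the same point

d_c = torch.sum((x - p_correct) ** 2)  # 0.0
d_w = torch.sum((x - p_wrong) ** 2)    # 0.0
print(relative_distance(d_c, d_w))     # tensor(nan) -- the divide-by-zero

epsilon = 1e-7
d_c = torch.sum((x - (p_correct + epsilon)) ** 2)  # tiny but strictly positive
d_w = torch.sum((x - (p_wrong + epsilon)) ** 2)
print(relative_distance(d_c, d_w))                 # finite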

@@ -15,6 +15,6 @@ class GLVQLoss(torch.nn.Module):
     def forward(self, outputs, targets):
         distances, plabels = outputs
-        mu = glvq_loss(distances, targets, plabels)
+        mu = glvq_loss(distances, targets, prototype_labels=plabels)
         batch_loss = self.squashing(mu + self.margin, beta=self.beta)
         return torch.sum(batch_loss, dim=0)
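
Passing plabels via the keyword prototype_labels makes the call robust to the positional order of glvq_loss. Below is a conceptual sketch, assuming distances has shape (batch, n_prototypes) and plain integer labels, of how a glvq_loss-style function would use the prototype labels to pick the closest correct and closest wrong prototype per sample; it is not the library's exact implementation.

import torch

# Conceptual sketch only: how prototype_labels would be used to split the
# distance matrix. Assumes integer (non-one-hot) labels for simplicity.
def glvq_loss_sketch(distances, target_labels, prototype_labels):
    # matcher[n, k] is True when prototype k carries the label of sample n.
    matcher = target_labels.unsqueeze(1) == prototype_labels.unsqueeze(0)
    inf = torch.full_like(distances, float("inf"))
    d_correct = torch.where(matcher, distances, inf).min(dim=1).values
    d_wrong = torch.where(~matcher, distances, inf).min(dim=1).values
    # Classifier function mu in [-1, 1]; negative means correctly classified.
    return (d_correct - d_wrong) / (d_correct + d_wrong)

distances = torch.tensor([[0.1, 2.0], [1.5, 0.2]])  # 2 samples x 2 prototypes
prototype_labels = torch.tensor([0, 1])
targets = torch.tensor([0, 1])
print(glvq_loss_sketch(distances, targets, prototype_labels=prototype_labels))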