Fix divide-by-zero in example
This commit is contained in:
@@ -76,7 +76,11 @@ def stratified_mean(x_train, y_train, prototype_distribution, one_hot=True):
|
||||
|
||||
|
||||
@register_initializer
|
||||
def stratified_random(x_train, y_train, prototype_distribution, one_hot=True):
|
||||
def stratified_random(x_train,
|
||||
y_train,
|
||||
prototype_distribution,
|
||||
one_hot=True,
|
||||
epsilon=1e-7):
|
||||
nprotos = torch.sum(prototype_distribution)
|
||||
pdim = x_train.shape[1]
|
||||
protos = torch.empty(nprotos, pdim)
|
||||
@@ -89,7 +93,7 @@ def stratified_random(x_train, y_train, prototype_distribution, one_hot=True):
|
||||
xl = x_train[matcher]
|
||||
rand_index = torch.zeros(1).long().random_(0, xl.shape[0] - 1)
|
||||
random_xl = xl[rand_index]
|
||||
protos[i] = random_xl
|
||||
protos[i] = random_xl + epsilon
|
||||
plabels = labels_from(prototype_distribution, one_hot=one_hot)
|
||||
return protos, plabels
|
||||
|
||||
|
@@ -15,6 +15,6 @@ class GLVQLoss(torch.nn.Module):
|
||||
|
||||
def forward(self, outputs, targets):
|
||||
distances, plabels = outputs
|
||||
mu = glvq_loss(distances, targets, plabels)
|
||||
mu = glvq_loss(distances, targets, prototype_labels=plabels)
|
||||
batch_loss = self.squashing(mu + self.margin, beta=self.beta)
|
||||
return torch.sum(batch_loss, dim=0)
|
||||
|
Reference in New Issue
Block a user