Fix divide-by-zero in example

This commit is contained in:
Jensun Ravichandran 2020-09-23 15:29:26 +02:00
parent 3e6aa6a20b
commit d5ab9c3771
3 changed files with 19 additions and 11 deletions

View File

@ -23,15 +23,16 @@ class Model(torch.nn.Module):
def __init__(self, **kwargs): def __init__(self, **kwargs):
"""GLVQ model.""" """GLVQ model."""
super().__init__() super().__init__()
self.p1 = Prototypes1D(input_dim=2, self.proto_layer = Prototypes1D(
prototypes_per_class=3, input_dim=2,
nclasses=3, prototypes_per_class=3,
prototype_initializer='stratified_random', nclasses=3,
data=[x_train, y_train]) prototype_initializer='stratified_random',
data=[x_train, y_train])
def forward(self, x): def forward(self, x):
protos = self.p1.prototypes protos = self.proto_layer.prototypes
plabels = self.p1.prototype_labels plabels = self.proto_layer.prototype_labels
dis = euclidean_distance(x, protos) dis = euclidean_distance(x, protos)
return dis, plabels return dis, plabels
@ -61,7 +62,10 @@ for epoch in range(70):
optimizer.step() optimizer.step()
# Get the prototypes form the model # Get the prototypes form the model
protos = model.p1.prototypes.data.numpy() protos = model.proto_layer.prototypes.data.numpy()
if np.isnan(np.sum(protos)):
print(f'Stopping because of `nan` in prototypes.')
break
# Visualize the data and the prototypes # Visualize the data and the prototypes
ax = fig.gca() ax = fig.gca()

View File

@ -76,7 +76,11 @@ def stratified_mean(x_train, y_train, prototype_distribution, one_hot=True):
@register_initializer @register_initializer
def stratified_random(x_train, y_train, prototype_distribution, one_hot=True): def stratified_random(x_train,
y_train,
prototype_distribution,
one_hot=True,
epsilon=1e-7):
nprotos = torch.sum(prototype_distribution) nprotos = torch.sum(prototype_distribution)
pdim = x_train.shape[1] pdim = x_train.shape[1]
protos = torch.empty(nprotos, pdim) protos = torch.empty(nprotos, pdim)
@ -89,7 +93,7 @@ def stratified_random(x_train, y_train, prototype_distribution, one_hot=True):
xl = x_train[matcher] xl = x_train[matcher]
rand_index = torch.zeros(1).long().random_(0, xl.shape[0] - 1) rand_index = torch.zeros(1).long().random_(0, xl.shape[0] - 1)
random_xl = xl[rand_index] random_xl = xl[rand_index]
protos[i] = random_xl protos[i] = random_xl + epsilon
plabels = labels_from(prototype_distribution, one_hot=one_hot) plabels = labels_from(prototype_distribution, one_hot=one_hot)
return protos, plabels return protos, plabels

View File

@ -15,6 +15,6 @@ class GLVQLoss(torch.nn.Module):
def forward(self, outputs, targets): def forward(self, outputs, targets):
distances, plabels = outputs distances, plabels = outputs
mu = glvq_loss(distances, targets, plabels) mu = glvq_loss(distances, targets, prototype_labels=plabels)
batch_loss = self.squashing(mu + self.margin, beta=self.beta) batch_loss = self.squashing(mu + self.margin, beta=self.beta)
return torch.sum(batch_loss, dim=0) return torch.sum(batch_loss, dim=0)