Fix divide-by-zero in example
parent 3e6aa6a20b
commit d5ab9c3771
```diff
@@ -23,15 +23,16 @@ class Model(torch.nn.Module):
     def __init__(self, **kwargs):
         """GLVQ model."""
         super().__init__()
-        self.p1 = Prototypes1D(input_dim=2,
-                               prototypes_per_class=3,
-                               nclasses=3,
-                               prototype_initializer='stratified_random',
-                               data=[x_train, y_train])
+        self.proto_layer = Prototypes1D(
+            input_dim=2,
+            prototypes_per_class=3,
+            nclasses=3,
+            prototype_initializer='stratified_random',
+            data=[x_train, y_train])
 
     def forward(self, x):
-        protos = self.p1.prototypes
-        plabels = self.p1.prototype_labels
+        protos = self.proto_layer.prototypes
+        plabels = self.proto_layer.prototype_labels
         dis = euclidean_distance(x, protos)
         return dis, plabels
 
```
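For orientation, a sketch of how the renamed layer is driven end to end. This is a minimal sketch, not part of the commit: the toy tensors are assumptions, while `Model`, `Prototypes1D`, and `euclidean_distance` are the names from the hunk above.

```python
# Hypothetical usage sketch; assumes the patched Model class above is in
# scope along with prototorch's Prototypes1D and euclidean_distance.
import torch

x_train = torch.rand(150, 2)           # toy stand-in for the example's data
y_train = torch.randint(0, 3, (150,))  # three classes, matching nclasses=3

model = Model()
distances, plabels = model(x_train)    # forward() returns (distances, prototype labels)
# 3 classes x 3 prototypes_per_class gives 9 prototypes, one distance column each
```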
```diff
@@ -61,7 +62,10 @@ for epoch in range(70):
     optimizer.step()
 
     # Get the prototypes from the model
-    protos = model.p1.prototypes.data.numpy()
+    protos = model.proto_layer.prototypes.data.numpy()
+    if np.isnan(np.sum(protos)):
+        print(f'Stopping because of `nan` in prototypes.')
+        break
 
     # Visualize the data and the prototypes
     ax = fig.gca()
```
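The added guard relies on NaN propagation: a single `nan` poisons the sum, so one scalar test covers the whole prototype array without an element-wise scan. A self-contained illustration:

```python
import numpy as np

protos = np.array([[0.1, 0.2],
                   [np.nan, 0.4]])  # one bad prototype component
print(np.isnan(np.sum(protos)))    # True -> the training loop above breaks
```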
```diff
@@ -76,7 +76,11 @@ def stratified_mean(x_train, y_train, prototype_distribution, one_hot=True):
 
 
 @register_initializer
-def stratified_random(x_train, y_train, prototype_distribution, one_hot=True):
+def stratified_random(x_train,
+                      y_train,
+                      prototype_distribution,
+                      one_hot=True,
+                      epsilon=1e-7):
     nprotos = torch.sum(prototype_distribution)
     pdim = x_train.shape[1]
     protos = torch.empty(nprotos, pdim)
@@ -89,7 +93,7 @@ def stratified_random(x_train, y_train, prototype_distribution, one_hot=True):
         xl = x_train[matcher]
         rand_index = torch.zeros(1).long().random_(0, xl.shape[0] - 1)
         random_xl = xl[rand_index]
-        protos[i] = random_xl
+        protos[i] = random_xl + epsilon
     plabels = labels_from(prototype_distribution, one_hot=one_hot)
     return protos, plabels
 
```
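The new `epsilon` offset is what addresses the divide-by-zero in the commit title. The diff does not spell out the trigger, but one consistent reading is this: `stratified_random` copies prototypes directly from training samples, so an initial distance can be exactly zero, and the gradient of the Euclidean distance `sqrt(sum((x - p)**2))` divides by `sqrt(0)` on the first backward pass. A minimal reproduction of that failure mode, with made-up values:

```python
import torch

x = torch.tensor([0.4, 0.2])
proto = torch.tensor([0.4, 0.2], requires_grad=True)  # prototype copied from a sample

d = torch.sqrt(torch.sum((x - proto) ** 2))  # Euclidean distance, exactly 0
d.backward()
print(proto.grad)  # tensor([nan, nan]): sqrt's derivative 1/(2*sqrt(0)) is inf,
                   # and inf times the zero inner gradient yields nan

# The patched initializer nudges the prototype off the sample, so the
# initial distance (and hence sqrt's input) is strictly positive:
proto_eps = torch.tensor([0.4, 0.2], requires_grad=True)
d = torch.sqrt(torch.sum((x - (proto_eps + 1e-7)) ** 2))
d.backward()
print(proto_eps.grad)  # finite gradients
```

Adding the same scalar to every component is a blunt nudge, but at these data scales it is enough to keep the initial distances nonzero.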
```diff
@@ -15,6 +15,6 @@ class GLVQLoss(torch.nn.Module):
 
     def forward(self, outputs, targets):
         distances, plabels = outputs
-        mu = glvq_loss(distances, targets, plabels)
+        mu = glvq_loss(distances, targets, prototype_labels=plabels)
         batch_loss = self.squashing(mu + self.margin, beta=self.beta)
         return torch.sum(batch_loss, dim=0)
```
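Passing `plabels` by keyword makes the call robust to positional-signature drift. For readers unfamiliar with the quantity being computed, here is a hypothetical sketch of a `glvq_loss` with this signature; prototorch's actual implementation may differ, and the tensors below are made up. It illustrates the standard GLVQ classifier function mu = (d+ - d-) / (d+ + d-), whose denominator is another place where a zero initial distance pair would blow up:

```python
# Hypothetical sketch of a glvq_loss with this signature (prototorch's real
# implementation may differ); mu < 0 means the sample is classified correctly.
import torch

def glvq_loss(distances, target_labels, prototype_labels):
    """Per-sample GLVQ classifier function mu = (d+ - d-) / (d+ + d-)."""
    # (num_samples, num_protos) mask: which prototypes share the sample's label
    matching = target_labels.unsqueeze(1) == prototype_labels.unsqueeze(0)
    inf = torch.full_like(distances, float('inf'))
    d_plus = torch.where(matching, distances, inf).min(dim=1).values    # nearest correct
    d_minus = torch.where(~matching, distances, inf).min(dim=1).values  # nearest wrong
    return (d_plus - d_minus) / (d_plus + d_minus)

distances = torch.tensor([[0.2, 0.9],
                          [0.8, 0.1]])
plabels = torch.tensor([0, 1])   # one prototype per class in this toy case
targets = torch.tensor([0, 0])
print(glvq_loss(distances, targets, prototype_labels=plabels))
# tensor([-0.6364,  0.7778])
```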