Fix divide-by-zero in example

This commit is contained in:
Jensun Ravichandran
2020-09-23 15:29:26 +02:00
parent 3e6aa6a20b
commit d5ab9c3771
3 changed files with 19 additions and 11 deletions

View File

@@ -23,15 +23,16 @@ class Model(torch.nn.Module):
def __init__(self, **kwargs):
"""GLVQ model."""
super().__init__()
self.p1 = Prototypes1D(input_dim=2,
prototypes_per_class=3,
nclasses=3,
prototype_initializer='stratified_random',
data=[x_train, y_train])
self.proto_layer = Prototypes1D(
input_dim=2,
prototypes_per_class=3,
nclasses=3,
prototype_initializer='stratified_random',
data=[x_train, y_train])
def forward(self, x):
protos = self.p1.prototypes
plabels = self.p1.prototype_labels
protos = self.proto_layer.prototypes
plabels = self.proto_layer.prototype_labels
dis = euclidean_distance(x, protos)
return dis, plabels
@@ -61,7 +62,10 @@ for epoch in range(70):
optimizer.step()
# Get the prototypes form the model
protos = model.p1.prototypes.data.numpy()
protos = model.proto_layer.prototypes.data.numpy()
if np.isnan(np.sum(protos)):
print(f'Stopping because of `nan` in prototypes.')
break
# Visualize the data and the prototypes
ax = fig.gca()