Fix divide-by-zero in example
This commit is contained in:
@@ -23,15 +23,16 @@ class Model(torch.nn.Module):
|
||||
def __init__(self, **kwargs):
|
||||
"""GLVQ model."""
|
||||
super().__init__()
|
||||
self.p1 = Prototypes1D(input_dim=2,
|
||||
prototypes_per_class=3,
|
||||
nclasses=3,
|
||||
prototype_initializer='stratified_random',
|
||||
data=[x_train, y_train])
|
||||
self.proto_layer = Prototypes1D(
|
||||
input_dim=2,
|
||||
prototypes_per_class=3,
|
||||
nclasses=3,
|
||||
prototype_initializer='stratified_random',
|
||||
data=[x_train, y_train])
|
||||
|
||||
def forward(self, x):
|
||||
protos = self.p1.prototypes
|
||||
plabels = self.p1.prototype_labels
|
||||
protos = self.proto_layer.prototypes
|
||||
plabels = self.proto_layer.prototype_labels
|
||||
dis = euclidean_distance(x, protos)
|
||||
return dis, plabels
|
||||
|
||||
@@ -61,7 +62,10 @@ for epoch in range(70):
|
||||
optimizer.step()
|
||||
|
||||
# Get the prototypes form the model
|
||||
protos = model.p1.prototypes.data.numpy()
|
||||
protos = model.proto_layer.prototypes.data.numpy()
|
||||
if np.isnan(np.sum(protos)):
|
||||
print(f'Stopping because of `nan` in prototypes.')
|
||||
break
|
||||
|
||||
# Visualize the data and the prototypes
|
||||
ax = fig.gca()
|
||||
|
Reference in New Issue
Block a user