Fix divide-by-zero in example
This commit is contained in:
		@@ -23,15 +23,16 @@ class Model(torch.nn.Module):
 | 
				
			|||||||
    def __init__(self, **kwargs):
 | 
					    def __init__(self, **kwargs):
 | 
				
			||||||
        """GLVQ model."""
 | 
					        """GLVQ model."""
 | 
				
			||||||
        super().__init__()
 | 
					        super().__init__()
 | 
				
			||||||
        self.p1 = Prototypes1D(input_dim=2,
 | 
					        self.proto_layer = Prototypes1D(
 | 
				
			||||||
 | 
					            input_dim=2,
 | 
				
			||||||
            prototypes_per_class=3,
 | 
					            prototypes_per_class=3,
 | 
				
			||||||
            nclasses=3,
 | 
					            nclasses=3,
 | 
				
			||||||
            prototype_initializer='stratified_random',
 | 
					            prototype_initializer='stratified_random',
 | 
				
			||||||
            data=[x_train, y_train])
 | 
					            data=[x_train, y_train])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def forward(self, x):
 | 
					    def forward(self, x):
 | 
				
			||||||
        protos = self.p1.prototypes
 | 
					        protos = self.proto_layer.prototypes
 | 
				
			||||||
        plabels = self.p1.prototype_labels
 | 
					        plabels = self.proto_layer.prototype_labels
 | 
				
			||||||
        dis = euclidean_distance(x, protos)
 | 
					        dis = euclidean_distance(x, protos)
 | 
				
			||||||
        return dis, plabels
 | 
					        return dis, plabels
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -61,7 +62,10 @@ for epoch in range(70):
 | 
				
			|||||||
    optimizer.step()
 | 
					    optimizer.step()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Get the prototypes form the model
 | 
					    # Get the prototypes form the model
 | 
				
			||||||
    protos = model.p1.prototypes.data.numpy()
 | 
					    protos = model.proto_layer.prototypes.data.numpy()
 | 
				
			||||||
 | 
					    if np.isnan(np.sum(protos)):
 | 
				
			||||||
 | 
					        print(f'Stopping because of `nan` in prototypes.')
 | 
				
			||||||
 | 
					        break
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Visualize the data and the prototypes
 | 
					    # Visualize the data and the prototypes
 | 
				
			||||||
    ax = fig.gca()
 | 
					    ax = fig.gca()
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -76,7 +76,11 @@ def stratified_mean(x_train, y_train, prototype_distribution, one_hot=True):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@register_initializer
 | 
					@register_initializer
 | 
				
			||||||
def stratified_random(x_train, y_train, prototype_distribution, one_hot=True):
 | 
					def stratified_random(x_train,
 | 
				
			||||||
 | 
					                      y_train,
 | 
				
			||||||
 | 
					                      prototype_distribution,
 | 
				
			||||||
 | 
					                      one_hot=True,
 | 
				
			||||||
 | 
					                      epsilon=1e-7):
 | 
				
			||||||
    nprotos = torch.sum(prototype_distribution)
 | 
					    nprotos = torch.sum(prototype_distribution)
 | 
				
			||||||
    pdim = x_train.shape[1]
 | 
					    pdim = x_train.shape[1]
 | 
				
			||||||
    protos = torch.empty(nprotos, pdim)
 | 
					    protos = torch.empty(nprotos, pdim)
 | 
				
			||||||
@@ -89,7 +93,7 @@ def stratified_random(x_train, y_train, prototype_distribution, one_hot=True):
 | 
				
			|||||||
        xl = x_train[matcher]
 | 
					        xl = x_train[matcher]
 | 
				
			||||||
        rand_index = torch.zeros(1).long().random_(0, xl.shape[0] - 1)
 | 
					        rand_index = torch.zeros(1).long().random_(0, xl.shape[0] - 1)
 | 
				
			||||||
        random_xl = xl[rand_index]
 | 
					        random_xl = xl[rand_index]
 | 
				
			||||||
        protos[i] = random_xl
 | 
					        protos[i] = random_xl + epsilon
 | 
				
			||||||
    plabels = labels_from(prototype_distribution, one_hot=one_hot)
 | 
					    plabels = labels_from(prototype_distribution, one_hot=one_hot)
 | 
				
			||||||
    return protos, plabels
 | 
					    return protos, plabels
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -15,6 +15,6 @@ class GLVQLoss(torch.nn.Module):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def forward(self, outputs, targets):
 | 
					    def forward(self, outputs, targets):
 | 
				
			||||||
        distances, plabels = outputs
 | 
					        distances, plabels = outputs
 | 
				
			||||||
        mu = glvq_loss(distances, targets, plabels)
 | 
					        mu = glvq_loss(distances, targets, prototype_labels=plabels)
 | 
				
			||||||
        batch_loss = self.squashing(mu + self.margin, beta=self.beta)
 | 
					        batch_loss = self.squashing(mu + self.margin, beta=self.beta)
 | 
				
			||||||
        return torch.sum(batch_loss, dim=0)
 | 
					        return torch.sum(batch_loss, dim=0)
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user