From 3edb13baf4ab5e0bdd935e2ba83475b3e1122b48 Mon Sep 17 00:00:00 2001
From: Jensun Ravichandran
Date: Mon, 1 Mar 2021 18:52:54 +0100
Subject: [PATCH] Update examples/glvq_iris.py script

---
 examples/glvq_iris.py | 26 +++++++++++++++-----------
 setup.py              |  1 +
 2 files changed, 16 insertions(+), 11 deletions(-)

diff --git a/examples/glvq_iris.py b/examples/glvq_iris.py
index 8d24991..fd23fee 100644
--- a/examples/glvq_iris.py
+++ b/examples/glvq_iris.py
@@ -5,6 +5,7 @@ import torch
 from matplotlib import pyplot as plt
 from sklearn.datasets import load_iris
 from sklearn.preprocessing import StandardScaler
+from torchinfo import summary

 from prototorch.functions.distances import euclidean_distance
 from prototorch.modules.losses import GLVQLoss
@@ -27,7 +28,7 @@ class Model(torch.nn.Module):
             input_dim=2,
             prototypes_per_class=3,
             nclasses=3,
-            prototype_initializer='stratified_random',
+            prototype_initializer="stratified_random",
             data=[x_train, y_train])

     def forward(self, x):
@@ -40,21 +41,24 @@ class Model(torch.nn.Module):
 # Build the GLVQ model
 model = Model()

+# Print summary using torchinfo (might be buggy/incorrect)
+print(summary(model))
+
 # Optimize using SGD optimizer from `torch.optim`
 optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
-criterion = GLVQLoss(squashing='sigmoid_beta', beta=10)
+criterion = GLVQLoss(squashing="sigmoid_beta", beta=10)

 x_in = torch.Tensor(x_train)
 y_in = torch.Tensor(y_train)

 # Training loop
-title = 'Prototype Visualization'
+title = "Prototype Visualization"
 fig = plt.figure(title)
 for epoch in range(70):
     # Compute loss
     dis, plabels = model(x_in)
     loss = criterion([dis, plabels], y_in)
-    print(f'Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f}')
+    print(f"Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f}")

     # Take a gradient descent step
     optimizer.zero_grad()
@@ -64,23 +68,23 @@ for epoch in range(70):
     # Get the prototypes form the model
     protos = model.proto_layer.prototypes.data.numpy()
     if np.isnan(np.sum(protos)):
-        print('Stopping training because of `nan` in prototypes.')
+        print("Stopping training because of `nan` in prototypes.")
         break

     # Visualize the data and the prototypes
     ax = fig.gca()
     ax.cla()
     ax.set_title(title)
-    ax.set_xlabel('Data dimension 1')
-    ax.set_ylabel('Data dimension 2')
-    cmap = 'viridis'
-    ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor='k')
+    ax.set_xlabel("Data dimension 1")
+    ax.set_ylabel("Data dimension 2")
+    cmap = "viridis"
+    ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
     ax.scatter(protos[:, 0],
                protos[:, 1],
                c=plabels,
                cmap=cmap,
-               edgecolor='k',
-               marker='D',
+               edgecolor="k",
+               marker="D",
                s=50)

     # Paint decision regions
diff --git a/setup.py b/setup.py
index a16692a..8cfa440 100644
--- a/setup.py
+++ b/setup.py
@@ -27,6 +27,7 @@ DATASETS = [
 EXAMPLES = [
     "sklearn",
     "matplotlib",
+    "torchinfo",
 ]
 TESTS = ["pytest"]
 ALL = DOCS + DATASETS + EXAMPLES + TESTS