Fix imports in examples/gmlvq_tecator.py

This commit is contained in:
Jensun Ravichandran 2021-03-01 18:45:41 +01:00
parent 2322876eb6
commit 42cedbb2b8

View File

@ -6,13 +6,12 @@ from torch.utils.data import DataLoader
from prototorch.datasets.tecator import Tecator from prototorch.datasets.tecator import Tecator
from prototorch.functions.distances import sed from prototorch.functions.distances import sed
from prototorch.functions.normalizations import normalize_omegat_
from prototorch.modules import Prototypes1D from prototorch.modules import Prototypes1D
from prototorch.modules.losses import GLVQLoss from prototorch.modules.losses import GLVQLoss
from prototorch.utils.colors import handles_and_colors from prototorch.utils.colors import get_legend_handles
# Prepare the dataset and dataloader # Prepare the dataset and dataloader
train_data = Tecator(root='./artifacts', train=True) train_data = Tecator(root="./artifacts", train=True)
train_loader = DataLoader(train_data, batch_size=128, shuffle=True) train_loader = DataLoader(train_data, batch_size=128, shuffle=True)
@ -24,7 +23,7 @@ class Model(torch.nn.Module):
self.p1 = Prototypes1D(input_dim=100, self.p1 = Prototypes1D(input_dim=100,
prototypes_per_class=2, prototypes_per_class=2,
nclasses=2, nclasses=2,
prototype_initializer='stratified_random', prototype_initializer="stratified_random",
data=[x, y]) data=[x, y])
self.omega = torch.nn.Linear(in_features=100, self.omega = torch.nn.Linear(in_features=100,
out_features=100, out_features=100,
@ -53,7 +52,7 @@ print(model)
# Optimize using Adam optimizer from `torch.optim` # Optimize using Adam optimizer from `torch.optim`
optimizer = torch.optim.Adam(model.parameters(), lr=0.001_0) optimizer = torch.optim.Adam(model.parameters(), lr=0.001_0)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=75, gamma=0.1) scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=75, gamma=0.1)
criterion = GLVQLoss(squashing='identity', beta=10) criterion = GLVQLoss(squashing="identity", beta=10)
# Training loop # Training loop
for epoch in range(150): for epoch in range(150):
@ -64,29 +63,24 @@ for epoch in range(150):
distances, plabels = model(xb) distances, plabels = model(xb)
loss = criterion([distances, plabels], yb) loss = criterion([distances, plabels], yb)
epoch_loss += loss.item() epoch_loss += loss.item()
# Backprop # Backprop
loss.backward() loss.backward()
# Normalize omega
normalize_omegat_(model.omega.weight)
# Take a gradient descent step # Take a gradient descent step
optimizer.step() optimizer.step()
scheduler.step() scheduler.step()
lr = optimizer.param_groups[0]['lr'] lr = optimizer.param_groups[0]["lr"]
print(f'Epoch: {epoch + 1:03d} Loss: {epoch_loss:06.02f} lr: {lr:07.06f}') print(f"Epoch: {epoch + 1:03d} Loss: {epoch_loss:06.02f} lr: {lr:07.06f}")
# Get the omega matrix form the model # Get the omega matrix form the model
omega = model.omega.weight.data.numpy().T omega = model.omega.weight.data.numpy().T
# Visualize the lambda matrix # Visualize the lambda matrix
title = 'Lambda Matrix Visualization' title = "Lambda Matrix Visualization"
fig = plt.figure(title) fig = plt.figure(title)
ax = fig.gca() ax = fig.gca()
ax.set_title(title) ax.set_title(title)
im = ax.imshow(omega.dot(omega.T), cmap='viridis') im = ax.imshow(omega.dot(omega.T), cmap="viridis")
plt.show() plt.show()
# Get the prototypes form the model # Get the prototypes form the model
@ -94,14 +88,14 @@ protos = model.p1.prototypes.data.numpy()
plabels = model.p1.prototype_labels plabels = model.p1.prototype_labels
# Visualize the prototypes # Visualize the prototypes
title = 'Tecator Prototypes' title = "Tecator Prototypes"
fig = plt.figure(title) fig = plt.figure(title)
ax = fig.gca() ax = fig.gca()
ax.set_title(title) ax.set_title(title)
ax.set_xlabel('Spectral frequencies') ax.set_xlabel("Spectral frequencies")
ax.set_ylabel('Absorption') ax.set_ylabel("Absorption")
clabels = ['Class 0 - Low fat', 'Class 1 - High fat'] clabels = ["Class 0 - Low fat", "Class 1 - High fat"]
handles, colors = handles_and_colors(clabels, marker='line') handles, colors = get_legend_handles(clabels, marker="line", zero_indexed=True)
for x, y in zip(protos, plabels): for x, y in zip(protos, plabels):
ax.plot(x, c=colors[int(y)]) ax.plot(x, c=colors[int(y)])
ax.legend(handles, clabels) ax.legend(handles, clabels)