Add siamese example using GMLVQ and Tecator

This commit is contained in:
Jensun Ravichandran 2020-07-13 09:31:48 +02:00
parent 8a4a596035
commit 6e72b9267a

108
examples/gmlvq_tecator.py Normal file
View File

@ -0,0 +1,108 @@
"""ProtoTorch "siamese" GMLVQ example using Tecator."""
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from prototorch.datasets.tecator import Tecator
from prototorch.functions.distances import sed
from prototorch.functions.normalizations import normalize_omegat_
from prototorch.modules import Prototypes1D
from prototorch.modules.losses import GLVQLoss
from prototorch.utils.colors import handles_and_colors
# Prepare the dataset and dataloader
train_data = Tecator(root='./artifacts', train=True)
train_loader = DataLoader(train_data, batch_size=128, shuffle=True)
class Model(torch.nn.Module):
def __init__(self, **kwargs):
"""GMLVQ model as a siamese network."""
super().__init__()
x, y = train_data.data, train_data.targets
self.p1 = Prototypes1D(input_dim=100,
prototypes_per_class=2,
nclasses=2,
prototype_initializer='stratified_random',
data=[x, y])
self.omega = torch.nn.Linear(in_features=100,
out_features=100,
bias=False)
torch.nn.init.eye_(self.omega.weight)
def forward(self, x):
protos = self.p1.prototypes
plabels = self.p1.prototype_labels
# Process `x` and `protos` through `omega`
x_map = self.omega(x)
protos_map = self.omega(protos)
# Compute distances and output
dis = sed(x_map, protos_map)
return dis, plabels
# Build the GLVQ model
model = Model()
# Print a summary of the model
print(model)
# Optimize using Adam optimizer from `torch.optim`
optimizer = torch.optim.Adam(model.parameters(), lr=0.001_0)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=75, gamma=0.1)
criterion = GLVQLoss(squashing='identity', beta=10)
# Training loop
for epoch in range(150):
epoch_loss = 0.0 # zero-out epoch loss
optimizer.zero_grad() # zero-out gradients
for xb, yb in train_loader:
# Compute loss
distances, plabels = model(xb)
loss = criterion([distances, plabels], yb)
epoch_loss += loss.item()
# Backprop
loss.backward()
# Normalize omega
normalize_omegat_(model.omega.weight)
# Take a gradient descent step
optimizer.step()
scheduler.step()
lr = optimizer.param_groups[0]['lr']
print(f'Epoch: {epoch + 1:03d} Loss: {epoch_loss:06.02f} lr: {lr:07.06f}')
# Get the omega matrix form the model
omega = model.omega.weight.data.numpy().T
# Visualize the lambda matrix
title = 'Lambda Matrix Visualization'
fig = plt.figure(title)
ax = fig.gca()
ax.set_title(title)
im = ax.imshow(omega.dot(omega.T), cmap='viridis')
plt.show()
# Get the prototypes form the model
protos = model.p1.prototypes.data.numpy()
plabels = model.p1.prototype_labels
# Visualize the prototypes
title = 'Tecator Prototypes'
fig = plt.figure(title)
ax = fig.gca()
ax.set_title(title)
ax.set_xlabel('Spectral frequencies')
ax.set_ylabel('Absorption')
clabels = ['Class 0 - Low fat', 'Class 1 - High fat']
handles, colors = handles_and_colors(clabels, marker='line')
for x, y in zip(protos, plabels):
ax.plot(x, c=colors[int(y)])
ax.legend(handles, clabels)
plt.show()