prototorch/examples/gmlvq_tecator.py

109 lines
3.3 KiB
Python
Raw Normal View History

"""ProtoTorch "siamese" GMLVQ example using Tecator."""
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from prototorch.datasets.tecator import Tecator
from prototorch.functions.distances import sed
from prototorch.functions.normalizations import normalize_omegat_
from prototorch.modules import Prototypes1D
from prototorch.modules.losses import GLVQLoss
from prototorch.utils.colors import handles_and_colors
# Load the Tecator training split (stored under ./artifacts) and wrap it in a
# shuffling loader that yields minibatches of up to 128 samples.
train_data = Tecator(train=True, root='./artifacts')
train_loader = DataLoader(train_data,
                          shuffle=True,
                          batch_size=128)
class Model(torch.nn.Module):
    """GMLVQ model as a siamese network.

    Inputs and prototypes are both passed through the same bias-free linear
    map ``omega`` (initialized to the identity), and the squared Euclidean
    distance (``sed``) is computed between them in the mapped space.
    """

    def __init__(self, prototypes_per_class=2, nclasses=2, **kwargs):
        """Initialize prototypes and ``omega`` from the Tecator training split.

        Args:
            prototypes_per_class: Number of prototypes per class (default 2).
            nclasses: Number of classes (default 2).
            **kwargs: Accepted for backward compatibility; currently ignored.
        """
        super().__init__()
        x, y = train_data.data, train_data.targets
        # Derive the feature width from the data instead of hard-coding it,
        # so prototypes and omega always match the inputs.
        # NOTE(review): assumes `x` is 2-D (samples, features) — confirm.
        input_dim = x.shape[1]
        self.p1 = Prototypes1D(input_dim=input_dim,
                               prototypes_per_class=prototypes_per_class,
                               nclasses=nclasses,
                               prototype_initializer='stratified_random',
                               data=[x, y])
        self.omega = torch.nn.Linear(in_features=input_dim,
                                     out_features=input_dim,
                                     bias=False)
        # Identity start: the mapped distance initially equals the plain
        # squared Euclidean distance.
        torch.nn.init.eye_(self.omega.weight)

    def forward(self, x):
        """Return distances from ``x`` to all prototypes, and prototype labels.

        Args:
            x: Input batch; presumably (batch, input_dim) — TODO confirm.

        Returns:
            Tuple ``(distances, plabels)``. ``distances`` is ``sed`` between
            the omega-mapped inputs and the omega-mapped prototypes.
            ``plabels`` are the prototype labels held by ``self.p1``.
        """
        protos = self.p1.prototypes
        plabels = self.p1.prototype_labels
        # Siamese mapping: inputs and prototypes share the same omega.
        x_map = self.omega(x)
        protos_map = self.omega(protos)
        dis = sed(x_map, protos_map)
        return dis, plabels
# Build the GMLVQ model
model = Model()

# Print a summary of the model
print(model)

# Optimize using Adam from `torch.optim`; StepLR divides the learning rate
# by 10 every 75 scheduler steps (stepped once per epoch below).
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=75, gamma=0.1)
criterion = GLVQLoss(squashing='identity', beta=10)

# Training loop
for epoch in range(150):
    epoch_loss = 0.0  # sum of per-batch losses for this epoch
    for xb, yb in train_loader:
        # Zero-out gradients for every minibatch so each update uses only
        # the gradient of the current batch (no leakage from earlier ones).
        optimizer.zero_grad()

        # Compute loss
        distances, plabels = model(xb)
        loss = criterion([distances, plabels], yb)
        epoch_loss += loss.item()

        # Backprop and take a gradient descent step
        loss.backward()
        optimizer.step()

        # Normalize omega *after* the update so the parameters used for the
        # next batch (and visualized at the end) are the normalized ones.
        # `normalize_omegat_` modifies the weight in place; no_grad keeps
        # that in-place change out of autograd.
        with torch.no_grad():
            normalize_omegat_(model.omega.weight)

    # The scheduler counts epochs, so step it once per epoch.
    scheduler.step()

    lr = optimizer.param_groups[0]['lr']
    print(f'Epoch: {epoch + 1:03d} Loss: {epoch_loss:06.02f} lr: {lr:07.06f}')
# Get the omega matrix from the model. `torch.nn.Linear` computes x @ W.T,
# so omega = W.T and the relevance matrix is Lambda = omega @ omega.T.
# `.detach().cpu()` makes the conversion work even if the model is on a GPU.
omega = model.omega.weight.detach().cpu().numpy().T

# Visualize the lambda matrix
title = 'Lambda Matrix Visualization'
fig = plt.figure(title)
ax = fig.gca()
ax.set_title(title)
im = ax.imshow(omega.dot(omega.T), cmap='viridis')
# Add a scale so the relevance values can actually be read off the plot.
fig.colorbar(im, ax=ax)
plt.show()
# Get the prototypes from the model (detached and moved to CPU so the
# NumPy conversion also works for models trained on a GPU).
protos = model.p1.prototypes.detach().cpu().numpy()
plabels = model.p1.prototype_labels

# Visualize the prototypes as spectra, colored by their class label
title = 'Tecator Prototypes'
fig = plt.figure(title)
ax = fig.gca()
ax.set_title(title)
ax.set_xlabel('Spectral frequencies')
ax.set_ylabel('Absorption')
clabels = ['Class 0 - Low fat', 'Class 1 - High fat']
handles, colors = handles_and_colors(clabels, marker='line')
for x, y in zip(protos, plabels):
    ax.plot(x, c=colors[int(y)])
ax.legend(handles, clabels)
plt.show()