162
examples/gtlvq_mnist.py
Normal file
162
examples/gtlvq_mnist.py
Normal file
@@ -0,0 +1,162 @@
|
||||
"""
|
||||
ProtoTorch GTLVQ example using MNIST data.
|
||||
The GTLVQ is placed as an classification model on
|
||||
top of a CNN, considered as featurer extractor.
|
||||
Initialization of subpsace and prototypes in
|
||||
Siamnese fashion
|
||||
For more info about GTLVQ see:
|
||||
DOI:10.1109/IJCNN.2016.7727534
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision
|
||||
from torchvision import transforms
|
||||
from prototorch.modules.losses import GLVQLoss
|
||||
from prototorch.functions.helper import calculate_prototype_accuracy
|
||||
from prototorch.modules.models import GTLVQ
|
||||
|
||||
# Parameters and options
|
||||
n_epochs = 50
|
||||
batch_size_train = 64
|
||||
batch_size_test = 1000
|
||||
learning_rate = 0.1
|
||||
momentum = 0.5
|
||||
log_interval = 10
|
||||
cuda = "cuda:1"
|
||||
random_seed = 1
|
||||
device = torch.device(cuda if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
# Configures reproducability
|
||||
torch.manual_seed(random_seed)
|
||||
np.random.seed(random_seed)
|
||||
|
||||
# Prepare and preprocess the data
|
||||
train_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST(
|
||||
'./files/',
|
||||
train=True,
|
||||
download=True,
|
||||
transform=torchvision.transforms.Compose(
|
||||
[transforms.ToTensor(),
|
||||
transforms.Normalize((0.1307, ), (0.3081, ))])),
|
||||
batch_size=batch_size_train,
|
||||
shuffle=True)
|
||||
|
||||
test_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST(
|
||||
'./files/',
|
||||
train=False,
|
||||
download=True,
|
||||
transform=torchvision.transforms.Compose(
|
||||
[transforms.ToTensor(),
|
||||
transforms.Normalize((0.1307, ), (0.3081, ))])),
|
||||
batch_size=batch_size_test,
|
||||
shuffle=True)
|
||||
|
||||
|
||||
# Define the GLVQ model plus appropriate feature extractor
|
||||
class CNNGTLVQ(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_classes,
|
||||
subspace_data,
|
||||
prototype_data,
|
||||
tangent_projection_type="local",
|
||||
prototypes_per_class=2,
|
||||
bottleneck_dim=128,
|
||||
):
|
||||
super(CNNGTLVQ, self).__init__()
|
||||
|
||||
#Feature Extractor - Simple CNN
|
||||
self.fe = nn.Sequential(nn.Conv2d(1, 32, 3, 1), nn.ReLU(),
|
||||
nn.Conv2d(32, 64, 3, 1), nn.ReLU(),
|
||||
nn.MaxPool2d(2), nn.Dropout(0.25),
|
||||
nn.Flatten(), nn.Linear(9216, bottleneck_dim),
|
||||
nn.Dropout(0.5), nn.LeakyReLU(),
|
||||
nn.LayerNorm(bottleneck_dim))
|
||||
|
||||
# Forward pass of subspace and prototype initialization data through feature extractor
|
||||
subspace_data = self.fe(subspace_data)
|
||||
prototype_data[0] = self.fe(prototype_data[0])
|
||||
|
||||
# Initialization of GTLVQ
|
||||
self.gtlvq = GTLVQ(num_classes,
|
||||
subspace_data,
|
||||
prototype_data,
|
||||
tangent_projection_type=tangent_projection_type,
|
||||
feature_dim=bottleneck_dim,
|
||||
prototypes_per_class=prototypes_per_class)
|
||||
|
||||
def forward(self, x):
|
||||
# Feature Extraction
|
||||
x = self.fe(x)
|
||||
|
||||
# GTLVQ Forward pass
|
||||
dis = self.gtlvq(x)
|
||||
return dis
|
||||
|
||||
|
||||
# Get init data
|
||||
subspace_data = torch.cat(
|
||||
[next(iter(train_loader))[0],
|
||||
next(iter(test_loader))[0]])
|
||||
prototype_data = next(iter(train_loader))
|
||||
|
||||
# Build the CNN GTLVQ model
|
||||
model = CNNGTLVQ(10,
|
||||
subspace_data,
|
||||
prototype_data,
|
||||
tangent_projection_type="local",
|
||||
bottleneck_dim=128).to(device)
|
||||
|
||||
# Optimize using SGD optimizer from `torch.optim`
|
||||
optimizer = torch.optim.Adam([{
|
||||
'params': model.fe.parameters()
|
||||
}, {
|
||||
'params': model.gtlvq.parameters()
|
||||
}],
|
||||
lr=learning_rate)
|
||||
criterion = GLVQLoss(squashing='sigmoid_beta', beta=10)
|
||||
|
||||
# Training loop
|
||||
for epoch in range(n_epochs):
|
||||
for batch_idx, (x_train, y_train) in enumerate(train_loader):
|
||||
model.train()
|
||||
x_train, y_train = x_train.to(device), y_train.to(device)
|
||||
optimizer.zero_grad()
|
||||
|
||||
distances = model(x_train)
|
||||
plabels = model.gtlvq.cls.prototype_labels.to(device)
|
||||
|
||||
# Compute loss.
|
||||
loss = criterion([distances, plabels], y_train)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# GTLVQ uses projected SGD, which means to orthogonalize the subspaces after every gradient update.
|
||||
model.gtlvq.orthogonalize_subspace()
|
||||
|
||||
if batch_idx % log_interval == 0:
|
||||
acc = calculate_prototype_accuracy(distances, y_train, plabels)
|
||||
print(
|
||||
f'Epoch: {epoch + 1:02d}/{n_epochs:02d} Epoch Progress: {100. * batch_idx / len(train_loader):02.02f} % Loss: {loss.item():02.02f} \
|
||||
Train Acc: {acc.item():02.02f}')
|
||||
|
||||
# Test
|
||||
with torch.no_grad():
|
||||
model.eval()
|
||||
correct = 0
|
||||
total = 0
|
||||
for x_test, y_test in test_loader:
|
||||
x_test, y_test = x_test.to(device), y_test.to(device)
|
||||
test_distances = model(torch.tensor(x_test))
|
||||
test_plabels = model.gtlvq.cls.prototype_labels.to(device)
|
||||
i = torch.argmin(test_distances, 1)
|
||||
correct += torch.sum(y_test == test_plabels[i])
|
||||
total += y_test.size(0)
|
||||
print('Accuracy of the network on the test images: %d %%' %
|
||||
(torch.true_divide(correct, total) * 100))
|
||||
|
||||
# Save the model
|
||||
PATH = './glvq_mnist_model.pth'
|
||||
torch.save(model.state_dict(), PATH)
|
Reference in New Issue
Block a user