"""
ProtoTorch GTLVQ example using MNIST data.

GTLVQ is placed as a classification model on top of a CNN, which acts as the
feature extractor. The subspaces and prototypes are initialized in a Siamese
fashion.

For more info about GTLVQ see:
DOI:10.1109/IJCNN.2016.7727534
"""

import numpy as np
import torch
import torch.nn as nn
import torchvision
from prototorch.functions.helper import calculate_prototype_accuracy
from prototorch.modules.losses import GLVQLoss
from prototorch.modules.models import GTLVQ
from torchvision import transforms

# Parameters and options
num_epochs = 50
batch_size_train = 64
batch_size_test = 1000
learning_rate = 0.1
momentum = 0.5
log_interval = 10
cuda = "cuda:0"
random_seed = 1
device = torch.device(cuda if torch.cuda.is_available() else "cpu")

# Configure reproducibility
torch.manual_seed(random_seed)
np.random.seed(random_seed)

# Prepare and preprocess the data
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(
        "./files/",
        train=True,
        download=True,
        transform=torchvision.transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ]),
    ),
    batch_size=batch_size_train,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(
        "./files/",
        train=False,
        download=True,
        transform=torchvision.transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ]),
    ),
    batch_size=batch_size_test,
    shuffle=True,
)
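
# The normalization constants (0.1307, 0.3081) are the commonly used MNIST
# training-set mean and standard deviation. As a sketch, they could be
# recomputed from the raw data like this (loads the full training set once):
#
#   raw = torchvision.datasets.MNIST("./files/", train=True, download=True,
#                                    transform=transforms.ToTensor())
#   pixels = torch.cat([img.view(-1) for img, _ in raw])
#   print(pixels.mean(), pixels.std())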

# Define the GTLVQ model plus appropriate feature extractor
class CNNGTLVQ(torch.nn.Module):
    def __init__(
        self,
        num_classes,
        subspace_data,
        prototype_data,
        tangent_projection_type="local",
        prototypes_per_class=2,
        bottleneck_dim=128,
    ):
        super(CNNGTLVQ, self).__init__()

        # Feature Extractor - Simple CNN
        self.fe = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, 1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout(0.25),
            nn.Flatten(),
            nn.Linear(9216, bottleneck_dim),
            nn.Dropout(0.5),
            nn.LeakyReLU(),
            nn.LayerNorm(bottleneck_dim),
        )

        # Forward pass of the subspace and prototype initialization data
        # through the feature extractor
        subspace_data = self.fe(subspace_data)
        prototype_data[0] = self.fe(prototype_data[0])

        # Initialization of GTLVQ
        self.gtlvq = GTLVQ(
            num_classes,
            subspace_data,
            prototype_data,
            tangent_projection_type=tangent_projection_type,
            feature_dim=bottleneck_dim,
            prototypes_per_class=prototypes_per_class,
        )

    def forward(self, x):
        # Feature Extraction
        x = self.fe(x)

        # GTLVQ Forward pass
        dis = self.gtlvq(x)
        return dis


# Get init data
subspace_data = torch.cat(
    [next(iter(train_loader))[0],
     next(iter(test_loader))[0]])
prototype_data = next(iter(train_loader))
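# `subspace_data` is a batch of raw images (train + test) used to initialize
# the tangent subspaces, while `prototype_data` is an (images, labels) batch
# from which the class prototypes are initialized inside CNNGTLVQ.__init__.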
# Build the CNN GTLVQ model
model = CNNGTLVQ(
    10,
    subspace_data,
    prototype_data,
    tangent_projection_type="local",
    bottleneck_dim=128,
).to(device)
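
# The model maps a batch of 1x28x28 MNIST images to a matrix of tangent
# distances, one column per prototype; classification picks the prototype
# with the smallest distance (see the test loop below).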

# Optimize using the Adam optimizer from `torch.optim`
optimizer = torch.optim.Adam(
    [{
        "params": model.fe.parameters()
    }, {
        "params": model.gtlvq.parameters()
    }],
    lr=learning_rate,
)
criterion = GLVQLoss(squashing="sigmoid_beta", beta=10)
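# GLVQLoss follows the standard GLVQ cost: for each sample it computes the
# relative distance mu = (d_correct - d_incorrect) / (d_correct + d_incorrect)
# between the closest correct and closest incorrect prototype, here squashed
# with a beta-sigmoid of steepness beta=10.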
# Training loop
for epoch in range(num_epochs):
    for batch_idx, (x_train, y_train) in enumerate(train_loader):
        model.train()
        x_train, y_train = x_train.to(device), y_train.to(device)
        optimizer.zero_grad()

        distances = model(x_train)
        plabels = model.gtlvq.cls.component_labels.to(device)

        # Compute loss
        loss = criterion([distances, plabels], y_train)
        loss.backward()
        optimizer.step()

        # GTLVQ uses projected SGD: after every gradient update the subspace
        # bases are re-orthogonalized.
        model.gtlvq.orthogonalize_subspace()

        if batch_idx % log_interval == 0:
            acc = calculate_prototype_accuracy(distances, y_train, plabels)
            print(f"Epoch: {epoch + 1:02d}/{num_epochs:02d} "
                  f"Epoch Progress: {100. * batch_idx / len(train_loader):02.02f} % "
                  f"Loss: {loss.item():02.02f} "
                  f"Train Acc: {acc.item():02.02f}")

    # Test
    with torch.no_grad():
        model.eval()
        correct = 0
        total = 0
        for x_test, y_test in test_loader:
            x_test, y_test = x_test.to(device), y_test.to(device)
            test_distances = model(x_test)
            test_plabels = model.gtlvq.cls.component_labels.to(device)
            i = torch.argmin(test_distances, 1)
            correct += torch.sum(y_test == test_plabels[i])
            total += y_test.size(0)
        print("Accuracy of the network on the test images: %d %%" %
              (torch.true_divide(correct, total) * 100))

# Save the model
PATH = "./glvq_mnist_model.pth"
torch.save(model.state_dict(), PATH)
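
# A minimal sketch of how the saved weights could be reloaded for inference
# (assumes the `model` built above, i.e. the same architecture and init data):
model.load_state_dict(torch.load(PATH))
model.eval()
with torch.no_grad():
    sample, _ = next(iter(test_loader))
    sample_distances = model(sample.to(device))
    sample_predictions = model.gtlvq.cls.component_labels.to(device)[
        torch.argmin(sample_distances, 1)]
    print("Predicted labels for one test batch:", sample_predictions[:10])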