prototorch/examples/gtlvq_mnist.py

185 lines
5.2 KiB
Python
Raw Normal View History

2021-01-12 17:11:46 +00:00
"""
ProtoTorch GTLVQ example using MNIST data.
2021-01-14 09:04:43 +00:00
The GTLVQ is placed as an classification model on
2021-01-12 17:11:46 +00:00
top of a CNN, considered as featurer extractor.
2021-01-14 09:04:43 +00:00
Initialization of subpsace and prototypes in
2021-01-12 17:11:46 +00:00
Siamnese fashion
For more info about GTLVQ see:
DOI:10.1109/IJCNN.2016.7727534
"""
import numpy as np
import torch
import torch.nn as nn
import torchvision
2021-06-16 11:46:09 +00:00
from torchvision import transforms
2021-01-12 17:11:46 +00:00
from prototorch.functions.helper import calculate_prototype_accuracy
2021-04-23 15:24:53 +00:00
from prototorch.modules.losses import GLVQLoss
2021-01-12 17:11:46 +00:00
from prototorch.modules.models import GTLVQ
# Parameters and options.
# Training schedule
num_epochs = 50
log_interval = 10  # print training stats every `log_interval` batches

# Batch sizes per split
batch_size_train = 64
batch_size_test = 1000

# Optimizer settings
learning_rate = 0.1
momentum = 0.5  # NOTE(review): unused -- the optimizer in this script is Adam.

# Device selection: prefer the first CUDA device, fall back to the CPU.
cuda = "cuda:0"
random_seed = 1
if torch.cuda.is_available():
    device = torch.device(cuda)
else:
    device = torch.device("cpu")

# Seed NumPy and PyTorch so runs are reproducible.
torch.manual_seed(random_seed)
np.random.seed(random_seed)
# Prepare and preprocess the data.
# Both splits share one pipeline: convert to a tensor, then normalize with
# 0.1307 / 0.3081 (presumably the MNIST training-set mean / std).
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307, ), (0.3081, )),
])


def _mnist_loader(train, batch_size):
    """Return a shuffling DataLoader over the MNIST train (or test) split.

    The dataset is stored under ``./files/`` and downloaded if missing.
    """
    mnist = torchvision.datasets.MNIST(
        "./files/",
        train=train,
        download=True,
        transform=mnist_transform,
    )
    return torch.utils.data.DataLoader(
        mnist,
        batch_size=batch_size,
        shuffle=True,
    )


train_loader = _mnist_loader(True, batch_size_train)
test_loader = _mnist_loader(False, batch_size_test)
# Define the GTLVQ model plus appropriate feature extractor
class CNNGTLVQ(torch.nn.Module):
    """GTLVQ classifier stacked on top of a small CNN feature extractor.

    The CNN (``self.fe``) maps each input image to a ``bottleneck_dim``
    embedding, and the GTLVQ head (``self.gtlvq``) operates on those
    embeddings. The GTLVQ subspace and prototypes are initialized from data
    that has first been passed through the same (still untrained) CNN.

    Args:
        num_classes: Number of target classes; forwarded to ``GTLVQ``.
        subspace_data: Image batch used to initialize the GTLVQ subspace.
        prototype_data: Indexable batch whose element 0 holds the images
            used to initialize the prototypes; the remaining elements
            (presumably the labels) are passed through unchanged. The
            caller's object is not modified.
        tangent_projection_type: Forwarded to ``GTLVQ``.
        prototypes_per_class: Forwarded to ``GTLVQ``.
        bottleneck_dim: Width of the CNN embedding; also used as
            ``GTLVQ``'s ``feature_dim``.
    """

    def __init__(
        self,
        num_classes,
        subspace_data,
        prototype_data,
        tangent_projection_type="local",
        prototypes_per_class=2,
        bottleneck_dim=128,
    ):
        super().__init__()

        # Feature extractor - simple CNN.
        # 9216 = 64 * 12 * 12: two unpadded 3x3 convs (28 -> 26 -> 24) and a
        # 2x2 max-pool (24 -> 12), so this assumes 1x28x28 (MNIST) inputs.
        self.fe = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, 1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout(0.25),
            nn.Flatten(),
            nn.Linear(9216, bottleneck_dim),
            nn.Dropout(0.5),
            nn.LeakyReLU(),
            nn.LayerNorm(bottleneck_dim),
        )

        # Forward pass of subspace and prototype initialization data through
        # the feature extractor. NOTE(review): the module is still in training
        # mode at this point, so the dropout layers are active for this pass.
        subspace_data = self.fe(subspace_data)
        # Build a new list instead of assigning into ``prototype_data[0]``,
        # which would overwrite the caller's batch in place.
        prototype_data = [self.fe(prototype_data[0]), *prototype_data[1:]]

        # Initialization of GTLVQ
        self.gtlvq = GTLVQ(
            num_classes,
            subspace_data,
            prototype_data,
            tangent_projection_type=tangent_projection_type,
            feature_dim=bottleneck_dim,
            prototypes_per_class=prototypes_per_class,
        )

    def forward(self, x):
        """Embed ``x`` with the CNN and return the GTLVQ head's output.

        The script treats this output as prototype distances (it feeds it to
        the GLVQ loss and takes its argmin).
        """
        # Feature extraction
        x = self.fe(x)
        # GTLVQ forward pass
        dis = self.gtlvq(x)
        return dis
# Get init data.
# Only training images are used: initializing the subspace from test images
# would leak test inputs into a model that is evaluated on that same test set.
# A second loader over the training set with `batch_size_test` keeps the
# number of initialization samples at batch_size_train + batch_size_test.
subspace_init_loader = torch.utils.data.DataLoader(
    train_loader.dataset,
    batch_size=batch_size_test,
    shuffle=True,
)
subspace_data = torch.cat([
    next(iter(train_loader))[0],
    next(iter(subspace_init_loader))[0],
])
prototype_data = next(iter(train_loader))

# Build the CNN GTLVQ model (10 MNIST digit classes).
model = CNNGTLVQ(
    10,
    subspace_data,
    prototype_data,
    tangent_projection_type="local",
    bottleneck_dim=128,
).to(device)
# Optimize using the Adam optimizer from `torch.optim`.
# The feature extractor and the GTLVQ head are registered as separate
# parameter groups. Both currently share `learning_rate`, but the split allows
# per-group settings later.
optimizer = torch.optim.Adam(
    [
        {"params": model.fe.parameters()},
        {"params": model.gtlvq.parameters()},
    ],
    lr=learning_rate,
)
# GLVQ loss with a sigmoid squashing function (steepness beta=10).
criterion = GLVQLoss(squashing="sigmoid_beta", beta=10)
# Training loop
for epoch in range(num_epochs):
    # The evaluation at the end of each epoch switches to eval mode, so
    # re-enable training mode (dropout active) once at the start of the epoch.
    model.train()
    for batch_idx, (x_train, y_train) in enumerate(train_loader):
        x_train, y_train = x_train.to(device), y_train.to(device)
        optimizer.zero_grad()

        distances = model(x_train)
        plabels = model.gtlvq.cls.component_labels.to(device)

        # Compute loss.
        loss = criterion([distances, plabels], y_train)
        loss.backward()
        optimizer.step()

        # GTLVQ uses projected SGD, which means to orthogonalize the
        # subspaces after every gradient update.
        model.gtlvq.orthogonalize_subspace()

        if batch_idx % log_interval == 0:
            acc = calculate_prototype_accuracy(distances, y_train, plabels)
            # Implicit string concatenation (rather than a backslash inside
            # the literal) keeps stray indentation out of the log line.
            print(
                f"Epoch: {epoch + 1:02d}/{num_epochs:02d} "
                f"Epoch Progress: {100. * batch_idx / len(train_loader):02.02f} % "
                f"Loss: {loss.item():02.02f} "
                f"Train Acc: {acc.item():02.02f}")

    # Test
    with torch.no_grad():
        model.eval()
        # Same label attribute as the training step; the labels do not change
        # while evaluating, so fetch them once.
        test_plabels = model.gtlvq.cls.component_labels.to(device)
        correct = 0
        total = 0
        for x_test, y_test in test_loader:
            x_test, y_test = x_test.to(device), y_test.to(device)
            # x_test is already a tensor; wrapping it in torch.tensor() would
            # only copy it and emit a UserWarning.
            test_distances = model(x_test)
            # The closest prototype's label is the prediction.
            winners = torch.argmin(test_distances, 1)
            correct += torch.sum(y_test == test_plabels[winners]).item()
            total += y_test.size(0)
        print("Accuracy of the network on the test images: %d %%" %
              (100 * correct / total))
# Save the model: only the learned parameters (the state dict) are written,
# so the CNNGTLVQ class is needed to load them back.
# NOTE(review): the filename says "glvq" although this is the GTLVQ example.
PATH = "./glvq_mnist_model.pth"
torch.save(obj=model.state_dict(), f=PATH)