Remove usage of Prototype1D

Update Iris example to new component API
Update Tecator example to new component API
Update LGMLVQ example to new component API
Update GTLVQ to new component API
This commit is contained in:
Alexander Engelsberger 2021-05-28 15:57:26 +02:00
parent 94fe4435a8
commit 40ef3aeda2
5 changed files with 75 additions and 74 deletions

View File

@ -3,10 +3,10 @@
import numpy as np
import torch
from matplotlib import pyplot as plt
from prototorch.components import LabeledComponents, StratifiedMeanInitializer
from prototorch.functions.competitions import wtac
from prototorch.functions.distances import euclidean_distance
from prototorch.modules.losses import GLVQLoss
from prototorch.modules.prototypes import Prototypes1D
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from torchinfo import summary
@ -24,19 +24,17 @@ class Model(torch.nn.Module):
def __init__(self):
"""GLVQ model for training on 2D Iris data."""
super().__init__()
self.proto_layer = Prototypes1D(
input_dim=2,
prototypes_per_class=3,
num_classes=3,
prototype_initializer="stratified_random",
data=[x_train, y_train],
prototype_initializer = StratifiedMeanInitializer([x_train, y_train])
prototype_distribution = {"num_classes": 3, "prototypes_per_class": 3}
self.proto_layer = LabeledComponents(
prototype_distribution,
prototype_initializer,
)
def forward(self, x):
protos = self.proto_layer.prototypes
plabels = self.proto_layer.prototype_labels
dis = euclidean_distance(x, protos)
return dis, plabels
prototypes, prototype_labels = self.proto_layer()
distances = euclidean_distance(x, prototypes)
return distances, prototype_labels
# Build the GLVQ model
@ -53,43 +51,46 @@ x_in = torch.Tensor(x_train)
y_in = torch.Tensor(y_train)
# Training loop
title = "Prototype Visualization"
fig = plt.figure(title)
TITLE = "Prototype Visualization"
fig = plt.figure(TITLE)
for epoch in range(70):
# Compute loss
dis, plabels = model(x_in)
loss = criterion([dis, plabels], y_in)
distances, prototype_labels = model(x_in)
loss = criterion([distances, prototype_labels], y_in)
# Compute Accuracy
with torch.no_grad():
pred = wtac(dis, plabels)
correct = pred.eq(y_in.view_as(pred)).sum().item()
predictions = wtac(distances, prototype_labels)
correct = predictions.eq(y_in.view_as(predictions)).sum().item()
acc = 100.0 * correct / len(x_train)
print(
f"Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f} Acc: {acc:05.02f}%"
)
# Take a gradient descent step
# Optimizer step
optimizer.zero_grad()
loss.backward()
optimizer.step()
# Get the prototypes form the model
protos = model.proto_layer.prototypes.data.numpy()
if np.isnan(np.sum(protos)):
prototypes = model.proto_layer.components.numpy()
if np.isnan(np.sum(prototypes)):
print("Stopping training because of `nan` in prototypes.")
break
# Visualize the data and the prototypes
ax = fig.gca()
ax.cla()
ax.set_title(title)
ax.set_title(TITLE)
ax.set_xlabel("Data dimension 1")
ax.set_ylabel("Data dimension 2")
cmap = "viridis"
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
ax.scatter(
protos[:, 0],
protos[:, 1],
c=plabels,
prototypes[:, 0],
prototypes[:, 1],
c=prototype_labels,
cmap=cmap,
edgecolor="k",
marker="D",
@ -97,7 +98,7 @@ for epoch in range(70):
)
# Paint decision regions
x = np.vstack((x_train, protos))
x = np.vstack((x_train, prototypes))
x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50),
@ -107,7 +108,7 @@ for epoch in range(70):
torch_input = torch.Tensor(mesh_input)
d = model(torch_input)[0]
w_indices = torch.argmin(d, dim=1)
y_pred = torch.index_select(plabels, 0, w_indices)
y_pred = torch.index_select(prototype_labels, 0, w_indices)
y_pred = y_pred.reshape(xx.shape)
# Plot voronoi regions

View File

@ -2,9 +2,9 @@
import matplotlib.pyplot as plt
import torch
from prototorch.components import LabeledComponents, StratifiedMeanInitializer
from prototorch.datasets.tecator import Tecator
from prototorch.functions.distances import sed
from prototorch.modules import Prototypes1D
from prototorch.modules.losses import GLVQLoss
from prototorch.utils.colors import get_legend_handles
from torch.utils.data import DataLoader
@ -18,22 +18,22 @@ class Model(torch.nn.Module):
def __init__(self, **kwargs):
"""GMLVQ model as a siamese network."""
super().__init__()
x, y = train_data.data, train_data.targets
self.p1 = Prototypes1D(
input_dim=100,
prototypes_per_class=2,
num_classes=2,
prototype_initializer="stratified_random",
data=[x, y],
prototype_initializer = StratifiedMeanInitializer(train_loader)
prototype_distribution = {"num_classes": 2, "prototypes_per_class": 2}
self.proto_layer = LabeledComponents(
prototype_distribution,
prototype_initializer,
)
self.omega = torch.nn.Linear(in_features=100,
out_features=100,
bias=False)
torch.nn.init.eye_(self.omega.weight)
def forward(self, x):
protos = self.p1.prototypes
plabels = self.p1.prototype_labels
protos = self.proto_layer.components
plabels = self.proto_layer.component_labels
# Process `x` and `protos` through `omega`
x_map = self.omega(x)
@ -85,8 +85,8 @@ im = ax.imshow(omega.dot(omega.T), cmap="viridis")
plt.show()
# Get the prototypes form the model
protos = model.p1.prototypes.data.numpy()
plabels = model.p1.prototype_labels
protos = model.proto_layer.components.numpy()
plabels = model.proto_layer.component_labels.numpy()
# Visualize the prototypes
title = "Tecator Prototypes"

View File

@ -24,7 +24,7 @@ batch_size_test = 1000
learning_rate = 0.1
momentum = 0.5
log_interval = 10
cuda = "cuda:1"
cuda = "cuda:0"
random_seed = 1
device = torch.device(cuda if torch.cuda.is_available() else "cpu")
@ -147,7 +147,7 @@ for epoch in range(num_epochs):
optimizer.zero_grad()
distances = model(x_train)
plabels = model.gtlvq.cls.prototype_labels.to(device)
plabels = model.gtlvq.cls.component_labels.to(device)
# Compute loss.
loss = criterion([distances, plabels], y_train)

View File

@ -3,14 +3,12 @@
import numpy as np
import torch
from matplotlib import pyplot as plt
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score
from prototorch.components import LabeledComponents, StratifiedMeanInitializer
from prototorch.functions.competitions import stratified_min
from prototorch.functions.distances import lomega_distance
from prototorch.functions.init import eye_
from prototorch.modules.losses import GLVQLoss
from prototorch.modules.prototypes import Prototypes1D
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score
# Prepare training data
x_train, y_train = load_iris(True)
@ -22,19 +20,19 @@ class Model(torch.nn.Module):
def __init__(self):
"""Local-GMLVQ model."""
super().__init__()
self.p1 = Prototypes1D(
input_dim=2,
prototype_distribution=[1, 2, 2],
prototype_initializer="stratified_random",
data=[x_train, y_train],
prototype_initializer = StratifiedMeanInitializer([x_train, y_train])
prototype_distribution = [1, 2, 2]
self.proto_layer = LabeledComponents(
prototype_distribution,
prototype_initializer,
)
omegas = torch.zeros(5, 2, 2)
omegas = torch.eye(2, 2).repeat(5, 1, 1)
self.omegas = torch.nn.Parameter(omegas)
eye_(self.omegas)
def forward(self, x):
protos = self.p1.prototypes
plabels = self.p1.prototype_labels
protos, plabels = self.proto_layer()
omegas = self.omegas
dis = lomega_distance(x, protos, omegas)
return dis, plabels
@ -69,7 +67,7 @@ for epoch in range(100):
optimizer.step()
# Get the prototypes form the model
protos = model.p1.prototypes.data.numpy()
protos = model.proto_layer.components.numpy()
# Visualize the data and the prototypes
ax = fig.gca()

View File

@ -1,9 +1,10 @@
import torch
from prototorch.components import LabeledComponents, StratifiedMeanInitializer
from prototorch.functions.distances import euclidean_distance_matrix
from prototorch.functions.normalization import orthogonalization
from prototorch.modules.prototypes import Prototypes1D
from torch import nn
class GTLVQ(nn.Module):
r""" Generalized Tangent Learning Vector Quantization
@ -81,13 +82,13 @@ class GTLVQ(nn.Module):
self.feature_dim = feature_dim
self.num_classes = num_classes
self.cls = Prototypes1D(
input_dim=feature_dim,
prototypes_per_class=prototypes_per_class,
nclasses=num_classes,
prototype_initializer="stratified_mean",
data=prototype_data,
)
cls_initializer = StratifiedMeanInitializer(prototype_data)
cls_distribution = {
"num_classes": num_classes,
"prototypes_per_class": prototypes_per_class,
}
self.cls = LabeledComponents(cls_distribution, cls_initializer)
if subspace_data is None:
raise ValueError("Init Data must be specified!")
@ -119,12 +120,12 @@ class GTLVQ(nn.Module):
subspaces = subspace[:, :num_subspaces]
self.subspaces = nn.Parameter(subspaces, requires_grad=True)
def init_local_subspace(self, data,num_subspaces,num_protos):
data = data - torch.mean(data,dim=0)
_,_,v = torch.svd(data,some=False)
v = v[:,:num_subspaces]
subspaces = v.unsqueeze(0).repeat_interleave(num_protos,0)
self.subspaces = nn.Parameter(subspaces,requires_grad=True)
def init_local_subspace(self, data, num_subspaces, num_protos):
data = data - torch.mean(data, dim=0)
_, _, v = torch.svd(data, some=False)
v = v[:, :num_subspaces]
subspaces = v.unsqueeze(0).repeat_interleave(num_protos, 0)
self.subspaces = nn.Parameter(subspaces, requires_grad=True)
def global_tangent_distances(self, x):
# Tangent Projection
@ -138,22 +139,23 @@ class GTLVQ(nn.Module):
def local_tangent_distances(self, x):
# Tangent Distance
x = x.unsqueeze(1).expand(x.size(0), self.cls.prototypes.size(0),
x = x.unsqueeze(1).expand(x.size(0), self.cls.num_components,
x.size(-1))
protos = self.cls.prototypes.unsqueeze(0).expand(
x.size(0), self.cls.prototypes.size(0), x.size(-1))
protos = self.cls()[0].unsqueeze(0).expand(x.size(0),
self.cls.num_components,
x.size(-1))
projectors = torch.eye(
self.subspaces.shape[-2], device=x.device) - torch.bmm(
self.subspaces, self.subspaces.permute([0, 2, 1]))
diff = (x - protos)
diff = diff.permute([1, 0, 2])
diff = torch.bmm(diff, projectors)
diff = torch.norm(diff,2,dim=-1).T
diff = torch.norm(diff, 2, dim=-1).T
return diff
def get_parameters(self):
return {
"params": self.cls.prototypes,
"params": self.cls.components,
}, {
"params": self.subspaces
}