Remove usage of Prototype1D

Update Iris example to new component API
Update Tecator example to new component API
Update LGMLVQ example to new component API
Update GTLVQ to new component API
This commit is contained in:
Alexander Engelsberger 2021-05-28 15:57:26 +02:00
parent 94fe4435a8
commit 40ef3aeda2
5 changed files with 75 additions and 74 deletions

View File

@ -3,10 +3,10 @@
import numpy as np import numpy as np
import torch import torch
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from prototorch.components import LabeledComponents, StratifiedMeanInitializer
from prototorch.functions.competitions import wtac from prototorch.functions.competitions import wtac
from prototorch.functions.distances import euclidean_distance from prototorch.functions.distances import euclidean_distance
from prototorch.modules.losses import GLVQLoss from prototorch.modules.losses import GLVQLoss
from prototorch.modules.prototypes import Prototypes1D
from sklearn.datasets import load_iris from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler from sklearn.preprocessing import StandardScaler
from torchinfo import summary from torchinfo import summary
@ -24,19 +24,17 @@ class Model(torch.nn.Module):
def __init__(self): def __init__(self):
"""GLVQ model for training on 2D Iris data.""" """GLVQ model for training on 2D Iris data."""
super().__init__() super().__init__()
self.proto_layer = Prototypes1D( prototype_initializer = StratifiedMeanInitializer([x_train, y_train])
input_dim=2, prototype_distribution = {"num_classes": 3, "prototypes_per_class": 3}
prototypes_per_class=3, self.proto_layer = LabeledComponents(
num_classes=3, prototype_distribution,
prototype_initializer="stratified_random", prototype_initializer,
data=[x_train, y_train],
) )
def forward(self, x): def forward(self, x):
protos = self.proto_layer.prototypes prototypes, prototype_labels = self.proto_layer()
plabels = self.proto_layer.prototype_labels distances = euclidean_distance(x, prototypes)
dis = euclidean_distance(x, protos) return distances, prototype_labels
return dis, plabels
# Build the GLVQ model # Build the GLVQ model
@ -53,43 +51,46 @@ x_in = torch.Tensor(x_train)
y_in = torch.Tensor(y_train) y_in = torch.Tensor(y_train)
# Training loop # Training loop
title = "Prototype Visualization" TITLE = "Prototype Visualization"
fig = plt.figure(title) fig = plt.figure(TITLE)
for epoch in range(70): for epoch in range(70):
# Compute loss # Compute loss
dis, plabels = model(x_in) distances, prototype_labels = model(x_in)
loss = criterion([dis, plabels], y_in) loss = criterion([distances, prototype_labels], y_in)
# Compute Accuracy
with torch.no_grad(): with torch.no_grad():
pred = wtac(dis, plabels) predictions = wtac(distances, prototype_labels)
correct = pred.eq(y_in.view_as(pred)).sum().item() correct = predictions.eq(y_in.view_as(predictions)).sum().item()
acc = 100.0 * correct / len(x_train) acc = 100.0 * correct / len(x_train)
print( print(
f"Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f} Acc: {acc:05.02f}%" f"Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f} Acc: {acc:05.02f}%"
) )
# Take a gradient descent step # Optimizer step
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() loss.backward()
optimizer.step() optimizer.step()
# Get the prototypes form the model # Get the prototypes form the model
protos = model.proto_layer.prototypes.data.numpy() prototypes = model.proto_layer.components.numpy()
if np.isnan(np.sum(protos)): if np.isnan(np.sum(prototypes)):
print("Stopping training because of `nan` in prototypes.") print("Stopping training because of `nan` in prototypes.")
break break
# Visualize the data and the prototypes # Visualize the data and the prototypes
ax = fig.gca() ax = fig.gca()
ax.cla() ax.cla()
ax.set_title(title) ax.set_title(TITLE)
ax.set_xlabel("Data dimension 1") ax.set_xlabel("Data dimension 1")
ax.set_ylabel("Data dimension 2") ax.set_ylabel("Data dimension 2")
cmap = "viridis" cmap = "viridis"
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k") ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
ax.scatter( ax.scatter(
protos[:, 0], prototypes[:, 0],
protos[:, 1], prototypes[:, 1],
c=plabels, c=prototype_labels,
cmap=cmap, cmap=cmap,
edgecolor="k", edgecolor="k",
marker="D", marker="D",
@ -97,7 +98,7 @@ for epoch in range(70):
) )
# Paint decision regions # Paint decision regions
x = np.vstack((x_train, protos)) x = np.vstack((x_train, prototypes))
x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1 x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1 y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50), xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50),
@ -107,7 +108,7 @@ for epoch in range(70):
torch_input = torch.Tensor(mesh_input) torch_input = torch.Tensor(mesh_input)
d = model(torch_input)[0] d = model(torch_input)[0]
w_indices = torch.argmin(d, dim=1) w_indices = torch.argmin(d, dim=1)
y_pred = torch.index_select(plabels, 0, w_indices) y_pred = torch.index_select(prototype_labels, 0, w_indices)
y_pred = y_pred.reshape(xx.shape) y_pred = y_pred.reshape(xx.shape)
# Plot voronoi regions # Plot voronoi regions

View File

@ -2,9 +2,9 @@
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import torch import torch
from prototorch.components import LabeledComponents, StratifiedMeanInitializer
from prototorch.datasets.tecator import Tecator from prototorch.datasets.tecator import Tecator
from prototorch.functions.distances import sed from prototorch.functions.distances import sed
from prototorch.modules import Prototypes1D
from prototorch.modules.losses import GLVQLoss from prototorch.modules.losses import GLVQLoss
from prototorch.utils.colors import get_legend_handles from prototorch.utils.colors import get_legend_handles
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -18,22 +18,22 @@ class Model(torch.nn.Module):
def __init__(self, **kwargs): def __init__(self, **kwargs):
"""GMLVQ model as a siamese network.""" """GMLVQ model as a siamese network."""
super().__init__() super().__init__()
x, y = train_data.data, train_data.targets prototype_initializer = StratifiedMeanInitializer(train_loader)
self.p1 = Prototypes1D( prototype_distribution = {"num_classes": 2, "prototypes_per_class": 2}
input_dim=100,
prototypes_per_class=2, self.proto_layer = LabeledComponents(
num_classes=2, prototype_distribution,
prototype_initializer="stratified_random", prototype_initializer,
data=[x, y],
) )
self.omega = torch.nn.Linear(in_features=100, self.omega = torch.nn.Linear(in_features=100,
out_features=100, out_features=100,
bias=False) bias=False)
torch.nn.init.eye_(self.omega.weight) torch.nn.init.eye_(self.omega.weight)
def forward(self, x): def forward(self, x):
protos = self.p1.prototypes protos = self.proto_layer.components
plabels = self.p1.prototype_labels plabels = self.proto_layer.component_labels
# Process `x` and `protos` through `omega` # Process `x` and `protos` through `omega`
x_map = self.omega(x) x_map = self.omega(x)
@ -85,8 +85,8 @@ im = ax.imshow(omega.dot(omega.T), cmap="viridis")
plt.show() plt.show()
# Get the prototypes form the model # Get the prototypes form the model
protos = model.p1.prototypes.data.numpy() protos = model.proto_layer.components.numpy()
plabels = model.p1.prototype_labels plabels = model.proto_layer.component_labels.numpy()
# Visualize the prototypes # Visualize the prototypes
title = "Tecator Prototypes" title = "Tecator Prototypes"

View File

@ -24,7 +24,7 @@ batch_size_test = 1000
learning_rate = 0.1 learning_rate = 0.1
momentum = 0.5 momentum = 0.5
log_interval = 10 log_interval = 10
cuda = "cuda:1" cuda = "cuda:0"
random_seed = 1 random_seed = 1
device = torch.device(cuda if torch.cuda.is_available() else "cpu") device = torch.device(cuda if torch.cuda.is_available() else "cpu")
@ -147,7 +147,7 @@ for epoch in range(num_epochs):
optimizer.zero_grad() optimizer.zero_grad()
distances = model(x_train) distances = model(x_train)
plabels = model.gtlvq.cls.prototype_labels.to(device) plabels = model.gtlvq.cls.component_labels.to(device)
# Compute loss. # Compute loss.
loss = criterion([distances, plabels], y_train) loss = criterion([distances, plabels], y_train)

View File

@ -3,14 +3,12 @@
import numpy as np import numpy as np
import torch import torch
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from sklearn.datasets import load_iris from prototorch.components import LabeledComponents, StratifiedMeanInitializer
from sklearn.metrics import accuracy_score
from prototorch.functions.competitions import stratified_min from prototorch.functions.competitions import stratified_min
from prototorch.functions.distances import lomega_distance from prototorch.functions.distances import lomega_distance
from prototorch.functions.init import eye_
from prototorch.modules.losses import GLVQLoss from prototorch.modules.losses import GLVQLoss
from prototorch.modules.prototypes import Prototypes1D from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score
# Prepare training data # Prepare training data
x_train, y_train = load_iris(True) x_train, y_train = load_iris(True)
@ -22,19 +20,19 @@ class Model(torch.nn.Module):
def __init__(self): def __init__(self):
"""Local-GMLVQ model.""" """Local-GMLVQ model."""
super().__init__() super().__init__()
self.p1 = Prototypes1D(
input_dim=2, prototype_initializer = StratifiedMeanInitializer([x_train, y_train])
prototype_distribution=[1, 2, 2], prototype_distribution = [1, 2, 2]
prototype_initializer="stratified_random", self.proto_layer = LabeledComponents(
data=[x_train, y_train], prototype_distribution,
prototype_initializer,
) )
omegas = torch.zeros(5, 2, 2)
omegas = torch.eye(2, 2).repeat(5, 1, 1)
self.omegas = torch.nn.Parameter(omegas) self.omegas = torch.nn.Parameter(omegas)
eye_(self.omegas)
def forward(self, x): def forward(self, x):
protos = self.p1.prototypes protos, plabels = self.proto_layer()
plabels = self.p1.prototype_labels
omegas = self.omegas omegas = self.omegas
dis = lomega_distance(x, protos, omegas) dis = lomega_distance(x, protos, omegas)
return dis, plabels return dis, plabels
@ -69,7 +67,7 @@ for epoch in range(100):
optimizer.step() optimizer.step()
# Get the prototypes form the model # Get the prototypes form the model
protos = model.p1.prototypes.data.numpy() protos = model.proto_layer.components.numpy()
# Visualize the data and the prototypes # Visualize the data and the prototypes
ax = fig.gca() ax = fig.gca()

View File

@ -1,9 +1,10 @@
import torch import torch
from prototorch.components import LabeledComponents, StratifiedMeanInitializer
from prototorch.functions.distances import euclidean_distance_matrix from prototorch.functions.distances import euclidean_distance_matrix
from prototorch.functions.normalization import orthogonalization from prototorch.functions.normalization import orthogonalization
from prototorch.modules.prototypes import Prototypes1D
from torch import nn from torch import nn
class GTLVQ(nn.Module): class GTLVQ(nn.Module):
r""" Generalized Tangent Learning Vector Quantization r""" Generalized Tangent Learning Vector Quantization
@ -81,13 +82,13 @@ class GTLVQ(nn.Module):
self.feature_dim = feature_dim self.feature_dim = feature_dim
self.num_classes = num_classes self.num_classes = num_classes
self.cls = Prototypes1D( cls_initializer = StratifiedMeanInitializer(prototype_data)
input_dim=feature_dim, cls_distribution = {
prototypes_per_class=prototypes_per_class, "num_classes": num_classes,
nclasses=num_classes, "prototypes_per_class": prototypes_per_class,
prototype_initializer="stratified_mean", }
data=prototype_data,
) self.cls = LabeledComponents(cls_distribution, cls_initializer)
if subspace_data is None: if subspace_data is None:
raise ValueError("Init Data must be specified!") raise ValueError("Init Data must be specified!")
@ -138,10 +139,11 @@ class GTLVQ(nn.Module):
def local_tangent_distances(self, x): def local_tangent_distances(self, x):
# Tangent Distance # Tangent Distance
x = x.unsqueeze(1).expand(x.size(0), self.cls.prototypes.size(0), x = x.unsqueeze(1).expand(x.size(0), self.cls.num_components,
x.size(-1))
protos = self.cls()[0].unsqueeze(0).expand(x.size(0),
self.cls.num_components,
x.size(-1)) x.size(-1))
protos = self.cls.prototypes.unsqueeze(0).expand(
x.size(0), self.cls.prototypes.size(0), x.size(-1))
projectors = torch.eye( projectors = torch.eye(
self.subspaces.shape[-2], device=x.device) - torch.bmm( self.subspaces.shape[-2], device=x.device) - torch.bmm(
self.subspaces, self.subspaces.permute([0, 2, 1])) self.subspaces, self.subspaces.permute([0, 2, 1]))
@ -153,7 +155,7 @@ class GTLVQ(nn.Module):
def get_parameters(self): def get_parameters(self):
return { return {
"params": self.cls.prototypes, "params": self.cls.components,
}, { }, {
"params": self.subspaces "params": self.subspaces
} }