Remove usage of Prototypes1D

- Update Iris example to new component API
- Update Tecator example to new component API
- Update LGMLVQ example to new component API
- Update GTLVQ to new component API
parent 94fe4435a8
commit 40ef3aeda2
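Every file in this diff applies the same migration away from the removed Prototypes1D layer. As a quick orientation before the hunks, here is a minimal sketch of the pattern, assuming the prototorch version targeted by this commit; the class, argument, and attribute names (LabeledComponents, StratifiedMeanInitializer, the (components, labels) return value) are taken from the hunks below rather than from separately verified documentation, and the GLVQModel class name is only for illustration.

import torch
from prototorch.components import LabeledComponents, StratifiedMeanInitializer
from prototorch.functions.distances import euclidean_distance


class GLVQModel(torch.nn.Module):
    def __init__(self, x_train, y_train):
        super().__init__()
        # Old API (removed): Prototypes1D(input_dim=..., prototypes_per_class=...,
        #     num_classes=..., prototype_initializer="stratified_random",
        #     data=[x_train, y_train])
        # New API: a distribution dict plus an explicit initializer object.
        prototype_initializer = StratifiedMeanInitializer([x_train, y_train])
        prototype_distribution = {"num_classes": 3, "prototypes_per_class": 3}
        self.proto_layer = LabeledComponents(prototype_distribution,
                                             prototype_initializer)

    def forward(self, x):
        # Calling the layer returns (components, labels); this replaces the old
        # .prototypes / .prototype_labels attributes, which become
        # .components / .component_labels on the new layer.
        prototypes, prototype_labels = self.proto_layer()
        distances = euclidean_distance(x, prototypes)
        return distances, prototype_labels

The GTLVQ module below follows the same pattern, keeping the layer on self.cls and reading self.cls.num_components where it previously used self.cls.prototypes.size(0).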
@@ -3,10 +3,10 @@
 import numpy as np
 import torch
 from matplotlib import pyplot as plt
+from prototorch.components import LabeledComponents, StratifiedMeanInitializer
 from prototorch.functions.competitions import wtac
 from prototorch.functions.distances import euclidean_distance
 from prototorch.modules.losses import GLVQLoss
-from prototorch.modules.prototypes import Prototypes1D
 from sklearn.datasets import load_iris
 from sklearn.preprocessing import StandardScaler
 from torchinfo import summary
@@ -24,19 +24,17 @@ class Model(torch.nn.Module):
     def __init__(self):
         """GLVQ model for training on 2D Iris data."""
         super().__init__()
-        self.proto_layer = Prototypes1D(
-            input_dim=2,
-            prototypes_per_class=3,
-            num_classes=3,
-            prototype_initializer="stratified_random",
-            data=[x_train, y_train],
+        prototype_initializer = StratifiedMeanInitializer([x_train, y_train])
+        prototype_distribution = {"num_classes": 3, "prototypes_per_class": 3}
+        self.proto_layer = LabeledComponents(
+            prototype_distribution,
+            prototype_initializer,
         )
 
     def forward(self, x):
-        protos = self.proto_layer.prototypes
-        plabels = self.proto_layer.prototype_labels
-        dis = euclidean_distance(x, protos)
-        return dis, plabels
+        prototypes, prototype_labels = self.proto_layer()
+        distances = euclidean_distance(x, prototypes)
+        return distances, prototype_labels
 
 
 # Build the GLVQ model
@@ -53,43 +51,46 @@ x_in = torch.Tensor(x_train)
 y_in = torch.Tensor(y_train)
 
 # Training loop
-title = "Prototype Visualization"
-fig = plt.figure(title)
+TITLE = "Prototype Visualization"
+fig = plt.figure(TITLE)
 for epoch in range(70):
     # Compute loss
-    dis, plabels = model(x_in)
-    loss = criterion([dis, plabels], y_in)
+    distances, prototype_labels = model(x_in)
+    loss = criterion([distances, prototype_labels], y_in)
 
+    # Compute Accuracy
     with torch.no_grad():
-        pred = wtac(dis, plabels)
-        correct = pred.eq(y_in.view_as(pred)).sum().item()
+        predictions = wtac(distances, prototype_labels)
+        correct = predictions.eq(y_in.view_as(predictions)).sum().item()
         acc = 100.0 * correct / len(x_train)
 
     print(
         f"Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f} Acc: {acc:05.02f}%"
     )
 
-    # Take a gradient descent step
+    # Optimizer step
     optimizer.zero_grad()
     loss.backward()
     optimizer.step()
 
     # Get the prototypes form the model
-    protos = model.proto_layer.prototypes.data.numpy()
-    if np.isnan(np.sum(protos)):
+    prototypes = model.proto_layer.components.numpy()
+    if np.isnan(np.sum(prototypes)):
         print("Stopping training because of `nan` in prototypes.")
         break
 
     # Visualize the data and the prototypes
     ax = fig.gca()
     ax.cla()
-    ax.set_title(title)
+    ax.set_title(TITLE)
     ax.set_xlabel("Data dimension 1")
     ax.set_ylabel("Data dimension 2")
     cmap = "viridis"
     ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
     ax.scatter(
-        protos[:, 0],
-        protos[:, 1],
-        c=plabels,
+        prototypes[:, 0],
+        prototypes[:, 1],
+        c=prototype_labels,
         cmap=cmap,
         edgecolor="k",
         marker="D",
@@ -97,7 +98,7 @@ for epoch in range(70):
     )
 
     # Paint decision regions
-    x = np.vstack((x_train, protos))
+    x = np.vstack((x_train, prototypes))
     x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
     y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
     xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50),
@@ -107,7 +108,7 @@ for epoch in range(70):
     torch_input = torch.Tensor(mesh_input)
     d = model(torch_input)[0]
     w_indices = torch.argmin(d, dim=1)
-    y_pred = torch.index_select(plabels, 0, w_indices)
+    y_pred = torch.index_select(prototype_labels, 0, w_indices)
     y_pred = y_pred.reshape(xx.shape)
 
     # Plot voronoi regions

@@ -2,9 +2,9 @@
 
 import matplotlib.pyplot as plt
 import torch
+from prototorch.components import LabeledComponents, StratifiedMeanInitializer
 from prototorch.datasets.tecator import Tecator
 from prototorch.functions.distances import sed
-from prototorch.modules import Prototypes1D
 from prototorch.modules.losses import GLVQLoss
 from prototorch.utils.colors import get_legend_handles
 from torch.utils.data import DataLoader
@@ -18,22 +18,22 @@ class Model(torch.nn.Module):
     def __init__(self, **kwargs):
         """GMLVQ model as a siamese network."""
         super().__init__()
-        x, y = train_data.data, train_data.targets
-        self.p1 = Prototypes1D(
-            input_dim=100,
-            prototypes_per_class=2,
-            num_classes=2,
-            prototype_initializer="stratified_random",
-            data=[x, y],
+        prototype_initializer = StratifiedMeanInitializer(train_loader)
+        prototype_distribution = {"num_classes": 2, "prototypes_per_class": 2}
+
+        self.proto_layer = LabeledComponents(
+            prototype_distribution,
+            prototype_initializer,
         )
 
         self.omega = torch.nn.Linear(in_features=100,
                                      out_features=100,
                                      bias=False)
         torch.nn.init.eye_(self.omega.weight)
 
     def forward(self, x):
-        protos = self.p1.prototypes
-        plabels = self.p1.prototype_labels
+        protos = self.proto_layer.components
+        plabels = self.proto_layer.component_labels
 
         # Process `x` and `protos` through `omega`
         x_map = self.omega(x)
@@ -85,8 +85,8 @@ im = ax.imshow(omega.dot(omega.T), cmap="viridis")
 plt.show()
 
 # Get the prototypes form the model
-protos = model.p1.prototypes.data.numpy()
-plabels = model.p1.prototype_labels
+protos = model.proto_layer.components.numpy()
+plabels = model.proto_layer.component_labels.numpy()
 
 # Visualize the prototypes
 title = "Tecator Prototypes"

@@ -24,7 +24,7 @@ batch_size_test = 1000
 learning_rate = 0.1
 momentum = 0.5
 log_interval = 10
-cuda = "cuda:1"
+cuda = "cuda:0"
 random_seed = 1
 device = torch.device(cuda if torch.cuda.is_available() else "cpu")
 
@@ -147,7 +147,7 @@ for epoch in range(num_epochs):
     optimizer.zero_grad()
 
     distances = model(x_train)
-    plabels = model.gtlvq.cls.prototype_labels.to(device)
+    plabels = model.gtlvq.cls.component_labels.to(device)
 
     # Compute loss.
     loss = criterion([distances, plabels], y_train)

@@ -3,14 +3,12 @@
 import numpy as np
 import torch
 from matplotlib import pyplot as plt
-from sklearn.datasets import load_iris
-from sklearn.metrics import accuracy_score
+from prototorch.components import LabeledComponents, StratifiedMeanInitializer
 
 from prototorch.functions.competitions import stratified_min
 from prototorch.functions.distances import lomega_distance
-from prototorch.functions.init import eye_
 from prototorch.modules.losses import GLVQLoss
-from prototorch.modules.prototypes import Prototypes1D
+from sklearn.datasets import load_iris
+from sklearn.metrics import accuracy_score
 
 # Prepare training data
 x_train, y_train = load_iris(True)
@@ -22,19 +20,19 @@ class Model(torch.nn.Module):
     def __init__(self):
         """Local-GMLVQ model."""
         super().__init__()
-        self.p1 = Prototypes1D(
-            input_dim=2,
-            prototype_distribution=[1, 2, 2],
-            prototype_initializer="stratified_random",
-            data=[x_train, y_train],
+
+        prototype_initializer = StratifiedMeanInitializer([x_train, y_train])
+        prototype_distribution = [1, 2, 2]
+        self.proto_layer = LabeledComponents(
+            prototype_distribution,
+            prototype_initializer,
         )
-        omegas = torch.zeros(5, 2, 2)
+
+        omegas = torch.eye(2, 2).repeat(5, 1, 1)
         self.omegas = torch.nn.Parameter(omegas)
-        eye_(self.omegas)
 
     def forward(self, x):
-        protos = self.p1.prototypes
-        plabels = self.p1.prototype_labels
+        protos, plabels = self.proto_layer()
         omegas = self.omegas
         dis = lomega_distance(x, protos, omegas)
         return dis, plabels
@@ -69,7 +67,7 @@ for epoch in range(100):
     optimizer.step()
 
     # Get the prototypes form the model
-    protos = model.p1.prototypes.data.numpy()
+    protos = model.proto_layer.components.numpy()
 
     # Visualize the data and the prototypes
     ax = fig.gca()

@@ -1,9 +1,10 @@
 import torch
+from prototorch.components import LabeledComponents, StratifiedMeanInitializer
 from prototorch.functions.distances import euclidean_distance_matrix
 from prototorch.functions.normalization import orthogonalization
-from prototorch.modules.prototypes import Prototypes1D
 from torch import nn
 
 
 class GTLVQ(nn.Module):
     r""" Generalized Tangent Learning Vector Quantization
 
@@ -81,13 +82,13 @@ class GTLVQ(nn.Module):
         self.feature_dim = feature_dim
         self.num_classes = num_classes
 
-        self.cls = Prototypes1D(
-            input_dim=feature_dim,
-            prototypes_per_class=prototypes_per_class,
-            nclasses=num_classes,
-            prototype_initializer="stratified_mean",
-            data=prototype_data,
-        )
+        cls_initializer = StratifiedMeanInitializer(prototype_data)
+        cls_distribution = {
+            "num_classes": num_classes,
+            "prototypes_per_class": prototypes_per_class,
+        }
+
+        self.cls = LabeledComponents(cls_distribution, cls_initializer)
 
         if subspace_data is None:
             raise ValueError("Init Data must be specified!")
@@ -138,10 +139,11 @@ class GTLVQ(nn.Module):
     def local_tangent_distances(self, x):
 
         # Tangent Distance
-        x = x.unsqueeze(1).expand(x.size(0), self.cls.prototypes.size(0),
-                                  x.size(-1))
-        protos = self.cls.prototypes.unsqueeze(0).expand(
-            x.size(0), self.cls.prototypes.size(0), x.size(-1))
+        x = x.unsqueeze(1).expand(x.size(0), self.cls.num_components,
+                                  x.size(-1))
+        protos = self.cls()[0].unsqueeze(0).expand(x.size(0),
+                                                   self.cls.num_components,
+                                                   x.size(-1))
         projectors = torch.eye(
             self.subspaces.shape[-2], device=x.device) - torch.bmm(
                 self.subspaces, self.subspaces.permute([0, 2, 1]))
@@ -153,7 +155,7 @@ class GTLVQ(nn.Module):
 
     def get_parameters(self):
         return {
-            "params": self.cls.prototypes,
+            "params": self.cls.components,
         }, {
             "params": self.subspaces
         }