Remove usage of Prototype1D
Update Iris example to new component API Update Tecator example to new component API Update LGMLVQ example to new component API Update GTLVQ to new component API
This commit is contained in:
parent
94fe4435a8
commit
40ef3aeda2
@ -3,10 +3,10 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from matplotlib import pyplot as plt
|
||||
from prototorch.components import LabeledComponents, StratifiedMeanInitializer
|
||||
from prototorch.functions.competitions import wtac
|
||||
from prototorch.functions.distances import euclidean_distance
|
||||
from prototorch.modules.losses import GLVQLoss
|
||||
from prototorch.modules.prototypes import Prototypes1D
|
||||
from sklearn.datasets import load_iris
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from torchinfo import summary
|
||||
@ -24,19 +24,17 @@ class Model(torch.nn.Module):
|
||||
def __init__(self):
|
||||
"""GLVQ model for training on 2D Iris data."""
|
||||
super().__init__()
|
||||
self.proto_layer = Prototypes1D(
|
||||
input_dim=2,
|
||||
prototypes_per_class=3,
|
||||
num_classes=3,
|
||||
prototype_initializer="stratified_random",
|
||||
data=[x_train, y_train],
|
||||
prototype_initializer = StratifiedMeanInitializer([x_train, y_train])
|
||||
prototype_distribution = {"num_classes": 3, "prototypes_per_class": 3}
|
||||
self.proto_layer = LabeledComponents(
|
||||
prototype_distribution,
|
||||
prototype_initializer,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
protos = self.proto_layer.prototypes
|
||||
plabels = self.proto_layer.prototype_labels
|
||||
dis = euclidean_distance(x, protos)
|
||||
return dis, plabels
|
||||
prototypes, prototype_labels = self.proto_layer()
|
||||
distances = euclidean_distance(x, prototypes)
|
||||
return distances, prototype_labels
|
||||
|
||||
|
||||
# Build the GLVQ model
|
||||
@ -53,43 +51,46 @@ x_in = torch.Tensor(x_train)
|
||||
y_in = torch.Tensor(y_train)
|
||||
|
||||
# Training loop
|
||||
title = "Prototype Visualization"
|
||||
fig = plt.figure(title)
|
||||
TITLE = "Prototype Visualization"
|
||||
fig = plt.figure(TITLE)
|
||||
for epoch in range(70):
|
||||
# Compute loss
|
||||
dis, plabels = model(x_in)
|
||||
loss = criterion([dis, plabels], y_in)
|
||||
distances, prototype_labels = model(x_in)
|
||||
loss = criterion([distances, prototype_labels], y_in)
|
||||
|
||||
# Compute Accuracy
|
||||
with torch.no_grad():
|
||||
pred = wtac(dis, plabels)
|
||||
correct = pred.eq(y_in.view_as(pred)).sum().item()
|
||||
predictions = wtac(distances, prototype_labels)
|
||||
correct = predictions.eq(y_in.view_as(predictions)).sum().item()
|
||||
acc = 100.0 * correct / len(x_train)
|
||||
|
||||
print(
|
||||
f"Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f} Acc: {acc:05.02f}%"
|
||||
)
|
||||
|
||||
# Take a gradient descent step
|
||||
# Optimizer step
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# Get the prototypes form the model
|
||||
protos = model.proto_layer.prototypes.data.numpy()
|
||||
if np.isnan(np.sum(protos)):
|
||||
prototypes = model.proto_layer.components.numpy()
|
||||
if np.isnan(np.sum(prototypes)):
|
||||
print("Stopping training because of `nan` in prototypes.")
|
||||
break
|
||||
|
||||
# Visualize the data and the prototypes
|
||||
ax = fig.gca()
|
||||
ax.cla()
|
||||
ax.set_title(title)
|
||||
ax.set_title(TITLE)
|
||||
ax.set_xlabel("Data dimension 1")
|
||||
ax.set_ylabel("Data dimension 2")
|
||||
cmap = "viridis"
|
||||
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
|
||||
ax.scatter(
|
||||
protos[:, 0],
|
||||
protos[:, 1],
|
||||
c=plabels,
|
||||
prototypes[:, 0],
|
||||
prototypes[:, 1],
|
||||
c=prototype_labels,
|
||||
cmap=cmap,
|
||||
edgecolor="k",
|
||||
marker="D",
|
||||
@ -97,7 +98,7 @@ for epoch in range(70):
|
||||
)
|
||||
|
||||
# Paint decision regions
|
||||
x = np.vstack((x_train, protos))
|
||||
x = np.vstack((x_train, prototypes))
|
||||
x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
|
||||
y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
|
||||
xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50),
|
||||
@ -107,7 +108,7 @@ for epoch in range(70):
|
||||
torch_input = torch.Tensor(mesh_input)
|
||||
d = model(torch_input)[0]
|
||||
w_indices = torch.argmin(d, dim=1)
|
||||
y_pred = torch.index_select(plabels, 0, w_indices)
|
||||
y_pred = torch.index_select(prototype_labels, 0, w_indices)
|
||||
y_pred = y_pred.reshape(xx.shape)
|
||||
|
||||
# Plot voronoi regions
|
||||
|
@ -2,9 +2,9 @@
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
from prototorch.components import LabeledComponents, StratifiedMeanInitializer
|
||||
from prototorch.datasets.tecator import Tecator
|
||||
from prototorch.functions.distances import sed
|
||||
from prototorch.modules import Prototypes1D
|
||||
from prototorch.modules.losses import GLVQLoss
|
||||
from prototorch.utils.colors import get_legend_handles
|
||||
from torch.utils.data import DataLoader
|
||||
@ -18,22 +18,22 @@ class Model(torch.nn.Module):
|
||||
def __init__(self, **kwargs):
|
||||
"""GMLVQ model as a siamese network."""
|
||||
super().__init__()
|
||||
x, y = train_data.data, train_data.targets
|
||||
self.p1 = Prototypes1D(
|
||||
input_dim=100,
|
||||
prototypes_per_class=2,
|
||||
num_classes=2,
|
||||
prototype_initializer="stratified_random",
|
||||
data=[x, y],
|
||||
prototype_initializer = StratifiedMeanInitializer(train_loader)
|
||||
prototype_distribution = {"num_classes": 2, "prototypes_per_class": 2}
|
||||
|
||||
self.proto_layer = LabeledComponents(
|
||||
prototype_distribution,
|
||||
prototype_initializer,
|
||||
)
|
||||
|
||||
self.omega = torch.nn.Linear(in_features=100,
|
||||
out_features=100,
|
||||
bias=False)
|
||||
torch.nn.init.eye_(self.omega.weight)
|
||||
|
||||
def forward(self, x):
|
||||
protos = self.p1.prototypes
|
||||
plabels = self.p1.prototype_labels
|
||||
protos = self.proto_layer.components
|
||||
plabels = self.proto_layer.component_labels
|
||||
|
||||
# Process `x` and `protos` through `omega`
|
||||
x_map = self.omega(x)
|
||||
@ -85,8 +85,8 @@ im = ax.imshow(omega.dot(omega.T), cmap="viridis")
|
||||
plt.show()
|
||||
|
||||
# Get the prototypes form the model
|
||||
protos = model.p1.prototypes.data.numpy()
|
||||
plabels = model.p1.prototype_labels
|
||||
protos = model.proto_layer.components.numpy()
|
||||
plabels = model.proto_layer.component_labels.numpy()
|
||||
|
||||
# Visualize the prototypes
|
||||
title = "Tecator Prototypes"
|
||||
|
@ -24,7 +24,7 @@ batch_size_test = 1000
|
||||
learning_rate = 0.1
|
||||
momentum = 0.5
|
||||
log_interval = 10
|
||||
cuda = "cuda:1"
|
||||
cuda = "cuda:0"
|
||||
random_seed = 1
|
||||
device = torch.device(cuda if torch.cuda.is_available() else "cpu")
|
||||
|
||||
@ -147,7 +147,7 @@ for epoch in range(num_epochs):
|
||||
optimizer.zero_grad()
|
||||
|
||||
distances = model(x_train)
|
||||
plabels = model.gtlvq.cls.prototype_labels.to(device)
|
||||
plabels = model.gtlvq.cls.component_labels.to(device)
|
||||
|
||||
# Compute loss.
|
||||
loss = criterion([distances, plabels], y_train)
|
||||
|
@ -3,14 +3,12 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from matplotlib import pyplot as plt
|
||||
from sklearn.datasets import load_iris
|
||||
from sklearn.metrics import accuracy_score
|
||||
|
||||
from prototorch.components import LabeledComponents, StratifiedMeanInitializer
|
||||
from prototorch.functions.competitions import stratified_min
|
||||
from prototorch.functions.distances import lomega_distance
|
||||
from prototorch.functions.init import eye_
|
||||
from prototorch.modules.losses import GLVQLoss
|
||||
from prototorch.modules.prototypes import Prototypes1D
|
||||
from sklearn.datasets import load_iris
|
||||
from sklearn.metrics import accuracy_score
|
||||
|
||||
# Prepare training data
|
||||
x_train, y_train = load_iris(True)
|
||||
@ -22,19 +20,19 @@ class Model(torch.nn.Module):
|
||||
def __init__(self):
|
||||
"""Local-GMLVQ model."""
|
||||
super().__init__()
|
||||
self.p1 = Prototypes1D(
|
||||
input_dim=2,
|
||||
prototype_distribution=[1, 2, 2],
|
||||
prototype_initializer="stratified_random",
|
||||
data=[x_train, y_train],
|
||||
|
||||
prototype_initializer = StratifiedMeanInitializer([x_train, y_train])
|
||||
prototype_distribution = [1, 2, 2]
|
||||
self.proto_layer = LabeledComponents(
|
||||
prototype_distribution,
|
||||
prototype_initializer,
|
||||
)
|
||||
omegas = torch.zeros(5, 2, 2)
|
||||
|
||||
omegas = torch.eye(2, 2).repeat(5, 1, 1)
|
||||
self.omegas = torch.nn.Parameter(omegas)
|
||||
eye_(self.omegas)
|
||||
|
||||
def forward(self, x):
|
||||
protos = self.p1.prototypes
|
||||
plabels = self.p1.prototype_labels
|
||||
protos, plabels = self.proto_layer()
|
||||
omegas = self.omegas
|
||||
dis = lomega_distance(x, protos, omegas)
|
||||
return dis, plabels
|
||||
@ -69,7 +67,7 @@ for epoch in range(100):
|
||||
optimizer.step()
|
||||
|
||||
# Get the prototypes form the model
|
||||
protos = model.p1.prototypes.data.numpy()
|
||||
protos = model.proto_layer.components.numpy()
|
||||
|
||||
# Visualize the data and the prototypes
|
||||
ax = fig.gca()
|
||||
|
@ -1,9 +1,10 @@
|
||||
import torch
|
||||
from prototorch.components import LabeledComponents, StratifiedMeanInitializer
|
||||
from prototorch.functions.distances import euclidean_distance_matrix
|
||||
from prototorch.functions.normalization import orthogonalization
|
||||
from prototorch.modules.prototypes import Prototypes1D
|
||||
from torch import nn
|
||||
|
||||
|
||||
class GTLVQ(nn.Module):
|
||||
r""" Generalized Tangent Learning Vector Quantization
|
||||
|
||||
@ -81,13 +82,13 @@ class GTLVQ(nn.Module):
|
||||
self.feature_dim = feature_dim
|
||||
self.num_classes = num_classes
|
||||
|
||||
self.cls = Prototypes1D(
|
||||
input_dim=feature_dim,
|
||||
prototypes_per_class=prototypes_per_class,
|
||||
nclasses=num_classes,
|
||||
prototype_initializer="stratified_mean",
|
||||
data=prototype_data,
|
||||
)
|
||||
cls_initializer = StratifiedMeanInitializer(prototype_data)
|
||||
cls_distribution = {
|
||||
"num_classes": num_classes,
|
||||
"prototypes_per_class": prototypes_per_class,
|
||||
}
|
||||
|
||||
self.cls = LabeledComponents(cls_distribution, cls_initializer)
|
||||
|
||||
if subspace_data is None:
|
||||
raise ValueError("Init Data must be specified!")
|
||||
@ -119,12 +120,12 @@ class GTLVQ(nn.Module):
|
||||
subspaces = subspace[:, :num_subspaces]
|
||||
self.subspaces = nn.Parameter(subspaces, requires_grad=True)
|
||||
|
||||
def init_local_subspace(self, data,num_subspaces,num_protos):
|
||||
data = data - torch.mean(data,dim=0)
|
||||
_,_,v = torch.svd(data,some=False)
|
||||
v = v[:,:num_subspaces]
|
||||
subspaces = v.unsqueeze(0).repeat_interleave(num_protos,0)
|
||||
self.subspaces = nn.Parameter(subspaces,requires_grad=True)
|
||||
def init_local_subspace(self, data, num_subspaces, num_protos):
|
||||
data = data - torch.mean(data, dim=0)
|
||||
_, _, v = torch.svd(data, some=False)
|
||||
v = v[:, :num_subspaces]
|
||||
subspaces = v.unsqueeze(0).repeat_interleave(num_protos, 0)
|
||||
self.subspaces = nn.Parameter(subspaces, requires_grad=True)
|
||||
|
||||
def global_tangent_distances(self, x):
|
||||
# Tangent Projection
|
||||
@ -138,22 +139,23 @@ class GTLVQ(nn.Module):
|
||||
def local_tangent_distances(self, x):
|
||||
|
||||
# Tangent Distance
|
||||
x = x.unsqueeze(1).expand(x.size(0), self.cls.prototypes.size(0),
|
||||
x = x.unsqueeze(1).expand(x.size(0), self.cls.num_components,
|
||||
x.size(-1))
|
||||
protos = self.cls.prototypes.unsqueeze(0).expand(
|
||||
x.size(0), self.cls.prototypes.size(0), x.size(-1))
|
||||
protos = self.cls()[0].unsqueeze(0).expand(x.size(0),
|
||||
self.cls.num_components,
|
||||
x.size(-1))
|
||||
projectors = torch.eye(
|
||||
self.subspaces.shape[-2], device=x.device) - torch.bmm(
|
||||
self.subspaces, self.subspaces.permute([0, 2, 1]))
|
||||
diff = (x - protos)
|
||||
diff = diff.permute([1, 0, 2])
|
||||
diff = torch.bmm(diff, projectors)
|
||||
diff = torch.norm(diff,2,dim=-1).T
|
||||
diff = torch.norm(diff, 2, dim=-1).T
|
||||
return diff
|
||||
|
||||
def get_parameters(self):
|
||||
return {
|
||||
"params": self.cls.prototypes,
|
||||
"params": self.cls.components,
|
||||
}, {
|
||||
"params": self.subspaces
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user