Remove examples
parent acf3272fd7
commit 40c1021c20
@@ -1,121 +0,0 @@
"""ProtoTorch GLVQ example using 2D Iris data."""

import numpy as np
import torch
from matplotlib import pyplot as plt
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from torchinfo import summary

from prototorch.components import LabeledComponents, StratifiedMeanInitializer
from prototorch.functions.competitions import wtac
from prototorch.functions.distances import euclidean_distance
from prototorch.modules.losses import GLVQLoss

# Prepare and preprocess the data
scaler = StandardScaler()
x_train, y_train = load_iris(return_X_y=True)
x_train = x_train[:, [0, 2]]
scaler.fit(x_train)
x_train = scaler.transform(x_train)


# Define the GLVQ model
class Model(torch.nn.Module):
    def __init__(self):
        """GLVQ model for training on 2D Iris data."""
        super().__init__()
        prototype_initializer = StratifiedMeanInitializer([x_train, y_train])
        prototype_distribution = {"num_classes": 3, "prototypes_per_class": 3}
        self.proto_layer = LabeledComponents(
            prototype_distribution,
            prototype_initializer,
        )

    def forward(self, x):
        prototypes, prototype_labels = self.proto_layer()
        distances = euclidean_distance(x, prototypes)
        return distances, prototype_labels


# Build the GLVQ model
model = Model()

# Print a summary using torchinfo (might be buggy/incorrect)
print(summary(model))

# Optimize using the SGD optimizer from `torch.optim`
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = GLVQLoss(squashing="sigmoid_beta", beta=10)

x_in = torch.Tensor(x_train)
y_in = torch.Tensor(y_train)

# Training loop
TITLE = "Prototype Visualization"
fig = plt.figure(TITLE)
for epoch in range(70):
    # Compute loss
    distances, prototype_labels = model(x_in)
    loss = criterion([distances, prototype_labels], y_in)

    # Compute accuracy
    with torch.no_grad():
        predictions = wtac(distances, prototype_labels)
        correct = predictions.eq(y_in.view_as(predictions)).sum().item()
        acc = 100.0 * correct / len(x_train)

    print(
        f"Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f} Acc: {acc:05.02f}%"
    )

    # Optimizer step
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Get the prototypes from the model
    prototypes = model.proto_layer.components.numpy()
    if np.isnan(np.sum(prototypes)):
        print("Stopping training because of `nan` in prototypes.")
        break

    # Visualize the data and the prototypes
    ax = fig.gca()
    ax.cla()
    ax.set_title(TITLE)
    ax.set_xlabel("Data dimension 1")
    ax.set_ylabel("Data dimension 2")
    cmap = "viridis"
    ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
    ax.scatter(
        prototypes[:, 0],
        prototypes[:, 1],
        c=prototype_labels,
        cmap=cmap,
        edgecolor="k",
        marker="D",
        s=50,
    )

    # Paint decision regions
    x = np.vstack((x_train, prototypes))
    x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
    y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50),
                         np.arange(y_min, y_max, 1 / 50))
    mesh_input = np.c_[xx.ravel(), yy.ravel()]

    torch_input = torch.Tensor(mesh_input)
    d = model(torch_input)[0]
    w_indices = torch.argmin(d, dim=1)
    y_pred = torch.index_select(prototype_labels, 0, w_indices)
    y_pred = y_pred.reshape(xx.shape)

    # Plot Voronoi regions
    ax.contourf(xx, yy, y_pred, cmap=cmap, alpha=0.35)

    ax.set_xlim(left=x_min, right=x_max)
    ax.set_ylim(bottom=y_min, top=y_max)

    plt.pause(0.1)
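A minimal inference sketch for the example above, assuming the fitted `scaler`, the trained `model`, and `wtac` are still in scope; `new_samples` is a hypothetical pair of raw sepal-length/petal-length measurements:

# Classify unseen raw measurements with the trained GLVQ model
new_samples = np.array([[5.1, 1.4], [6.5, 5.2]])  # hypothetical inputs
with torch.no_grad():
    # Reuse the fitted scaler so the inputs match the training distribution
    d_new, proto_labels = model(torch.Tensor(scaler.transform(new_samples)))
    # Winner-takes-all competition assigns the label of the closest prototype
    predicted_classes = wtac(d_new, proto_labels)
print(predicted_classes)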
@@ -1,104 +0,0 @@
"""ProtoTorch "siamese" GMLVQ example using Tecator."""

import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader

from prototorch.components import LabeledComponents, StratifiedMeanInitializer
from prototorch.datasets.tecator import Tecator
from prototorch.functions.distances import sed
from prototorch.modules.losses import GLVQLoss
from prototorch.utils.colors import get_legend_handles

# Prepare the dataset and dataloader
train_data = Tecator(root="./artifacts", train=True)
train_loader = DataLoader(train_data, batch_size=128, shuffle=True)


class Model(torch.nn.Module):
    def __init__(self, **kwargs):
        """GMLVQ model as a siamese network."""
        super().__init__()
        prototype_initializer = StratifiedMeanInitializer(train_loader)
        prototype_distribution = {"num_classes": 2, "prototypes_per_class": 2}

        self.proto_layer = LabeledComponents(
            prototype_distribution,
            prototype_initializer,
        )

        self.omega = torch.nn.Linear(in_features=100,
                                     out_features=100,
                                     bias=False)
        torch.nn.init.eye_(self.omega.weight)

    def forward(self, x):
        protos = self.proto_layer.components
        plabels = self.proto_layer.component_labels

        # Process `x` and `protos` through `omega`
        x_map = self.omega(x)
        protos_map = self.omega(protos)

        # Compute distances and output
        dis = sed(x_map, protos_map)
        return dis, plabels


# Build the GMLVQ model
model = Model()

# Print a summary of the model
print(model)

# Optimize using the Adam optimizer from `torch.optim`
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=75, gamma=0.1)
criterion = GLVQLoss(squashing="identity", beta=10)

# Training loop
for epoch in range(150):
    epoch_loss = 0.0  # zero-out epoch loss
    for xb, yb in train_loader:
        # Compute loss
        distances, plabels = model(xb)
        loss = criterion([distances, plabels], yb)
        epoch_loss += loss.item()
        # Backprop
        optimizer.zero_grad()  # zero-out gradients
        loss.backward()
        # Take a gradient descent step
        optimizer.step()
    scheduler.step()

    lr = optimizer.param_groups[0]["lr"]
    print(f"Epoch: {epoch + 1:03d} Loss: {epoch_loss:06.02f} lr: {lr:07.06f}")

# Get the omega matrix from the model
omega = model.omega.weight.data.numpy().T

# Visualize the lambda matrix
title = "Lambda Matrix Visualization"
fig = plt.figure(title)
ax = fig.gca()
ax.set_title(title)
im = ax.imshow(omega.dot(omega.T), cmap="viridis")
plt.show()

# Get the prototypes from the model
protos = model.proto_layer.components.numpy()
plabels = model.proto_layer.component_labels.numpy()

# Visualize the prototypes
title = "Tecator Prototypes"
fig = plt.figure(title)
ax = fig.gca()
ax.set_title(title)
ax.set_xlabel("Spectral frequencies")
ax.set_ylabel("Absorption")
clabels = ["Class 0 - Low fat", "Class 1 - High fat"]
handles, colors = get_legend_handles(clabels, marker="line", zero_indexed=True)
for x, y in zip(protos, plabels):
    ax.plot(x, c=colors[int(y)])
ax.legend(handles, clabels)
plt.show()
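A short follow-up sketch: in GMLVQ the diagonal of the lambda matrix is commonly read as a per-feature relevance profile, so the learned weighting of the spectral bands could be plotted directly from the `omega` computed above, assuming the variables from the example are still in scope:

# Per-feature relevance from the diagonal of lambda = omega @ omega.T
relevances = omega.dot(omega.T).diagonal()
fig = plt.figure("Relevance Profile")
ax = fig.gca()
ax.plot(relevances)
ax.set_xlabel("Spectral frequencies")
ax.set_ylabel("Relevance")
plt.show()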
@@ -1,184 +0,0 @@
"""
ProtoTorch GTLVQ example using MNIST data.

The GTLVQ is placed as a classification model on
top of a CNN, which serves as a feature extractor.
Subspaces and prototypes are initialized in a
siamese fashion.

For more info about GTLVQ see:
DOI:10.1109/IJCNN.2016.7727534
"""

import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms

from prototorch.functions.helper import calculate_prototype_accuracy
from prototorch.modules.losses import GLVQLoss
from prototorch.modules.models import GTLVQ

# Parameters and options
num_epochs = 50
batch_size_train = 64
batch_size_test = 1000
learning_rate = 0.1
momentum = 0.5
log_interval = 10
cuda = "cuda:0"
random_seed = 1
device = torch.device(cuda if torch.cuda.is_available() else "cpu")

# Configure reproducibility
torch.manual_seed(random_seed)
np.random.seed(random_seed)

# Prepare and preprocess the data
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(
        "./files/",
        train=True,
        download=True,
        transform=torchvision.transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ]),
    ),
    batch_size=batch_size_train,
    shuffle=True,
)

test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(
        "./files/",
        train=False,
        download=True,
        transform=torchvision.transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ]),
    ),
    batch_size=batch_size_test,
    shuffle=True,
)


# Define the GTLVQ model plus an appropriate feature extractor
class CNNGTLVQ(torch.nn.Module):
    def __init__(
        self,
        num_classes,
        subspace_data,
        prototype_data,
        tangent_projection_type="local",
        prototypes_per_class=2,
        bottleneck_dim=128,
    ):
        super(CNNGTLVQ, self).__init__()

        # Feature extractor - a simple CNN
        self.fe = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, 1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout(0.25),
            nn.Flatten(),
            nn.Linear(9216, bottleneck_dim),
            nn.Dropout(0.5),
            nn.LeakyReLU(),
            nn.LayerNorm(bottleneck_dim),
        )

        # Pass the subspace and prototype initialization data
        # through the feature extractor
        subspace_data = self.fe(subspace_data)
        prototype_data[0] = self.fe(prototype_data[0])

        # Initialize GTLVQ
        self.gtlvq = GTLVQ(
            num_classes,
            subspace_data,
            prototype_data,
            tangent_projection_type=tangent_projection_type,
            feature_dim=bottleneck_dim,
            prototypes_per_class=prototypes_per_class,
        )

    def forward(self, x):
        # Feature extraction
        x = self.fe(x)

        # GTLVQ forward pass
        dis = self.gtlvq(x)
        return dis


# Get the initialization data
subspace_data = torch.cat(
    [next(iter(train_loader))[0],
     next(iter(test_loader))[0]])
prototype_data = next(iter(train_loader))

# Build the CNN-GTLVQ model
model = CNNGTLVQ(
    10,
    subspace_data,
    prototype_data,
    tangent_projection_type="local",
    bottleneck_dim=128,
).to(device)

# Optimize using the Adam optimizer from `torch.optim`
optimizer = torch.optim.Adam(
    [{
        "params": model.fe.parameters()
    }, {
        "params": model.gtlvq.parameters()
    }],
    lr=learning_rate,
)
criterion = GLVQLoss(squashing="sigmoid_beta", beta=10)

# Training loop
for epoch in range(num_epochs):
    for batch_idx, (x_train, y_train) in enumerate(train_loader):
        model.train()
        x_train, y_train = x_train.to(device), y_train.to(device)
        optimizer.zero_grad()

        distances = model(x_train)
        plabels = model.gtlvq.cls.component_labels.to(device)

        # Compute loss
        loss = criterion([distances, plabels], y_train)
        loss.backward()
        optimizer.step()

        # GTLVQ uses projected SGD, which means the subspaces are
        # orthogonalized after every gradient update.
        model.gtlvq.orthogonalize_subspace()

        if batch_idx % log_interval == 0:
            acc = calculate_prototype_accuracy(distances, y_train, plabels)
            print(
                f"Epoch: {epoch + 1:02d}/{num_epochs:02d} "
                f"Epoch Progress: {100. * batch_idx / len(train_loader):02.02f}% "
                f"Loss: {loss.item():02.02f} Train Acc: {acc.item():02.02f}")

    # Test
    with torch.no_grad():
        model.eval()
        correct = 0
        total = 0
        for x_test, y_test in test_loader:
            x_test, y_test = x_test.to(device), y_test.to(device)
            test_distances = model(x_test)
            test_plabels = model.gtlvq.cls.component_labels.to(device)
            i = torch.argmin(test_distances, 1)
            correct += torch.sum(y_test == test_plabels[i])
            total += y_test.size(0)
        print("Accuracy of the network on the test images: %d %%" %
              (torch.true_divide(correct, total) * 100))

# Save the model
PATH = "./glvq_mnist_model.pth"
torch.save(model.state_dict(), PATH)
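A brief usage sketch for the saved checkpoint, using only standard PyTorch calls. It rebuilds the module with fresh initialization batches (the original `prototype_data` list is mutated in `__init__`, so it cannot be reused); the loaded `state_dict` overwrites the data-dependent initialization anyway:

# Rebuild the module, then restore the trained weights
restored = CNNGTLVQ(
    10,
    torch.cat([next(iter(train_loader))[0],
               next(iter(test_loader))[0]]),
    next(iter(train_loader)),  # fresh init batch; overwritten by load below
    tangent_projection_type="local",
    bottleneck_dim=128,
).to(device)
restored.load_state_dict(torch.load(PATH))
restored.eval()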
@@ -1,111 +0,0 @@
"""ProtoTorch LGMLVQ example using 2D Iris data."""

import numpy as np
import torch
from matplotlib import pyplot as plt
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score

from prototorch.components import LabeledComponents, StratifiedMeanInitializer
from prototorch.functions.distances import lomega_distance
from prototorch.functions.pooling import stratified_min_pooling
from prototorch.modules.losses import GLVQLoss

# Prepare training data
x_train, y_train = load_iris(return_X_y=True)
x_train = x_train[:, [0, 2]]


# Define the model
class Model(torch.nn.Module):
    def __init__(self):
        """Local-GMLVQ model."""
        super().__init__()

        prototype_initializer = StratifiedMeanInitializer([x_train, y_train])
        prototype_distribution = [1, 2, 2]
        self.proto_layer = LabeledComponents(
            prototype_distribution,
            prototype_initializer,
        )

        omegas = torch.eye(2, 2).repeat(5, 1, 1)
        self.omegas = torch.nn.Parameter(omegas)

    def forward(self, x):
        protos, plabels = self.proto_layer()
        omegas = self.omegas
        dis = lomega_distance(x, protos, omegas)
        return dis, plabels


# Build the model
model = Model()

# Optimize using the Adam optimizer from `torch.optim`
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = GLVQLoss(squashing="sigmoid_beta", beta=10)

x_in = torch.Tensor(x_train)
y_in = torch.Tensor(y_train)

# Training loop
title = "Prototype Visualization"
fig = plt.figure(title)
for epoch in range(100):
    # Compute loss
    dis, plabels = model(x_in)
    loss = criterion([dis, plabels], y_in)
    y_pred = np.argmin(stratified_min_pooling(dis, plabels).detach().numpy(),
                       axis=1)
    acc = accuracy_score(y_train, y_pred)
    log_string = f"Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f} "
    log_string += f"Acc: {acc * 100:05.02f}%"
    print(log_string)

    # Take a gradient descent step
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Get the prototypes from the model
    protos = model.proto_layer.components.numpy()

    # Visualize the data and the prototypes
    ax = fig.gca()
    ax.cla()
    ax.set_title(title)
    ax.set_xlabel("Data dimension 1")
    ax.set_ylabel("Data dimension 2")
    cmap = "viridis"
    ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
    ax.scatter(
        protos[:, 0],
        protos[:, 1],
        c=plabels,
        cmap=cmap,
        edgecolor="k",
        marker="D",
        s=50,
    )

    # Paint decision regions
    x = np.vstack((x_train, protos))
    x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
    y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50),
                         np.arange(y_min, y_max, 1 / 50))
    mesh_input = np.c_[xx.ravel(), yy.ravel()]

    d, plabels = model(torch.Tensor(mesh_input))
    y_pred = np.argmin(stratified_min_pooling(d, plabels).detach().numpy(),
                       axis=1)
    y_pred = y_pred.reshape(xx.shape)

    # Plot Voronoi regions
    ax.contourf(xx, yy, y_pred, cmap=cmap, alpha=0.35)

    ax.set_xlim(left=x_min, right=x_max)
    ax.set_ylim(bottom=y_min, top=y_max)

    plt.pause(0.1)
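A closing sketch for inspecting what the local-matrix model learned: each prototype carries its own omega, so per-prototype lambda matrices (taken here as omega_j^T omega_j, one common convention in the GMLVQ literature) can be computed from `model.omegas`, assuming the trained `model` from above:

# Per-prototype lambda matrices from the learned local omegas
with torch.no_grad():
    lambdas = torch.bmm(model.omegas.transpose(1, 2), model.omegas)
for j, lam in enumerate(lambdas):
    print(f"Prototype {j} lambda matrix:\n{lam.numpy()}")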