refactor(api)!: merge the new api changes into dev
BREAKING CHANGE: remove the following `prototorch/functions/*` `prototorch/components/*` `prototorch/modules/*` BREAKING CHANGE: move `initializers` into the `prototorch.initializers` namespace from the `prototorch.components` namespace BREAKING CHANGE: `functions` and `modules` and moved into `core` and `nn`
This commit is contained in:
96
examples/cbc_iris.py
Normal file
96
examples/cbc_iris.py
Normal file
@@ -0,0 +1,96 @@
|
||||
"""ProtoTorch CBC example using 2D Iris data."""
|
||||
|
||||
import torch
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
import prototorch as pt
|
||||
|
||||
|
||||
class CBC(torch.nn.Module):
|
||||
def __init__(self, data, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.components_layer = pt.components.ReasoningComponents(
|
||||
distribution=[2, 1, 2],
|
||||
components_initializer=pt.initializers.SSCI(data, noise=0.1),
|
||||
reasonings_initializer=pt.initializers.PPRI(components_first=True),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
components, reasonings = self.components_layer()
|
||||
sims = pt.similarities.euclidean_similarity(x, components)
|
||||
probs = pt.competitions.cbcc(sims, reasonings)
|
||||
return probs
|
||||
|
||||
|
||||
class VisCBC2D():
|
||||
def __init__(self, model, data):
|
||||
self.model = model
|
||||
self.x_train, self.y_train = pt.utils.parse_data_arg(data)
|
||||
self.title = "Components Visualization"
|
||||
self.fig = plt.figure(self.title)
|
||||
self.border = 0.1
|
||||
self.resolution = 100
|
||||
self.cmap = "viridis"
|
||||
|
||||
def on_epoch_end(self):
|
||||
x_train, y_train = self.x_train, self.y_train
|
||||
_components = self.model.components_layer._components.detach()
|
||||
ax = self.fig.gca()
|
||||
ax.cla()
|
||||
ax.set_title(self.title)
|
||||
ax.axis("off")
|
||||
ax.scatter(
|
||||
x_train[:, 0],
|
||||
x_train[:, 1],
|
||||
c=y_train,
|
||||
cmap=self.cmap,
|
||||
edgecolor="k",
|
||||
marker="o",
|
||||
s=30,
|
||||
)
|
||||
ax.scatter(
|
||||
_components[:, 0],
|
||||
_components[:, 1],
|
||||
c="w",
|
||||
cmap=self.cmap,
|
||||
edgecolor="k",
|
||||
marker="D",
|
||||
s=50,
|
||||
)
|
||||
x = torch.vstack((x_train, _components))
|
||||
mesh_input, xx, yy = pt.utils.mesh2d(x, self.border, self.resolution)
|
||||
with torch.no_grad():
|
||||
y_pred = self.model(
|
||||
torch.Tensor(mesh_input).type_as(_components)).argmax(1)
|
||||
y_pred = y_pred.cpu().reshape(xx.shape)
|
||||
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|
||||
plt.pause(0.2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train_ds = pt.datasets.Iris(dims=[0, 2])
|
||||
|
||||
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=32)
|
||||
|
||||
model = CBC(train_ds)
|
||||
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
|
||||
criterion = pt.losses.MarginLoss(margin=0.1)
|
||||
vis = VisCBC2D(model, train_ds)
|
||||
|
||||
for epoch in range(200):
|
||||
correct = 0.0
|
||||
for x, y in train_loader:
|
||||
y_oh = torch.eye(3)[y]
|
||||
y_pred = model(x)
|
||||
loss = criterion(y_pred, y_oh).mean(0)
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
correct += (y_pred.argmax(1) == y).float().sum(0)
|
||||
|
||||
acc = 100 * correct / len(train_ds)
|
||||
print(f"Epoch: {epoch} Accuracy: {acc:05.02f}%")
|
||||
vis.on_epoch_end()
|
@@ -1,39 +1,35 @@
|
||||
"""This example script shows the usage of the new components architecture.
|
||||
|
||||
Serialization/deserialization also works as expected.
|
||||
|
||||
"""
|
||||
|
||||
# DATASET
|
||||
import torch
|
||||
from sklearn.datasets import load_iris
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
|
||||
scaler = StandardScaler()
|
||||
x_train, y_train = load_iris(return_X_y=True)
|
||||
x_train = x_train[:, [0, 2]]
|
||||
scaler.fit(x_train)
|
||||
x_train = scaler.transform(x_train)
|
||||
import prototorch as pt
|
||||
|
||||
x_train = torch.Tensor(x_train)
|
||||
y_train = torch.Tensor(y_train)
|
||||
num_classes = len(torch.unique(y_train))
|
||||
ds = pt.datasets.Iris()
|
||||
|
||||
# CREATE NEW COMPONENTS
|
||||
from prototorch.components import *
|
||||
from prototorch.components.initializers import *
|
||||
|
||||
unsupervised = Components(6, SelectionInitializer(x_train))
|
||||
unsupervised = pt.components.Components(
|
||||
6,
|
||||
initializer=pt.initializers.ZCI(2),
|
||||
)
|
||||
print(unsupervised())
|
||||
|
||||
prototypes = LabeledComponents(
|
||||
(3, 2), StratifiedSelectionInitializer(x_train, y_train))
|
||||
prototypes = pt.components.LabeledComponents(
|
||||
(3, 2),
|
||||
components_initializer=pt.initializers.SSCI(ds),
|
||||
)
|
||||
print(prototypes())
|
||||
|
||||
components = ReasoningComponents(
|
||||
(3, 6), StratifiedSelectionInitializer(x_train, y_train))
|
||||
print(components())
|
||||
components = pt.components.ReasoningComponents(
|
||||
(3, 2),
|
||||
components_initializer=pt.initializers.SSCI(ds),
|
||||
reasonings_initializer=pt.initializers.PPRI(),
|
||||
)
|
||||
print(prototypes())
|
||||
|
||||
# TEST SERIALIZATION
|
||||
# Test Serialization
|
||||
import io
|
||||
|
||||
save = io.BytesIO()
|
||||
@@ -41,25 +37,20 @@ torch.save(unsupervised, save)
|
||||
save.seek(0)
|
||||
serialized_unsupervised = torch.load(save)
|
||||
|
||||
assert torch.all(unsupervised.components == serialized_unsupervised.components
|
||||
), "Serialization of Components failed."
|
||||
assert torch.all(unsupervised.components == serialized_unsupervised.components)
|
||||
|
||||
save = io.BytesIO()
|
||||
torch.save(prototypes, save)
|
||||
save.seek(0)
|
||||
serialized_prototypes = torch.load(save)
|
||||
|
||||
assert torch.all(prototypes.components == serialized_prototypes.components
|
||||
), "Serialization of Components failed."
|
||||
assert torch.all(prototypes.component_labels == serialized_prototypes.
|
||||
component_labels), "Serialization of Components failed."
|
||||
assert torch.all(prototypes.components == serialized_prototypes.components)
|
||||
assert torch.all(prototypes.labels == serialized_prototypes.labels)
|
||||
|
||||
save = io.BytesIO()
|
||||
torch.save(components, save)
|
||||
save.seek(0)
|
||||
serialized_components = torch.load(save)
|
||||
|
||||
assert torch.all(components.components == serialized_components.components
|
||||
), "Serialization of Components failed."
|
||||
assert torch.all(components.reasonings == serialized_components.reasonings
|
||||
), "Serialization of Components failed."
|
||||
assert torch.all(components.components == serialized_components.components)
|
||||
assert torch.all(components.reasonings == serialized_components.reasonings)
|
||||
|
Reference in New Issue
Block a user