refactor(api)!: merge the new api changes into dev

BREAKING CHANGE: remove the following
`prototorch/functions/*`
`prototorch/components/*`
`prototorch/modules/*`
BREAKING CHANGE: move `initializers` into the `prototorch.initializers`
namespace from the `prototorch.components` namespace
BREAKING CHANGE: `functions` and `modules` and moved into `core` and `nn`
This commit is contained in:
Jensun Ravichandran
2021-06-18 18:20:30 +02:00
49 changed files with 2465 additions and 2201 deletions

96
examples/cbc_iris.py Normal file
View File

@@ -0,0 +1,96 @@
"""ProtoTorch CBC example using 2D Iris data."""
import torch
from matplotlib import pyplot as plt
import prototorch as pt
class CBC(torch.nn.Module):
def __init__(self, data, **kwargs):
super().__init__(**kwargs)
self.components_layer = pt.components.ReasoningComponents(
distribution=[2, 1, 2],
components_initializer=pt.initializers.SSCI(data, noise=0.1),
reasonings_initializer=pt.initializers.PPRI(components_first=True),
)
def forward(self, x):
components, reasonings = self.components_layer()
sims = pt.similarities.euclidean_similarity(x, components)
probs = pt.competitions.cbcc(sims, reasonings)
return probs
class VisCBC2D():
def __init__(self, model, data):
self.model = model
self.x_train, self.y_train = pt.utils.parse_data_arg(data)
self.title = "Components Visualization"
self.fig = plt.figure(self.title)
self.border = 0.1
self.resolution = 100
self.cmap = "viridis"
def on_epoch_end(self):
x_train, y_train = self.x_train, self.y_train
_components = self.model.components_layer._components.detach()
ax = self.fig.gca()
ax.cla()
ax.set_title(self.title)
ax.axis("off")
ax.scatter(
x_train[:, 0],
x_train[:, 1],
c=y_train,
cmap=self.cmap,
edgecolor="k",
marker="o",
s=30,
)
ax.scatter(
_components[:, 0],
_components[:, 1],
c="w",
cmap=self.cmap,
edgecolor="k",
marker="D",
s=50,
)
x = torch.vstack((x_train, _components))
mesh_input, xx, yy = pt.utils.mesh2d(x, self.border, self.resolution)
with torch.no_grad():
y_pred = self.model(
torch.Tensor(mesh_input).type_as(_components)).argmax(1)
y_pred = y_pred.cpu().reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
plt.pause(0.2)
if __name__ == "__main__":
train_ds = pt.datasets.Iris(dims=[0, 2])
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=32)
model = CBC(train_ds)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = pt.losses.MarginLoss(margin=0.1)
vis = VisCBC2D(model, train_ds)
for epoch in range(200):
correct = 0.0
for x, y in train_loader:
y_oh = torch.eye(3)[y]
y_pred = model(x)
loss = criterion(y_pred, y_oh).mean(0)
optimizer.zero_grad()
loss.backward()
optimizer.step()
correct += (y_pred.argmax(1) == y).float().sum(0)
acc = 100 * correct / len(train_ds)
print(f"Epoch: {epoch} Accuracy: {acc:05.02f}%")
vis.on_epoch_end()

View File

@@ -1,39 +1,35 @@
"""This example script shows the usage of the new components architecture.
Serialization/deserialization also works as expected.
"""
# DATASET
import torch
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
x_train, y_train = load_iris(return_X_y=True)
x_train = x_train[:, [0, 2]]
scaler.fit(x_train)
x_train = scaler.transform(x_train)
import prototorch as pt
x_train = torch.Tensor(x_train)
y_train = torch.Tensor(y_train)
num_classes = len(torch.unique(y_train))
ds = pt.datasets.Iris()
# CREATE NEW COMPONENTS
from prototorch.components import *
from prototorch.components.initializers import *
unsupervised = Components(6, SelectionInitializer(x_train))
unsupervised = pt.components.Components(
6,
initializer=pt.initializers.ZCI(2),
)
print(unsupervised())
prototypes = LabeledComponents(
(3, 2), StratifiedSelectionInitializer(x_train, y_train))
prototypes = pt.components.LabeledComponents(
(3, 2),
components_initializer=pt.initializers.SSCI(ds),
)
print(prototypes())
components = ReasoningComponents(
(3, 6), StratifiedSelectionInitializer(x_train, y_train))
print(components())
components = pt.components.ReasoningComponents(
(3, 2),
components_initializer=pt.initializers.SSCI(ds),
reasonings_initializer=pt.initializers.PPRI(),
)
print(prototypes())
# TEST SERIALIZATION
# Test Serialization
import io
save = io.BytesIO()
@@ -41,25 +37,20 @@ torch.save(unsupervised, save)
save.seek(0)
serialized_unsupervised = torch.load(save)
assert torch.all(unsupervised.components == serialized_unsupervised.components
), "Serialization of Components failed."
assert torch.all(unsupervised.components == serialized_unsupervised.components)
save = io.BytesIO()
torch.save(prototypes, save)
save.seek(0)
serialized_prototypes = torch.load(save)
assert torch.all(prototypes.components == serialized_prototypes.components
), "Serialization of Components failed."
assert torch.all(prototypes.component_labels == serialized_prototypes.
component_labels), "Serialization of Components failed."
assert torch.all(prototypes.components == serialized_prototypes.components)
assert torch.all(prototypes.labels == serialized_prototypes.labels)
save = io.BytesIO()
torch.save(components, save)
save.seek(0)
serialized_components = torch.load(save)
assert torch.all(components.components == serialized_components.components
), "Serialization of Components failed."
assert torch.all(components.reasonings == serialized_components.reasonings
), "Serialization of Components failed."
assert torch.all(components.components == serialized_components.components)
assert torch.all(components.reasonings == serialized_components.reasonings)