Update components

This commit is contained in:
Jensun Ravichandran
2021-04-29 18:06:26 +02:00
parent a70166280a
commit 9b663477fd
3 changed files with 41 additions and 45 deletions

View File

@@ -1,8 +1,10 @@
#
# DATASET
#
import torch
"""This example script shows the usage of the new components architecture.
Serialization/deserialization also works as expected.
"""
# DATASET
import torch
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
@@ -16,9 +18,7 @@ x_train = torch.Tensor(x_train)
y_train = torch.Tensor(y_train)
num_classes = len(torch.unique(y_train))
#
# CREATE NEW COMPONENTS
#
from prototorch.components import *
from prototorch.components.initializers import *
@@ -33,9 +33,7 @@ components = ReasoningComponents(
(3, 6), StratifiedSelectionInitializer(x_train, y_train))
print(components())
#
# TEST SERIALIZATION
#
import io
save = io.BytesIO()
@@ -53,8 +51,8 @@ serialized_prototypes = torch.load(save)
assert torch.all(prototypes.components == serialized_prototypes.components
), "Serialization of Components failed."
assert torch.all(prototypes.labels == serialized_prototypes.labels
), "Serialization of Components failed."
assert torch.all(prototypes.component_labels == serialized_prototypes.
component_labels), "Serialization of Components failed."
save = io.BytesIO()
torch.save(components, save)