Update components
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
#
|
||||
# DATASET
|
||||
#
|
||||
import torch
|
||||
"""This example script shows the usage of the new components architecture.
|
||||
|
||||
Serialization/deserialization also works as expected.
|
||||
"""
|
||||
|
||||
# DATASET
|
||||
import torch
|
||||
from sklearn.datasets import load_iris
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
|
||||
@@ -16,9 +18,7 @@ x_train = torch.Tensor(x_train)
|
||||
y_train = torch.Tensor(y_train)
|
||||
num_classes = len(torch.unique(y_train))
|
||||
|
||||
#
|
||||
# CREATE NEW COMPONENTS
|
||||
#
|
||||
from prototorch.components import *
|
||||
from prototorch.components.initializers import *
|
||||
|
||||
@@ -33,9 +33,7 @@ components = ReasoningComponents(
|
||||
(3, 6), StratifiedSelectionInitializer(x_train, y_train))
|
||||
print(components())
|
||||
|
||||
#
|
||||
# TEST SERIALIZATION
|
||||
#
|
||||
import io
|
||||
|
||||
save = io.BytesIO()
|
||||
@@ -53,8 +51,8 @@ serialized_prototypes = torch.load(save)
|
||||
|
||||
assert torch.all(prototypes.components == serialized_prototypes.components
|
||||
), "Serialization of Components failed."
|
||||
assert torch.all(prototypes.labels == serialized_prototypes.labels
|
||||
), "Serialization of Components failed."
|
||||
assert torch.all(prototypes.component_labels == serialized_prototypes.
|
||||
component_labels), "Serialization of Components failed."
|
||||
|
||||
save = io.BytesIO()
|
||||
torch.save(components, save)
|
||||
|
Reference in New Issue
Block a user