2021-05-21 14:22:02 +00:00
|
|
|
"""ProtoTorch components test suite."""
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
2021-06-16 11:46:09 +00:00
|
|
|
import prototorch as pt
|
|
|
|
|
2021-05-21 14:22:02 +00:00
|
|
|
|
|
|
|
def test_labcomps_zeros_init():
    """LabeledComponents built with a Zeros initializer must be all zeros.

    The distribution ``[1, 1, 1]`` requests one component for each of
    three classes, and ``Zeros(2)`` makes each component 2-dimensional.
    The expected result is therefore a ``(3, 2)`` tensor of zeros.
    """
    protos = torch.zeros(3, 2)
    c = pt.components.LabeledComponents(
        distribution=[1, 1, 1],
        initializer=pt.components.Zeros(2),
    )
    # Check the shape explicitly: an elementwise comparison would otherwise
    # broadcast silently over mismatched shapes.
    assert c.components.shape == protos.shape
    # `.all()` rather than `.any()`: every element must match, not just one.
    assert (c.components == protos).all()
|
|
|
|
|
|
|
|
|
|
|
|
def test_labcomps_warmstart():
    """LabeledComponents must adopt pre-initialized components and labels.

    With ``initializer=None``, the components and their labels are passed
    in directly through ``initialized_components``. Both must be stored
    unchanged.
    """
    protos = torch.randn(3, 2)
    plabels = torch.tensor([1, 2, 3])
    c = pt.components.LabeledComponents(
        distribution=[1, 1, 1],
        initializer=None,
        initialized_components=[protos, plabels],
    )
    # Check shapes first so broadcasting cannot hide a mismatch.
    assert c.components.shape == protos.shape
    assert c.component_labels.shape == plabels.shape
    # `.all()` rather than `.any()`: the whole warm-start must be preserved,
    # not just a single matching element.
    assert (c.components == protos).all()
    assert (c.component_labels == plabels).all()
|