diff --git a/tests/test_components.py b/tests/test_components.py deleted file mode 100644 index 03bc215..0000000 --- a/tests/test_components.py +++ /dev/null @@ -1,25 +0,0 @@ -"""ProtoTorch components test suite.""" - -import prototorch as pt -import torch - - -def test_labcomps_zeros_init(): - protos = torch.zeros(3, 2) - c = pt.components.LabeledComponents( - distribution=[1, 1, 1], - initializer=pt.components.Zeros(2), - ) - assert (c.components == protos).any() == True - - -def test_labcomps_warmstart(): - protos = torch.randn(3, 2) - plabels = torch.tensor([1, 2, 3]) - c = pt.components.LabeledComponents( - distribution=[1, 1, 1], - initializer=None, - initialized_components=[protos, plabels], - ) - assert (c.components == protos).any() == True - assert (c.component_labels == plabels).any() == True