Remove test_components.py
This commit is contained in:
parent
4a99bcbf0d
commit
5e72fd8187
@ -1,25 +0,0 @@
|
||||
"""ProtoTorch components test suite."""
|
||||
|
||||
import prototorch as pt
|
||||
import torch
|
||||
|
||||
|
||||
def test_labcomps_zeros_init():
|
||||
protos = torch.zeros(3, 2)
|
||||
c = pt.components.LabeledComponents(
|
||||
distribution=[1, 1, 1],
|
||||
initializer=pt.components.Zeros(2),
|
||||
)
|
||||
assert (c.components == protos).any() == True
|
||||
|
||||
|
||||
def test_labcomps_warmstart():
|
||||
protos = torch.randn(3, 2)
|
||||
plabels = torch.tensor([1, 2, 3])
|
||||
c = pt.components.LabeledComponents(
|
||||
distribution=[1, 1, 1],
|
||||
initializer=None,
|
||||
initialized_components=[protos, plabels],
|
||||
)
|
||||
assert (c.components == protos).any() == True
|
||||
assert (c.component_labels == plabels).any() == True
|
Loading…
Reference in New Issue
Block a user