Remove test_components.py
This commit is contained in:
		@@ -1,25 +0,0 @@
 | 
				
			|||||||
"""ProtoTorch components test suite."""
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import prototorch as pt
 | 
					 | 
				
			||||||
import torch
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def test_labcomps_zeros_init():
 | 
					 | 
				
			||||||
    protos = torch.zeros(3, 2)
 | 
					 | 
				
			||||||
    c = pt.components.LabeledComponents(
 | 
					 | 
				
			||||||
        distribution=[1, 1, 1],
 | 
					 | 
				
			||||||
        initializer=pt.components.Zeros(2),
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
    assert (c.components == protos).any() == True
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def test_labcomps_warmstart():
 | 
					 | 
				
			||||||
    protos = torch.randn(3, 2)
 | 
					 | 
				
			||||||
    plabels = torch.tensor([1, 2, 3])
 | 
					 | 
				
			||||||
    c = pt.components.LabeledComponents(
 | 
					 | 
				
			||||||
        distribution=[1, 1, 1],
 | 
					 | 
				
			||||||
        initializer=None,
 | 
					 | 
				
			||||||
        initialized_components=[protos, plabels],
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
    assert (c.components == protos).any() == True
 | 
					 | 
				
			||||||
    assert (c.component_labels == plabels).any() == True
 | 
					 | 
				
			||||||
		Reference in New Issue
	
	Block a user