Add utils test suite
This commit is contained in:
parent
abae72d624
commit
0b2aaa42b8
47
tests/test_utils.py
Normal file
47
tests/test_utils.py
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
"""ProtoTorch utils test suite"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import prototorch as pt
|
||||||
|
|
||||||
|
|
||||||
|
def test_mesh2d_without_input():
|
||||||
|
mesh, xx, yy = pt.utils.mesh2d(border=2.0, resolution=10)
|
||||||
|
assert mesh.shape[0] == 100
|
||||||
|
assert mesh.shape[1] == 2
|
||||||
|
assert xx.shape[0] == 10
|
||||||
|
assert xx.shape[1] == 10
|
||||||
|
assert yy.shape[0] == 10
|
||||||
|
assert yy.shape[1] == 10
|
||||||
|
assert np.min(xx) == -2.0
|
||||||
|
assert np.max(xx) == 2.0
|
||||||
|
assert np.min(yy) == -2.0
|
||||||
|
assert np.max(yy) == 2.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_mesh2d_with_torch_input():
|
||||||
|
x = 10 * torch.rand(5, 2)
|
||||||
|
mesh, xx, yy = pt.utils.mesh2d(x, border=0.0, resolution=100)
|
||||||
|
assert mesh.shape[0] == 100 * 100
|
||||||
|
assert mesh.shape[1] == 2
|
||||||
|
assert xx.shape[0] == 100
|
||||||
|
assert xx.shape[1] == 100
|
||||||
|
assert yy.shape[0] == 100
|
||||||
|
assert yy.shape[1] == 100
|
||||||
|
assert np.min(xx) == x[:, 0].min()
|
||||||
|
assert np.max(xx) == x[:, 0].max()
|
||||||
|
assert np.min(yy) == x[:, 1].min()
|
||||||
|
assert np.max(yy) == x[:, 1].max()
|
||||||
|
|
||||||
|
|
||||||
|
def test_hex_to_rgb():
|
||||||
|
red_rgb = list(pt.utils.hex_to_rgb(["#ff0000"]))[0]
|
||||||
|
assert red_rgb[0] == 255
|
||||||
|
assert red_rgb[1] == 0
|
||||||
|
assert red_rgb[2] == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_rgb_to_hex():
|
||||||
|
blue_hex = list(pt.utils.rgb_to_hex([(0, 0, 255)]))[0]
|
||||||
|
assert blue_hex.lower() == "0000ff"
|
Loading…
Reference in New Issue
Block a user