From 0b2aaa42b8c8ee00ab5ba91f5dc7347887487e7e Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Fri, 11 Jun 2021 23:08:32 +0200 Subject: [PATCH] Add utils test suite --- tests/test_utils.py | 47 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 tests/test_utils.py diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..e8a5e06 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,47 @@ +"""ProtoTorch utils test suite""" + +import numpy as np +import torch + +import prototorch as pt + + +def test_mesh2d_without_input(): + mesh, xx, yy = pt.utils.mesh2d(border=2.0, resolution=10) + assert mesh.shape[0] == 100 + assert mesh.shape[1] == 2 + assert xx.shape[0] == 10 + assert xx.shape[1] == 10 + assert yy.shape[0] == 10 + assert yy.shape[1] == 10 + assert np.min(xx) == -2.0 + assert np.max(xx) == 2.0 + assert np.min(yy) == -2.0 + assert np.max(yy) == 2.0 + + +def test_mesh2d_with_torch_input(): + x = 10 * torch.rand(5, 2) + mesh, xx, yy = pt.utils.mesh2d(x, border=0.0, resolution=100) + assert mesh.shape[0] == 100 * 100 + assert mesh.shape[1] == 2 + assert xx.shape[0] == 100 + assert xx.shape[1] == 100 + assert yy.shape[0] == 100 + assert yy.shape[1] == 100 + assert np.min(xx) == x[:, 0].min() + assert np.max(xx) == x[:, 0].max() + assert np.min(yy) == x[:, 1].min() + assert np.max(yy) == x[:, 1].max() + + +def test_hex_to_rgb(): + red_rgb = list(pt.utils.hex_to_rgb(["#ff0000"]))[0] + assert red_rgb[0] == 255 + assert red_rgb[1] == 0 + assert red_rgb[2] == 0 + + +def test_rgb_to_hex(): + blue_hex = list(pt.utils.rgb_to_hex([(0, 0, 255)]))[0] + assert blue_hex.lower() == "0000ff"