From 07a2d6caaae6b95db5a1840f63e09107575710cf Mon Sep 17 00:00:00 2001 From: Alexander Engelsberger Date: Fri, 15 Oct 2021 13:08:19 +0200 Subject: [PATCH] feat: Add new mesh util --- prototorch/utils/utils.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/prototorch/utils/utils.py b/prototorch/utils/utils.py index 2d7462a..46ed01b 100644 --- a/prototorch/utils/utils.py +++ b/prototorch/utils/utils.py @@ -13,6 +13,32 @@ import torch from torch.utils.data import DataLoader, Dataset +def generate_mesh( + minima: torch.TensorType, + maxima: torch.TensorType, + border: float = 1.0, + resolution: int = 100, + device: torch.device = None, +): + # Apply Border + ptp = maxima - minima + shift = border * ptp + minima -= shift + maxima += shift + + # Generate Mesh + minima = minima.to(device).unsqueeze(1) + maxima = maxima.to(device).unsqueeze(1) + + factors = torch.linspace(0, 1, resolution, device=device) + marginals = factors * maxima + ((1 - factors) * minima) + + single_dimensions = torch.meshgrid(*marginals) + mesh_input = torch.stack([dim.ravel() for dim in single_dimensions], dim=1) + + return mesh_input, single_dimensions + + def mesh2d(x=None, border: float = 1.0, resolution: int = 100): if x is not None: x_shift = border * np.ptp(x[:, 0])