feat: Add new mesh util
This commit is contained in:
parent
3d3d27fbab
commit
07a2d6caaa
@ -13,6 +13,32 @@ import torch
|
|||||||
from torch.utils.data import DataLoader, Dataset
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
|
||||||
|
|
||||||
|
def generate_mesh(
|
||||||
|
minima: torch.TensorType,
|
||||||
|
maxima: torch.TensorType,
|
||||||
|
border: float = 1.0,
|
||||||
|
resolution: int = 100,
|
||||||
|
device: torch.device = None,
|
||||||
|
):
|
||||||
|
# Apply Border
|
||||||
|
ptp = maxima - minima
|
||||||
|
shift = border * ptp
|
||||||
|
minima -= shift
|
||||||
|
maxima += shift
|
||||||
|
|
||||||
|
# Generate Mesh
|
||||||
|
minima = minima.to(device).unsqueeze(1)
|
||||||
|
maxima = maxima.to(device).unsqueeze(1)
|
||||||
|
|
||||||
|
factors = torch.linspace(0, 1, resolution, device=device)
|
||||||
|
marginals = factors * maxima + ((1 - factors) * minima)
|
||||||
|
|
||||||
|
single_dimensions = torch.meshgrid(*marginals)
|
||||||
|
mesh_input = torch.stack([dim.ravel() for dim in single_dimensions], dim=1)
|
||||||
|
|
||||||
|
return mesh_input, single_dimensions
|
||||||
|
|
||||||
|
|
||||||
def mesh2d(x=None, border: float = 1.0, resolution: int = 100):
|
def mesh2d(x=None, border: float = 1.0, resolution: int = 100):
|
||||||
if x is not None:
|
if x is not None:
|
||||||
x_shift = border * np.ptp(x[:, 0])
|
x_shift = border * np.ptp(x[:, 0])
|
||||||
|
Loading…
Reference in New Issue
Block a user