feat: Add new mesh util
This commit is contained in:
parent
3d3d27fbab
commit
07a2d6caaa
@ -13,6 +13,32 @@ import torch
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
|
||||
def generate_mesh(
|
||||
minima: torch.TensorType,
|
||||
maxima: torch.TensorType,
|
||||
border: float = 1.0,
|
||||
resolution: int = 100,
|
||||
device: torch.device = None,
|
||||
):
|
||||
# Apply Border
|
||||
ptp = maxima - minima
|
||||
shift = border * ptp
|
||||
minima -= shift
|
||||
maxima += shift
|
||||
|
||||
# Generate Mesh
|
||||
minima = minima.to(device).unsqueeze(1)
|
||||
maxima = maxima.to(device).unsqueeze(1)
|
||||
|
||||
factors = torch.linspace(0, 1, resolution, device=device)
|
||||
marginals = factors * maxima + ((1 - factors) * minima)
|
||||
|
||||
single_dimensions = torch.meshgrid(*marginals)
|
||||
mesh_input = torch.stack([dim.ravel() for dim in single_dimensions], dim=1)
|
||||
|
||||
return mesh_input, single_dimensions
|
||||
|
||||
|
||||
def mesh2d(x=None, border: float = 1.0, resolution: int = 100):
|
||||
if x is not None:
|
||||
x_shift = border * np.ptp(x[:, 0])
|
||||
|
Loading…
Reference in New Issue
Block a user