
Mesh to Signed Distance Function

HJ Choi 2022. 11. 21. 12:05

This article is a personal note reviewing how to create signed distance functions from a 3D mesh using JAX.

Triangle Mesh (Tri-mesh)

A mesh is a common data representation of 3D objects, and there are several kinds. Triangle meshes and quadrilateral meshes are the most common types, and both store only surface information. Quadrilateral meshes are often preferred for flow simulation and finite element analysis, where they can give more accurate results. Triangle meshes are easier to compute with and are the standard in graphics. Since we only need shape information to create a signed distance function (SDF), we will use tri-meshes, which are easy to work with.

Tri-mesh data consists of vertex data and triangle data. The vertex data stores the coordinates of each vertex, indexed by vertex number. The triangle data stores, for each triangle, the indices of its three vertices. These indices are listed in a consistent order (the winding order), so when we compute normals with the cross product, all normals point consistently toward either the inside or the outside of the mesh. This matters because we are also going to store gradient information in the SDF grid.
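A minimal sketch of this layout, assuming a hand-made unit tetrahedron whose faces are wound counter-clockwise when seen from outside (a common convention, but check your own mesh loader). It gathers the per-triangle vertex arrays a, b, c used by the code in this article and checks that, with nor = cross(b - a, a - c) as defined below, -nor points out of the mesh.

```python
import jax.numpy as jp

# Vertex data: one (x, y, z) row per vertex; the row index is the vertex index.
vertices = jp.array([[0.0, 0.0, 0.0],
                     [1.0, 0.0, 0.0],
                     [0.0, 1.0, 0.0],
                     [0.0, 0.0, 1.0]])

# Triangle data: three vertex indices per triangle,
# ordered counter-clockwise when viewed from outside the mesh.
faces = jp.array([[0, 2, 1],
                  [0, 1, 3],
                  [0, 3, 2],
                  [1, 2, 3]])

# Per-triangle vertex coordinates, each of shape (T, 3).
a = vertices[faces[:, 0]]
b = vertices[faces[:, 1]]
c = vertices[faces[:, 2]]

# Normal as defined in this article: cross(ba, ac).
# With counter-clockwise winding it points into the mesh, so -nor points out.
nor = jp.cross(b - a, a - c)
outward = -nor / jp.linalg.norm(nor, axis=-1, keepdims=True)

# For a convex mesh, every outward normal points away from the mesh center.
face_center = (a + b + c) / 3.0
mesh_center = jp.mean(vertices, axis=0)
print(jp.sum(outward * (face_center - mesh_center), axis=-1) > 0)  # all four entries are True
```

If a mesh comes with clockwise winding instead, the same code still works but the sign in front of nor flips.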

Mesh SDF Creation

Point to Triangle Distance

One simple and accurate way to create an SDF grid from a mesh is to compute the distance from each grid point to every triangle of the mesh and take the minimum. The method for computing the distance between a point and a triangle is well described on Inigo Quilez's blog.

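import jax.numpy as jp

# a: 1st vertex of each triangle, shape (T, 3)
# b: 2nd vertex of each triangle, shape (T, 3)
# c: 3rd vertex of each triangle, shape (T, 3)
# p: query point whose distance we want, shape (3,)

ba = b - a
cb = c - b
ac = a - c
nor = jp.cross(ba, ac)

def dot(u, v):
  return jp.sum(u*v, axis=-1, keepdims=True)

def triDistance(p):
  pa = p - a
  pb = p - b
  pc = p - c

  # does the projection of p onto each triangle's plane fall outside the triangle?
  outside = (jp.sign(dot(jp.cross(ba,nor), pa)) +
             jp.sign(dot(jp.cross(cb,nor), pb)) +
             jp.sign(dot(jp.cross(ac,nor), pc))) < 2.0

  # vectors from p to the closest point on each of the three edges
  e1 = ba*jp.clip(dot(ba,pa)/dot(ba,ba),0.0,1.0)-pa
  e2 = cb*jp.clip(dot(cb,pb)/dot(cb,cb),0.0,1.0)-pb
  e3 = ac*jp.clip(dot(ac,pc)/dot(ac,ac),0.0,1.0)-pc
  # keep the shortest edge vector (the closest edge is not always the first failing one)
  edge = jp.where(dot(e1,e1) < dot(e2,e2), e1, e2)
  edge = jp.where(dot(edge,edge) < dot(e3,e3), edge, e3)
  # vector from p to its projection onto the triangle's plane
  face = -nor*dot(nor,pa)/dot(nor,nor)

  l = jp.where(outside, edge, face)

  d = jp.min(jp.linalg.norm(l, axis=-1, keepdims=True))
  return d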

You can see that the code for calculating the distance is pretty simple. However, there are some problems.
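To build a whole grid, the per-point distance can be mapped over all grid points with jax.vmap. The sketch below is self-contained and uses the squared-distance form from Inigo Quilez's blog (minimum over the three edges when the projected point falls outside the triangle, plane distance otherwise). The function name unsigned_distance, the one-triangle mesh, and the grid bounds are toy choices for illustration.

```python
import jax
import jax.numpy as jp

def dot(u, v):
  return jp.sum(u * v, axis=-1, keepdims=True)

def dot2(u):
  return dot(u, u)

def unsigned_distance(p, a, b, c):
  """Distance from point p (3,) to the closest of T triangles given as (T, 3) arrays."""
  ba, pa = b - a, p - a
  cb, pb = c - b, p - b
  ac, pc = a - c, p - c
  nor = jp.cross(ba, ac)
  # projection of p onto each triangle's plane falls outside the triangle?
  outside = (jp.sign(dot(jp.cross(ba, nor), pa)) +
             jp.sign(dot(jp.cross(cb, nor), pb)) +
             jp.sign(dot(jp.cross(ac, nor), pc))) < 2.0
  # squared distance to the closest edge
  edge = jp.minimum(jp.minimum(
      dot2(ba * jp.clip(dot(ba, pa) / dot2(ba), 0.0, 1.0) - pa),
      dot2(cb * jp.clip(dot(cb, pb) / dot2(cb), 0.0, 1.0) - pb)),
      dot2(ac * jp.clip(dot(ac, pc) / dot2(ac), 0.0, 1.0) - pc))
  # squared distance to the triangle's plane
  face = dot(nor, pa) ** 2 / dot2(nor)
  # minimum over all triangles
  return jp.sqrt(jp.min(jp.where(outside, edge, face)))

# Grid points: every (x, y, z) combination of the axis samples, shape (N, 3).
xs = jp.linspace(-1.0, 2.0, 4)
grid = jp.stack(jp.meshgrid(xs, xs, xs, indexing="ij"), axis=-1).reshape(-1, 3)

# Toy mesh: a single triangle in the z = 0 plane.
a = jp.array([[0.0, 0.0, 0.0]])
b = jp.array([[1.0, 0.0, 0.0]])
c = jp.array([[0.0, 1.0, 0.0]])

# vmap evaluates every grid point against every triangle in one batched call.
dist = jax.vmap(unsigned_distance, in_axes=(0, None, None, None))(grid, a, b, c)
print(dist.shape)  # (64,)
```

This brute-force version materializes intermediates of size (points, triangles, 3), which is exactly what causes the memory problems discussed in the optimization section.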

Signed Distance

The distance from "point to triangle distance" is not a "signed" distance: it is always non-negative. To know whether a point is inside or outside the mesh, you need a ray-triangle intersection function, the kind usually used for ray tracing. If the query point is inside the mesh, a ray cast from it in one direction crosses an odd number of triangles, whereas a ray from a point outside the mesh crosses an even number of triangles. The intersection function for a ray and a triangle is also well described on Inigo Quilez's blog.

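import jax.numpy as jp

# a: 1st vertex of each triangle, shape (T, 3)
# b: 2nd vertex of each triangle, shape (T, 3)
# c: 3rd vertex of each triangle, shape (T, 3)
# p: ray starting point (the query point), shape (3,)
# rd: ray direction, shape (3,)

ba = b - a
cb = c - b
ac = a - c
nor = jp.cross(ba, ac)

def dot(u, v):
  return jp.sum(u*v, axis=-1, keepdims=True)

# pa = p - a; u, v are the barycentric coordinates of the hit point, t is the distance along the ray
def intersect(pa,rd):
  q = jp.cross(pa, rd)
  d = -1/dot(rd,nor)
  u = d*dot(q,ac)
  v = d*dot(q,ba)
  t = d*dot(nor,pa)

  # a triangle is hit when u >= 0, v >= 0, u + v <= 1 and t >= 0;
  # the parity of the hit count is 1 if the point is inside the mesh and 0 if it is outside
  return jp.sum(jp.where(jp.logical_or(jp.logical_or(u<0, jp.logical_or(v<0, u+v>1)),t<0), 0, 1)) % 2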

Based on the intersection result, you flip the sign of the distance. However, there are edge cases, such as a ray that touches a triangle without crossing it, or a ray that passes exactly through the edge shared by two triangles (so one crossing is counted twice). To reduce this risk, use a ray direction that is not parallel to the x, y, or z axes, because human-made meshes tend to have flat surfaces and straight edges aligned with the axes. Alternatively, you can cast multiple rays and decide by majority vote.
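A minimal sketch of the majority-vote option, assuming a unit tetrahedron as the test mesh and three arbitrarily chosen unit directions with no zero component. The helper names crossing_parity and is_inside are hypothetical. The parity test follows the same formulation as intersect() above, except that a hit must pass every test, so the NaN or infinite values produced by a ray parallel to a triangle count as a miss rather than a hit.

```python
import jax.numpy as jp

# Toy mesh: unit tetrahedron, faces wound counter-clockwise from outside.
vertices = jp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
faces = jp.array([[0, 2, 1], [0, 1, 3], [0, 3, 2], [1, 2, 3]])
a, b, c = vertices[faces[:, 0]], vertices[faces[:, 1]], vertices[faces[:, 2]]

def crossing_parity(p, rd, a, b, c):
  """1 if the ray p + t*rd (t >= 0) hits an odd number of triangles, else 0."""
  ba, ac = b - a, a - c
  nor = jp.cross(ba, ac)
  pa = p - a
  q = jp.cross(pa, rd)
  d = -1.0 / jp.sum(rd * nor, axis=-1)
  u = d * jp.sum(q * ac, axis=-1)
  v = d * jp.sum(q * ba, axis=-1)
  t = d * jp.sum(nor * pa, axis=-1)
  # every comparison with NaN is False, so degenerate cases are misses
  hit = (u >= 0) & (v >= 0) & (u + v <= 1) & (t >= 0)
  return jp.sum(hit) % 2

def is_inside(p, a, b, c, directions):
  """Majority vote of the parity test over several ray directions."""
  votes = jp.array([crossing_parity(p, rd, a, b, c) for rd in directions])
  return 2 * jp.sum(votes) > len(directions)

# Unit directions away from the coordinate axes and planes (an arbitrary choice).
directions = jp.array([[0.48, 0.60, 0.64],
                       [-0.36, 0.48, 0.80],
                       [0.60, -0.64, 0.48]])

print(is_inside(jp.array([0.1, 0.1, 0.1]), a, b, c, directions))   # True
print(is_inside(jp.array([-0.5, 0.2, 0.2]), a, b, c, directions))  # False
```

An odd number of directions avoids ties in the vote.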

Gradient / Normal

This is rather simple. To get the normal information, you just have to get the normal of the triangle that gives the minimum distance from the point. Only a few lines have to be changed or added to "point to triangle distance."

d = jp.linalg.norm(l, axis=-1, keepdims=True)
index = jp.argmin(d)

d = d.at[index].get()
grad = -nor.at[index].get() # the normal facing outside of the mesh
grad = grad / jp.linalg.norm(grad, axis=-1, keepdims=True)

Optimization

If you use this code as-is to create an SDF grid, you will most likely just get an out-of-memory (OOM) error. Evaluating every grid point against every triangle at once creates far more intermediate data than fits in the GPU memory available on Colab, and "point to triangle distance" is particularly expensive. That's why commercial SDF creation algorithms use sweeping algorithms or ray intersection algorithms that calculate distances only to the triangles hit by multiple rays cast from the point. These are good enough, but I will use different optimizations so that I still get the exact solution from "point to triangle distance".
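A quick back-of-the-envelope estimate shows why the brute-force version runs out of memory. The grid size below is the 127x104x128 grid from the results section; the triangle count is an assumed value for illustration only, not a number from this article.

```python
# Grid size from the results section; the triangle count is a hypothetical example.
nx, ny, nz = 127, 104, 128
num_triangles = 10_000          # assumption for illustration only
bytes_per_float = 4             # float32, JAX's default dtype

num_points = nx * ny * nz
# a single (points, triangles, 3) float32 intermediate, such as p - a for every pair
bytes_one_array = num_points * num_triangles * 3 * bytes_per_float

print(num_points)               # 1690624
print(bytes_one_array / 1e9)    # about 202.9 (GB)
```

The brute-force version needs several arrays of that size at once (pa, pb, pc, the edge vectors, and so on), far beyond the memory of a single GPU.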

The first method uses bounding spheres of the triangles. Before computing triangle distances, we compute a bounding sphere for each triangle. Its center is the average of the triangle's three vertices (the centroid), and its radius is the distance from the centroid to the farthest vertex. The image below visualizes the bounding spheres of a low-polygon bunny mesh.
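A minimal sketch of the sphere construction described above. The names ccs and ccsR match the optimized code later in this article; the helper name bounding_spheres and the test triangle are hypothetical. Note that a centroid-centered sphere encloses the triangle but is not always the smallest possible bounding sphere.

```python
import jax.numpy as jp

def bounding_spheres(a, b, c):
  """Centroid-centered bounding sphere of each triangle.
  a, b, c: (T, 3) vertex arrays. Returns centers (T, 3) and radii (T,)."""
  ccs = (a + b + c) / 3.0
  # radius = distance from the centroid to the farthest of the three vertices
  ccsR = jp.max(jp.stack([jp.linalg.norm(a - ccs, axis=-1),
                          jp.linalg.norm(b - ccs, axis=-1),
                          jp.linalg.norm(c - ccs, axis=-1)], axis=-1), axis=-1)
  return ccs, ccsR

a = jp.array([[0.0, 0.0, 0.0]])
b = jp.array([[3.0, 0.0, 0.0]])
c = jp.array([[0.0, 3.0, 0.0]])
ccs, ccsR = bounding_spheres(a, b, c)
print(ccs)   # centroid (1, 1, 0)
print(ccsR)  # sqrt(5), about 2.236
```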

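Then, for each query point, we first find the sphere closest to the point and take the distance from the point to that sphere's center. Because the center lies on its triangle, this is an upper bound on the true minimum distance. Any triangle whose sphere is farther away than this bound cannot contain the closest point, so we skip it and compute the exact triangle distance only for the remaining triangles. This reduces the calculation a lot. However, JAX's JIT does not support boolean masking (array shapes must be fixed at compile time), so you have to find a way to mimic the masking behavior, for example with jp.where.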
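The culling test can be sketched on its own as a boolean mask with a static shape. This is a minimal sketch assuming made-up sphere data; candidate_mask is a hypothetical helper that mirrors the sphere_p < sphere_d comparison in the optimized code.

```python
import jax.numpy as jp

def candidate_mask(p, ccs, ccsR):
  """Boolean (T,) mask of triangles that may hold the closest point to p.
  ccs: (T, 3) sphere centers lying on the triangles (e.g. centroids); ccsR: (T,) radii."""
  center_dist = jp.linalg.norm(ccs - p, axis=-1)
  lower = center_dist - ccsR                # no point of triangle i is closer than this
  upper = center_dist[jp.argmin(lower)]     # that center lies on its triangle, so the minimum is at most this
  return lower < upper

# Three toy spheres along the x axis (made-up values).
ccs = jp.array([[0.0, 0.0, 0.0], [10.0, 0.0, 0.0], [2.0, 0.0, 0.0]])
ccsR = jp.array([1.0, 1.0, 1.5])
mask = candidate_mask(jp.array([0.5, 0.0, 0.0]), ccs, ccsR)
print(mask)  # True, False, True
```

The mask always has T entries, so it works under jit. Rejected triangles are not dropped; instead their values are overwritten with jp.where, as the optimized code does with the longer vector to the sphere center. Keep in mind that jp.where computes both of its branches.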

The second method is to keep variables local by splitting the work into small functions instead of using global variables. However, this quickly turns the code into spaghetti. Instead, even though we are using Python, we can explicitly free variables with del. This actually helps a lot in reducing GPU memory usage.

The last optimization is to serialize the calculation. I know this sounds odd, since we are using JAX precisely to speed things up through parallelization, but it is necessary because the GPU's or TPU's memory will run out otherwise. We have to trade time for memory. You can do this by splitting the array of query points into chunks and processing them one after another.
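A minimal sketch of the chunked evaluation, assuming a placeholder per-point function (a unit-sphere SDF named sdf_point, hypothetical) in place of triDistance; the chunk size is also an arbitrary choice. Each chunk goes through one jit-compiled, vmapped call, so peak memory scales with the chunk size instead of the full grid.

```python
import jax
import jax.numpy as jp

def sdf_point(p):
  """Hypothetical placeholder for the per-point function: unit-sphere SDF."""
  return jp.linalg.norm(p) - 1.0

# jit + vmap compile one batched call that is reused for every chunk of the same size.
batched = jax.jit(jax.vmap(sdf_point))

def evaluate_in_chunks(points, chunk_size):
  """Process the (N, 3) query points chunk by chunk to cap peak memory."""
  results = []
  for start in range(0, points.shape[0], chunk_size):
    chunk = points[start:start + chunk_size]
    results.append(batched(chunk))
  return jp.concatenate(results, axis=0)

xs = jp.linspace(-2.0, 2.0, 8)
points = jp.stack(jp.meshgrid(xs, xs, xs, indexing="ij"), axis=-1).reshape(-1, 3)
sdf = evaluate_in_chunks(points, chunk_size=100)
print(sdf.shape)  # (512,)
```

The last chunk here is smaller (12 points), which triggers one extra compilation; choosing a chunk size that divides the number of points avoids that.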

The optimized code is below.

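# reducing the GPU memory usage in Colab
# (these environment variables must be set before JAX is imported)
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

import jax.numpy as jp

# a, b, c: (T, 3) arrays holding the three vertices of every triangle
# ccs:     (T, 3) bounding-sphere centers (triangle centroids)
# ccsR:    (T,)   bounding-sphere radii

ba = b - a
cb = c - b
ac = a - c
nor = jp.cross(ba, ac)

def dot(u, v):
  return jp.sum(u*v, axis=-1, keepdims=True)

# ray-intersection function: 1 if the ray crosses an odd number of triangles (inside), else 0
def intersect(pa,rd):
  q = jp.cross(pa, rd)
  d = -1/dot(rd,nor)
  u = d*dot(q,ac)
  v = d*dot(q,ba)
  t = d*dot(nor,pa)
  t = jp.sum(jp.where(jp.logical_or(jp.logical_or(u<0, jp.logical_or(v<0, u+v>1)),t<0), 0, 1)) % 2
  return t

# triangle distance function: returns [grad_x, grad_y, grad_z, signed distance]
def triDistance(p):
  pa = p - a
  pb = p - b
  pc = p - c

  # bounding-sphere culling
  sphere_l = ccs-p                                      # vectors from p to the sphere centers
  sphere_p = jp.linalg.norm(sphere_l, axis=-1) - ccsR   # lower bound on each triangle's distance
  sphere_id = jp.argmin(sphere_p)
  sphere_d = jp.linalg.norm(sphere_l.at[sphere_id].get(), axis=-1)  # upper bound on the minimum distance

  # exact point-to-triangle vectors (same as "point to triangle distance")
  outside = (jp.sign(dot(jp.cross(ba,nor), pa)) +
             jp.sign(dot(jp.cross(cb,nor), pb)) +
             jp.sign(dot(jp.cross(ac,nor), pc))) < 2.0
  e1 = ba*jp.clip(dot(ba,pa)/dot(ba,ba),0.0,1.0)-pa
  e2 = cb*jp.clip(dot(cb,pb)/dot(cb,cb),0.0,1.0)-pb
  e3 = ac*jp.clip(dot(ac,pc)/dot(ac,ac),0.0,1.0)-pc
  edge = jp.where(dot(e1,e1) < dot(e2,e2), e1, e2)
  edge = jp.where(dot(edge,edge) < dot(e3,e3), edge, e3)
  face = -nor*dot(nor,pa)/dot(nor,nor)

  # triangles whose sphere lies beyond sphere_d keep the (longer) vector to their center
  l = jp.where(jp.expand_dims(sphere_p<sphere_d, axis=-1),
               jp.where(outside, edge, face),
               sphere_l)

  d = jp.linalg.norm(l, axis=-1, keepdims=True)
  index = jp.argmin(d)

  del(l)
  del(sphere_l)
  del(sphere_p)
  del(sphere_id)
  del(sphere_d)
  del(outside)
  del(e1)
  del(e2)
  del(e3)
  del(edge)
  del(face)
  del(pb)
  del(pc)

  d = d.at[index].get()
  grad = -nor.at[index].get()  # the normal facing outside of the mesh
  grad = grad / jp.linalg.norm(grad, axis=-1, keepdims=True)

  # ray direction that is not parallel to the x, y or z axis
  t = intersect(pa,jp.array([1,1,0]))

  d = jp.where(t==0, d, -d)  # negative inside the mesh

  del(pa)
  del(t)

  return jp.concatenate((grad, d), axis=-1)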

Mesh SDF Results

The first row is the SDF result for the low-polygon version of the Stanford Bunny. I found that the normal information for coarse shapes with sharp angles looks a little odd. However, high-fidelity meshes produce excellent results, as seen in the second and third rows. The second row is from a 31x26x32 SDF grid, which took 14 seconds to create, and the third row is from a 127x104x128 SDF grid, which took 5 minutes. That takes a while, but it is pretty fast for a personal SDF generator. Below are results from other meshes.

In the article "Signed Distance Function Visualization," I said that the first-order Taylor approximation of an SDF does not work well. However, I found that it works well for meshes; only simple geometric shapes have problems with the first-order approximation. This is good news: we already store gradients alongside distances in the SDF grid, and the better results let us use smaller grids for similar accuracy.
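One plausible way to use the stored gradients is sketched below. It assumes each grid node stores [grad_x, grad_y, grad_z, distance] (the per-point output of triDistance) and uses a nearest-node lookup; the helper name taylor_sdf and the lookup scheme are assumptions, since this article does not specify how values between nodes are evaluated. The test grid is built from the plane SDF z - 0.3, for which the first-order estimate is exact.

```python
import jax.numpy as jp

def taylor_sdf(x, grid_origin, spacing, grid):
  """First-order Taylor estimate d(x) ~ d(g) + grad(g) . (x - g) from the nearest node g.
  grid: (nx, ny, nz, 4) array of [grad_x, grad_y, grad_z, distance] per node."""
  idx = jp.round((x - grid_origin) / spacing).astype(jp.int32)
  idx = jp.clip(idx, 0, jp.array(grid.shape[:3]) - 1)
  node = grid[idx[0], idx[1], idx[2]]
  g = grid_origin + idx * spacing   # position of the nearest node
  return node[3] + jp.dot(node[:3], x - g)

# Toy 4x4x4 grid storing the SDF of the plane z = 0.3 (gradient (0, 0, 1)).
spacing = 0.5
origin = jp.zeros(3)
ids = jp.arange(4)
coords = origin + spacing * jp.stack(jp.meshgrid(ids, ids, ids, indexing="ij"), axis=-1)
dist = coords[..., 2] - 0.3
grad = jp.broadcast_to(jp.array([0.0, 0.0, 1.0]), coords.shape)
grid = jp.concatenate([grad, dist[..., None]], axis=-1)

print(taylor_sdf(jp.array([0.6, 0.7, 0.9]), origin, spacing, grid))  # close to 0.6
```

With trilinear interpolation of the distances alone, the same grid would need more nodes to match a curved surface; the stored gradient lets each node extrapolate to nearby points.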