Ruining Li commited on
Commit
ff8abe3
·
1 Parent(s): c388aa5

Modify post-processing

Browse files
Files changed (1) hide show
  1. infer_asset.py +11 -11
infer_asset.py CHANGED
@@ -328,7 +328,7 @@ def compute_part_components_for_mesh_cc(mesh, mesh_cc_faces, current_face_part_i
328
  return components
329
 
330
 
331
- def refine_part_ids_for_faces(mesh, face_part_ids):
332
  """
333
  Refine face part IDs to ensure each part ID forms a single connected component.
334
 
@@ -346,8 +346,6 @@ def refine_part_ids_for_faces(mesh, face_part_ids):
346
  Returns:
347
  refined_face_part_ids: refined part ID for each face [num_faces]
348
  """
349
- face_part_ids_final_strict = refine_part_ids_strict(mesh, face_part_ids)
350
-
351
  face_part_ids_final = face_part_ids.copy() # Don't modify the input
352
 
353
  # Step 1: Find connected components of the original mesh (immutable structure)
@@ -398,10 +396,10 @@ def refine_part_ids_for_faces(mesh, face_part_ids):
398
  face_part_ids_final[comps[comp_idx]['faces']] = chosen_part_id
399
  break
400
 
401
- return face_part_ids_final_strict, face_part_ids_final
402
 
403
 
404
- def find_part_ids_for_faces(mesh, part_ids, face_indices):
405
  """
406
  Assign part IDs to each face in the mesh.
407
 
@@ -433,9 +431,11 @@ def find_part_ids_for_faces(mesh, part_ids, face_indices):
433
  counts = np.bincount(point_part_ids)
434
  majority_part_id = np.argmax(counts)
435
  face_part_ids[face_idx] = majority_part_id
436
-
437
- face_part_ids_refined_strict, face_part_ids_refined = refine_part_ids_for_faces(mesh, face_part_ids)
438
- return face_part_ids, face_part_ids_refined_strict, face_part_ids_refined
 
 
439
 
440
 
441
  @torch.no_grad()
@@ -501,12 +501,12 @@ def save_articulated_meshes(mesh, face_indices, outputs, output_path, strict, an
501
  prismatic_axis = outputs[hyp_idx]['prismatic_axis']
502
  prismatic_range = outputs[hyp_idx]['prismatic_range']
503
 
504
- _, face_part_ids_refined_strict, face_part_ids_refined = find_part_ids_for_faces(
505
  mesh,
506
  part_ids,
507
- face_indices
 
508
  )
509
- face_part_ids = face_part_ids_refined_strict if strict else face_part_ids_refined
510
  unique_part_ids = np.unique(face_part_ids)
511
  num_parts = len(unique_part_ids)
512
  print(f"Found {num_parts} unique parts")
 
328
  return components
329
 
330
 
331
+ def refine_part_ids_nonstrict(mesh, face_part_ids):
332
  """
333
  Refine face part IDs to ensure each part ID forms a single connected component.
334
 
 
346
  Returns:
347
  refined_face_part_ids: refined part ID for each face [num_faces]
348
  """
 
 
349
  face_part_ids_final = face_part_ids.copy() # Don't modify the input
350
 
351
  # Step 1: Find connected components of the original mesh (immutable structure)
 
396
  face_part_ids_final[comps[comp_idx]['faces']] = chosen_part_id
397
  break
398
 
399
+ return face_part_ids_final
400
 
401
 
402
+ def find_part_ids_for_faces(mesh, part_ids, face_indices, strict=False):
403
  """
404
  Assign part IDs to each face in the mesh.
405
 
 
431
  counts = np.bincount(point_part_ids)
432
  majority_part_id = np.argmax(counts)
433
  face_part_ids[face_idx] = majority_part_id
434
+
435
+ if strict:
436
+ return refine_part_ids_strict(mesh, face_part_ids)
437
+ else:
438
+ return refine_part_ids_nonstrict(mesh, face_part_ids)
439
 
440
 
441
  @torch.no_grad()
 
501
  prismatic_axis = outputs[hyp_idx]['prismatic_axis']
502
  prismatic_range = outputs[hyp_idx]['prismatic_range']
503
 
504
+ face_part_ids = find_part_ids_for_faces(
505
  mesh,
506
  part_ids,
507
+ face_indices,
508
+ strict=strict
509
  )
 
510
  unique_part_ids = np.unique(face_part_ids)
511
  num_parts = len(unique_part_ids)
512
  print(f"Found {num_parts} unique parts")