import itertools
import traceback
from collections import Counter, defaultdict
from pathlib import Path

import numpy as np
import torch
import trimesh

from particulate.visualization_utils import (
    get_3D_arrow_on_points,
    create_arrow,
    create_ring,
    create_textured_mesh_parts,
    ARROW_COLOR_REVOLUTE,
    ARROW_COLOR_PRISMATIC,
)
from particulate.articulation_utils import plucker_to_axis_point
from particulate.export_utils import export_animated_glb_file
from partfield_utils import obtain_partfield_feats, get_partfield_model


DATA_CONFIG = {
    'sharp_point_ratio': 0.5,
    'normalize_points': True,
}

# Rotations applied as ``vertices @ R.T``; each one maps the named up axis onto +Z.
# "z" already points up and therefore needs no rotation (stored as None).
_UP_DIR_ROTATIONS = {
    "x": np.array([[0, 0, -1], [0, 1, 0], [1, 0, 0]], dtype=np.float32),
    "-x": np.array([[0, 0, 1], [0, 1, 0], [-1, 0, 0]], dtype=np.float32),
    "y": np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]], dtype=np.float32),
    "-y": np.array([[1, 0, 0], [0, 0, 1], [0, -1, 0]], dtype=np.float32),
    "z": None,
    "-z": np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]], dtype=np.float32),
}


def sharp_sample_pointcloud(mesh, num_points: int = 8192):
    """Sample points along the sharp edges of a mesh, weighted by edge length.

    An edge is sharp when at least one pair of its adjacent faces has normals
    whose angle lies strictly between 30 and 150 degrees, and both of those
    normals have a norm above 1e-8.

    Args:
        mesh: trimesh object.
        num_points: number of points to draw along sharp edges.

    Returns:
        samples: (num_points, 3) sampled positions; (0, 3) if no edge is sharp.
        normals: (num_points, 3) per-sample normal, i.e. the (unnormalized)
            average of the two face normals that made the edge sharp.
        edge_indices: (num_points,) index of the sharp edge each sample lies on.
        sharp_edge_faces: list holding, for each sharp edge, all its adjacent faces.
    """
    vertices = mesh.vertices
    face_normals = mesh.face_normals

    # Map each undirected edge, keyed as (min_vertex_id, max_vertex_id) so both
    # orientations collapse to one key, to every face that contains it.
    edge_to_faces = defaultdict(list)
    for face_idx, face in enumerate(mesh.faces):
        for edge in ((face[0], face[1]), (face[1], face[2]), (face[2], face[0])):
            edge_to_faces[tuple(sorted(edge))].append(face_idx)

    cos_30 = np.cos(np.radians(30))  # ~ 0.866
    cos_150 = np.cos(np.radians(150))  # ~ -0.866

    sharp_edges = []
    sharp_edge_normals = []
    sharp_edge_faces = []
    for edge_key, face_indices in edge_to_faces.items():
        # Boundary edges (a single face) yield no pairs and are skipped.
        # Pairs are visited in i < j order; the first qualifying pair decides
        # the edge's averaged normal.
        for face_a, face_b in itertools.combinations(face_indices, 2):
            n1 = face_normals[face_a]
            n2 = face_normals[face_b]
            if (cos_150 < np.dot(n1, n2) < cos_30
                    and np.linalg.norm(n1) > 1e-8
                    and np.linalg.norm(n2) > 1e-8):
                sharp_edges.append(edge_key)
                sharp_edge_normals.append((n1 + n2) / 2)
                sharp_edge_faces.append(face_indices)  # all adjacent faces
                break

    if not sharp_edges:
        samples = np.zeros((0, 3), dtype=np.float64)
        normals = np.zeros((0, 3), dtype=np.float64)
        edge_indices = np.zeros((0,), dtype=np.int32)
        return samples, normals, edge_indices, sharp_edge_faces

    edges = np.asarray(sharp_edges, dtype=np.int32)  # (num_sharp_edges, 2)
    sharp_edge_normals = np.asarray(sharp_edge_normals, dtype=np.float64)
    sharp_verts_a = vertices[edges[:, 0]]
    sharp_verts_b = vertices[edges[:, 1]]

    lengths = np.linalg.norm(sharp_verts_b - sharp_verts_a, axis=-1)
    total_length = np.sum(lengths)
    if total_length > 0:
        weights = lengths / total_length
    else:
        # All sharp edges are degenerate; avoid 0/0 and pick them uniformly.
        weights = np.full(len(lengths), 1.0 / len(lengths))
    cdf = weights.cumsum()

    random_number = np.random.rand(num_points)
    w = np.random.rand(num_points, 1)
    # Rounding can leave cdf[-1] slightly below 1, in which case searchsorted
    # returns len(cdf) for a draw above it; clamp to the last valid edge.
    index = np.minimum(np.searchsorted(cdf, random_number), len(cdf) - 1)

    samples = w * sharp_verts_a[index] + (1 - w) * sharp_verts_b[index]
    normals = sharp_edge_normals[index]
    return samples, normals, index, sharp_edge_faces


def _sample_one_point_per_face(mesh, face_order):
    """Return one uniformly distributed point on each face in ``face_order``, shape (len, 3)."""
    # Column 0 plays the role of r1 and column 1 of r2 for each face, drawn row by row.
    draws = np.random.random((len(face_order), 2))
    sqrt_r1 = np.sqrt(draws[:, :1])
    r2 = draws[:, 1:]
    # Barycentric coordinates of a uniform sample on a triangle.
    u = 1 - sqrt_r1
    v = sqrt_r1 * (1 - r2)
    w = sqrt_r1 * r2
    triangles = mesh.vertices[mesh.faces[face_order]]  # (len, 3 corners, 3 coords)
    return u * triangles[:, 0] + v * triangles[:, 1] + w * triangles[:, 2]


def sample_points(mesh, num_points, sharp_point_ratio, at_least_one_point_per_face=False):
    """Sample points from a mesh using sharp-edge and uniform surface sampling.

    ``int(num_points * sharp_point_ratio)`` points are drawn on sharp edges and
    the rest uniformly on the surface. If sharp points were requested but the
    mesh has no sharp edge, all ``num_points`` are drawn uniformly.

    Args:
        mesh: trimesh object.
        num_points: total number of points to return.
        sharp_point_ratio: fraction of points to draw on sharp edges.
        at_least_one_point_per_face: if True, the uniform part starts with one
            point on every face (in random order) before the remaining points.

    Returns:
        points: (num_points, 3) positions; sharp-edge points come first.
        normals: (num_points, 3) normals (averaged edge normals / face normals).
        sharp_flag: (num_points,) bool, True for sharp-edge points.
        face_indices: (num_points,) face each point is attributed to; a sharp
            point gets a random face among those adjacent to its edge.

    Raises:
        ValueError: if ``at_least_one_point_per_face`` is set and the mesh has
            more faces than the uniform point budget.
    """
    num_points_sharp_edges = int(num_points * sharp_point_ratio)
    num_points_uniform = num_points - num_points_sharp_edges

    points_sharp, normals_sharp, edge_indices, sharp_edge_faces = sharp_sample_pointcloud(
        mesh, num_points_sharp_edges)

    # Sharp points were requested but none could be produced: go fully uniform.
    if len(points_sharp) == 0 and num_points_sharp_edges > 0:
        print("Warning: No sharp edges found, sampling all points uniformly")
        num_points_uniform = num_points

    if at_least_one_point_per_face:
        num_faces = len(mesh.faces)
        if num_points_uniform < num_faces:
            raise ValueError(
                "Unable to sample at least one point per face: "
                f"{num_faces} faces > {num_points_uniform} uniform sample points"
            )
        face_perm = np.random.permutation(num_faces)
        points_uniform = _sample_one_point_per_face(mesh, face_perm)
        normals_uniform = mesh.face_normals[face_perm]
        face_indices = face_perm

        num_remaining_points = num_points_uniform - num_faces
        if num_remaining_points > 0:
            points_remaining, face_indices_remaining = mesh.sample(
                num_remaining_points, return_index=True)
            points_uniform = np.concatenate([points_uniform, points_remaining], axis=0)
            normals_uniform = np.concatenate(
                [normals_uniform, mesh.face_normals[face_indices_remaining]], axis=0)
            face_indices = np.concatenate([face_perm, face_indices_remaining], axis=0)
    else:
        points_uniform, face_indices = mesh.sample(num_points_uniform, return_index=True)
        normals_uniform = mesh.face_normals[face_indices]

    points = np.concatenate([points_sharp, points_uniform], axis=0)
    normals = np.concatenate([normals_sharp, normals_uniform], axis=0)
    sharp_flag = np.concatenate([
        np.ones(len(points_sharp), dtype=np.bool_),
        np.zeros(len(points_uniform), dtype=np.bool_),
    ], axis=0)

    # Attribute each sharp point to one randomly chosen face adjacent to its edge.
    sharp_face_indices = np.zeros(len(points_sharp), dtype=np.int32)
    for i, edge_idx in enumerate(edge_indices):
        sharp_face_indices[i] = np.random.choice(sharp_edge_faces[edge_idx])

    face_indices = np.concatenate([sharp_face_indices, face_indices], axis=0)
    return points, normals, sharp_flag, face_indices


def prepare_inputs(mesh, num_points_global: int = 40000, num_points_decode: int = 2048, device: str = "cuda"):
    """Sample point sets from a mesh and build the model's input tensors.

    Two sets are sampled: ``num_points_global`` positions that are passed,
    together with the decode points, to ``obtain_partfield_feats``; and
    ``num_points_decode`` points (at least one per face) with normals. When
    ``DATA_CONFIG['normalize_points']`` is set, both sets are centered and
    scaled by the bounding box of their union.

    Returns:
        inputs: dict with ``xyz`` and ``normals`` as float tensors of shape
            (1, num_points_decode, 3) on ``device``, and ``feats`` as returned
            by ``obtain_partfield_feats``.
        sharp_flag: (num_points_decode,) bool array of sharp-edge points.
        face_indices: (num_points_decode,) face index of every decode point.
    """
    sharp_point_ratio = DATA_CONFIG['sharp_point_ratio']
    all_points, _, _, _ = sample_points(mesh, num_points_global, sharp_point_ratio)
    points, normals, sharp_flag, face_indices = sample_points(
        mesh, num_points_decode, sharp_point_ratio, at_least_one_point_per_face=True)

    if DATA_CONFIG['normalize_points']:
        combined = np.concatenate([all_points, points], axis=0)
        bbmin = combined.min(0)
        bbmax = combined.max(0)
        center = (bbmin + bbmax) * 0.5
        scale = 1.0 / (bbmax - bbmin).max()
        all_points = (all_points - center) * scale
        points = (points - center) * scale

    all_points = torch.from_numpy(all_points).to(device).float().unsqueeze(0)
    points = torch.from_numpy(points).to(device).float().unsqueeze(0)
    normals = torch.from_numpy(normals).to(device).float().unsqueeze(0)

    partfield_model = get_partfield_model(device=device)
    feats = obtain_partfield_feats(partfield_model, all_points, points)

    return dict(xyz=points, normals=normals, feats=feats), sharp_flag, face_indices


def refine_part_ids_strict(mesh, face_part_ids):
    """Give every connected component (CC) of the mesh a single part ID.

    Within each CC, the labelled part ID (!= -1) with the largest total face
    area wins and is written to every face of the CC. CCs without any labelled
    face are left unchanged.

    Args:
        mesh: trimesh object.
        face_part_ids: part ID for each face [num_faces]; -1 means unassigned.

    Returns:
        refined_face_part_ids: refined copy of ``face_part_ids`` [num_faces].
    """
    face_part_ids = face_part_ids.copy()  # don't modify the input
    area_faces = mesh.area_faces

    # mesh.face_adjacency lists pairs of faces sharing an edge.
    mesh_components = trimesh.graph.connected_components(
        edges=mesh.face_adjacency,
        nodes=np.arange(len(mesh.faces)),
        min_len=1,
    )

    for component in mesh_components:
        part_id_areas = defaultdict(float)
        for face_idx in component:
            part_id = face_part_ids[face_idx]
            if part_id == -1:
                continue  # unassigned faces do not vote
            part_id_areas[part_id] += area_faces[face_idx]

        if not part_id_areas:
            continue  # nothing labelled in this CC

        # Ties resolve to the part ID first encountered in the component.
        dominant_part_id = max(part_id_areas, key=part_id_areas.__getitem__)
        for face_idx in component:
            face_part_ids[face_idx] = dominant_part_id

    return face_part_ids


def compute_part_components_for_mesh_cc(mesh, mesh_cc_faces, current_face_part_ids, face_adjacency_dict):
    """Split the labelled faces of one mesh CC into per-part connected pieces.

    Two faces belong to the same piece when they share the same part ID and
    are connected through adjacent faces of that same part ID. Faces labelled
    -1 are ignored.

    Args:
        mesh: trimesh object.
        mesh_cc_faces: integer array of the face indices in this mesh CC.
        current_face_part_ids: part ID for each face of the whole mesh.
        face_adjacency_dict: face index -> set of edge-adjacent face indices.

    Returns:
        List of dicts with keys 'faces' (array of face indices), 'part_id' and
        'area' (summed face area).
    """
    area_faces = mesh.area_faces
    cc_part_ids = current_face_part_ids[mesh_cc_faces]
    components = []

    for part_id in np.unique(cc_part_ids):
        if part_id == -1:
            continue

        faces_with_part = mesh_cc_faces[cc_part_ids == part_id]
        faces_with_part_set = set(faces_with_part)

        # Adjacency restricted to faces carrying this part ID.
        edges_for_part = [
            [face_i, face_j]
            for face_i in faces_with_part
            for face_j in face_adjacency_dict[face_i]
            if face_j in faces_with_part_set
        ]

        if not edges_for_part:
            # No two faces of this part touch: each face is its own piece.
            for face_i in faces_with_part:
                components.append({
                    'faces': np.array([face_i]),
                    'part_id': part_id,
                    'area': area_faces[face_i],
                })
            continue

        comps = trimesh.graph.connected_components(
            edges=np.array(edges_for_part),
            nodes=faces_with_part,
            min_len=1,
        )
        for comp in comps:
            comp_faces = np.array(list(comp))
            components.append({
                'faces': comp_faces,
                'part_id': part_id,
                'area': np.sum(area_faces[comp_faces]),
            })

    return components


def _merge_smallest_fragment(comps, face_part_ids, face_adjacency_dict):
    """Relabel the smallest reassignable fragment in place; return True if one was relabelled."""
    comps = sorted(comps, key=lambda c: c['area'])

    part_id_areas = defaultdict(float)
    comps_per_part = Counter()
    for comp in comps:
        part_id_areas[comp['part_id']] += comp['area']
        comps_per_part[comp['part_id']] += 1

    for comp in comps:
        if comps_per_part[comp['part_id']] <= 1:
            continue  # this part is already a single piece

        faces_set = set(comp['faces'])
        neighbor_part_ids = set()
        for face_i in faces_set:
            for face_j in face_adjacency_dict[face_i]:
                if face_j not in faces_set:
                    neighbor_part_ids.add(face_part_ids[face_j])
        # Unassigned neighbours have no entry in part_id_areas and cannot absorb a fragment.
        neighbor_part_ids.discard(-1)
        if not neighbor_part_ids:
            continue

        chosen_part_id = max(neighbor_part_ids, key=lambda pid: part_id_areas[pid])
        face_part_ids[comp['faces']] = chosen_part_id
        return True

    return False


def refine_part_ids_nonstrict(mesh, face_part_ids):
    """Make each part ID form a single connected piece within every mesh CC.

    Within each connected component of the mesh, while some part ID is split
    into several pieces, the smallest such piece (by area) that touches a
    labelled face of another part is relabelled to the adjacent part ID with
    the largest total area in that CC. Each relabel merges the piece into a
    neighbour, so the number of pieces strictly decreases and the loop ends.
    Faces labelled -1 stay -1.

    Args:
        mesh: trimesh object.
        face_part_ids: initial part ID for each face [num_faces]; -1 means unassigned.

    Returns:
        refined_face_part_ids: refined copy of ``face_part_ids`` [num_faces].
    """
    face_part_ids_final = face_part_ids.copy()  # don't modify the input

    # Connected components of the original mesh (fixed throughout).
    mesh_components = trimesh.graph.connected_components(
        edges=mesh.face_adjacency,
        nodes=np.arange(len(mesh.faces)),
        min_len=1,
    )
    mesh_components = [np.array(list(comp)) for comp in mesh_components]

    # Face adjacency lookup (fixed throughout).
    face_adjacency_dict = {i: set() for i in range(len(mesh.faces))}
    for face_i, face_j in mesh.face_adjacency:
        face_adjacency_dict[face_i].add(face_j)
        face_adjacency_dict[face_j].add(face_i)

    for mesh_cc_faces in mesh_components:
        while True:
            comps = compute_part_components_for_mesh_cc(
                mesh, mesh_cc_faces, face_part_ids_final, face_adjacency_dict)
            if not _merge_smallest_fragment(comps, face_part_ids_final, face_adjacency_dict):
                break

    return face_part_ids_final


def find_part_ids_for_faces(mesh, part_ids, face_indices, strict=False):
    """Assign a part ID to each mesh face from per-point predictions.

    Each face takes the majority part ID among the points attributed to it
    (ties go to the smallest ID); faces with no point start as -1. The result
    is then refined with ``refine_part_ids_strict`` or ``refine_part_ids_nonstrict``.

    Args:
        mesh: trimesh object.
        part_ids: integer part ID for each sampled point [num_points].
        face_indices: face each point lies on, -1 to ignore the point [num_points].
        strict: choose strict (one part per mesh CC) or non-strict refinement.

    Returns:
        face_part_ids: part ID for each face [num_faces].
    """
    num_faces = len(mesh.faces)
    face_part_ids = np.full(num_faces, -1, dtype=np.int32)

    face_to_votes = defaultdict(list)
    for point_idx, face_idx in enumerate(face_indices):
        if face_idx == -1:
            continue  # point not attributed to a face
        face_to_votes[face_idx].append(part_ids[point_idx])

    for face_idx, votes in face_to_votes.items():
        votes = np.asarray(votes)
        # Shift by the minimum so bincount also accepts negative labels;
        # argmax still picks the smallest ID among tied counts.
        offset = votes.min()
        face_part_ids[face_idx] = offset + np.argmax(np.bincount(votes - offset))

    if strict:
        return refine_part_ids_strict(mesh, face_part_ids)
    return refine_part_ids_nonstrict(mesh, face_part_ids)


@torch.no_grad()
def infer_single_asset(
    mesh,
    up_dir,
    model,
    num_points,
    min_part_confidence=0.0,
    num_points_global: int = 40000,
):
    """Orient and normalize a mesh, sample it, and run ``model.infer`` on it.

    The mesh copy is rotated so that ``up_dir`` maps onto +Z, then centered on
    its bounding box and scaled so its largest extent is 1 (fitting [-0.5, 0.5]^3).

    Args:
        mesh: trimesh object; not modified.
        up_dir: one of "x", "-x", "y", "-y", "z", "-z" (case-insensitive).
        model: object exposing ``infer(xyz=, feats=, normals=, output_all_hyps=, min_part_confidence=)``.
        num_points: number of decode points passed to the model.
        min_part_confidence: forwarded to ``model.infer``.
        num_points_global: number of context points for ``prepare_inputs``.

    Returns:
        (outputs, face_indices, mesh_transformed): model outputs, the face index
        of each decode point, and the transformed mesh copy.

    Raises:
        ValueError: for an unknown ``up_dir`` or a mesh with zero extent.
    """
    key = up_dir.lower() if isinstance(up_dir, str) else None
    if key not in _UP_DIR_ROTATIONS:
        raise ValueError(f"Invalid up direction: {up_dir}")

    mesh_transformed = mesh.copy()
    rotation_matrix = _UP_DIR_ROTATIONS[key]
    if rotation_matrix is not None:
        mesh_transformed.vertices = mesh_transformed.vertices @ rotation_matrix.T

    # Normalize to the [-0.5, 0.5]^3 bounding box using the largest dimension.
    vertices = mesh_transformed.vertices
    bbox_min = vertices.min(axis=0)
    bbox_max = vertices.max(axis=0)
    center = (bbox_min + bbox_max) / 2
    scale = (bbox_max - bbox_min).max()
    if not scale > 0:
        raise ValueError("Mesh has zero extent and cannot be normalized")
    mesh_transformed.vertices = (vertices - center) / scale

    inputs, sharp_flag, face_indices = prepare_inputs(
        mesh_transformed,
        num_points_global=num_points_global,
        num_points_decode=num_points,
    )

    outputs = model.infer(
        xyz=inputs['xyz'],
        feats=inputs['feats'],
        normals=inputs['normals'],
        output_all_hyps=True,
        min_part_confidence=min_part_confidence,
    )

    return outputs, face_indices, mesh_transformed


def save_articulated_meshes(mesh, face_indices, outputs, output_path, strict, animation_frames: int = 50, hyp_idx: int = 0):
    """Split a mesh into predicted parts and export GLB visualizations.

    Writes ``mesh_parts_with_axes.glb`` (textured parts plus joint axis arrows)
    and, on a best-effort basis, ``animated_textured.glb`` into ``output_path``.
    Export failures of the animated file are printed with a traceback, not raised.

    Args:
        mesh: trimesh object the predictions refer to.
        face_indices: face index of every predicted point.
        outputs: model outputs; ``outputs[hyp_idx]`` must provide the keys read below.
        output_path: directory for the exported files.
        strict: forwarded to ``find_part_ids_for_faces``.
        animation_frames: number of frames in the animated export.
        hyp_idx: which hypothesis of ``outputs`` to export.

    Returns:
        Tuple of (mesh_parts_original, unique_part_ids, motion_hierarchy,
        is_part_revolute, is_part_prismatic, revolute_plucker, revolute_range,
        prismatic_axis, prismatic_range).
    """
    hyp = outputs[hyp_idx]
    part_ids = hyp['part_ids']
    motion_hierarchy = hyp['motion_hierarchy']
    is_part_revolute = hyp['is_part_revolute']
    is_part_prismatic = hyp['is_part_prismatic']
    revolute_plucker = hyp['revolute_plucker']
    revolute_range = hyp['revolute_range']
    prismatic_axis = hyp['prismatic_axis']
    prismatic_range = hyp['prismatic_range']

    out_dir = Path(output_path)

    face_part_ids = find_part_ids_for_faces(mesh, part_ids, face_indices, strict=strict)

    unique_part_ids = np.unique(face_part_ids)
    num_parts = len(unique_part_ids)
    print(f"Found {num_parts} unique parts")

    mesh_parts_original = [
        mesh.submesh([face_part_ids == part_id], append=True) for part_id in unique_part_ids
    ]
    mesh_parts_segmented = create_textured_mesh_parts([mp.copy() for mp in mesh_parts_original])

    axes = []
    for part_id, mesh_part in zip(unique_part_ids, mesh_parts_segmented):
        if part_id < 0:
            # Unassigned faces (-1) have no joint; indexing with -1 would read the last part's joint.
            continue
        if is_part_revolute[part_id]:
            axis, point = plucker_to_axis_point(revolute_plucker[part_id])
            arrow_start, arrow_end = get_3D_arrow_on_points(
                axis, mesh_part.vertices, fixed_point=point, extension=0.2)
            axes.append(create_arrow(
                arrow_start, arrow_end, color=ARROW_COLOR_REVOLUTE, radius=0.01, radius_tip=0.018))
            # Rings at both arrow ends mark a rotational joint.
            arrow_dir = arrow_end - arrow_start
            axes.append(create_ring(
                arrow_start, arrow_dir, major_radius=0.03, minor_radius=0.006, color=ARROW_COLOR_REVOLUTE))
            axes.append(create_ring(
                arrow_end, arrow_dir, major_radius=0.03, minor_radius=0.006, color=ARROW_COLOR_REVOLUTE))
        elif is_part_prismatic[part_id]:
            axis = prismatic_axis[part_id]
            arrow_start, arrow_end = get_3D_arrow_on_points(axis, mesh_part.vertices, extension=0.2)
            axes.append(create_arrow(
                arrow_start, arrow_end, color=ARROW_COLOR_PRISMATIC, radius=0.01, radius_tip=0.018))

    trimesh.Scene(mesh_parts_segmented + axes).export(out_dir / "mesh_parts_with_axes.glb")

    print("Exporting animated GLB files...")
    try:
        export_animated_glb_file(
            mesh_parts_original,
            unique_part_ids,
            motion_hierarchy,
            is_part_revolute,
            is_part_prismatic,
            revolute_plucker,
            revolute_range,
            prismatic_axis,
            prismatic_range,
            animation_frames,
            str(out_dir / "animated_textured.glb"),
            include_axes=False,
            axes_meshes=None,
        )
    except Exception as e:
        # Best effort: the static export above already succeeded.
        print(f"Error exporting animated_textured.glb: {e}")
        traceback.print_exc()

    return (
        mesh_parts_original,
        unique_part_ids,
        motion_hierarchy,
        is_part_revolute,
        is_part_prismatic,
        revolute_plucker,
        revolute_range,
        prismatic_axis,
        prismatic_range,
    )