import os
import gradio as gr
import tempfile
import shutil
import torch
from omegaconf import OmegaConf
import trimesh
from pathlib import Path
from huggingface_hub import hf_hub_download
import zipfile
from datetime import datetime
import spaces

from infer_asset import infer_single_asset, save_articulated_meshes
from particulate.models import Articulate3D_B
from particulate.data_utils import load_obj_raw_preserve
from particulate.export_utils import export_urdf, export_mjcf
from particulate.visualization_utils import plot_mesh
from yacs.config import CfgNode

# Allow CfgNode objects inside checkpoints when torch.load uses the weights-only unpickler.
torch.serialization.add_safe_globals([CfgNode])


class ParticulateApp:
    """
    Main application class for Particulate with Gradio interface.
    """

    def __init__(self, model_config_path: str, output_dir: str):
        self.model_config = OmegaConf.load(model_config_path)
        self.output_dir = output_dir
        os.makedirs(self.output_dir, exist_ok=True)

        self.model = Articulate3D_B(**self.model_config)
        self.model.eval()

        # Always load to CPU initially to support Zero GPU and avoid VRAM usage when idle.
        print("Downloading/Loading model from Hugging Face...")
        self.model_checkpoint = hf_hub_download(repo_id="rayli/Particulate", filename="model.pt")
        self.model.load_state_dict(torch.load(self.model_checkpoint, map_location="cpu"))
        # Model stays on CPU until inference.

        model_dir = os.path.join("PartField", "model")
        os.makedirs(model_dir, exist_ok=True)
        hf_hub_download(repo_id="mikaelaangel/partfield-ckpt", filename="model_objaverse.ckpt", local_dir=model_dir)
        print("Model loaded successfully.")

    def visualize_mesh(self, input_mesh_path):
        if input_mesh_path is None:
            return None, None

        # Handle Gradio file object (dict) or file path (string).
        if isinstance(input_mesh_path, dict):
            file_path = input_mesh_path.get("path") or input_mesh_path.get("name")
        else:
            file_path = input_mesh_path

        print(f"Visualizing mesh from: {file_path}")
        if file_path.endswith(".obj"):
            verts, faces = load_obj_raw_preserve(Path(file_path))
            mesh = trimesh.Trimesh(vertices=verts, faces=faces)
        else:
            mesh = trimesh.load(file_path, process=False)
            if isinstance(mesh, trimesh.Scene):
                mesh = trimesh.util.concatenate(mesh.geometry.values())

        return plot_mesh(mesh), mesh

    def predict(
        self,
        mesh,
        min_part_confidence,
        num_points,
        up_dir,
        animation_frames,
        strict,
    ):
        if mesh is None:
            # Return the full 12-tuple expected by the Gradio outputs.
            return (None, None, "Please upload a 3D model.", *[None] * 9)

        try:
            outputs, face_indices, mesh_transformed = self._predict_impl(
                mesh, min_part_confidence, num_points, up_dir
            )

            with tempfile.TemporaryDirectory() as temp_dir:
                (
                    mesh_parts_original,
                    unique_part_ids,
                    motion_hierarchy,
                    is_part_revolute,
                    is_part_prismatic,
                    revolute_plucker,
                    revolute_range,
                    prismatic_axis,
                    prismatic_range,
                ) = save_articulated_meshes(
                    mesh_transformed,
                    face_indices,
                    outputs,
                    strict=strict,
                    animation_frames=int(animation_frames),
                    output_path=temp_dir,
                )

                animated_glb_file = os.path.join(temp_dir, "animated_textured.glb")
                prediction_file = os.path.join(temp_dir, "mesh_parts_with_axes.glb")

                if os.path.exists(animated_glb_file) and os.path.exists(prediction_file):
                    # Copy to a persistent location in the output directory.
                    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                    dest_animated_glb_file = os.path.join(self.output_dir, f"animated_textured_{timestamp}.glb")
                    dest_prediction_file = os.path.join(self.output_dir, f"mesh_parts_with_axes_{timestamp}.glb")
                    shutil.copy(animated_glb_file, dest_animated_glb_file)
                    shutil.copy(prediction_file, dest_prediction_file)
                    # Temporary for debugging.
                    # mesh_transformed.export(os.path.join(self.output_dir, f"mesh_parts_with_axes_{timestamp}.glb"))
                    return (
                        dest_animated_glb_file,
                        dest_prediction_file,
                        "Success!",
                        mesh_parts_original,
                        unique_part_ids,
                        motion_hierarchy,
                        is_part_revolute,
                        is_part_prismatic,
                        revolute_plucker,
                        revolute_range,
                        prismatic_axis,
                        prismatic_range,
                    )
                else:
                    return (
                        None,
                        None,
                        "No output file generated.",
                        *[None] * 9,
                    )
        except ValueError as e:
            # Surface validation errors (e.g., too many faces) directly to the UI.
            message = str(e)
            if "faces > 51.2k" in message:
                message = "Mesh has more than 51.2k faces; please reduce the face count."
            return (
                None,
                None,
                message,
                *[None] * 9,
            )
        except Exception as e:
            import traceback
            traceback.print_exc()
            return (
                None,
                None,
                f"Error: {str(e)}",
                *[None] * 9,
            )

    @spaces.GPU(duration=20)
    def _predict_impl(
        self,
        mesh,
        min_part_confidence,
        num_points,
        up_dir,
    ):
        return infer_single_asset(
            mesh=mesh,
            up_dir=up_dir,
            model=self.model.to("cuda"),
            num_points=int(num_points),
            min_part_confidence=min_part_confidence,
        )

    def export_urdf(
        self,
        mesh_parts,
        unique_part_ids,
        motion_hierarchy,
        is_part_revolute,
        is_part_prismatic,
        revolute_plucker,
        revolute_range,
        prismatic_axis,
        prismatic_range,
    ):
        if mesh_parts is None:
            return None, "Please run inference first."

        try:
            with tempfile.TemporaryDirectory() as temp_dir:
                # Calls the module-level export_urdf from particulate.export_utils.
                export_urdf(
                    mesh_parts,
                    unique_part_ids,
                    motion_hierarchy,
                    is_part_revolute,
                    is_part_prismatic,
                    revolute_plucker,
                    revolute_range,
                    prismatic_axis,
                    prismatic_range,
                    output_path=os.path.join(temp_dir, "urdf", "model.urdf"),
                    name="model",
                )

                # Zip the output directory.
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                with zipfile.ZipFile(os.path.join(self.output_dir, f"urdf_{timestamp}.zip"), "w") as zipf:
                    for root, dirs, files in os.walk(os.path.join(temp_dir, "urdf")):
                        for file in files:
                            zipf.write(
                                os.path.join(root, file),
                                os.path.relpath(os.path.join(root, file), os.path.join(temp_dir, "urdf")),
                            )

                return os.path.join(self.output_dir, f"urdf_{timestamp}.zip"), "Success!"
        except Exception as e:
            print(f"Error exporting URDF: {e}")
            import traceback
            traceback.print_exc()
            return None, f"Error exporting URDF: {str(e)}"

    def export_mjcf(
        self,
        mesh_parts,
        unique_part_ids,
        motion_hierarchy,
        is_part_revolute,
        is_part_prismatic,
        revolute_plucker,
        revolute_range,
        prismatic_axis,
        prismatic_range,
    ):
        if mesh_parts is None:
            return None, "Please run inference first."

        try:
            with tempfile.TemporaryDirectory() as temp_dir:
                # Calls the module-level export_mjcf from particulate.export_utils.
                export_mjcf(
                    mesh_parts,
                    unique_part_ids,
                    motion_hierarchy,
                    is_part_revolute,
                    is_part_prismatic,
                    revolute_plucker,
                    revolute_range,
                    prismatic_axis,
                    prismatic_range,
                    output_path=os.path.join(temp_dir, "mjcf", "model.xml"),
                    name="model",
                )

                # Zip the output directory.
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                with zipfile.ZipFile(os.path.join(self.output_dir, f"mjcf_{timestamp}.zip"), "w") as zipf:
                    for root, dirs, files in os.walk(os.path.join(temp_dir, "mjcf")):
                        for file in files:
                            zipf.write(
                                os.path.join(root, file),
                                os.path.relpath(os.path.join(root, file), os.path.join(temp_dir, "mjcf")),
                            )

                return os.path.join(self.output_dir, f"mjcf_{timestamp}.zip"), "Success!"
        except Exception as e:
            print(f"Error exporting MJCF: {e}")
            import traceback
            traceback.print_exc()
            return None, f"Error exporting MJCF: {str(e)}"


def create_gradio_app(particulate_app):
    # Get example files from the examples folder.
    examples_dir = "examples"
    example_files = []
    example_images = []
    if os.path.exists(examples_dir):
        for file in sorted(os.listdir(examples_dir)):
            if file.lower().endswith((".glb", ".obj")):
                base_name = os.path.splitext(file)[0]
                png_path = os.path.join(examples_dir, base_name + ".png")
                if os.path.exists(png_path):
                    example_files.append(os.path.join(examples_dir, file))
                    example_images.append(png_path)

    with gr.Blocks(title="Particulate Demo") as demo:
        gr.HTML(
            """
            <h1>Particulate: Feed-Forward 3D Object Articulation</h1>
            <p>🌟 GitHub Repository | 🚀 Project Page</p>
            <p>
              Upload a 3D model (.obj or .glb format supported) to articulate it. Particulate takes this model and
              predicts the underlying articulated structure, which can be directly exported to URDF or MJCF format.
            </p>
            <p><b>Getting Started:</b></p>
            <ol>
              <li><b>Upload a 3D model.</b> We support meshes (.obj or .glb format) with fewer than 51.2k faces.
                  You can use Hunyuan3D-v3 to generate a 3D model; please make sure to select the '50k' face count.</li>
              <li><b>Preview:</b> Your uploaded 3D model will be visualized below.</li>
              <li><b>Confirm Orientation:</b> Select the direction (one of X, -X, Y, -Y, Z, -Z) that corresponds to the
                  up direction of the object in the preview (for all example assets, the up direction is -Z).</li>
              <li><b>Run Inference:</b> Click the "Run Inference" button to start inference.</li>
              <li><b>Visualization:</b> The articulated 3D model with animation and the model prediction
                  (3D part segmentation, motion types, and axes) will appear on the right. You can rotate, pan, and zoom
                  to explore the model, and download the GLB file.</li>
              <li><b>Adjust Inference Parameters (Optional):</b> You can potentially obtain better results by adjusting
                  the following parameters:
                <ul>
                  <li><b>Min Part Confidence:</b> Increasing this value merges low-confidence parts into other parts.
                      Consider increasing it if the prediction is over-segmented.</li>
                  <li><b>Refine with Connected Components:</b> If toggled on, the prediction is post-processed so that each
                      articulated part is a union of connected components of the original mesh (i.e., no connected component
                      is split across parts). Keep this on (the default) if the input mesh has clean connected components.</li>
                  <li>Normally, you should not need to change the other parameters.</li>
                </ul>
              </li>
            </ol>
            """
        )

        loaded_mesh = gr.State(None)
        mesh_parts = gr.State(None)
        unique_part_ids = gr.State(None)
        motion_hierarchy = gr.State(None)
        is_part_revolute = gr.State(None)
        is_part_prismatic = gr.State(None)
        revolute_plucker = gr.State(None)
        revolute_range = gr.State(None)
        prismatic_axis = gr.State(None)
        prismatic_range = gr.State(None)

        with gr.Row():
            with gr.Column(scale=1):
                input_mesh = gr.Model3D(
                    label="Upload 3D Model",
                    interactive=True
                )

                if example_files and example_images:
                    example_dataset = gr.Dataset(
                        label="Example Models",
                        components=[gr.Image(visible=False)],
                        samples=[[img] for img in example_images],
                        type="index"
                    )

                    def load_example(index):
                        return example_files[index]

                    example_dataset.click(
                        fn=load_example,
                        inputs=[example_dataset],
                        outputs=[input_mesh]
                    )

                mesh_plot = gr.Plot(label="Mesh Preview")

                with gr.Accordion("Inference Parameters", open=True):
                    with gr.Row():
                        up_dir = gr.Radio(
                            choices=["X", "Y", "Z", "-X", "-Y", "-Z"],
                            value="-Z",
                            label="Up Direction (Select after viewing plot)"
                        )
                        animation_frames = gr.Number(value=50, label="Animation Frames", precision=0)
                    with gr.Row():
                        num_points = gr.Number(
                            value=102400, label="Number of Points", precision=0, minimum=2048, maximum=102400
                        )
                        min_part_confidence = gr.Slider(
                            minimum=0.0, maximum=1.0, value=0.0, label="Min Part Confidence"
                        )
                    with gr.Row():
                        strict = gr.Checkbox(label="Refine with Connected Components", value=True)

                run_btn = gr.Button("Run Inference", variant="primary")

            with gr.Column(scale=2):
                animated_model = gr.Model3D(label="Animated 3D Model")
                prediction_model = gr.Model3D(label="Visualization of Model Prediction")
                status_text = gr.Textbox(label="Status")

                with gr.Row():
                    urdf_btn = gr.Button("Export URDF")
                    mjcf_btn = gr.Button("Export MJCF")
                with gr.Row():
                    urdf_status = gr.Textbox(label="URDF Status")
                    mjcf_status = gr.Textbox(label="MJCF Status")
                with gr.Row():
                    urdf_file = gr.File(label="URDF Zip File")
                    mjcf_file = gr.File(label="MJCF Zip File")

        # Event triggers
        input_mesh.change(
            fn=particulate_app.visualize_mesh,
            inputs=[input_mesh],
            outputs=[mesh_plot, loaded_mesh]
        )

        run_btn.click(
            fn=particulate_app.predict,
            inputs=[
                loaded_mesh,
                min_part_confidence,
                num_points,
                up_dir,
                animation_frames,
                strict
            ],
            outputs=[
                animated_model,
                prediction_model,
                status_text,
                mesh_parts,
                unique_part_ids,
                motion_hierarchy,
                is_part_revolute,
                is_part_prismatic,
                revolute_plucker,
                revolute_range,
                prismatic_axis,
                prismatic_range
            ]
        )

        urdf_btn.click(
            fn=particulate_app.export_urdf,
            inputs=[
                mesh_parts,
                unique_part_ids,
                motion_hierarchy,
                is_part_revolute,
                is_part_prismatic,
                revolute_plucker,
                revolute_range,
                prismatic_axis,
                prismatic_range
            ],
            outputs=[urdf_file, urdf_status]
        )

        mjcf_btn.click(
            fn=particulate_app.export_mjcf,
            inputs=[
                mesh_parts,
                unique_part_ids,
                motion_hierarchy,
                is_part_revolute,
                is_part_prismatic,
                revolute_plucker,
                revolute_range,
                prismatic_axis,
                prismatic_range
            ],
            outputs=[mjcf_file, mjcf_status]
        )

    return demo


if __name__ == "__main__":
    output_dir = "gradio_outputs"
    os.makedirs(output_dir, exist_ok=True)

    # Load model configuration
    model_config_path = "configs/particulate-B.yaml"

    # Initialize app
    print("Initializing Particulate App...")
    app = ParticulateApp(model_config_path, output_dir)

    # Create and launch Gradio demo
    demo = create_gradio_app(app)
    print("Launching Gradio server...")
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)