Spaces:
Running
on
Zero
Running
on
Zero
| import os | |
| import gradio as gr | |
| import tempfile | |
| import shutil | |
| import torch | |
| from omegaconf import OmegaConf | |
| import trimesh | |
| from pathlib import Path | |
| from huggingface_hub import hf_hub_download | |
| import zipfile | |
| from datetime import datetime | |
| import spaces | |
| from infer_asset import infer_single_asset, save_articulated_meshes | |
| from particulate.models import Articulate3D_B | |
| from particulate.data_utils import load_obj_raw_preserve | |
| from particulate.export_utils import export_urdf, export_mjcf | |
| from particulate.visualization_utils import plot_mesh | |
| from yacs.config import CfgNode | |
# Register yacs' CfgNode with torch's allowlist of globals that the restricted
# ``weights_only`` unpickler is permitted to reconstruct, so torch.load can read
# checkpoints that store a CfgNode (presumably the training config saved next to
# the weights — confirm against the checkpoint contents).
_TORCH_SAFE_GLOBALS = (CfgNode,)
torch.serialization.add_safe_globals(list(_TORCH_SAFE_GLOBALS))
class ParticulateApp:
    """
    Main application class for Particulate with Gradio interface.

    Owns the Articulate3D_B model and implements the callbacks wired into the
    Gradio UI: mesh preview, articulation inference, and URDF/MJCF export.
    Generated GLB files and export archives are written to ``output_dir``
    under unique file names.
    """

    # Number of articulation values that follow the (animated GLB, prediction
    # GLB, status) triple in predict()'s result; they populate nine gr.State
    # outputs and are later passed back to export_urdf / export_mjcf.
    _NUM_ARTICULATION_OUTPUTS = 9

    def __init__(self, model_config_path: str, output_dir: str):
        """Build the model and download the checkpoints it needs.

        Args:
            model_config_path: Path to an OmegaConf YAML file whose top-level
                keys are passed as keyword arguments to ``Articulate3D_B``.
            output_dir: Directory that receives generated GLB files and
                export archives; created if it does not exist.
        """
        self.model_config = OmegaConf.load(model_config_path)
        self.output_dir = output_dir
        os.makedirs(self.output_dir, exist_ok=True)
        self.model = Articulate3D_B(**self.model_config)
        self.model.eval()
        # Always load to CPU initially to support Zero GPU and avoid VRAM usage when idle
        print("Downloading/Loading model from Hugging Face...")
        self.model_checkpoint = hf_hub_download(repo_id="rayli/Particulate", filename="model.pt")
        # NOTE(review): torch.load unpickles a file fetched from the Hub. That is only
        # safe under the restricted weights_only unpickler (the default since
        # PyTorch 2.6, presumably why CfgNode is allowlisted at module level);
        # consider passing weights_only=True explicitly if older torch must be supported.
        self.model.load_state_dict(torch.load(self.model_checkpoint, map_location="cpu"))
        # Model stays on CPU until inference
        # The PartField checkpoint is placed under a path relative to the current
        # working directory; presumably PartField loads it from there — confirm.
        model_dir = os.path.join("PartField", "model")
        os.makedirs(model_dir, exist_ok=True)
        hf_hub_download(repo_id="mikaelaangel/partfield-ckpt", filename="model_objaverse.ckpt", local_dir=model_dir)
        print("Model loaded successfully.")

    @staticmethod
    def _unique_tag():
        """Return a file-name tag: a second-resolution timestamp plus 8 random hex chars.

        The timestamp keeps outputs sortable; the random suffix stops two
        requests that finish within the same second from overwriting each
        other's files in the shared output directory.
        """
        import uuid
        return f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"

    @classmethod
    def _failed_prediction(cls, message):
        """Return a predict() result carrying only a status message.

        The tuple always has one entry per Gradio output wired to predict():
        two model paths, the status text, and the articulation values.
        """
        return (None, None, message, *[None] * cls._NUM_ARTICULATION_OUTPUTS)

    def visualize_mesh(self, input_mesh_path):
        """Load an uploaded mesh and build its preview plot.

        Args:
            input_mesh_path: File path (str) from gr.Model3D, a Gradio file
                dict with a "path" or "name" key, or None.

        Returns:
            (plot, mesh): the figure from ``plot_mesh`` and the loaded mesh,
            or (None, None) when no file is given.

        Raises:
            ValueError: if the file is a scene without any triangle mesh.
        """
        if input_mesh_path is None:
            return None, None
        # Handle Gradio file object (dict) or file path (string)
        if isinstance(input_mesh_path, dict):
            file_path = input_mesh_path.get("path") or input_mesh_path.get("name")
        else:
            file_path = input_mesh_path
        print(f"Visualizing mesh from: {file_path}")
        # Case-insensitive, matching the extension filter used for the examples.
        if file_path.lower().endswith(".obj"):
            # OBJ files are parsed by load_obj_raw_preserve instead of trimesh.load.
            verts, faces = load_obj_raw_preserve(Path(file_path))
            mesh = trimesh.Trimesh(vertices=verts, faces=faces)
        else:
            mesh = trimesh.load(file_path, process=False)
        if isinstance(mesh, trimesh.Scene):
            # Only triangle meshes can be concatenated; skip paths/point clouds.
            parts = [geom for geom in mesh.geometry.values() if isinstance(geom, trimesh.Trimesh)]
            if not parts:
                raise ValueError(f"No triangle mesh found in {file_path}")
            mesh = trimesh.util.concatenate(parts)
        return plot_mesh(mesh), mesh

    def predict(
        self,
        mesh,
        min_part_confidence,
        num_points,
        up_dir,
        animation_frames,
        strict,
    ):
        """Run articulation inference on a loaded mesh and render the results.

        Args:
            mesh: The mesh returned by ``visualize_mesh`` (typically a
                trimesh.Trimesh), or None when nothing has been uploaded.
            min_part_confidence: Forwarded to ``infer_single_asset``.
            num_points: Number of points used for inference (cast to int).
            up_dir: Up axis of the object: "X", "Y", "Z", "-X", "-Y" or "-Z".
            animation_frames: Frame count of the animated GLB (cast to int).
            strict: Forwarded to ``save_articulated_meshes`` (the "Refine with
                Connected Components" checkbox).

        Returns:
            A 12-tuple matching the Gradio outputs: animated GLB path,
            prediction GLB path, status message, then the nine values returned
            by ``save_articulated_meshes``. On failure the paths and the nine
            values are None and the status message explains the problem.
        """
        if mesh is None:
            # Return a value for every wired output, not just two.
            return self._failed_prediction("Please upload a 3D model.")
        try:
            outputs, face_indices, mesh_transformed = self._predict_impl(
                mesh,
                min_part_confidence,
                num_points,
                up_dir,
            )
            with tempfile.TemporaryDirectory() as temp_dir:
                (
                    mesh_parts_original,
                    unique_part_ids,
                    motion_hierarchy,
                    is_part_revolute,
                    is_part_prismatic,
                    revolute_plucker,
                    revolute_range,
                    prismatic_axis,
                    prismatic_range,
                ) = save_articulated_meshes(
                    mesh_transformed, face_indices, outputs,
                    strict=strict,
                    animation_frames=int(animation_frames),
                    output_path=temp_dir,
                )
                animated_glb_file = os.path.join(temp_dir, "animated_textured.glb")
                prediction_file = os.path.join(temp_dir, "mesh_parts_with_axes.glb")
                if not (os.path.exists(animated_glb_file) and os.path.exists(prediction_file)):
                    return self._failed_prediction("No output file generated.")
                # Copy to a persistent location before the temporary directory is removed.
                tag = self._unique_tag()
                dest_animated_glb_file = os.path.join(self.output_dir, f"animated_textured_{tag}.glb")
                dest_prediction_file = os.path.join(self.output_dir, f"mesh_parts_with_axes_{tag}.glb")
                shutil.copy(animated_glb_file, dest_animated_glb_file)
                shutil.copy(prediction_file, dest_prediction_file)
            return (
                dest_animated_glb_file, dest_prediction_file, "Success!",
                mesh_parts_original,
                unique_part_ids,
                motion_hierarchy,
                is_part_revolute,
                is_part_prismatic,
                revolute_plucker,
                revolute_range,
                prismatic_axis,
                prismatic_range,
            )
        except ValueError as e:
            # Surface validation errors (e.g., too many faces) directly to the UI
            message = str(e)
            if "faces > 51.2k" in message:
                message = "Mesh has more than 51.2k faces; please reduce the face count."
            return self._failed_prediction(message)
        except Exception as e:
            import traceback
            traceback.print_exc()
            return self._failed_prediction(f"Error: {str(e)}")

    def _predict_impl(
        self,
        mesh,
        min_part_confidence,
        num_points,
        up_dir,
    ):
        """Move the model to CUDA and run ``infer_single_asset`` on ``mesh``.

        Returns whatever ``infer_single_asset`` returns; ``predict`` unpacks it
        as (outputs, face_indices, mesh_transformed).
        """
        # NOTE(review): `spaces` is imported but no @spaces.GPU decorator is visible
        # in this file; on ZeroGPU, CUDA work must run inside a GPU-decorated
        # function — confirm where that happens (e.g. inside infer_asset).
        return infer_single_asset(
            mesh=mesh,
            up_dir=up_dir,
            model=self.model.to('cuda'),
            num_points=int(num_points),
            min_part_confidence=min_part_confidence,
        )

    def _export_archive(self, exporter, fmt, model_filename, label, articulation):
        """Run an exporter into a temporary directory and zip its output.

        Args:
            exporter: Export function called as
                ``exporter(*articulation, output_path=..., name="model")``.
            fmt: Sub-directory the exporter writes into and the archive's
                file-name prefix (e.g. "urdf").
            model_filename: Name of the main model file inside ``fmt``.
            label: Format name used in error messages (e.g. "URDF").
            articulation: The nine articulation values produced by predict().

        Returns:
            (zip_path, "Success!") on success, or (None, error message).
        """
        try:
            with tempfile.TemporaryDirectory() as temp_dir:
                export_root = os.path.join(temp_dir, fmt)
                exporter(
                    *articulation,
                    output_path=os.path.join(export_root, model_filename),
                    name="model",
                )
                # Zip the output directory; entry names are relative to export_root.
                zip_path = os.path.join(self.output_dir, f"{fmt}_{self._unique_tag()}.zip")
                with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zipf:
                    for root, _dirs, files in os.walk(export_root):
                        for file in files:
                            file_path = os.path.join(root, file)
                            zipf.write(file_path, os.path.relpath(file_path, export_root))
            return zip_path, "Success!"
        except Exception as e:
            print(f"Error exporting {label}: {e}")
            import traceback
            traceback.print_exc()
            return None, f"Error exporting {label}: {str(e)}"

    def export_urdf(
        self,
        mesh_parts,
        unique_part_ids,
        motion_hierarchy,
        is_part_revolute,
        is_part_prismatic,
        revolute_plucker,
        revolute_range,
        prismatic_axis,
        prismatic_range
    ):
        """Export the latest prediction as a zipped URDF; return (zip path, status)."""
        if mesh_parts is None:
            return None, "Please run inference first."
        # Inside a method the bare name ``export_urdf`` resolves to the
        # module-level function imported from particulate.export_utils.
        return self._export_archive(
            export_urdf, "urdf", "model.urdf", "URDF",
            (
                mesh_parts,
                unique_part_ids,
                motion_hierarchy,
                is_part_revolute,
                is_part_prismatic,
                revolute_plucker,
                revolute_range,
                prismatic_axis,
                prismatic_range,
            ),
        )

    def export_mjcf(
        self,
        mesh_parts,
        unique_part_ids,
        motion_hierarchy,
        is_part_revolute,
        is_part_prismatic,
        revolute_plucker,
        revolute_range,
        prismatic_axis,
        prismatic_range
    ):
        """Export the latest prediction as a zipped MJCF; return (zip path, status)."""
        if mesh_parts is None:
            return None, "Please run inference first."
        # Inside a method the bare name ``export_mjcf`` resolves to the
        # module-level function imported from particulate.export_utils.
        return self._export_archive(
            export_mjcf, "mjcf", "model.xml", "MJCF",
            (
                mesh_parts,
                unique_part_ids,
                motion_hierarchy,
                is_part_revolute,
                is_part_prismatic,
                revolute_plucker,
                revolute_range,
                prismatic_axis,
                prismatic_range,
            ),
        )
# Introductory HTML shown at the top of the demo page.
INTRO_HTML = """
<h1>Particulate: Feed-Forward 3D Object Articulation</h1>
<p>
  <a href="https://github.com/ruiningli/particulate" target="_blank" rel="noopener noreferrer">🌟 GitHub Repository</a> |
  <a href="https://ruiningli.com/particulate" target="_blank" rel="noopener noreferrer">🚀 Project Page</a>
</p>
<div style="font-size: 16px; line-height: 1.5;">
  <p>Upload a 3D model (.obj or .glb format supported) to articulate it. Particulate takes this model and predicts the underlying articulated structure, which can be directly exported to <u>URDF</u> or <u>MJCF</u> format.</p>
  <h3>Getting Started:</h3>
  <ol>
    <li><strong>Upload a 3D Model:</strong> We support meshes (.obj or .glb format) with fewer than 51.2k faces. You can use <a href="https://3d.hunyuan.tencent.com/" target="_blank" rel="noopener noreferrer">Hunyuan3D-v3</a> to generate a 3D model: please <u>make sure to select the '50k' face count</u>.</li>
    <li><strong>Preview:</strong> Your uploaded 3D model will be visualized below.</li>
    <li><strong>Confirm Orientation:</strong> Select the direction (one of X, -X, Y, -Y, Z, -Z) that corresponds to the up direction of the object in the preview (for all example assets, the up direction is -Z).</li>
    <li><strong>Run Inference:</strong> Click the "Run Inference" button to start the inference process.</li>
    <li><strong>Visualization:</strong> The articulated 3D model with animation and model prediction (3D part segmentation, motion types and axes) will appear on the right. You can rotate, pan, and zoom to explore the model, and download the GLB file.</li>
    <li>
      <strong>Adjust Inference Parameters (Optional):</strong>
      You can potentially obtain better results by adjusting the following parameters:
      <ul>
        <li><em>Min Part Confidence:</em> Increasing this value will merge parts that have low confidence scores into other parts. Consider increasing this value if the prediction is over-segmented.</li>
        <li><em>Refine with Connected Components:</em> If toggled on, the prediction will be post-processed to ensure that each articulated part is a union of different connected components in the original mesh (i.e., no connected components are split across parts). Toggle this on (default) if the input mesh has clean connected components.</li>
        <li><em>Normally, you should not need to change the other parameters.</em></li>
      </ul>
    </li>
  </ol>
</div>
"""


def create_gradio_app(particulate_app):
    """Build the Particulate Gradio Blocks UI around ``particulate_app``.

    Args:
        particulate_app: Object providing the ``visualize_mesh``, ``predict``,
            ``export_urdf`` and ``export_mjcf`` callbacks (a ParticulateApp).

    Returns:
        The ``gr.Blocks`` demo; the caller is responsible for launching it.
    """
    # Collect examples: every .glb/.obj in ./examples that has a same-named
    # .png thumbnail. The two lists are index-aligned.
    examples_dir = "examples"
    example_files = []
    example_images = []
    if os.path.exists(examples_dir):
        for file in sorted(os.listdir(examples_dir)):
            if file.lower().endswith(('.glb', '.obj')):
                base_name = os.path.splitext(file)[0]
                png_path = os.path.join(examples_dir, base_name + ".png")
                if os.path.exists(png_path):
                    example_files.append(os.path.join(examples_dir, file))
                    example_images.append(png_path)

    with gr.Blocks(title="Particulate Demo") as demo:
        gr.HTML(INTRO_HTML)

        # Per-session state: the mesh loaded for preview, plus the nine
        # articulation values produced by predict() and consumed by the exports.
        loaded_mesh = gr.State(None)
        mesh_parts = gr.State(None)
        unique_part_ids = gr.State(None)
        motion_hierarchy = gr.State(None)
        is_part_revolute = gr.State(None)
        is_part_prismatic = gr.State(None)
        revolute_plucker = gr.State(None)
        revolute_range = gr.State(None)
        prismatic_axis = gr.State(None)
        prismatic_range = gr.State(None)
        # Same order as predict()'s trailing outputs and the export_* parameters.
        articulation_states = [
            mesh_parts,
            unique_part_ids,
            motion_hierarchy,
            is_part_revolute,
            is_part_prismatic,
            revolute_plucker,
            revolute_range,
            prismatic_axis,
            prismatic_range,
        ]

        with gr.Row():
            with gr.Column(scale=1):
                input_mesh = gr.Model3D(
                    label="Upload 3D Model",
                    interactive=True
                )
                if example_files and example_images:
                    example_dataset = gr.Dataset(
                        label="Example Models",
                        components=[gr.Image(visible=False)],
                        samples=[[img] for img in example_images],
                        type="index"
                    )

                    def load_example(index):
                        # type="index" makes the Dataset pass the clicked sample's index.
                        return example_files[index]

                    example_dataset.click(
                        fn=load_example,
                        inputs=[example_dataset],
                        outputs=[input_mesh]
                    )
                mesh_plot = gr.Plot(label="Mesh Preview")
                with gr.Accordion("Inference Parameters", open=True):
                    with gr.Row():
                        up_dir = gr.Radio(choices=["X", "Y", "Z", "-X", "-Y", "-Z"], value="-Z", label="Up Direction (Select after viewing plot)")
                        animation_frames = gr.Number(value=50, label="Animation Frames", precision=0)
                    with gr.Row():
                        num_points = gr.Number(value=102400, label="Number of Points", precision=0, minimum=2048, maximum=102400)
                        min_part_confidence = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, label="Min Part Confidence")
                    with gr.Row():
                        strict = gr.Checkbox(label="Refine with Connected Components", value=True)
                run_btn = gr.Button("Run Inference", variant="primary")
            with gr.Column(scale=2):
                animated_model = gr.Model3D(label="Animated 3D Model")
                prediction_model = gr.Model3D(label="Visualization of Model Prediction")
                status_text = gr.Textbox(label="Status")
                with gr.Row():
                    urdf_btn = gr.Button("Export URDF")
                    mjcf_btn = gr.Button("Export MJCF")
                with gr.Row():
                    urdf_status = gr.Textbox(label="URDF Status")
                    mjcf_status = gr.Textbox(label="MJCF Status")
                with gr.Row():
                    urdf_file = gr.File(label="URDF Zip File")
                    mjcf_file = gr.File(label="MJCF Zip File")

        # Event triggers
        input_mesh.change(
            fn=particulate_app.visualize_mesh,
            inputs=[input_mesh],
            outputs=[mesh_plot, loaded_mesh]
        )
        run_btn.click(
            fn=particulate_app.predict,
            inputs=[
                loaded_mesh,
                min_part_confidence,
                num_points,
                up_dir,
                animation_frames,
                strict
            ],
            outputs=[animated_model, prediction_model, status_text, *articulation_states]
        )
        urdf_btn.click(
            fn=particulate_app.export_urdf,
            inputs=articulation_states,
            outputs=[urdf_file, urdf_status]
        )
        mjcf_btn.click(
            fn=particulate_app.export_mjcf,
            inputs=articulation_states,
            outputs=[mjcf_file, mjcf_status]
        )
    return demo
if __name__ == "__main__":
    # Generated GLBs and exported archives are written to this directory.
    output_dir = "gradio_outputs"
    os.makedirs(output_dir, exist_ok=True)
    model_config_path = "configs/particulate-B.yaml"

    # Building the app loads the model configuration and downloads checkpoints.
    print("Initializing Particulate App...")
    app = ParticulateApp(model_config_path, output_dir)

    demo = create_gradio_app(app)
    launch_options = {
        "server_name": "0.0.0.0",  # listen on all network interfaces
        "server_port": 7860,
        # NOTE(review): share=True asks Gradio for a public share link when run
        # outside Hugging Face Spaces — confirm that exposure is intended.
        "share": True,
    }
    print("Launching Gradio server...")
    demo.launch(**launch_options)