import os
import gradio as gr
import tempfile
import shutil
import torch
from omegaconf import OmegaConf
import trimesh
from pathlib import Path
from huggingface_hub import hf_hub_download
import zipfile
from datetime import datetime
import spaces
from infer_asset import infer_single_asset, save_articulated_meshes
from particulate.models import Articulate3D_B
from particulate.data_utils import load_obj_raw_preserve
from particulate.export_utils import export_urdf, export_mjcf
from particulate.visualization_utils import plot_mesh
from yacs.config import CfgNode

# Allowlist CfgNode so torch.load can safely deserialize checkpoints that
# embed a yacs config object (required by torch's safe-globals mechanism).
torch.serialization.add_safe_globals([CfgNode])


class ParticulateApp:
    """
    Main application class for Particulate with Gradio interface.
    """

    def __init__(self, model_config_path: str, output_dir: str):
        """Load model config, instantiate the network on CPU, and fetch checkpoints.

        Args:
            model_config_path: Path to an OmegaConf YAML whose keys are passed
                as keyword arguments to ``Articulate3D_B``.
            output_dir: Directory where generated GLB/zip artifacts are written;
                created if it does not exist.
        """
        self.model_config = OmegaConf.load(model_config_path)
        self.output_dir = output_dir
        os.makedirs(self.output_dir, exist_ok=True)
        self.model = Articulate3D_B(**self.model_config)
        self.model.eval()
        # Always load to CPU initially to support Zero GPU and avoid VRAM usage when idle
        print("Downloading/Loading model from Hugging Face...")
        self.model_checkpoint = hf_hub_download(repo_id="rayli/Particulate", filename="model.pt")
        self.model.load_state_dict(torch.load(self.model_checkpoint, map_location="cpu"))
        # Model stays on CPU until inference
        # Pre-fetch the PartField checkpoint into the location the inference
        # pipeline expects (PartField/model/model_objaverse.ckpt).
        model_dir = os.path.join("PartField", "model")
        os.makedirs(model_dir, exist_ok=True)
        hf_hub_download(repo_id="mikaelaangel/partfield-ckpt", filename="model_objaverse.ckpt", local_dir=model_dir)
        print("Model loaded successfully.")

    def visualize_mesh(self, input_mesh_path):
        """Load an uploaded mesh and return (plot figure, trimesh object).

        Returns ``(None, None)`` when no file is provided. ``.obj`` files are
        loaded via ``load_obj_raw_preserve`` to keep the raw vertex/face order;
        everything else goes through ``trimesh.load`` with ``process=False``.
        """
        if input_mesh_path is None:
            return None, None
        # Handle Gradio file object (dict) or file path (string)
        if isinstance(input_mesh_path, dict):
            file_path = input_mesh_path.get("path") or input_mesh_path.get("name")
        else:
            file_path = input_mesh_path
        print(f"Visualizing mesh from: {file_path}")
        if file_path.endswith(".obj"):
            verts, faces = load_obj_raw_preserve(Path(file_path))
            mesh = trimesh.Trimesh(vertices=verts, faces=faces)
        else:
            mesh = trimesh.load(file_path, process=False)
            # Flatten a multi-geometry scene into a single mesh.
            if isinstance(mesh, trimesh.Scene):
                mesh = trimesh.util.concatenate(mesh.geometry.values())
        return plot_mesh(mesh), mesh

    def predict(
        self,
        mesh,
        min_part_confidence,
        num_points,
        up_dir,
        animation_frames,
        strict,
    ):
        """Run articulation inference and export viewer-ready GLB files.

        On success returns a 12-tuple: (animated GLB path, prediction GLB path,
        status message, then 9 articulation-state values that are stashed in
        Gradio state for the later URDF/MJCF export calls). On failure returns
        (None, None, message, *[None] * 9).

        NOTE(review): the ``mesh is None`` early return yields only a 2-tuple,
        unlike every other branch — presumably the Gradio wiring tolerates
        this; confirm against the event handlers.
        """
        if mesh is None:
            return None, "Please upload a 3D model."
        try:
            # GPU-decorated helper; moves the model to CUDA for the call.
            outputs, face_indices, mesh_transformed = self._predict_impl(
                mesh, min_part_confidence, num_points, up_dir
            )
            with tempfile.TemporaryDirectory() as temp_dir:
                (
                    mesh_parts_original,
                    unique_part_ids,
                    motion_hierarchy,
                    is_part_revolute,
                    is_part_prismatic,
                    revolute_plucker,
                    revolute_range,
                    prismatic_axis,
                    prismatic_range
                ) = save_articulated_meshes(
                    mesh_transformed,
                    face_indices,
                    outputs,
                    strict=strict,
                    animation_frames=int(animation_frames),
                    output_path=temp_dir
                )
                animated_glb_file = os.path.join(temp_dir, "animated_textured.glb")
                prediction_file = os.path.join(temp_dir, "mesh_parts_with_axes.glb")
                if os.path.exists(animated_glb_file) and os.path.exists(prediction_file):
                    # Copy to a persistent location in the output directory
                    # (the temp dir is deleted when this `with` block exits).
                    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                    dest_animated_glb_file = os.path.join(self.output_dir, f"animated_textured_{timestamp}.glb")
                    dest_prediction_file = os.path.join(self.output_dir, f"mesh_parts_with_axes_{timestamp}.glb")
                    shutil.copy(animated_glb_file, dest_animated_glb_file)
                    shutil.copy(prediction_file, dest_prediction_file)
                    # Temporary for debugging.
                    # mesh_transformed.export(os.path.join(self.output_dir, f"mesh_parts_with_axes_{timestamp}.glb"))
                    return (
                        dest_animated_glb_file,
                        dest_prediction_file,
                        f"Success!",
                        mesh_parts_original,
                        unique_part_ids,
                        motion_hierarchy,
                        is_part_revolute,
                        is_part_prismatic,
                        revolute_plucker,
                        revolute_range,
                        prismatic_axis,
                        prismatic_range
                    )
                else:
                    return (
                        None,
                        None,
                        f"No output file generated.",
                        *[None] * 9
                    )
        except ValueError as e:
            # Surface validation errors (e.g., too many faces) directly to the UI
            message = str(e)
            if "faces > 51.2k" in message:
                message = "Mesh has more than 51.2k faces; please reduce the face count."
            return (
                None,
                None,
                message,
                *[None] * 9
            )
        except Exception as e:
            import traceback
            traceback.print_exc()
            return (
                None,
                None,
                f"Error: {str(e)}",
                *[None] * 9
            )

    @spaces.GPU(duration=20)
    def _predict_impl(
        self,
        mesh,
        min_part_confidence,
        num_points,
        up_dir
    ):
        """GPU-scoped inference wrapper.

        The ``@spaces.GPU`` decorator grants a CUDA device only for the
        duration of this call (Zero GPU); the model is moved to CUDA here
        rather than at load time.
        """
        return infer_single_asset(
            mesh=mesh,
            up_dir=up_dir,
            model=self.model.to('cuda'),
            num_points=int(num_points),
            min_part_confidence=min_part_confidence,
        )

    def export_urdf(
        self,
        mesh_parts,
        unique_part_ids,
        motion_hierarchy,
        is_part_revolute,
        is_part_prismatic,
        revolute_plucker,
        revolute_range,
        prismatic_axis,
        prismatic_range
    ):
        """Write a URDF for the predicted articulation and zip it into output_dir.

        Returns ``(zip_path, "Success!")`` or ``(None, error_message)``. The
        nine articulation arguments are the state values produced by
        ``predict``.
        """
        if mesh_parts is None:
            return None, "Please run inference first."
        try:
            with tempfile.TemporaryDirectory() as temp_dir:
                # NOTE: this bare name resolves to the module-level
                # particulate.export_utils.export_urdf import, not this method
                # (function bodies do not see class scope).
                export_urdf(
                    mesh_parts,
                    unique_part_ids,
                    motion_hierarchy,
                    is_part_revolute,
                    is_part_prismatic,
                    revolute_plucker,
                    revolute_range,
                    prismatic_axis,
                    prismatic_range,
                    output_path=os.path.join(temp_dir, "urdf", "model.urdf"),
                    name="model"
                )
                # Zip the output directory
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                with zipfile.ZipFile(os.path.join(self.output_dir, f"urdf_{timestamp}.zip"), "w") as zipf:
                    for root, dirs, files in os.walk(os.path.join(temp_dir, "urdf")):
                        for file in files:
                            # Store paths relative to the urdf/ root so the
                            # archive unpacks cleanly.
                            zipf.write(os.path.join(root, file), os.path.relpath(os.path.join(root, file), os.path.join(temp_dir, "urdf")))
                return os.path.join(self.output_dir, f"urdf_{timestamp}.zip"), "Success!"
        except Exception as e:
            print(f"Error exporting URDF: {e}")
            import traceback
            traceback.print_exc()
            return None, f"Error exporting URDF: {str(e)}"

    def export_mjcf(
        self,
        mesh_parts,
        unique_part_ids,
        motion_hierarchy,
        is_part_revolute,
        is_part_prismatic,
        revolute_plucker,
        revolute_range,
        prismatic_axis,
        prismatic_range
    ):
        """Write an MJCF (MuJoCo XML) model and zip it into output_dir.

        Mirrors ``export_urdf``; returns ``(zip_path, "Success!")`` or
        ``(None, error_message)``.
        """
        if mesh_parts is None:
            return None, "Please run inference first."
        try:
            with tempfile.TemporaryDirectory() as temp_dir:
                # Bare name resolves to the imported export_mjcf function,
                # not this method (function bodies do not see class scope).
                export_mjcf(
                    mesh_parts,
                    unique_part_ids,
                    motion_hierarchy,
                    is_part_revolute,
                    is_part_prismatic,
                    revolute_plucker,
                    revolute_range,
                    prismatic_axis,
                    prismatic_range,
                    output_path=os.path.join(temp_dir, "mjcf", "model.xml"),
                    name="model"
                )
                # Zip the output directory
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                with zipfile.ZipFile(os.path.join(self.output_dir, f"mjcf_{timestamp}.zip"), "w") as zipf:
                    for root, dirs, files in os.walk(os.path.join(temp_dir, "mjcf")):
                        for file in files:
                            zipf.write(os.path.join(root, file), os.path.relpath(os.path.join(root, file), os.path.join(temp_dir, "mjcf")))
                return os.path.join(self.output_dir, f"mjcf_{timestamp}.zip"), "Success!"
        except Exception as e:
            print(f"Error exporting MJCF: {e}")
            import traceback
            traceback.print_exc()
            return None, f"Error exporting MJCF: {str(e)}"


def create_gradio_app(particulate_app):
    # Get example files from examples folder
    # Only meshes that ship with a matching .png thumbnail are listed.
    examples_dir = "examples"
    example_files = []
    example_images = []
    if os.path.exists(examples_dir):
        for file in sorted(os.listdir(examples_dir)):
            if file.lower().endswith(('.glb', '.obj')):
                base_name = os.path.splitext(file)[0]
                png_path = os.path.join(examples_dir, base_name + ".png")
                if os.path.exists(png_path):
                    example_files.append(os.path.join(examples_dir, file))
                    example_images.append(png_path)
    with gr.Blocks(title="Particulate Demo") as demo:
        gr.HTML(
            """
🌟 GitHub Repository | 🚀 Project Page
Upload a 3D model (.obj or .glb format supported) to articulate it. Particulate takes this model and predicts the underlying articulated structure, which can be directly exported to URDF or MJCF format.