Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
| import boto3 | |
| import json | |
| from os import path as osp | |
| # from botocore.config import Config | |
| # from botocore.exceptions import ClientError | |
| import h5py | |
| import io | |
| import numpy as np | |
| import skimage | |
| import trimesh | |
| import os | |
| from scipy.spatial import KDTree | |
| import gc | |
| from plyfile import PlyData | |
| ## For remeshing | |
| import mesh2sdf | |
| import tetgen | |
| import vtk | |
| import math | |
| import tempfile | |
| ### For mesh processing | |
| import pymeshlab | |
| from partfield.utils import * | |
| ######################### | |
| ## To handle quad inputs | |
| ######################### | |
def quad_to_triangle_mesh(F):
    """
    Convert a triangle/quad face list into a pure triangle face list.

    Every quad ``[a, b, c, d]`` is split along the ``a-c`` diagonal into the
    triangles ``[a, b, c]`` and ``[a, c, d]``. Triangles are kept unchanged and
    in place, so a mixed (quad-dominant) face list keeps all of its faces.
    Faces with any other vertex count are skipped with a printed warning.

    Parameters:
        F: Sequence of faces, each a sequence of vertex indices. Either a 2D
            NumPy array or a (possibly ragged) list of lists.

    Returns:
        ``F`` itself when every face is already a triangle; an empty
        ``(0, 3)`` integer array when ``F`` has no faces; otherwise a new
        NumPy array with one row of three indices per output triangle.
    """
    # No faces at all: there is no first face to inspect, return an empty array.
    if len(F) == 0:
        return np.empty((0, 3), dtype=np.int64)

    # Fast path: a rectangular (N, 3) array is already a triangle list.
    if isinstance(F, np.ndarray) and F.ndim == 2 and F.shape[1] == 3:
        return F

    # Inspect every face, not just the first one, before deciding to skip.
    if all(len(face) == 3 for face in F):
        return F

    new_faces = []
    for face in F:
        n_corners = len(face)
        if n_corners == 3:
            # Keep existing triangles of a mixed mesh instead of dropping them.
            new_faces.append([face[0], face[1], face[2]])
        elif n_corners == 4:
            # Split the quad into two triangles sharing the 0-2 diagonal.
            new_faces.append([face[0], face[1], face[2]])
            new_faces.append([face[0], face[2], face[3]])
        else:
            print(f"Warning: Skipping non-triangle/non-quad face {face}")

    return np.array(new_faces)
| ######################### | |
class Demo_Dataset(torch.utils.data.Dataset):
    """
    Dataset over the point-cloud or mesh files found directly in one directory.

    With ``cfg.is_pc`` set, ``.ply`` files are read as point clouds; otherwise
    ``.obj`` / ``.glb`` / ``.off`` files are read as meshes. Each item is a
    dict with ``'uid'`` and ``'pc'`` (a float32 tensor), plus ``'vertices'``
    and ``'faces'`` in mesh mode.

    Config fields read: ``cfg.dataset.data_path``, ``cfg.is_pc``,
    ``cfg.preprocess_mesh`` and ``cfg.result_name``.
    """

    # Accepted file extensions (lower case) for each input mode.
    _PC_EXTENSIONS = (".ply",)
    _MESH_EXTENSIONS = (".obj", ".glb", ".off")

    def __init__(self, cfg):
        super().__init__()

        self.data_path = cfg.dataset.data_path
        self.is_pc = cfg.is_pc

        wanted = self._PC_EXTENSIONS if self.is_pc else self._MESH_EXTENSIONS
        # Match on the real file extension (case-insensitively) rather than on
        # a substring, so names such as "x.obj.bak" are not picked up. The
        # listing is sorted because os.listdir returns entries in arbitrary
        # order, which would make the index -> file mapping unstable.
        self.data_list = [
            name
            for name in sorted(os.listdir(self.data_path))
            if os.path.splitext(name)[1].lower() in wanted
        ]

        # Number of surface samples drawn per mesh in get_model.
        self.pc_num_pts = 100000

        self.preprocess_mesh = cfg.preprocess_mesh
        self.result_name = cfg.result_name

        print("val dataset len:", len(self.data_list))

    def __len__(self):
        return len(self.data_list)

    def load_ply_to_numpy(self, filename):
        """
        Load a PLY file and extract the point cloud as a (N, 3) NumPy array.

        Parameters:
            filename (str): Path to the PLY file.

        Returns:
            numpy.ndarray: Point cloud array of shape (N, 3).
        """
        ply_data = PlyData.read(filename)

        # Extract vertex data
        vertex_data = ply_data["vertex"]

        # Stack the x, y, z columns and transpose to one point per row.
        points = np.vstack([vertex_data["x"], vertex_data["y"], vertex_data["z"]]).T

        return points

    @staticmethod
    def _normalize_to_unit_box(points):
        """Center points on their bounding box and scale its longest side to 1.8."""
        bbmin = points.min(0)
        bbmax = points.max(0)
        center = (bbmin + bbmax) * 0.5
        extent = (bbmax - bbmin).max()
        # A zero-extent input (e.g. a single point) would divide by zero and
        # produce NaNs; in that case only center the points.
        scale = 2.0 * 0.9 / extent if extent > 0 else 1.0
        return (points - center) * scale

    @staticmethod
    def _clean_mesh(mesh):
        """Run the PyMeshLab cleanup filters and write the result back into ``mesh``."""
        # Create a PyMeshLab mesh directly from vertices and faces
        ml_mesh = pymeshlab.Mesh(vertex_matrix=mesh.vertices, face_matrix=mesh.faces)

        # Create a MeshSet and add the mesh
        ms = pymeshlab.MeshSet()
        ms.add_mesh(ml_mesh, "from_trimesh")

        # Apply filters
        ms.apply_filter('meshing_remove_duplicate_faces')
        ms.apply_filter('meshing_remove_duplicate_vertices')
        percentageMerge = pymeshlab.PercentageValue(0.5)
        ms.apply_filter('meshing_merge_close_vertices', threshold=percentageMerge)
        ms.apply_filter('meshing_remove_unreferenced_vertices')

        # Copy the cleaned geometry back into the caller's mesh object.
        processed = ms.current_mesh()
        mesh.vertices = processed.vertex_matrix()
        mesh.faces = processed.face_matrix()

    def get_model(self, ply_file):
        """
        Load one file from ``self.data_path`` and build its sample dict.

        Parameters:
            ply_file (str): File name (or relative path) under ``self.data_path``.

        Returns:
            dict: ``'uid'`` and ``'pc'`` (float32 tensor); in mesh mode also
            ``'vertices'`` and ``'faces'`` of the processed mesh.

        Side effects:
            In mesh mode the processed mesh is exported to
            ``exp_results/<result_name>/input_<uid>_0.ply``.
        """
        # NOTE(review): the uid is the text between the last two dots, so
        # "a.v1.obj" and "b.v1.obj" both map to "v1". Kept unchanged because
        # other code may derive result file names the same way -- confirm
        # before switching to os.path.splitext.
        uid = ply_file.split(".")[-2].replace("/", "_")

        file_path = os.path.join(self.data_path, ply_file)

        if self.is_pc:
            pc = self._normalize_to_unit_box(self.load_ply_to_numpy(file_path))
        else:
            mesh = load_mesh_util(file_path)
            faces = mesh.faces
            mesh.vertices = self._normalize_to_unit_box(mesh.vertices)

            ### Make sure it is a triangle mesh -- just convert the quad
            mesh.faces = quad_to_triangle_mesh(faces)

            print("before preprocessing...")
            print(mesh.vertices.shape)
            print(mesh.faces.shape)
            print()

            ### Pre-process mesh
            if self.preprocess_mesh:
                self._clean_mesh(mesh)

                print("after preprocessing...")
                print(mesh.vertices.shape)
                print(mesh.faces.shape)

            ### Save input
            save_dir = f"exp_results/{self.result_name}"
            os.makedirs(save_dir, exist_ok=True)
            view_id = 0
            mesh.export(f'{save_dir}/input_{uid}_{view_id}.ply')

            # Sample the point cloud from the mesh surface.
            pc, _ = trimesh.sample.sample_surface(mesh, self.pc_num_pts)

        result = {
            'uid': uid
        }

        result['pc'] = torch.tensor(pc, dtype=torch.float32)

        if not self.is_pc:
            result['vertices'] = mesh.vertices
            result['faces'] = mesh.faces

        return result

    def __getitem__(self, index):
        # Run a garbage collection pass before loading each item.
        gc.collect()

        return self.get_model(self.data_list[index])
| ############## | |
| ############################### | |
class Demo_Remesh_Dataset(torch.utils.data.Dataset):
    """
    Dataset over the ``.obj`` / ``.glb`` meshes in one directory, with remeshing.

    Each mesh is normalized, optionally cleaned with PyMeshLab, exported, and
    then remeshed (distance field -> marching cubes -> tetrahedralization ->
    surface extraction). If remeshing fails, the mesh as it was before the
    remesh step is used instead. Each item is a dict with ``'uid'``, ``'pc'``
    (float32 tensor), ``'vertices'`` and ``'faces'``.

    Config fields read: ``cfg.dataset.data_path``, ``cfg.preprocess_mesh`` and
    ``cfg.result_name``.
    """

    # Accepted mesh file extensions (lower case).
    _MESH_EXTENSIONS = (".obj", ".glb")

    def __init__(self, cfg):
        super().__init__()

        self.data_path = cfg.dataset.data_path

        # Match on the real file extension (case-insensitively) rather than on
        # a substring, and sort because os.listdir order is arbitrary.
        self.data_list = [
            name
            for name in sorted(os.listdir(self.data_path))
            if os.path.splitext(name)[1].lower() in self._MESH_EXTENSIONS
        ]

        # Number of surface samples drawn per mesh in get_model.
        self.pc_num_pts = 100000

        self.preprocess_mesh = cfg.preprocess_mesh
        self.result_name = cfg.result_name

        print("val dataset len:", len(self.data_list))

    def __len__(self):
        return len(self.data_list)

    @staticmethod
    def _normalize_to_unit_box(points):
        """Center points on their bounding box and scale its longest side to 1.8."""
        bbmin = points.min(0)
        bbmax = points.max(0)
        center = (bbmin + bbmax) * 0.5
        extent = (bbmax - bbmin).max()
        # A zero-extent input would divide by zero and produce NaNs; in that
        # case only center the points.
        scale = 2.0 * 0.9 / extent if extent > 0 else 1.0
        return (points - center) * scale

    @staticmethod
    def _clean_mesh(mesh):
        """Run the PyMeshLab cleanup filters and write the result back into ``mesh``."""
        # Create a PyMeshLab mesh directly from vertices and faces
        ml_mesh = pymeshlab.Mesh(vertex_matrix=mesh.vertices, face_matrix=mesh.faces)

        # Create a MeshSet and add the mesh
        ms = pymeshlab.MeshSet()
        ms.add_mesh(ml_mesh, "from_trimesh")

        # Apply filters
        ms.apply_filter('meshing_remove_duplicate_faces')
        ms.apply_filter('meshing_remove_duplicate_vertices')
        percentageMerge = pymeshlab.PercentageValue(0.5)
        ms.apply_filter('meshing_merge_close_vertices', threshold=percentageMerge)
        ms.apply_filter('meshing_remove_unreferenced_vertices')

        # Copy the cleaned geometry back into the caller's mesh object.
        processed = ms.current_mesh()
        mesh.vertices = processed.vertex_matrix()
        mesh.faces = processed.face_matrix()

    @staticmethod
    def _remesh(mesh, size=256):
        """Rebuild ``mesh`` via a distance field, marching cubes and a tet mesh.

        ``size`` is the grid resolution passed to mesh2sdf. Returns a new mesh;
        any failure propagates to the caller as an exception.
        """
        level = 2 / size

        sdf = mesh2sdf.core.compute(mesh.vertices, mesh.faces, size)
        # NOTE: the negative value is not reliable if the mesh is not watertight
        udf = np.abs(sdf)
        vertices, faces, _, _ = skimage.measure.marching_cubes(udf, level)

        #### Make tet #####
        components = trimesh.Trimesh(vertices, faces).split(only_watertight=False)
        if len(components) > 100000:
            raise NotImplementedError

        fixed_components = []
        for component in components:
            component.fix_normals()
            fixed_components.append(component)
        shell = trimesh.util.concatenate(fixed_components)

        # generate tet mesh; maxvolume is sqrt(2)/12 * a**3 with a = 2/size,
        # i.e. the volume of a regular tetrahedron with that edge length.
        tet = tetgen.TetGen(shell.vertices, shell.faces)
        tet.tetrahedralize(plc=True, nobisect=1., quality=True, fixedvolume=True, maxvolume=math.sqrt(2) / 12 * (2 / size) ** 3)

        # The VTK reader/writer work on file names, so round-trip through a
        # private temporary directory that is removed on exit, even on error.
        with tempfile.TemporaryDirectory() as tmp_dir:
            vtk_path = os.path.join(tmp_dir, "tet.vtk")
            obj_path = os.path.join(tmp_dir, "surface.obj")

            tet.grid.save(vtk_path)

            # extract surface mesh from tet mesh
            reader = vtk.vtkUnstructuredGridReader()
            reader.SetFileName(vtk_path)
            reader.Update()
            surface_filter = vtk.vtkDataSetSurfaceFilter()
            surface_filter.SetInputConnection(reader.GetOutputPort())
            surface_filter.Update()
            polydata = surface_filter.GetOutput()

            writer = vtk.vtkOBJWriter()
            writer.SetFileName(obj_path)
            writer.SetInputData(polydata)
            writer.Update()

            new_mesh = load_mesh_util(obj_path)
            new_mesh.vertices = new_mesh.vertices * (2.0 / size) - 1.0  # normalize it to [-1, 1]

        return new_mesh

    def get_model(self, ply_file):
        """
        Load, normalize, optionally clean, export and remesh one mesh file.

        Parameters:
            ply_file (str): File name under ``self.data_path``.

        Returns:
            dict: ``'uid'``, ``'pc'`` (float32 tensor), ``'vertices'`` and
            ``'faces'`` of the remeshed mesh, or of the pre-remesh mesh when
            remeshing raised an exception.

        Side effects:
            The mesh before remeshing is exported to
            ``exp_results/<result_name>/input_<uid>_0.ply``.
        """
        # NOTE(review): the uid is the text between the last two dots, so
        # names with several dots can collide -- confirm before changing.
        uid = ply_file.split(".")[-2]

        mesh = load_mesh_util(os.path.join(self.data_path, ply_file))
        mesh.vertices = self._normalize_to_unit_box(mesh.vertices)

        ### Pre-process mesh
        if self.preprocess_mesh:
            self._clean_mesh(mesh)

            print("after preprocessing...")
            print(mesh.vertices.shape)
            print(mesh.faces.shape)

        ### Save input
        save_dir = f"exp_results/{self.result_name}"
        os.makedirs(save_dir, exist_ok=True)
        view_id = 0
        mesh.export(f'{save_dir}/input_{uid}_{view_id}.ply')

        ###### Remesh ######
        try:
            mesh = self._remesh(mesh)
        except Exception as err:
            # Best effort: keep the un-remeshed mesh, but report why remeshing
            # failed. KeyboardInterrupt / SystemExit are no longer swallowed.
            print(f"Error in tet. Falling back to the input mesh: {err!r}")

        pc, _ = trimesh.sample.sample_surface(mesh, self.pc_num_pts)

        result = {
            'uid': uid
        }

        result['pc'] = torch.tensor(pc, dtype=torch.float32)
        result['vertices'] = mesh.vertices
        result['faces'] = mesh.faces

        return result

    def __getitem__(self, index):
        # Run a garbage collection pass before loading each item.
        gc.collect()

        return self.get_model(self.data_list[index])
class Correspondence_Demo_Dataset(Demo_Dataset):
    """
    Variant of ``Demo_Dataset`` whose file list comes from the config.

    The parent constructor runs first; afterwards ``data_list`` is replaced by
    ``cfg.dataset.all_files`` and the remaining attributes are assigned again
    from the same config fields.
    """

    def __init__(self, cfg):
        super().__init__(cfg)

        dataset_cfg = cfg.dataset

        self.data_path = dataset_cfg.data_path
        self.is_pc = cfg.is_pc

        # Use the explicit file list from the config as the item list.
        self.data_list = dataset_cfg.all_files

        self.pc_num_pts = 100000
        self.preprocess_mesh = cfg.preprocess_mesh
        self.result_name = cfg.result_name

        print("val dataset len:", len(self.data_list))