File size: 15,746 Bytes
4f22fc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
from sklearn.cluster import AgglomerativeClustering, KMeans
import numpy as np
import trimesh
import matplotlib.pyplot as plt
import numpy as np
import os
import argparse
import time

import json
from os.path import join
from typing import List

from collections import defaultdict
from scipy.sparse import coo_matrix, csr_matrix
from scipy.sparse.csgraph import connected_components

from plyfile import PlyData
import open3d as o3d

from scipy.spatial import cKDTree
from collections import Counter
from partfield.utils import *

#### Export to file #####
def export_colored_mesh_ply(V, F, FL, filename='segmented_mesh.ply'):
    """
    Export a mesh with per-face segmentation labels into a colored PLY file.

    Each unique label is mapped to a distinct color from the "tab20" colormap
    and stamped onto the corresponding faces.

    Parameters:
    - V (np.ndarray): Vertices array of shape (N, 3)
    - F (np.ndarray): Faces array of shape (M, 3)
    - FL (np.ndarray): Face labels of shape (M,) or (M, 1)
    - filename (str): Output filename
    """
    assert V.shape[1] == 3
    assert F.shape[1] == 3
    assert F.shape[0] == FL.shape[0]

    # Flatten (M, 1) label arrays so boolean masks below line up with faces.
    FL = np.squeeze(FL)

    # Generate distinct colors for each unique label
    unique_labels = np.unique(FL)
    colormap = plt.cm.get_cmap("tab20", len(unique_labels))
    label_to_color = {
        label: (np.array(colormap(i)[:3]) * 255).astype(np.uint8)
        for i, label in enumerate(unique_labels)
    }

    mesh = trimesh.Trimesh(vertices=V, faces=F)

    # Assign colors per label via boolean masks instead of a per-face Python
    # loop — same result, much faster on large meshes.
    face_colors = np.empty((F.shape[0], 4), dtype=np.uint8)
    face_colors[:, 3] = 255  # fully opaque alpha
    for label, color in label_to_color.items():
        face_colors[FL == label, :3] = color
    mesh.visual.face_colors = face_colors

    mesh.export(filename)
    # Bug fix: the message previously printed a literal placeholder instead of
    # interpolating the output path.
    print(f"Exported mesh to {filename}")

def export_pointcloud_with_labels_to_ply(V, VL, filename='colored_pointcloud.ply'):
    """
    Export a labeled point cloud to a PLY file with vertex colors.

    Each unique label is mapped to a distinct "tab20" color; Open3D expects
    colors as floats in [0, 1], which is what the colormap already returns.

    Parameters:
    - V: (N, 3) numpy array of XYZ coordinates
    - VL: (N,) or (N, 1) numpy array of integer labels
    - filename: Output PLY file name
    """
    assert V.shape[0] == VL.shape[0], "Number of vertices and labels must match"

    # Flatten (N, 1) label arrays for the per-point lookup below.
    VL = np.squeeze(VL)

    # Generate unique colors for each label
    unique_labels = np.unique(VL)
    colormap = plt.cm.get_cmap("tab20", len(unique_labels))
    label_to_color = {
        label: colormap(i)[:3] for i, label in enumerate(unique_labels)
    }

    # Map labels to RGB colors (float triples in [0, 1])
    colors = np.array([label_to_color[label] for label in VL])

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(V)
    pcd.colors = o3d.utility.Vector3dVector(colors)

    # Save to .ply
    o3d.io.write_point_cloud(filename, pcd)
    # Bug fix: the message previously printed a literal placeholder instead of
    # interpolating the output path.
    print(f"Point cloud saved to {filename}")
#########################

def construct_face_adjacency_matrix(face_list):
    """
    Given a list of faces (each face is a 3-tuple of vertex indices),
    construct a face-based adjacency matrix of shape (num_faces, num_faces).
    Two faces are adjacent if they share an edge.

    If multiple connected components exist, dummy edges are added (linking one
    representative face of each component in a chain) to turn them into a
    single connected component, as required by downstream connectivity-
    constrained clustering.

    Parameters
    ----------
    face_list : list of tuples
        List of faces, each face is a tuple (v0, v1, v2) of vertex indices.

    Returns
    -------
    face_adjacency : scipy.sparse.csr_matrix
        A CSR sparse matrix of shape (num_faces, num_faces),
        containing 1s for adjacent faces and 0s otherwise.
        Additional edges are added if the faces are in multiple components.
    """

    num_faces = len(face_list)
    if num_faces == 0:
        # Return an empty matrix if no faces
        return csr_matrix((0, 0))

    # Step 1: Map each undirected edge -> list of face indices containing it.
    # Endpoints are stored sorted so (2,5) and (5,2) hash to the same key.
    edge_to_faces = defaultdict(list)
    for f_idx, (v0, v1, v2) in enumerate(face_list):
        edges = [
            tuple(sorted((v0, v1))),
            tuple(sorted((v1, v2))),
            tuple(sorted((v2, v0)))
        ]
        for e in edges:
            edge_to_faces[e].append(f_idx)

    # Step 2: Build symmetric (row, col) adjacency lists among faces.
    row = []
    col = []
    for faces_sharing_e in edge_to_faces.values():
        f_indices = list(set(faces_sharing_e))  # unique face indices for this edge
        if len(f_indices) > 1:
            # Mark every pair of faces sharing this edge as adjacent.
            for i in range(len(f_indices)):
                for j in range(i + 1, len(f_indices)):
                    f_i = f_indices[i]
                    f_j = f_indices[j]
                    row.append(f_i)
                    col.append(f_j)
                    row.append(f_j)
                    col.append(f_i)

    # Create a COO matrix, then convert it to CSR
    data = np.ones(len(row), dtype=np.int8)
    face_adjacency = coo_matrix(
        (data, (row, col)),
        shape=(num_faces, num_faces)
    ).tocsr()

    # Step 3 (bug fix): the docstring promised dummy edges for disconnected
    # meshes, but none were added.  Link one representative face of each
    # component to the next component's representative so the graph is
    # connected.  A connected input is returned unchanged.
    n_components, component_labels = connected_components(
        face_adjacency, directed=False)
    if n_components > 1:
        reps = [int(np.flatnonzero(component_labels == c)[0])
                for c in range(n_components)]
        extra_row, extra_col = [], []
        for a, b in zip(reps[:-1], reps[1:]):
            extra_row.extend((a, b))
            extra_col.extend((b, a))
        extra = coo_matrix(
            (np.ones(len(extra_row), dtype=np.int8), (extra_row, extra_col)),
            shape=(num_faces, num_faces)
        ).tocsr()
        # maximum() keeps entries binary even if a dummy edge already existed.
        face_adjacency = face_adjacency.maximum(extra)

    return face_adjacency


def relabel_coarse_mesh(dense_mesh, dense_labels, coarse_mesh):
    """
    Relabels a coarse mesh using voting from a dense mesh, where every dense face gets to vote.

    Labels are shifted by +1 before voting so that 0 can mark coarse faces
    that received no votes (unassigned).

    Parameters:
        dense_mesh (trimesh.Trimesh): High-resolution input mesh.
        dense_labels (numpy.ndarray): Per-face labels for the dense mesh (shape: (N_dense_faces,)).
        coarse_mesh (trimesh.Trimesh): Coarser mesh to be relabeled.

    Returns:
        numpy.ndarray: New labels for the coarse mesh (shape: (N_coarse_faces,)),
        dtype int32; 0 means "no dense face voted for this coarse face".
    """
    # Compute centroids for both dense and coarse mesh faces
    dense_centroids = dense_mesh.vertices[dense_mesh.faces].mean(axis=1)  # (N_dense_faces, 3)
    coarse_centroids = coarse_mesh.vertices[coarse_mesh.faces].mean(axis=1)  # (N_coarse_faces, 3)

    # Use KDTree to efficiently find nearest coarse face for each dense face
    tree = cKDTree(coarse_centroids)
    _, nearest_coarse_faces = tree.query(dense_centroids)  # (N_dense_faces,)

    # Bug fix: the original did `dense_labels += 1`, mutating the caller's
    # array in place.  Shift out-of-place instead.
    shifted_labels = np.asarray(dense_labels) + 1

    # Every dense face votes for its nearest coarse face; accumulate vote
    # counts directly instead of building per-face vote lists.
    face_label_votes = defaultdict(Counter)
    for dense_face_idx, coarse_face_idx in enumerate(nearest_coarse_faces):
        face_label_votes[coarse_face_idx][shifted_labels[dense_face_idx]] += 1

    # Assign new labels based on majority voting; faces with no votes stay 0.
    coarse_labels = np.zeros(len(coarse_mesh.faces), dtype=np.int32)
    for coarse_face_idx, votes in face_label_votes.items():
        coarse_labels[coarse_face_idx] = votes.most_common(1)[0][0]

    return coarse_labels

class UnionFind:
    """Disjoint-set forest with path compression and union by rank."""

    def __init__(self, n):
        # Every element starts as the root of its own singleton set.
        self.parent = list(range(n))
        self.rank = [1] * n

    def find(self, x):
        # Locate the root, then compress the path so future lookups are O(1).
        root = x
        while self.parent[root] != root:
            root = self.parent[root]
        while self.parent[x] != root:
            self.parent[x], x = root, self.parent[x]
        return root

    def union(self, x, y):
        root_x, root_y = self.find(x), self.find(y)
        if root_x == root_y:
            return
        # Attach the shorter tree under the taller one; on a tie the first
        # root wins and its rank grows by one.
        if self.rank[root_x] < self.rank[root_y]:
            self.parent[root_x] = root_y
        else:
            self.parent[root_y] = root_x
            if self.rank[root_x] == self.rank[root_y]:
                self.rank[root_x] += 1

def hierarchical_clustering_labels(children, n_samples, max_cluster=20):
    """
    Replay an agglomerative merge tree and snapshot flat labelings.

    `children` is the sklearn-style merge array: merge step i creates internal
    node `i + n_samples` from its two children.  Once the number of remaining
    clusters drops to `max_cluster` or below, the per-sample root ids are
    recorded after every further merge.

    Returns a list of label lists, one per recorded merge step.
    """
    # The forest holds all leaves plus every internal merge node.
    forest = UnionFind(2 * n_samples - 1)

    remaining = n_samples
    snapshots = []
    for merge_step, (left, right) in enumerate(children):
        new_node = merge_step + n_samples
        forest.union(left, new_node)
        forest.union(right, new_node)
        remaining -= 1  # each merge removes exactly one cluster

        if remaining <= max_cluster:
            snapshots.append([forest.find(s) for s in range(n_samples)])

    return snapshots

def load_ply_to_numpy(filename):
    """
    Load a PLY file and extract the point cloud as a (N, 3) NumPy array.

    Parameters:
        filename (str): Path to the PLY file.

    Returns:
        numpy.ndarray: Point cloud array of shape (N, 3).
    """
    vertex_data = PlyData.read(filename)["vertex"]
    # Stack the x/y/z property columns side by side into an (N, 3) array.
    return np.column_stack((vertex_data["x"], vertex_data["y"], vertex_data["z"]))

def solve_clustering(input_fname, uid, view_id, save_dir="test_results1", max_cluster=20, out_render_fol= "test_render_clustering", filehandle=None, use_agglo=False, max_num_clusters=18, viz_dense=False, export_mesh=True):
    """
    Cluster per-face PartField features of one model into part segmentations
    and write the results to disk (.npy label arrays and, optionally, colored
    .ply meshes).

    Parameters:
        input_fname: NOTE(review): unused on entry — immediately overwritten
            below with the coarse-mesh path derived from save_dir/uid/view_id.
        uid: Model identifier used in all input/output file names.
        view_id: View identifier used in the same file names.
        save_dir (str): Directory holding feat_pca_*/input_*/part_feat_* files.
        max_cluster (int): Only used to build output file names in the
            agglomerative branch (see note near the end).
        out_render_fol (str): Output root; expected to contain "cluster_out"
            and "ply" subfolders (created by the __main__ block).
        filehandle: Unused.
        use_agglo (bool): If True, run connectivity-constrained agglomerative
            clustering; otherwise run a KMeans sweep over cluster counts.
        max_num_clusters (int): Upper bound on the cluster counts explored.
        viz_dense (bool): If True, export labels on the dense feature mesh;
            otherwise labels are voted onto the coarse input mesh.
        export_mesh (bool): If True, also export colored .ply meshes.

    Returns:
        None. Side effects only (files written); returns early if the feature
        file cannot be loaded.
    """
    print(uid, view_id)

    # Dense mesh carrying the learned features; batched runs use a "_batch"
    # suffix, hence the fallback.  NOTE(review): bare except also hides real
    # I/O errors, not just the missing-file case.
    try:
        mesh_fname = f'{save_dir}/feat_pca_{uid}_{view_id}.ply'
        dense_mesh = load_mesh_util(mesh_fname)
    except:
        mesh_fname = f'{save_dir}/feat_pca_{uid}_{view_id}_batch.ply'
        dense_mesh = load_mesh_util(mesh_fname)

    # Normalize the dense mesh to a centered box with max extent 1.8
    # (0.9 * 2) so both meshes share one coordinate frame.
    vertices = dense_mesh.vertices
    bbmin = vertices.min(0)
    bbmax = vertices.max(0)
    center = (bbmin + bbmax) * 0.5
    scale = 2.0 * 0.9 / (bbmax - bbmin).max()
    vertices = (vertices - center) * scale
    dense_mesh.vertices = vertices

    ### Load coarse mesh
    # NOTE(review): the input_fname argument is discarded here.
    input_fname = f'{save_dir}/input_{uid}_{view_id}.ply'
    coarse_mesh = trimesh.load(input_fname, force='mesh')
    vertices = coarse_mesh.vertices

    # Same normalization applied to the coarse mesh.
    bbmin = vertices.min(0)
    bbmax = vertices.max(0)
    center = (bbmin + bbmax) * 0.5
    scale = 2.0 * 0.9 / (bbmax - bbmin).max()
    vertices = (vertices - center) * scale
    coarse_mesh.vertices = vertices
    #####################        

    # Per-dense-face feature vectors; batched runs use a "_batch" suffix.
    # If neither file loads, skip this model entirely.
    try:
        point_feat = np.load(f'{save_dir}/part_feat_{uid}_{view_id}.npy')
    except:
        try:
            point_feat = np.load(f'{save_dir}/part_feat_{uid}_{view_id}_batch.npy')

        except:
            print()
            print("pointfeat loading error. skipping...")
            print(f'{save_dir}/part_feat_{uid}_{view_id}_batch.npy')
            return

    # L2-normalize features so clustering operates on unit vectors.
    point_feat = point_feat / np.linalg.norm(point_feat, axis=-1, keepdims=True)

    if not use_agglo:
        # KMeans sweep: one independent segmentation per cluster count.
        for num_cluster in range(2, max_num_clusters):
            clustering = KMeans(n_clusters=num_cluster, random_state=0).fit(point_feat)
            labels = clustering.labels_

            if not viz_dense:
                #### Relabel coarse from dense ####
                labels = relabel_coarse_mesh(dense_mesh, labels, coarse_mesh)
                V = coarse_mesh.vertices
                F = coarse_mesh.faces
                ###################################
            else:
                V = dense_mesh.vertices
                F = dense_mesh.faces   

            # Remap arbitrary label values to consecutive 0..k-1 indices.
            pred_labels = np.zeros((len(labels), 1))
            for i, label in enumerate(np.unique(labels)):
                # print(i, label)
                pred_labels[labels == label] = i  # Assign RGB values to each label


            # NOTE(review): fname is computed but never used.
            fname = str(uid) + "_" + str(view_id) + "_" + str(num_cluster).zfill(2)
            fname_clustering = os.path.join(out_render_fol, "cluster_out", str(uid) + "_" + str(view_id) + "_" + str(num_cluster).zfill(2))
            np.save(fname_clustering, pred_labels)

            if export_mesh :
                fname_mesh = os.path.join(out_render_fol, "ply", str(uid) + "_" + str(view_id) + "_" + str(num_cluster).zfill(2) + ".ply")
                export_colored_mesh_ply(V, F, pred_labels, filename=fname_mesh)  

    else:

        # Agglomerative branch: cluster with mesh-connectivity constraints,
        # then replay the merge tree to extract one labeling per level.
        # NOTE(review): unlike the KMeans branch, no .npy labels are saved here.
        adj_matrix = construct_face_adjacency_matrix(dense_mesh.faces)
        clustering = AgglomerativeClustering(connectivity=adj_matrix,
                                    n_clusters=1,
                                    ).fit(point_feat)
        hierarchical_labels = hierarchical_clustering_labels(clustering.children_, point_feat.shape[0], max_cluster=max_num_clusters)

        all_FL = []
        for n_cluster in range(max_num_clusters):
            print("Processing cluster: "+str(n_cluster))
            labels = hierarchical_labels[n_cluster]
            all_FL.append(labels)
        
        
        all_FL = np.array(all_FL)
        # NOTE(review): this unique_labels is unused — it is recomputed per
        # level inside the loop below.
        unique_labels = np.unique(all_FL)

        for n_cluster in range(max_num_clusters):
            FL = all_FL[n_cluster]

            if not viz_dense:
                #### Relabel coarse from dense ####
                FL = relabel_coarse_mesh(dense_mesh, FL, coarse_mesh)
                V = coarse_mesh.vertices
                F = coarse_mesh.faces
                ###################################
            else:
                V = dense_mesh.vertices
                F = dense_mesh.faces                

            # Remap root ids to consecutive indices.
            # NOTE(review): relabel is computed but the export below uses the
            # raw FL values — confirm whether relabel was meant to be exported.
            unique_labels = np.unique(FL)
            relabel = np.zeros((len(FL), 1))
            for i, label in enumerate(unique_labels):
                relabel[FL == label] = i  # Assign RGB values to each label

            if export_mesh :
                # NOTE(review): file name uses max_cluster - n_cluster while the
                # labels come from hierarchical_labels[n_cluster]; presumably the
                # intent is that level n_cluster has (max_cluster - n_cluster)
                # clusters — verify against hierarchical_clustering_labels.
                fname_mesh = os.path.join(out_render_fol, "ply", str(uid) + "_" + str(view_id) + "_" + str(max_cluster - n_cluster).zfill(2) + ".ply")
                export_colored_mesh_ply(V, F, FL, filename=fname_mesh)        
        
            
if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--source_dir', default= "", type=str)
    parser.add_argument('--root', default= "", type=str)
    parser.add_argument('--dump_dir', default= "", type=str)
    
    parser.add_argument('--max_num_clusters', default= 18, type=int)
    parser.add_argument('--use_agglo', default= True, type=bool)
    parser.add_argument('--viz_dense', default= False, type=bool)
    parser.add_argument('--export_mesh', default= True, type=bool)


    FLAGS = parser.parse_args()
    root = FLAGS.root
    OUTPUT_FOL = FLAGS.dump_dir
    SOURCE_DIR = FLAGS.source_dir

    MAX_NUM_CLUSTERS = FLAGS.max_num_clusters
    USE_AGGLO = FLAGS.use_agglo
    EXPORT_MESH = FLAGS.export_mesh

    models = os.listdir(root)
    os.makedirs(OUTPUT_FOL, exist_ok=True)

    if EXPORT_MESH:
        ply_fol = os.path.join(OUTPUT_FOL, "ply")
        os.makedirs(ply_fol, exist_ok=True)    

    cluster_fol = os.path.join(OUTPUT_FOL, "cluster_out")
    os.makedirs(cluster_fol, exist_ok=True) 

    #### Get existing model_ids ###
    all_files = os.listdir(os.path.join(OUTPUT_FOL, "ply"))

    existing_model_ids = []
    for sample in all_files:
        uid = sample.split("_")[0]
        view_id = sample.split("_")[1]
        # sample_name = str(uid) + "_" + str(view_id)
        sample_name = str(uid)

        if sample_name not in existing_model_ids:
            existing_model_ids.append(sample_name)
    ##############################

    all_files = os.listdir(SOURCE_DIR)
    selected = []
    for f in all_files:
        if (".obj" in f or ".glb" in f) and f.split(".")[0] not in existing_model_ids:
            selected.append(f)
    
    print("Number of models to process: " + str(len(selected)))
    

    for model in selected:
        fname = os.path.join(SOURCE_DIR, model)
        uid = model.split(".")[-2]
        view_id = 0

        solve_clustering(fname, uid, view_id, save_dir=root, out_render_fol= OUTPUT_FOL, use_agglo=USE_AGGLO, max_num_clusters=MAX_NUM_CLUSTERS, viz_dense=FLAGS.viz_dense, export_mesh=EXPORT_MESH)