Ruining Li committed on
Commit
4f22fc0
·
1 Parent(s): 5cbf9bb

Init: add PartField + particulate, track example assets via LFS

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .gitattributes +1 -0
  2. PartField/LICENSE +36 -0
  3. PartField/README.md +242 -0
  4. PartField/applications/.polyscope.ini +6 -0
  5. PartField/applications/README.md +142 -0
  6. PartField/applications/multi_shape_cosegment.py +482 -0
  7. PartField/applications/pack_labels_to_obj.py +47 -0
  8. PartField/applications/run_smooth_functional_map.py +80 -0
  9. PartField/applications/shape_pair.py +385 -0
  10. PartField/applications/single_shape.py +758 -0
  11. PartField/compute_metric.py +97 -0
  12. PartField/configs/final/correspondence_demo.yaml +44 -0
  13. PartField/configs/final/demo.yaml +28 -0
  14. PartField/download_demo_data.sh +19 -0
  15. PartField/environment.yml +772 -0
  16. PartField/partfield/__pycache__/dataloader.cpython-310.pyc +0 -0
  17. PartField/partfield/__pycache__/model_trainer_pvcnn_only_demo.cpython-310.pyc +0 -0
  18. PartField/partfield/__pycache__/utils.cpython-310.pyc +0 -0
  19. PartField/partfield/config/__init__.py +26 -0
  20. PartField/partfield/config/__pycache__/__init__.cpython-310.pyc +0 -0
  21. PartField/partfield/config/__pycache__/defaults.cpython-310.pyc +0 -0
  22. PartField/partfield/config/defaults.py +92 -0
  23. PartField/partfield/dataloader.py +366 -0
  24. PartField/partfield/model/PVCNN/__pycache__/conv_pointnet.cpython-310.pyc +0 -0
  25. PartField/partfield/model/PVCNN/__pycache__/dnnlib_util.cpython-310.pyc +0 -0
  26. PartField/partfield/model/PVCNN/__pycache__/encoder_pc.cpython-310.pyc +0 -0
  27. PartField/partfield/model/PVCNN/__pycache__/pc_encoder.cpython-310.pyc +0 -0
  28. PartField/partfield/model/PVCNN/__pycache__/unet_3daware.cpython-310.pyc +0 -0
  29. PartField/partfield/model/PVCNN/conv_pointnet.py +251 -0
  30. PartField/partfield/model/PVCNN/dnnlib_util.py +1074 -0
  31. PartField/partfield/model/PVCNN/encoder_pc.py +243 -0
  32. PartField/partfield/model/PVCNN/pc_encoder.py +90 -0
  33. PartField/partfield/model/PVCNN/pv_module/__init__.py +2 -0
  34. PartField/partfield/model/PVCNN/pv_module/__pycache__/__init__.cpython-310.pyc +0 -0
  35. PartField/partfield/model/PVCNN/pv_module/__pycache__/pvconv.cpython-310.pyc +0 -0
  36. PartField/partfield/model/PVCNN/pv_module/__pycache__/shared_mlp.cpython-310.pyc +0 -0
  37. PartField/partfield/model/PVCNN/pv_module/__pycache__/voxelization.cpython-310.pyc +0 -0
  38. PartField/partfield/model/PVCNN/pv_module/ball_query.py +34 -0
  39. PartField/partfield/model/PVCNN/pv_module/frustum.py +141 -0
  40. PartField/partfield/model/PVCNN/pv_module/functional/__init__.py +1 -0
  41. PartField/partfield/model/PVCNN/pv_module/functional/__pycache__/__init__.cpython-310.pyc +0 -0
  42. PartField/partfield/model/PVCNN/pv_module/functional/__pycache__/devoxelization.cpython-310.pyc +0 -0
  43. PartField/partfield/model/PVCNN/pv_module/functional/devoxelization.py +12 -0
  44. PartField/partfield/model/PVCNN/pv_module/loss.py +10 -0
  45. PartField/partfield/model/PVCNN/pv_module/pointnet.py +113 -0
  46. PartField/partfield/model/PVCNN/pv_module/pvconv.py +38 -0
  47. PartField/partfield/model/PVCNN/pv_module/shared_mlp.py +35 -0
  48. PartField/partfield/model/PVCNN/pv_module/voxelization.py +50 -0
  49. PartField/partfield/model/PVCNN/unet_3daware.py +427 -0
  50. PartField/partfield/model/UNet/__pycache__/buildingblocks.cpython-310.pyc +0 -0
.gitattributes CHANGED
@@ -33,4 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ examples/*.glb filter=lfs diff=lfs merge=lfs -text
37
  examples/*.png filter=lfs diff=lfs merge=lfs -text
PartField/LICENSE ADDED
@@ -0,0 +1,36 @@
1
+ NVIDIA License
2
+
3
+ 1. Definitions
4
+
5
+ “Licensor” means any person or entity that distributes its Work.
6
+ “Work” means (a) the original work of authorship made available under this license, which may include software, documentation, or other files, and (b) any additions to or derivative works thereof that are made available under this license.
7
+ The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S. copyright law; provided, however, that for the purposes of this license, derivative works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work.
8
+ Works are “made available” under this license by including in or with the Work either (a) a copyright notice referencing the applicability of this license to the Work, or (b) a copy of this license.
9
+
10
+ 2. License Grant
11
+
12
+ 2.1 Copyright Grant. Subject to the terms and conditions of this license, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to use, reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form.
13
+
14
+ 3. Limitations
15
+
16
+ 3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this license, (b) you include a complete copy of this license with your distribution, and (c) you retain without modification any copyright, patent, trademark, or attribution notices that are present in the Work.
17
+
18
+ 3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms. Notwithstanding Your Terms, this license (including the redistribution requirements in Section 3.1) will continue to apply to the Work itself.
19
+
20
+ 3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially. Notwithstanding the foregoing, NVIDIA Corporation and its affiliates may use the Work and any derivative works commercially. As used herein, “non-commercially” means for non-commercial research and educational purposes only.
21
+
22
+ 3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under this license from such Licensor (including the grant in Section 2.1) will terminate immediately.
23
+
24
+ 3.5 Trademarks. This license does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks, except as necessary to reproduce the notices described in this license.
25
+
26
+ 3.6 Termination. If you violate any term of this license, then your rights under this license (including the grant in Section 2.1) will terminate immediately.
27
+
28
+ 4. Disclaimer of Warranty.
29
+
30
+ THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
31
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE.
32
+
33
+ 5. Limitation of Liability.
34
+
35
+ EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
36
+
PartField/README.md ADDED
@@ -0,0 +1,242 @@
1
+ # PartField: Learning 3D Feature Fields for Part Segmentation and Beyond [ICCV 2025]
2
+ **[[Project]](https://research.nvidia.com/labs/toronto-ai/partfield-release/)** **[[PDF]](https://arxiv.org/pdf/2504.11451)**
3
+
4
+ Minghua Liu*, Mikaela Angelina Uy*, Donglai Xiang, Hao Su, Sanja Fidler, Nicholas Sharp, Jun Gao
5
+
6
+
7
+
8
+ ## Overview
9
+ ![Alt text](assets/teaser.png)
10
+
11
+ PartField is a feedforward model that predicts part-based feature fields for 3D shapes. Our learned features can be clustered to yield a high-quality part decomposition, outperforming the latest open-world 3D part segmentation approaches in both quality and speed. PartField can be applied to a wide variety of inputs in terms of modality, semantic class, and style. The learned feature field exhibits consistency across shapes, enabling applications such as cosegmentation, interactive selection, and correspondence.
12
+
13
+ ## Table of Contents
14
+
15
+ - [Pretrained Model](#pretrained-model)
16
+ - [Environment Setup](#environment-setup)
17
+ - [TLDR](#tldr)
18
+ - [Example Run](#example-run)
19
+ - [Interactive Tools and Applications](#interactive-tools-and-applications)
20
+ - [Evaluation on PartObjaverse-Tiny](#evaluation-on-partobjaverse-tiny)
21
+ - [Discussion](#discussion-clustering-with-messy-mesh-connectivities)
22
+ - [Citation](#citation)
23
+
24
+
25
+ ## Pretrained Model
26
+ ```
27
+ mkdir model
28
+ ```
29
+ Our pretrained model can be downloaded here: [Trained on Objaverse](https://huggingface.co/mikaelaangel/partfield-ckpt/blob/main/model_objaverse.ckpt). Place the checkpoint in the `model/` folder created above. Due to licensing restrictions, we are unable to release the model that was also trained on PartNet.
30
+
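If you prefer to fetch the checkpoint programmatically, here is a minimal sketch using `huggingface_hub` (the repo id and filename are taken from the link above; `local_dir` simply mirrors the `mkdir model` step):

```
# Hedged sketch: download the Objaverse-trained checkpoint into the model/ folder.
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(
    repo_id="mikaelaangel/partfield-ckpt",  # repo id from the link above
    filename="model_objaverse.ckpt",
    local_dir="model",
)
print(ckpt_path)  # expected: model/model_objaverse.ckpt
```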
31
+ ## Environment Setup
32
+
33
+ We use Python 3.10 with PyTorch 2.4 and CUDA 12.4. The environment and required packages can be installed individually as follows:
34
+ ```
35
+ conda create -n partfield python=3.10
36
+ conda activate partfield
37
+ conda install nvidia/label/cuda-12.4.0::cuda
38
+ pip install psutil
39
+ pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu124
40
+ pip install lightning==2.2 h5py yacs trimesh scikit-image loguru boto3
41
+ pip install mesh2sdf tetgen pymeshlab plyfile einops libigl polyscope potpourri3d simple_parsing arrgh open3d
42
+ pip install torch-scatter -f https://data.pyg.org/whl/torch-2.4.0+cu124.html
43
+ apt install libx11-6 libgl1 libxrender1
44
+ pip install vtk
45
+ ```
46
+
47
+ An environment file is also provided and can be used for installation:
48
+ ```
49
+ conda env create -f environment.yml
50
+ conda activate partfield
51
+ ```
52
+
53
+ ## TLDR
54
+ 1. Input data (`.obj` or `.glb` for meshes, `.ply` for splats) are stored in subfolders under `data/`. You can create a new subfolder and copy your custom files into it.
55
+ 2. Extract PartField features by running the script `partfield_inference.py`, passing the arguments `result_name [FEAT_FOL]` and `dataset.data_path [DATA_PATH]`. The output features will be saved in `exp_results/partfield_features/[FEAT_FOL]`.
56
+ 3. Segmented parts can be obtained by running the script `run_part_clustering.py`, using the arguments `--root exp_results/partfield_features/[FEAT_FOL]` and `--dump_dir [PART_OUT_FOL]`. The output segmentations will be saved in `exp_results/clustering/[PART_OUT_FOL]`.
57
+ 4. Application demo scripts are available in the `applications/` directory and can be used after extracting PartField features (i.e., after running `partfield_inference.py` on the desired demo data).
58
+
59
+ ## Example Run
60
+ ### Download Demo Data
61
+
62
+ #### Mesh Data
63
+ We showcase the feasibility of PartField using sample meshes from Objaverse (artist-created) and Trellis3D (AI-generated). Sample data can be downloaded below:
64
+ ```
65
+ sh download_demo_data.sh
66
+ ```
67
+ Downloaded meshes can be found in `data/objaverse_samples/` and `data/trellis_samples/`.
68
+
69
+ #### Gaussian Splats
70
+ We also demonstrate our approach using Gaussian splatting reconstructions as input. Sample splat reconstruction data from the NeRF dataset can be found [here](https://drive.google.com/drive/folders/1l0njShLq37hn1TovgeF-PVGBBrAdNQnf?usp=sharing). Download the data and place it in the `data/splat_samples/` folder.
71
+
72
+ ### Extract Feature Field
73
+ #### Mesh Data
74
+
75
+ ```
76
+ python partfield_inference.py -c configs/final/demo.yaml --opts continue_ckpt model/model_objaverse.ckpt result_name partfield_features/objaverse dataset.data_path data/objaverse_samples
77
+ python partfield_inference.py -c configs/final/demo.yaml --opts continue_ckpt model/model_objaverse.ckpt result_name partfield_features/trellis dataset.data_path data/trellis_samples
78
+ ```
79
+
80
+ #### Point Clouds / Gaussian Splats
81
+ ```
82
+ python partfield_inference.py -c configs/final/demo.yaml --opts continue_ckpt model/model_objaverse.ckpt result_name partfield_features/splat dataset.data_path data/splat_samples is_pc True
83
+ ```
84
+
85
+ ### Part Segmentation
86
+ #### Mesh Data
87
+
88
+ We use agglomerative clustering for part segmentation on mesh inputs.
89
+ ```
90
+ python run_part_clustering.py --root exp_results/partfield_features/objaverse --dump_dir exp_results/clustering/objaverse --source_dir data/objaverse_samples --use_agglo True --max_num_clusters 20 --option 0
91
+ ```
92
+
93
+ When the input mesh has multiple connected components or poor connectivity, defining face adjacency by connecting geometrically close faces can yield better results (see discussion below):
94
+ ```
95
+ python run_part_clustering.py --root exp_results/partfield_features/trellis --dump_dir exp_results/clustering/trellis --source_dir data/trellis_samples --use_agglo True --max_num_clusters 20 --option 1 --with_knn True
96
+ ```
97
+
98
+ Note that agglomerative clustering does not return a fixed clustering result, but rather a hierarchical part tree, where the root node represents the whole shape and each leaf node corresponds to a single triangle face. You can explore more clustering results by adaptively traversing the tree, such as deciding which part should be further segmented.
99
+
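For example, one way to explore different granularities is to re-cut the hierarchy at several cluster counts. The sketch below is illustrative only and uses scikit-learn's `AgglomerativeClustering`; the names `feat` (per-face PartField features) and `adj` (sparse face-adjacency matrix) are assumptions, not variables from `run_part_clustering.py`:

```
# Illustrative sketch: cut the hierarchical part tree at several granularities.
import numpy as np
from sklearn.cluster import AgglomerativeClustering

def cluster_at_level(feat, adj, n_clusters):
    """feat: [num_faces, C] per-face features; adj: sparse face-adjacency matrix."""
    agglo = AgglomerativeClustering(
        n_clusters=n_clusters,
        connectivity=adj,       # restrict merges to adjacent faces
        linkage="average",
    )
    labels = agglo.fit_predict(feat)  # one part label per triangle face
    return labels, agglo.children_    # children_ encodes the full merge tree

# Inspect a few levels and keep whichever segmentation looks best:
# for k in (5, 10, 20):
#     labels, _ = cluster_at_level(feat, adj, k)
```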
100
+ #### Point Cloud / Gaussian Splats
101
+ We use K-Means clustering for part segmentation on point cloud inputs.
102
+ ```
103
+ python run_part_clustering.py --root exp_results/partfield_features/splat --dump_dir exp_results/clustering/splat --source_dir data/splat_samples --max_num_clusters 20 --is_pc True
104
+ ```
105
+
106
+ ## Interactive Tools and Applications
107
+ We include UI tools to demonstrate various applications of PartField. Set up and try out our demos [here](applications/)!
108
+
109
+ ![Alt text](assets/co-seg.png)
110
+
111
+ ![Alt text](assets/regression_interactive_segmentation_guitars.gif)
112
+
113
+ ## Evaluation on PartObjaverse-Tiny
114
+
115
+ ![Alt text](assets/results_combined_compressed2.gif)
116
+
117
+ To evaluate all models in PartObjaverse-Tiny, you can download the data [here](https://github.com/Pointcept/SAMPart3D/blob/main/PartObjaverse-Tiny/PartObjaverse-Tiny.md) and run the following commands:
118
+ ```
119
+ python partfield_inference.py -c configs/final/demo.yaml --opts continue_ckpt model/model_objaverse.ckpt result_name partfield_features/partobjtiny dataset.data_path data/PartObjaverse-Tiny/PartObjaverse-Tiny_mesh n_point_per_face 2000 n_sample_each 10000
120
+ python run_part_clustering.py --root exp_results/partfield_features/partobjtiny/ --dump_dir exp_results/clustering/partobjtiny --source_dir data/PartObjaverse-Tiny/PartObjaverse-Tiny_mesh --use_agglo True --max_num_clusters 20 --option 0
121
+ ```
122
+ If an OOM error occurs, you can reduce the number of points sampled per face—for example, by setting `n_point_per_face` to 500.
123
+
124
+ Evaluation metrics can be obtained by running the command below. The per-category average mIoU reported in the paper is also computed.
125
+ ```
126
+ python compute_metric.py
127
+ ```
128
+ This evaluation code builds on top of the implementation released by [SAMPart3D](https://github.com/Pointcept/SAMPart3D). Users with their own data and corresponding ground truths can easily modify this script to compute their metrics.
129
+
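For intuition, a minimal per-shape mean-IoU sketch is shown below; it is illustrative only, and the exact matching and averaging protocol is implemented in `compute_metric.py`, following SAMPart3D:

```
# Hedged sketch of a per-shape mean IoU over ground-truth parts (illustrative only).
import numpy as np

def mean_iou(pred, gt):
    """pred, gt: integer part labels, one entry per face."""
    ious = []
    for part in np.unique(gt):
        gt_mask = gt == part
        # best IoU of this ground-truth part against any predicted part
        best = max(
            np.logical_and(pred == p, gt_mask).sum() / np.logical_or(pred == p, gt_mask).sum()
            for p in np.unique(pred)
        )
        ious.append(best)
    return float(np.mean(ious))
```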
130
+ ## Discussion: Clustering with Messy Mesh Connectivities
131
+ <!-- Some meshes can get messy with a lot of connected components, here the connectivity information may not be useful, causing failure cases when using Agglomerative clustering. In these cases, we provide two alternatives, 1) cluster using KMeans. We provide sample code below, or 2) converting the input mesh to a manifold surface mesh.
132
+
133
+ Sample data download:
134
+ ```
135
+ cd data
136
+ mkdir messy_meshes_samples
137
+ cd messy_meshes_samples
138
+ wget https://huggingface.co/datasets/allenai/objaverse/resolve/main/glbs/000-007/00790c705e4c4a1fbc0af9bf5c9e9525.glb
139
+ wget https://huggingface.co/datasets/allenai/objaverse/resolve/main/glbs/000-132/13cc3ffc69964894a2bc94154aed687f.glb
140
+ ```
141
+
142
+ Extract Partfield feature on the original mesh and run KMeans clustering:
143
+ ```
144
+ python partfield_inference.py -c configs/final/demo.yaml --opts continue_ckpt model/model_objaverse.ckpt result_name partfield_features/messy_meshes_samples dataset.data_path data/messy_meshes_samples
145
+ python run_part_clustering.py --root exp_results/partfield_features/messy_meshes_samples/ --dump_dir exp_results/clustering/messy_meshes_samples_kmeans/ --source_dir data/messy_meshes_samples --max_num_clusters 20
146
+ ```
147
+
148
+ Extract convert mesh into a surface manifold, extract Partfield feature and run agglomerative clustering:
149
+ ```
150
+ python partfield_inference.py -c configs/final/demo.yaml --opts continue_ckpt model/model_objaverse.ckpt result_name partfield_features/messy_meshes_samples_remesh dataset.data_path data/messy_meshes_samples remesh_demo True
151
+ python run_part_clustering_remesh.py --root exp_results/partfield_features/messy_meshes_samples_remesh --dump_dir exp_results/clustering/messy_meshes_samples_remesh --source_dir data/messy_meshes_samples --use_agglo True --max_num_clusters 20
152
+
153
+ python run_part_clustering_remesh.py --root exp_results/partfield_features/trellis_remesh --dump_dir exp_results/clustering/trellis_remesh --source_dir data/trellis_samples --use_agglo True --max_num_clusters 20
154
+ ``` -->
155
+ When using agglomerative clustering for part segmentation, an adjacency matrix is passed into the algorithm, which ideally requires the mesh to be a single connected component. However, some meshes can be messy, containing multiple connected components. If the input mesh is not a single connected component, we add pseudo-edges to the adjacency matrix to make it one. By default, we take a simple approach: adding `N-1` pseudo-edges as a chain to connect `N` components together. However, this approach can lead to poor results when the mesh is poorly connected and fragmented.
156
+
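A minimal sketch of this default chaining strategy is shown below (illustrative, not the repository's exact implementation; `adj` is assumed to be a symmetric scipy sparse face-adjacency matrix):

```
# Hedged sketch: connect N components with N-1 pseudo-edges forming a chain.
import numpy as np
from scipy.sparse import coo_matrix
from scipy.sparse.csgraph import connected_components

def chain_components(adj):
    n_comp, comp_labels = connected_components(adj, directed=False)
    if n_comp <= 1:
        return adj
    # one representative face per connected component
    reps = [np.flatnonzero(comp_labels == c)[0] for c in range(n_comp)]
    rows, cols = [], []
    for a, b in zip(reps[:-1], reps[1:]):  # chain the representatives
        rows += [a, b]
        cols += [b, a]
    pseudo = coo_matrix((np.ones(len(rows)), (rows, cols)), shape=adj.shape)
    return (adj + pseudo).tocsr()
```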
157
+ <img src="assets/messy_meshes_screenshot/option0.png" width="480"/>
158
+
159
+ ```
160
+ python run_part_clustering.py --root exp_results/partfield_features/trellis --dump_dir exp_results/clustering/trellis_bad --source_dir data/trellis_samples --use_agglo True --max_num_clusters 20 --option 0
161
+ ```
162
+
163
+ When this occurs, we explore different options that can lead to better results:
164
+
165
+ ### 1. Preprocess Input Mesh
166
+
167
+ We can perform a simple cleanup on the input meshes by removing duplicate vertices and faces, and by merging nearby vertices using `pymeshlab`. This preprocessing step can be enabled via a flag when generating PartField features:
168
+
169
+ ```
170
+ python partfield_inference.py -c configs/final/demo.yaml --opts continue_ckpt model/model_objaverse.ckpt result_name partfield_features/trellis_preprocess dataset.data_path data/trellis_samples preprocess_mesh True
171
+ ```
172
+
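For reference, the same kind of cleanup can be done directly with `pymeshlab` along these lines (a hedged sketch; the exact filter names depend on the installed pymeshlab version, and this is not necessarily the code path used by `preprocess_mesh`):

```
# Hedged sketch: basic mesh cleanup with pymeshlab.
import pymeshlab

def cleanup_mesh(in_path, out_path):
    ms = pymeshlab.MeshSet()
    ms.load_new_mesh(in_path)
    ms.meshing_remove_duplicate_vertices()
    ms.meshing_remove_duplicate_faces()
    ms.meshing_merge_close_vertices()  # default threshold merges nearly coincident vertices
    ms.save_current_mesh(out_path)
```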
173
+ When running agglomerative clustering on a cleaned-up mesh, we observe improved part segmentation:
174
+
175
+ <img src="assets/messy_meshes_screenshot/preprocess.png" width="480"/>
176
+
177
+ ```
178
+ python run_part_clustering.py --root exp_results/partfield_features/trellis_preprocess --dump_dir exp_results/clustering/trellis_preprocess --source_dir data/trellis_samples --use_agglo True --max_num_clusters 20 --option 0
179
+ ```
180
+
181
+ ### 2. Cluster with KMeans
182
+
183
+ If preprocessing or otherwise modifying the input mesh is not desirable, an alternative is to use KMeans clustering, which does not rely on an adjacency matrix.
184
+
185
+
186
+ <img src="assets/messy_meshes_screenshot/kmeans.png" width="480"/>
187
+
188
+ ```
189
+ python run_part_clustering.py --root exp_results/partfield_features/trellis --dump_dir exp_results/clustering/trellis_kmeans --source_dir data/trellis_samples --max_num_clusters 20
190
+ ```
191
+
192
+ ### 3. MST-based Adjacency Matrix
193
+
194
+ Instead of simply chaining the connected components of the input mesh, we also explore adding pseudo-edges to the adjacency matrix by constructing a KNN graph using face centroids and computing the minimum spanning tree of that graph.
195
+
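A hedged sketch of that construction is shown below (illustrative only; the actual option is selected via `--option 1 --with_knn True` in `run_part_clustering.py`):

```
# Hedged sketch: MST-based pseudo-edges from a KNN graph over face centroids.
from sklearn.neighbors import kneighbors_graph
from scipy.sparse.csgraph import minimum_spanning_tree

def mst_pseudo_edges(face_centroids, k=8):
    # distance-weighted KNN graph over face centroids
    knn = kneighbors_graph(face_centroids, n_neighbors=k, mode="distance")
    # the MST of that graph connects all faces with a minimal set of edges
    mst = minimum_spanning_tree(knn)
    mst = (mst + mst.T).tocsr()  # symmetrize
    mst.data[:] = 1.0            # use as unweighted adjacency entries
    return mst
```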
196
+ <img src="assets/messy_meshes_screenshot/option1.png" width="480"/>
197
+
198
+ ```
199
+ python run_part_clustering.py --root exp_results/partfield_features/trellis --dump_dir exp_results/clustering/trellis_faceadj --source_dir data/trellis_samples --use_agglo True --max_num_clusters 20 --option 1 --with_knn True
200
+ ```
201
+
202
+ <!-- ### Remesh with Marching Cubes (experimental)
203
+
204
+ We also explore computing the SDF of the input mesh and then running marching cubes, resulting in a surface mesh that is guaranteed to be a single connected component. We then run clustering on the new mesh and map back the segmentation labels to the original mesh by a voting scheme.
205
+
206
+ <img src="assets/messy_meshes_screenshot/remesh.png" width="480"/>
207
+
208
+ ```
209
+ python partfield_inference.py -c configs/final/demo.yaml --opts continue_ckpt model/model_objaverse.ckpt result_name partfield_features/trellis_remesh dataset.data_path data/trellis_samples remesh_demo True
210
+
211
+ python run_part_clustering_remesh.py --root exp_results/partfield_features/trellis_remesh --dump_dir exp_results/clustering/trellis_remesh --source_dir data/trellis_samples --use_agglo True --max_num_clusters 20
212
+ ```-->
213
+
214
+ ### More Challenging Meshes!
215
+ The proposed approaches improve results for some meshes, but we find that certain cases still do not produce satisfactory segmentations. We leave these challenges for future work. If you're interested, here are some examples of difficult meshes we encountered:
216
+
217
+ **Challenging Meshes:**
218
+ ```
219
+ cd data
220
+ mkdir challenge_samples
221
+ cd challenge_samples
222
+ wget https://huggingface.co/datasets/allenai/objaverse/resolve/main/glbs/000-007/00790c705e4c4a1fbc0af9bf5c9e9525.glb
223
+ wget https://huggingface.co/datasets/allenai/objaverse/resolve/main/glbs/000-132/13cc3ffc69964894a2bc94154aed687f.glb
224
+ ```
225
+
226
+ ## Citation
227
+ ```
228
+ @inproceedings{partfield2025,
229
+ title={PartField: Learning 3D Feature Fields for Part Segmentation and Beyond},
230
+ author={Minghua Liu and Mikaela Angelina Uy and Donglai Xiang and Hao Su and Sanja Fidler and Nicholas Sharp and Jun Gao},
231
+ booktitle={ICCV},
+ year={2025}
232
+ }
233
+ ```
234
+
235
+ ## References
236
+ PartField borrows code from the following repositories:
237
+ - [OpenLRM](https://github.com/3DTopia/OpenLRM)
238
+ - [PyTorch 3D UNet](https://github.com/wolny/pytorch-3dunet)
239
+ - [PVCNN](https://github.com/mit-han-lab/pvcnn)
240
+ - [SAMPart3D](https://github.com/Pointcept/SAMPart3D) — evaluation script
241
+
242
+ Many thanks to the authors for sharing their code!
PartField/applications/.polyscope.ini ADDED
@@ -0,0 +1,6 @@
1
+ {
2
+ "windowHeight": 1104,
3
+ "windowPosX": 66,
4
+ "windowPosY": 121,
5
+ "windowWidth": 2215
6
+ }
PartField/applications/README.md ADDED
@@ -0,0 +1,142 @@
1
+ # Interactive Tools and Applications
2
+
3
+ ## Single-Shape Feature and Segmentation Visualization Tool
4
+ We can visualize the output features and segmentation of a single shape by running the script below:
5
+
6
+ ```
7
+ cd applications/
8
+ python single_shape.py --data_root ../exp_results/partfield_features/trellis/ --filename dwarf
9
+ ```
10
+
11
+ - `Mode: pca, feature_viz, cluster_agglo, cluster_kmeans`
12
+ - `pca` : Visualizes the PCA of the PartField features of the input model as colors.
13
+ - `feature_viz` : Visualizes each dimension of the PartField features as a colormap.
14
+ - `cluster_agglo` : Visualizes the part segmentation of the input model using Agglomerative clustering.
15
+ - Number of clusters is specified with the slider.
16
+ - `Adj Matrix Def`: Specifies how the adjacency matrix is defined for the clustering algorithm by adding dummy edges to make the input mesh a single connected component.
17
+ - `Add KNN edges` : Adds additional dummy edges based on k nearest neighbors.
18
+ - `cluster_kmeans` : Visualizes the part segmentation of the input model using KMeans clustering.
19
+ - Number of clusters is specified with the slider.
20
+
21
+
22
+ ## Shape-Pair Co-Segmentation and Feature Exploration Tool
23
+ We provide a tool to analyze and visualize a pair of shapes. It has two main functionalities: 1) **Co-segmentation** via co-clustering and 2) PartField **feature exploration** and visualization. Try it out as follows:
24
+
25
+ ```
26
+ cd applications/
27
+ python shape_pair.py --data_root ../exp_results/partfield_features/trellis/ --filename dwarf --filename_alt goblin
28
+ ```
29
+
30
+ ### Co-Clustering for Co-Segmentation
31
+
32
+ This section explains the use case for `Mode: co-segmentation`.
33
+
34
+ ![Alt text](../assets/co-seg.png)
35
+
36
+ The shape pair is co-segmented via co-clustering; in this application, we use the KMeans clustering algorithm. The `first shape (left)` is separated into parts via **unsupervised clustering** of its features with KMeans, and the parts of the `second shape (right)` are then derived from them.
37
+
38
+ Below is a list of parameters (a minimal code sketch follows the list):
39
+ - `Source init`:
40
+ - `True`: Initializes the cluster centers of the second shape (right) with the cluster centers of the first shape (left).
41
+ - `False`: Uses KMeans++ to initialize the cluster centers for KMeans for the second shape.
42
+ - `Independent`:
43
+ - `True`: The labels produced by KMeans clustering are directly used as the parts of the second shape. Correspondence with the parts of the first shape is not explicitly computed after clustering.
44
+ - `False`: After KMeans clustering is run on the features of the second shape, the mean feature of each of its parts is computed, along with the mean feature of each part of the first shape. The parts of the second shape are then assigned labels based on the nearest-neighbor part of the first shape.
45
+ - `Num cluster`:
46
+ - `Model1` : A slider is used to specify the number of parts for the first shape, i.e. number of clusters for KMeans clustering.
47
+ - `Model2` : A slider is used to specify the number of parts for the second shape, i.e. number of clusters for KMeans clustering. Note: if `Source init` is set to `True` then this slider is ignored and the number of clusters for Model1 is used.
48
+
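The sketch below illustrates this co-clustering procedure with scikit-learn's `KMeans`; it is a simplified stand-in for the interactive implementation in `shape_pair.py`, and assumes `feat_src` / `feat_tgt` are per-face PartField features with no empty clusters:

```
# Hedged sketch of co-segmentation with "Source init" and nearest-part matching.
import numpy as np
from sklearn.cluster import KMeans

def cosegment(feat_src, feat_tgt, n_clusters, source_init=True):
    km_src = KMeans(n_clusters=n_clusters, n_init=10).fit(feat_src)
    if source_init:
        # initialize the second shape's clusters with the first shape's centers
        km_tgt = KMeans(n_clusters=n_clusters, init=km_src.cluster_centers_, n_init=1).fit(feat_tgt)
    else:
        km_tgt = KMeans(n_clusters=n_clusters, n_init=10).fit(feat_tgt)
    src_labels, tgt_labels = km_src.labels_, km_tgt.labels_
    # relabel each target part with the source part whose mean feature is closest
    src_means = np.stack([feat_src[src_labels == c].mean(0) for c in range(n_clusters)])
    tgt_means = np.stack([feat_tgt[tgt_labels == c].mean(0) for c in range(n_clusters)])
    match = np.linalg.norm(tgt_means[:, None] - src_means[None], axis=-1).argmin(1)
    return src_labels, match[tgt_labels]
```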
49
+
50
+ ### Feature Exploration and Visualization
51
+
52
+ This section explains the use case for `Mode: feature_explore`.
53
+
54
+ ![Alt text](../assets/feature_exploration2.png)
55
+
56
+ This mode allows us to select a query point on the first shape (left); the feature distance to all points on the second shape (right), and on the first shape itself, is then visualized as a colormap. A minimal sketch of the distance computation follows the list below.
57
+ - `range` : A slider to specify the distance radius for feature similarity visualization. Larger values result in larger highlighted areas.
58
+ - `continuous` :
59
+ - `False` : Query point is specified with a mouse click.
60
+ - `True` : Move your mouse over the first mesh to continuously visualize feature distances.
61
+
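The distance being visualized is the cosine feature distance computed by `feature_distance_np` in `shape_pair.py`; a minimal equivalent for a single query face is:

```
# Cosine feature distance from one query face to all faces (mirrors feature_distance_np).
import numpy as np

def feature_distance(feat, query_idx):
    """feat: [num_faces, C] PartField features; query_idx: index of the clicked face."""
    feat_n = feat / np.linalg.norm(feat, axis=1, keepdims=True)
    cos_sim = feat_n @ feat_n[query_idx]
    return (1.0 - cos_sim) / 2.0  # in [0, 1]; small values = similar, shown as a colormap
```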
62
+ ## Multi-shape Cosegmentation Tool
63
+ We further demonstrate PartField for cosegmentation of a set of shapes. Try out our demo application as follows:
64
+
65
+ ### Dependency Installation
66
+ Let's first install the necessary dependencies for this tool:
67
+ ```
68
+ pip install cuml-cu12
69
+ pip install xgboost
70
+ ```
71
+
72
+ ### Dataset
73
+ We use the Shape COSEG dataset for our demo. We first download the dataset here:
74
+ ```
75
+ mkdir data/coseg_guitar
76
+ cd data/coseg_guitar
77
+ wget https://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Guitars/shapes.zip
78
+ wget https://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Guitars/gt.zip
79
+ unzip shapes.zip
80
+ unzip gt.zip
81
+ ```
82
+
83
+ Now, let's extract Partfield features for the set of shapes:
84
+ ```
85
+ python partfield_inference.py -c configs/final/demo.yaml --opts continue_ckpt model/model_objaverse.ckpt result_name partfield_features/coseg_guitar/ dataset.data_path data/coseg_guitar/shapes
86
+ ```
87
+
88
+ Now, we're ready to run the tool! We support two modes: 1) **Few-shot** with click-based annotations and 2) **Supervised** with ground truth labels.
89
+
90
+ ### Annotate Mode
91
+ ![Alt text](../assets/regression_interactive_segmentation_guitars.gif)
92
+
93
+ We can run our few-shot segmentation tool as follows:
94
+ ```
95
+ cd applications/
96
+ python multi_shape_cosegment.py --meshes ../exp_results/partfield_features/coseg_guitar/
97
+ ```
98
+ We can annotate the segments with a few clicks. A classifier is then run to obtain the part segmentation.
99
+ - N_class: number of segmentation class labels
100
+ - Annotate:
101
+ - `00, 01, 02, ...`: Select the segmentation class label, then click on a shape region of that class.
102
+ - `Undo Last Annotation`: Removes and disregards the last annotation made.
103
+ - Fit:
104
+ - Fit Method: Selects the classification method used for fitting. Default uses `Logistic Regression`.
105
+ - `Update Fit`: By default, the fitting process is automatically updated. This can also be changed to a manual update.
106
+
107
+ ### Ground Truth Labels Mode
108
+
109
+ ![Alt text](../assets/gt-based_coseg.png)
110
+
111
+ Alternatively, we can use the ground truth labels of a subset of the shapes to train the classifier.
112
+
113
+ ```
114
+ cd applications/
115
+ python multi_shape_cosegment.py --meshes ../exp_results/partfield_features/coseg_guitar/ --n_train_subset 15
116
+ ```
117
+ `Fit Method` can also be selected here to choose the classifier to be used.
118
+
119
+
120
+ ## 3D Correspondences
121
+
122
+ First, we clone the repository [SmoothFunctionalMaps](https://github.com/RobinMagnet/SmoothFunctionalMaps) and install additional packages.
123
+ ```
124
+ pip install omegaconf robust_laplacian
125
+ git submodule init
126
+ git submodule update --recursive
127
+ ```
128
+
129
+ Download the [DenseCorr3D dataset](https://drive.google.com/file/d/1bpgsNu8JewRafhdRN4woQL7ObQtfgcpu/view?usp=sharing) into the `data` folder. Unzip the contents and ensure that the file structure is organized so that you can access
130
+ `data/DenseCorr3D/animals/071b8_toy_animals_017`.
131
+
132
+ Extract the PartField features.
133
+ ```
134
+ # run in root directory of this repo
135
+ python partfield_inference.py -c configs/final/correspondence_demo.yaml --opts continue_ckpt model/model_objaverse.ckpt preprocess_mesh True
136
+ ```
137
+
138
+ Run the functional map.
139
+ ```
140
+ cd applications/
141
+ python run_smooth_functional_map.py -c ../configs/final/correspondence_demo.yaml --opts
142
+ ```
PartField/applications/multi_shape_cosegment.py ADDED
@@ -0,0 +1,482 @@
1
+ import numpy as np
2
+ import torch
3
+ import argparse
4
+ from dataclasses import dataclass
5
+
6
+ from arrgh import arrgh
7
+ import polyscope as ps
8
+ import polyscope.imgui as psim
9
+ import potpourri3d as pp3d
10
+ import trimesh
11
+
12
+ import cuml
13
+ import xgboost as xgb
14
+
15
+ import os, random
16
+
17
+ import sys
18
+ sys.path.append("..")
19
+ from partfield.utils import *
20
+
21
+ @dataclass
22
+ class State:
23
+
24
+ objects = None
25
+ train_objects = None
26
+
27
+ # Input options
28
+ subsample_inputs: int = -1
29
+ n_train_subset: int = 0
30
+
31
+ # Label
32
+ N_class: int = 2
33
+
34
+ # Annotations
35
+ # A annotations (initially A = 0)
36
+ anno_feat: np.array = np.zeros((0,448), dtype=np.float32) # [A,F]
37
+ anno_label: np.array = np.zeros((0,), dtype=np.int32) # [A]
38
+ anno_pos: np.array = np.zeros((0,3), dtype=np.float32) # [A,3]
39
+
40
+ # Intermediate selection data
41
+ is_selecting: bool = False
42
+ selection_class: int = 0
43
+
44
+ # Fitting algorithm
45
+ fit_to: str = "Annotations"
46
+ fit_method : str = "LogisticRegression"
47
+ auto_update_fit: bool = True
48
+
49
+ # Training data
50
+ # T training datapoints
51
+ train_feat: np.array = np.zeros((0,448), dtype=np.float32) # [T,F]
52
+ train_label: np.array = np.zeros((0,), dtype=np.int32) # [T]
53
+
54
+ # Viz
55
+ grid_w : int = 8
56
+ per_obj_shift : float = 2.
57
+ anno_radius : float = 0.01
58
+ ps_cloud_annotation = None
59
+ ps_structure_name_to_index_map = {}
60
+
61
+
62
+ fit_methods_list = ["LinearRegression", "LogisticRegression", "LinearSVC", "RandomForest", "NearestNeighbors", "XGBoost"]
63
+ fit_to_list = ["Annotations", "TrainingSet"]
64
+
65
+ def load_mesh_and_features(mesh_filepath, ind, require_gt=False, gt_label_fol = ""):
66
+
67
+ dirpath, filename = os.path.split(mesh_filepath)
68
+ filename_core = filename[9:-6] # splits off "feat_pca_" ... "_0.ply"
69
+ feature_filename = "part_feat_"+ filename_core + "_0_batch.npy"
70
+ feature_filepath = os.path.join(dirpath, feature_filename)
71
+
72
+ gt_filename = filename_core + ".seg"
73
+ gt_filepath = os.path.join(gt_label_fol, gt_filename)
74
+ have_gt = os.path.isfile(gt_filepath)
75
+
76
+ print(" Reading file:")
77
+ print(f" Mesh filename: {mesh_filepath}")
78
+ print(f" Feature filename: {feature_filepath}")
79
+ print(f" Ground Truth Label filename: {gt_filepath} -- present = {have_gt}")
80
+
81
+ # load features
82
+ feat = np.load(feature_filepath, allow_pickle=False)
83
+ feat = feat.astype(np.float32)
84
+
85
+ # load mesh things
86
+ # TODO replace this with just loading V/F from numpy archive
87
+ tm = load_mesh_util(mesh_filepath)
88
+
89
+ V = np.array(tm.vertices, dtype=np.float32)
90
+ F = np.array(tm.faces)
91
+
92
+ # load ground truth, if available
93
+ if have_gt:
94
+ gt_labels = np.loadtxt(gt_filepath)
95
+ gt_labels = gt_labels.astype(np.int32) - 1
96
+ else:
97
+ if require_gt:
98
+ raise ValueError("could not find ground-truth file, but it is required")
99
+ gt_labels = None
100
+
101
+ # pca_colors = None
102
+
103
+ return {
104
+ 'nicename' : f"{ind:02d}_{filename_core}",
105
+ 'mesh_filepath' : mesh_filepath,
106
+ 'feature_filepath' : feature_filepath,
107
+ 'V' : V,
108
+ 'F' : F,
109
+ 'feat_np' : feat,
110
+ # 'feat_pt' : torch.tensor(feat, device='cuda'),
111
+ 'gt_labels' : gt_labels
112
+ }
113
+
114
+ def shift_for_ind(state : State, ind):
115
+
116
+ x_ind = ind % state.grid_w
117
+ y_ind = ind // state.grid_w
118
+
119
+ shift = np.array([state.per_obj_shift * x_ind, 0, -state.per_obj_shift * y_ind])
120
+
121
+ return shift
122
+
123
+ def viz_upper_limit(state : State, ind_count):
124
+
125
+ x_max = min(ind_count, state.grid_w)
126
+ y_max = ind_count // state.grid_w
127
+
128
+ bound = np.array([state.per_obj_shift * x_max, 0, -state.per_obj_shift * y_max])
129
+
130
+ return bound
131
+
132
+
133
+ def initialize_object_viz(state : State, obj, index=0):
134
+ obj['ps_mesh'] = ps.register_surface_mesh(obj['nicename'], obj['V'], obj['F'], color=(.8, .8, .8))
135
+ shift = shift_for_ind(state, index)
136
+ obj['ps_mesh'].translate(shift)
137
+ obj['ps_mesh'].set_selection_mode('faces_only')
138
+ state.ps_structure_name_to_index_map[obj['nicename']] = index
139
+
140
+ def update_prediction(state: State):
141
+
142
+ print("Updating predictions..")
143
+
144
+ N_anno = state.anno_label.shape[0]
145
+
146
+ # Quick out if we don't have at least two distinct class labels present
147
+ if(state.fit_to == "Annotations" and len(np.unique(state.anno_label)) <= 1):
148
+ return state
149
+
150
+ # Quick out if fitting to the training set but no training data has been loaded
151
+ if(state.fit_to == "TrainingSet" and state.train_objects is None):
152
+ return state
153
+
154
+ if state.fit_method == "LinearRegression":
155
+ classifier = cuml.multiclass.MulticlassClassifier(cuml.linear_model.LinearRegression(), strategy='ovr')
156
+ elif state.fit_method == "LogisticRegression":
157
+ classifier = cuml.multiclass.MulticlassClassifier(cuml.linear_model.LogisticRegression(), strategy='ovr')
158
+ elif state.fit_method == "LinearSVC":
159
+ classifier = cuml.multiclass.MulticlassClassifier(cuml.svm.LinearSVC(), strategy='ovr')
160
+ elif state.fit_method == "RandomForest":
161
+ classifier = cuml.ensemble.RandomForestClassifier()
162
+ elif state.fit_method == "NearestNeighbors":
163
+ classifier = cuml.multiclass.MulticlassClassifier(cuml.neighbors.KNeighborsRegressor(n_neighbors=1), strategy='ovr')
164
+ elif state.fit_method == "XGBoost":
165
+ classifier = xgb.XGBClassifier(max_depth=7, n_estimators=1000)
166
+ else:
167
+ raise ValueError("unrecognized fit method")
168
+
169
+ if state.fit_to == "TrainingSet":
170
+
171
+ all_train_feats = []
172
+ all_train_labels = []
173
+ for obj in state.train_objects:
174
+ all_train_feats.append(obj['feat_np'])
175
+ all_train_labels.append(obj['gt_labels'])
176
+
177
+ all_train_feats = np.concatenate(all_train_feats, axis=0)
178
+ all_train_labels = np.concatenate(all_train_labels, axis=0)
179
+
180
+ state.N_class = np.max(all_train_labels) + 1
181
+
182
+ classifier.fit(all_train_feats, all_train_labels)
183
+
184
+
185
+ elif state.fit_to == "Annotations":
186
+ classifier.fit(state.anno_feat,state.anno_label)
187
+ else:
188
+ raise ValueError("unrecognized fit to")
189
+
190
+ n_total = 0
191
+ n_correct = 0
192
+
193
+ for obj in state.objects:
194
+ obj['pred_label'] = classifier.predict(obj['feat_np'])
195
+
196
+ if obj['gt_labels'] is not None:
197
+ n_total += obj['gt_labels'].shape[0]
198
+ n_correct += np.sum(obj['pred_label'] == obj['gt_labels'], dtype=np.int32)
199
+
200
+ if(state.fit_to == "TrainingSet" and n_total > 0):
201
+ frac = n_correct / n_total
202
+ print(f"Test accuracy: {n_correct:d} / {n_total:d} {100*frac:.02f}%")
203
+
204
+
205
+ print("Done updating predictions.")
206
+
207
+ return state
208
+
209
+ def update_prediction_viz(state: State):
210
+
211
+ for obj in state.objects:
212
+ if 'pred_label' in obj:
213
+ obj['ps_mesh'].add_scalar_quantity("pred labels", obj['pred_label'], defined_on='faces', vminmax=(0,state.N_class-1), cmap='turbo', enabled=True)
214
+
215
+ return state
216
+
217
+ def update_annotation_viz(state: State):
218
+
219
+ ps_cloud = ps.register_point_cloud("annotations", state.anno_pos, radius=state.anno_radius, material='candy')
220
+ ps_cloud.add_scalar_quantity("labels", state.anno_label, vminmax=(0,state.N_class-1), cmap='turbo', enabled=True)
221
+
222
+ state.ps_cloud_annotation = ps_cloud
223
+
224
+ return state
225
+
226
+
227
+ def filter_old_labels(state: State):
228
+ """
229
+ Filter out annotations from classes that don't exist any more
230
+ """
231
+
232
+ keep_mask = state.anno_label < state.N_class
233
+ state.anno_feat = state.anno_feat[keep_mask,:]
234
+ state.anno_label = state.anno_label[keep_mask]
235
+ state.anno_pos = state.anno_pos[keep_mask,:]
236
+
237
+ return state
238
+
239
+ def undo_last_annotation(state: State):
240
+
241
+ state.anno_feat = state.anno_feat[:-1,:]
242
+ state.anno_label = state.anno_label[:-1]
243
+ state.anno_pos = state.anno_pos[:-1,:]
244
+
245
+ return state
246
+
247
+ def ps_callback(state_list):
248
+ state : State = state_list[0] # hacky pass-by-reference, since we want to edit it below
249
+
250
+
251
+ # If we're in selection mode, that's the only thing we can do
252
+ if state.is_selecting:
253
+
254
+ psim.TextUnformatted(f"Annotating class {state.selection_class:02d}. Click on any mesh face.")
255
+
256
+ io = psim.GetIO()
257
+ if io.MouseClicked[0]:
258
+ screen_coords = io.MousePos
259
+ pick_result = ps.pick(screen_coords=screen_coords)
260
+
261
+ # Check if we hit one of the meshes
262
+ if pick_result.is_hit and pick_result.structure_name in state.ps_structure_name_to_index_map:
263
+ if pick_result.structure_data['element_type'] != "face":
264
+ # shouldn't be possible
265
+ raise ValueError("pick returned non-face")
266
+
267
+ i_obj = state.ps_structure_name_to_index_map[pick_result.structure_name]
268
+ f_hit = pick_result.structure_data['index']
269
+
270
+ obj = state.objects[i_obj]
271
+ V = obj['V']
272
+ F = obj['F']
273
+ feat = obj['feat_np']
274
+
275
+ face_corners = V[F[f_hit,:],:]
276
+ new_anno_feat = feat[f_hit,:]
277
+ new_anno_label = state.selection_class
278
+ new_anno_pos = np.mean(face_corners, axis=0) + shift_for_ind(state, i_obj)
279
+
280
+ state.anno_feat = np.concatenate((state.anno_feat, new_anno_feat[None,:]))
281
+ state.anno_label = np.concatenate((state.anno_label, np.array((new_anno_label,))))
282
+ state.anno_pos = np.concatenate((state.anno_pos, new_anno_pos[None,:]))
283
+
284
+ state = update_annotation_viz(state)
285
+ state.is_selecting = False
286
+ needs_pred_update = True
287
+
288
+ if state.auto_update_fit:
289
+ state = update_prediction(state)
290
+ state = update_prediction_viz(state)
291
+
292
+
293
+ return
294
+
295
+ # If not selecting, build the main UI
296
+ needs_pred_update = False
297
+
298
+ psim.PushItemWidth(150)
299
+ changed, state.N_class = psim.InputInt("N_class", state.N_class, step=1)
300
+ psim.PopItemWidth()
301
+ if changed:
302
+ state = filter_old_labels(state)
303
+ state = update_annotation_viz(state)
304
+
305
+
306
+ # Check for keypress annotation
307
+ io = psim.GetIO()
308
+ class_keys = { 'w' : 0, '1' : 1, '2' : 2, '3' : 3, '4' : 4, '5' : 5, '6' : 6, '7' : 7, '8' : 8, '9' : 9,}
309
+ for c in class_keys:
310
+ if class_keys[c] >= state.N_class:
311
+ continue
312
+
313
+ if psim.IsKeyPressed(ps.get_key_code(c)):
314
+ state.is_selecting = True
315
+ state.selection_class = class_keys[c]
316
+
317
+
318
+ psim.SetNextItemOpen(True, psim.ImGuiCond_FirstUseEver)
319
+ if(psim.TreeNode("Annotate")):
320
+
321
+ psim.TextUnformatted("New class annotation. Select class to add add annotation for:")
322
+ psim.TextUnformatted("(alternately, press key {w,1,2,3,4...})")
323
+ for i_class in range(state.N_class):
324
+
325
+ if i_class > 0:
326
+ psim.SameLine()
327
+
328
+ if psim.Button(f"{i_class:02d}"):
329
+ state.is_selecting = True
330
+ state.selection_class = i_class
331
+
332
+
333
+ if psim.Button("Undo Last Annotation"):
334
+ state = undo_last_annotation(state)
335
+ state = update_annotation_viz(state)
336
+ needs_pred_update = True
337
+
338
+
339
+
340
+ psim.TreePop()
341
+
342
+ psim.SetNextItemOpen(True, psim.ImGuiCond_FirstUseEver)
343
+ if(psim.TreeNode("Fit")):
344
+
345
+ psim.PushItemWidth(150)
346
+
347
+ changed, ind = psim.Combo("Fit To", fit_to_list.index(state.fit_to), fit_to_list)
348
+ if changed:
349
+ state.fit_to = fit_to_list[ind]
350
+ needs_pred_update = True
351
+
352
+ changed, ind = psim.Combo("Fit Method", fit_methods_list.index(state.fit_method), fit_methods_list)
353
+ if changed:
354
+ state.fit_method = fit_methods_list[ind]
355
+ needs_pred_update = True
356
+
357
+ if psim.Button("Update fit"):
358
+ state = update_prediction(state)
359
+ state = update_prediction_viz(state)
360
+
361
+ psim.SameLine()
362
+
363
+ changed, state.auto_update_fit = psim.Checkbox("Auto-update fit", state.auto_update_fit)
364
+ if changed:
365
+ needs_pred_update = True
366
+
367
+
368
+ psim.PopItemWidth()
369
+
370
+ psim.TreePop()
371
+
372
+ psim.SetNextItemOpen(True, psim.ImGuiCond_FirstUseEver)
373
+ if(psim.TreeNode("Visualization")):
374
+
375
+ psim.PushItemWidth(150)
376
+ changed, state.anno_radius = psim.SliderFloat("Annotation Point Radius", state.anno_radius, 0.00001, 0.02)
377
+ if changed:
378
+ state = update_annotation_viz(state)
379
+ psim.PopItemWidth()
380
+
381
+ psim.TreePop()
382
+
383
+
384
+ if needs_pred_update and state.auto_update_fit:
385
+ state = update_prediction(state)
386
+ state = update_prediction_viz(state)
387
+
388
+
389
+ def main():
390
+
391
+ state = State()
392
+
393
+ ## Parse args
394
+ parser = argparse.ArgumentParser()
395
+
396
+ parser.add_argument('--meshes', nargs='+', help='List of meshes to process.', required=True)
397
+ parser.add_argument('--n_train_subset', default=0, help='How many meshes to train on.')
398
+ parser.add_argument('--gt_label_fol', default="../data/coseg_guitar/gt", help='Path where labels are stored.')
399
+ parser.add_argument('--subsample_inputs', default=state.subsample_inputs, help='Only show a random fraction of inputs')
400
+ parser.add_argument('--per_obj_shift', default=state.per_obj_shift, help='How to space out objects in UI grid')
401
+ parser.add_argument('--grid_w', default=state.grid_w, help='Grid width')
402
+
403
+ args = parser.parse_args()
404
+
405
+
406
+ state.n_train_subset = int(args.n_train_subset)
407
+ state.subsample_inputs = int(args.subsample_inputs)
408
+ state.per_obj_shift = float(args.per_obj_shift)
409
+ state.grid_w = int(args.grid_w)
410
+
411
+ ## Load data
412
+ # First, resolve directories to load all files in directory
413
+ all_filepaths = []
414
+ print("Resolving passed directories")
415
+ for entry in args.meshes:
416
+ if os.path.isdir(entry):
417
+ dir_path = entry
418
+ print(f" processing directory {dir_path}")
419
+ for filename in os.listdir(dir_path):
420
+ file_path = os.path.join(dir_path, filename)
421
+ if os.path.isfile(file_path) and file_path.endswith(".ply") and "feat_pca" in file_path:
422
+ print(f" adding file {file_path}")
423
+ all_filepaths.append(file_path)
424
+ else:
425
+ all_filepaths.append(entry)
426
+
427
+ random.shuffle(all_filepaths)
428
+
429
+ if state.subsample_inputs != -1:
430
+ all_filepaths = all_filepaths[:state.subsample_inputs]
431
+
432
+
433
+ if state.n_train_subset != 0:
434
+
435
+ print(state.n_train_subset)
436
+
437
+ train_filepaths = all_filepaths[:state.n_train_subset]
438
+ all_filepaths = all_filepaths[state.n_train_subset:]
439
+
440
+ print(f"Loading {len(train_filepaths)} files")
441
+ state.train_objects = []
442
+ for i, file_path in enumerate(train_filepaths):
443
+ state.train_objects.append(load_mesh_and_features(file_path, i, require_gt=True, gt_label_fol=args.gt_label_fol))
444
+
445
+ state.fit_to = "TrainingSet"
446
+
447
+ # Load files
448
+ print(f"Loading {len(all_filepaths)} files")
449
+ state.objects = []
450
+ for i, file_path in enumerate(all_filepaths):
451
+ state.objects.append(load_mesh_and_features(file_path, i))
452
+
453
+
454
+ ## Set up visualization
455
+ ps.init()
456
+ ps.set_automatically_compute_scene_extents(False)
457
+ lim = viz_upper_limit(state, len(state.objects))
458
+ ps.set_length_scale(np.linalg.norm(lim) / 4.)
459
+ low = np.array((0, -1., -1.))
460
+ high = lim
461
+ ps.set_bounding_box(low, high)
462
+
463
+ for ind, o in enumerate(state.objects):
464
+ initialize_object_viz(state, o, ind)
465
+
466
+ print(f"Loaded {len(state.objects)} objects")
467
+ if state.n_train_subset != 0:
468
+ print(f"Loaded {len(state.train_objects)} training objects")
469
+
470
+ # One first prediction
471
+ # (does nothing if there are no annotations / training data)
472
+ state = update_prediction(state)
473
+ state = update_prediction_viz(state)
474
+
475
+ # Start the interactive UI
476
+ ps.set_user_callback(lambda : ps_callback([state]))
477
+ ps.show()
478
+
479
+
480
+ if __name__ == "__main__":
481
+ main()
482
+
PartField/applications/pack_labels_to_obj.py ADDED
@@ -0,0 +1,47 @@
1
+ import sys, os, fnmatch, re
2
+ import argparse
3
+
4
+ import numpy as np
5
+ import matplotlib
6
+ from matplotlib import colors as mcolors
7
+ import matplotlib.cm
8
+ import potpourri3d as pp3d
9
+ import igl
10
+ from arrgh import arrgh
11
+
12
+ def main():
13
+
14
+ parser = argparse.ArgumentParser()
15
+
16
+ parser.add_argument("--input_mesh", type=str, required=True, help="The mesh to read from from, mesh file format.")
17
+ parser.add_argument("--input_labels", type=str, required=True, help="The labels, as a text file with one entry per line")
18
+ parser.add_argument("--label_count", type=int, default=-1, help="The number of labels to use for the visualization. If -1, computed as max of given labels.")
19
+ parser.add_argument("--output", type=str, required=True, help="The obj file to write output to")
20
+
21
+ args = parser.parse_args()
22
+
23
+
24
+ # Read the mesh
25
+ V, F = igl.read_triangle_mesh(args.input_mesh)
26
+
27
+ # Read the scalar function
28
+ S = np.loadtxt(args.input_labels)
29
+
30
+ # Convert integers to scalars on [0,1]
31
+ if args.label_count == -1:
32
+ N_max = np.max(S) + 1
33
+ else:
34
+ N_max = args.label_count
35
+ S = S.astype(np.float32) / max(N_max-1, 1)
36
+
37
+ # Validate and write
38
+ if len(S.shape) != 1 or S.shape[0] != F.shape[0]:
39
+ raise ValueError(f"when scalar_on==faces, the scalar should be a length num-faces numpy array, but it has shape {S.shape[0]} and F={F.shape[0]}")
40
+
41
+ S = np.stack((S, np.zeros_like(S)), axis=-1)
42
+
43
+ pp3d.write_mesh(V, F, args.output, UV_coords=S, UV_type='per-face')
44
+
45
+
46
+ if __name__ == "__main__":
47
+ main()
PartField/applications/run_smooth_functional_map.py ADDED
@@ -0,0 +1,80 @@
1
+ import os, sys
2
+ import numpy as np
3
+ import torch
4
+ import trimesh
5
+ import json
6
+
7
+ sys.path.append("..")
8
+ sys.path.append("../third_party/SmoothFunctionalMaps")
9
+ sys.path.append("../third_party/SmoothFunctionalMaps/pyFM")
10
+
11
+ from partfield.config import default_argument_parser, setup
12
+ from pyFM.mesh import TriMesh
13
+ from pyFM.spectral import mesh_FM_to_p2p
14
+ import DiscreteOpt
15
+
16
+
17
+ def vertex_color_map(vertices):
18
+ min_coord, max_coord = np.min(vertices, axis=0, keepdims=True), np.max(vertices, axis=0, keepdims=True)
19
+ cmap = (vertices - min_coord) / (max_coord - min_coord)
20
+ return cmap
21
+
22
+
23
+ if __name__ == '__main__':
24
+ parser = default_argument_parser()
25
+ args = parser.parse_args()
26
+ cfg = setup(args, freeze=False)
27
+
28
+ feature_dir = os.path.join("../exp_results", cfg.result_name)
29
+
30
+ all_files = cfg.dataset.all_files
31
+ assert len(all_files) % 2 == 0
32
+ num_pairs = len(all_files) // 2
33
+
34
+ device = "cuda"
35
+
36
+ output_dir = "../exp_results/correspondence/"
37
+ os.makedirs(output_dir, exist_ok=True)
38
+
39
+ for i in range(num_pairs):
40
+ file0 = all_files[2 * i]
41
+ file1 = all_files[2 * i + 1]
42
+
43
+ uid0 = file0.split(".")[-2].replace("/", "_")
44
+ uid1 = file1.split(".")[-2].replace("/", "_")
45
+
46
+ mesh0 = trimesh.load(os.path.join(feature_dir, f"input_{uid0}_0.ply"), process=True)
47
+ mesh1 = trimesh.load(os.path.join(feature_dir, f"input_{uid1}_0.ply"), process=True)
48
+
49
+ feat0 = np.load(os.path.join(feature_dir, f"part_feat_{uid0}_0_batch.npy"))
50
+ feat1 = np.load(os.path.join(feature_dir, f"part_feat_{uid1}_0_batch.npy"))
51
+
52
+ assert mesh0.vertices.shape[0] == feat0.shape[0], "num of vertices should match num of features"
53
+ assert mesh1.vertices.shape[0] == feat1.shape[0], "num of vertices should match num of features"
54
+
55
+ th_descr0 = torch.tensor(feat0, device=device, dtype=torch.float32)
56
+ th_descr1 = torch.tensor(feat1, device=device, dtype=torch.float32)
57
+
58
+ cdist_01 = torch.cdist(th_descr0, th_descr1, p=2)
59
+ p2p_10_init = cdist_01.argmin(dim=0).cpu().numpy()
60
+ p2p_01_init = cdist_01.argmin(dim=1).cpu().numpy()
61
+
62
+ fm_mesh0 = TriMesh(mesh0.vertices, mesh0.faces, area_normalize=True, center=True).process(k=200, intrinsic=True)
63
+ fm_mesh1 = TriMesh(mesh1.vertices, mesh1.faces, area_normalize=True, center=True).process(k=200, intrinsic=True)
64
+
65
+ model = DiscreteOpt.SmoothDiscreteOptimization(fm_mesh0, fm_mesh1)
66
+ model.set_params("zoomout_rhm")
67
+ model.opt_params.step = 10
68
+ model.solve_from_p2p(p2p_21=p2p_10_init, p2p_12=p2p_01_init, n_jobs=30, verbose=True)
69
+
70
+ p2p_10_FM = mesh_FM_to_p2p(model.FM_12, fm_mesh0, fm_mesh1, use_adj=True)
71
+
72
+ color0 = vertex_color_map(mesh0.vertices)
73
+ color1 = color0[p2p_10_FM]
74
+
75
+ output_mesh0 = trimesh.Trimesh(mesh0.vertices, mesh0.faces, vertex_colors=color0)
76
+ output_mesh1 = trimesh.Trimesh(mesh1.vertices, mesh1.faces, vertex_colors=color1)
77
+
78
+ output_mesh0.export(os.path.join(output_dir, f"correspondence_{uid0}_{uid1}_0.ply"))
79
+ output_mesh1.export(os.path.join(output_dir, f"correspondence_{uid0}_{uid1}_1.ply"))
80
+
PartField/applications/shape_pair.py ADDED
@@ -0,0 +1,385 @@
1
+ import numpy as np
2
+ import torch
3
+ import polyscope as ps
4
+ import polyscope.imgui as psim
5
+ import potpourri3d as pp3d
6
+ import trimesh
7
+ import igl
8
+ from dataclasses import dataclass
9
+ from simple_parsing import ArgumentParser
10
+ from arrgh import arrgh
11
+
12
+ ### For clustering
13
+ from collections import defaultdict
14
+ from sklearn.cluster import AgglomerativeClustering, DBSCAN, KMeans
15
+ from scipy.sparse import coo_matrix, csr_matrix
16
+ from scipy.spatial import KDTree
17
+ from scipy.sparse.csgraph import connected_components
18
+ from sklearn.neighbors import NearestNeighbors
19
+ import networkx as nx
20
+
21
+ from scipy.optimize import linear_sum_assignment
22
+
23
+ import os, sys
24
+ sys.path.append("..")
25
+ from partfield.utils import *
26
+
27
+ @dataclass
28
+ class Options:
29
+
30
+ """ Basic Options """
31
+ filename: str
32
+ filename_alt: str = None
33
+
34
+ """System Options"""
35
+ device: str = "cuda" # Device
36
+ debug: bool = False # enable debug checks
37
+ extras: bool = False # include extra output for viz/debugging
38
+
39
+ """ State """
40
+ mode: str = 'co-segmentation'
41
+ m: dict = None # mesh
42
+ m_alt: dict = None # second mesh
43
+
44
+ # pca mode
45
+
46
+ # feature explore mode
47
+ i_feature: int = 0
48
+
49
+ i_cluster: int = 1
50
+ i_cluster2: int = 1
51
+
52
+ i_eps: float = 0.6
53
+
54
+ ### For mixing in clustering
55
+ weight_dist = 1.0
56
+ weight_feat = 1.0
57
+
58
+ ### For clustering visualization
59
+ independent: bool = True
60
+ source_init: bool = True
61
+
62
+ feature_range: float = 0.1
63
+ continuous_explore: bool = False
64
+
65
+ viz_mode: str = "faces"
66
+
67
+ output_fol: str = "results_pair"
68
+
69
+ ### counter for screenshot
70
+ counter: int = 0
71
+
72
+ modes_list = ['feature_explore', "co-segmentation"]
73
+
74
+ def load_features(feature_filename, mesh_filename, viz_mode):
75
+
76
+ print("Reading features:")
77
+ print(f" Feature filename: {feature_filename}")
78
+ print(f" Mesh filename: {mesh_filename}")
79
+
80
+ # load features
81
+ feat = np.load(feature_filename, allow_pickle=True)
82
+ feat = feat.astype(np.float32)
83
+
84
+ # load mesh things
85
+ tm = load_mesh_util(mesh_filename)
86
+
87
+ V = np.array(tm.vertices, dtype=np.float32)
88
+ F = np.array(tm.faces)
89
+
90
+ if viz_mode == "faces":
91
+ pca_colors = np.array(tm.visual.face_colors, dtype=np.float32)
92
+ pca_colors = pca_colors[:,:3] / 255.
93
+
94
+ else:
95
+ pca_colors = np.array(tm.visual.vertex_colors, dtype=np.float32)
96
+ pca_colors = pca_colors[:,:3] / 255.
97
+
98
+ arrgh(V, F, pca_colors, feat)
99
+
100
+ return {
101
+ 'V' : V,
102
+ 'F' : F,
103
+ 'pca_colors' : pca_colors,
104
+ 'feat_np' : feat,
105
+ 'feat_pt' : torch.tensor(feat, device='cuda'),
106
+ 'trimesh' : tm,
107
+ 'label' : None,
108
+ 'num_cluster' : 1,
109
+ 'scalar' : None
110
+ }
111
+
112
+ def prep_feature_mesh(m, name='mesh'):
113
+ ps_mesh = ps.register_surface_mesh(name, m['V'], m['F'])
114
+ ps_mesh.set_selection_mode('faces_only')
115
+ m['ps_mesh'] = ps_mesh
116
+
117
+ def viz_pca_colors(m):
118
+ m['ps_mesh'].add_color_quantity('pca colors', m['pca_colors'], enabled=True, defined_on=m["viz_mode"])
119
+
120
+ def viz_feature(m, ind):
121
+ m['ps_mesh'].add_scalar_quantity('pca colors', m['feat_np'][:,ind], cmap='turbo', enabled=True, defined_on=m["viz_mode"])
122
+
123
+ def feature_distance_np(feats, query_feat):
124
+ # normalize
125
+ feats = feats / np.linalg.norm(feats,axis=1)[:,None]
126
+ query_feat = query_feat / np.linalg.norm(query_feat)
127
+ # cosine distance
128
+ cos_sim = np.dot(feats, query_feat)
129
+ cos_dist = (1 - cos_sim) / 2.
130
+ return cos_dist
131
+
132
+ def feature_distance_pt(feats, query_feat):
133
+ return (1. - torch.nn.functional.cosine_similarity(feats, query_feat[None,:], dim=-1)) / 2.
134
+
135
+
136
+ def ps_callback(opts):
137
+ m = opts.m
138
+
139
+ changed, ind = psim.Combo("Mode", modes_list.index(opts.mode), modes_list)
140
+ if changed:
141
+ opts.mode = modes_list[ind]
142
+ m['ps_mesh'].remove_all_quantities()
143
+ if opts.m_alt is not None:
144
+ opts.m_alt['ps_mesh'].remove_all_quantities()
145
+
146
+ elif opts.mode == 'feature_explore':
147
+ psim.TextUnformatted("Click on the mesh on the left")
148
+ psim.TextUnformatted("to highlight all faces within a given radius in feature space.""")
149
+
150
+ io = psim.GetIO()
151
+ if io.MouseClicked[0] or opts.continuous_explore:
152
+ screen_coords = io.MousePos
153
+ cam_params = ps.get_view_camera_parameters()
154
+
155
+ pick_result = ps.pick(screen_coords=screen_coords)
156
+
157
+ # Check if we hit one of the meshes
158
+ if pick_result.is_hit and pick_result.structure_name == "mesh":
159
+ if pick_result.structure_data['element_type'] != "face":
160
+ # shouldn't be possible
161
+ raise ValueError("pick returned non-face")
162
+
163
+ f_hit = pick_result.structure_data['index']
164
+ bary_weights = np.array(pick_result.structure_data['bary_coords'])
165
+
166
+ # get the feature via interpolation
167
+ point_feat = m['feat_np'][f_hit,:]
168
+ point_feat_pt = torch.tensor(point_feat, device='cuda')
169
+
170
+ all_dists1 = feature_distance_pt(m['feat_pt'], point_feat_pt).detach().cpu().numpy()
171
+ m['ps_mesh'].add_scalar_quantity("distance", all_dists1, cmap='blues', vminmax=(0, opts.feature_range), enabled=True, defined_on=m["viz_mode"])
172
+ opts.m['scalar'] = all_dists1
173
+
174
+ if opts.m_alt is not None:
175
+ all_dists2 = feature_distance_pt(opts.m_alt['feat_pt'], point_feat_pt).detach().cpu().numpy()
176
+ opts.m_alt['ps_mesh'].add_scalar_quantity("distance", all_dists2, cmap='blues', vminmax=(0, opts.feature_range), enabled=True, defined_on=m["viz_mode"])
177
+ opts.m_alt['scalar'] = all_dists2
178
+
179
+ else:
180
+ # not hit
181
+ pass
182
+
183
+ if psim.Button("Export"):
184
+ ### Save output
185
+ OUTPUT_FOL = opts.output_fol
186
+ fname1 = opts.filename
187
+ out_mesh_file = os.path.join(OUTPUT_FOL, fname1+'.obj')
188
+
189
+ igl.write_obj(out_mesh_file, opts.m["V"], opts.m["F"])
190
+ print("Saved '{}'.".format(out_mesh_file))
191
+
192
+ out_face_ids_file = os.path.join(OUTPUT_FOL, fname1 + '_feat_dist_' + str(opts.counter) +'.txt')
193
+ np.savetxt(out_face_ids_file, opts.m['scalar'], fmt='%f')
194
+ print("Saved '{}'.".format(out_face_ids_file))
195
+
196
+
197
+ fname2 = opts.filename_alt
198
+ out_mesh_file = os.path.join(OUTPUT_FOL, fname2+'.obj')
199
+
200
+ igl.write_obj(out_mesh_file, opts.m_alt["V"], opts.m_alt["F"])
201
+ print("Saved '{}'.".format(out_mesh_file))
202
+
203
+ out_face_ids_file = os.path.join(OUTPUT_FOL, fname2 + '_feat_dist_' + str(opts.counter) +'.txt')
204
+ np.savetxt(out_face_ids_file, opts.m_alt['scalar'], fmt='%f')
205
+ # print("Saved '{}'.".format(out_face_ids_file))
206
+
207
+ opts.counter += 1
208
+
209
+
210
+ _, opts.feature_range = psim.SliderFloat('range', opts.feature_range, v_min=0., v_max=1.0, power=3)
211
+ _, opts.continuous_explore = psim.Checkbox('continuous', opts.continuous_explore)
212
+
213
+ # TODO nsharp remember how the keycodes work
214
+ if io.KeysDown[ord('q')]:
215
+ opts.feature_range += 0.01
216
+ if io.KeysDown[ord('w')]:
217
+ opts.feature_range -= 0.01
218
+
219
+
220
+ elif opts.mode == "co-segmentation":
221
+
222
+ changed, opts.source_init = psim.Checkbox("Source Init", opts.source_init)
223
+ changed, opts.independent = psim.Checkbox("Independent", opts.independent)
224
+
225
+ psim.TextUnformatted("Use the slider to toggle the number of desired clusters.")
226
+ cluster_changed, opts.i_cluster = psim.SliderInt("num clusters for model1", opts.i_cluster, v_min=1, v_max=30)
227
+ cluster_changed, opts.i_cluster2 = psim.SliderInt("num clusters for model2", opts.i_cluster2, v_min=1, v_max=30)
228
+
229
+ # if cluster_changed:
230
+ if psim.Button("Recompute"):
231
+
232
+ ### Run clustering algorithm
233
+
234
+ ### Mesh 1
235
+ num_clusters1 = opts.i_cluster
236
+ point_feat1 = m['feat_np']
237
+ point_feat1 = point_feat1 / np.linalg.norm(point_feat1, axis=-1, keepdims=True)
238
+ clustering1 = KMeans(n_clusters=num_clusters1, random_state=0, n_init="auto").fit(point_feat1)
239
+
240
+ ### Get feature means per cluster
241
+ feature_means1 = []
242
+ for j in range(num_clusters1):
243
+ all_cluster_feat = point_feat1[clustering1.labels_==j]
244
+ mean_feat = np.mean(all_cluster_feat, axis=0)
245
+ feature_means1.append(mean_feat)
246
+
247
+ feature_means1 = np.array(feature_means1)
248
+ tree = KDTree(feature_means1)
249
+
250
+
251
+ if opts.source_init:
252
+ num_clusters2 = opts.i_cluster
253
+ init_mode = np.array(feature_means1)
254
+
255
+ ## default is kmeans++
256
+ else:
257
+ num_clusters2 = opts.i_cluster2
258
+ init_mode = "k-means++"
259
+
260
+ ### Mesh 2
261
+ point_feat2 = opts.m_alt['feat_np']
262
+ point_feat2 = point_feat2 / np.linalg.norm(point_feat2, axis=-1, keepdims=True)
263
+
264
+ clustering2 = KMeans(n_clusters=num_clusters2, random_state=0, init=init_mode).fit(point_feat2)
265
+
266
+ ### Get feature means per cluster
267
+ feature_means2 = []
268
+ for j in range(num_clusters2):
269
+ all_cluster_feat = point_feat2[clustering2.labels_==j]
270
+ mean_feat = np.mean(all_cluster_feat, axis=0)
271
+ feature_means2.append(mean_feat)
272
+
273
+ feature_means2 = np.array(feature_means2)
274
+ _, nn_idx = tree.query(feature_means2, k=1)
275
+
276
+ print(nn_idx)
277
+ print("Both KMeans")
278
+ print(np.unique(clustering1.labels_))
279
+ print(np.unique(clustering2.labels_))
280
+
281
+ relabelled_2 = nn_idx[clustering2.labels_]
282
+
283
+ print(np.unique(relabelled_2))
284
+ print()
285
+
286
+ m['ps_mesh'].add_scalar_quantity("cluster_both_kmeans", clustering1.labels_, cmap='turbo', vminmax=(0, num_clusters1-1), enabled=True, defined_on=m["viz_mode"])
287
+ opts.m['label'] = clustering1.labels_
288
+ opts.m['num_cluster'] = num_clusters1
289
+
290
+ if opts.independent:
291
+ opts.m_alt['ps_mesh'].add_scalar_quantity("cluster", clustering2.labels_, cmap='turbo', vminmax=(0, num_clusters2-1), enabled=True, defined_on=m["viz_mode"])
292
+ opts.m_alt['label'] = clustering2.labels_
293
+ opts.m_alt['num_cluster'] = num_clusters2
294
+ else:
295
+ opts.m_alt['ps_mesh'].add_scalar_quantity("cluster", relabelled_2, cmap='turbo', vminmax=(0, num_clusters1-1), enabled=True, defined_on=m["viz_mode"])
296
+ opts.m_alt['label'] = relabelled_2
297
+ opts.m_alt['num_cluster'] = num_clusters1
298
+
299
+
300
+ if psim.Button("Export"):
301
+ ### Save output
302
+ OUTPUT_FOL = opts.output_fol
303
+ fname1 = opts.filename
304
+ out_mesh_file = os.path.join(OUTPUT_FOL, fname1+'.obj')
305
+
306
+ igl.write_obj(out_mesh_file, opts.m["V"], opts.m["F"])
307
+ print("Saved '{}'.".format(out_mesh_file))
308
+
309
+ if m["viz_mode"] == "faces":
310
+ out_face_ids_file = os.path.join(OUTPUT_FOL, fname1 + "_" + str(opts.m['num_cluster']) + '_pred_face_ids.txt')
311
+ else:
312
+ out_face_ids_file = os.path.join(OUTPUT_FOL, fname1 + "_" + str(opts.m['num_cluster']) + '_pred_vertices_ids.txt')
313
+
314
+ np.savetxt(out_face_ids_file, opts.m['label'], fmt='%d')
315
+ print("Saved '{}'.".format(out_face_ids_file))
316
+
317
+
318
+ fname2 = opts.filename_alt
319
+ out_mesh_file = os.path.join(OUTPUT_FOL, fname2 +'.obj')
320
+
321
+ igl.write_obj(out_mesh_file, opts.m_alt["V"], opts.m_alt["F"])
322
+ print("Saved '{}'.".format(out_mesh_file))
323
+
324
+ if m["viz_mode"] == "faces":
325
+ out_face_ids_file = os.path.join(OUTPUT_FOL, fname2 + "_" + str(opts.m_alt['num_cluster']) + '_pred_face_ids.txt')
326
+ else:
327
+ out_face_ids_file = os.path.join(OUTPUT_FOL, fname2 + "_" + str(opts.m_alt['num_cluster']) + '_pred_vertices_ids.txt')
328
+
329
+ np.savetxt(out_face_ids_file, opts.m_alt['label'], fmt='%d')
330
+ print("Saved '{}'.".format(out_face_ids_file))
331
+
332
+
333
+ def main():
334
+ ## Parse args
335
+ # Uses simple_parsing library to automatically construct parser from the dataclass Options
336
+ parser = ArgumentParser()
337
+ parser.add_arguments(Options, dest="options")
338
+ parser.add_argument('--data_root', default="../exp_results/partfield_features/trellis", help='Path where the model features are stored.')
339
+ args = parser.parse_args()
340
+ opts: Options = args.options
341
+
342
+ DATA_ROOT = args.data_root
343
+
344
+ shape_1 = opts.filename
345
+ shape_2 = opts.filename_alt
346
+
347
+ if os.path.exists(os.path.join(DATA_ROOT, "part_feat_"+ shape_1 + "_0.npy")):
348
+ feature_fname1 = os.path.join(DATA_ROOT, "part_feat_"+ shape_1 + "_0.npy")
349
+ feature_fname2 = os.path.join(DATA_ROOT, "part_feat_"+ shape_2 + "_0.npy")
350
+
351
+ mesh_fname1 = os.path.join(DATA_ROOT, "feat_pca_"+ shape_1 + "_0.ply")
352
+ mesh_fname2 = os.path.join(DATA_ROOT, "feat_pca_"+ shape_2 + "_0.ply")
353
+ else:
354
+ feature_fname1 = os.path.join(DATA_ROOT, "part_feat_"+ shape_1 + "_0_batch.npy")
355
+ feature_fname2 = os.path.join(DATA_ROOT, "part_feat_"+ shape_2 + "_0_batch.npy")
356
+
357
+ mesh_fname1 = os.path.join(DATA_ROOT, "feat_pca_"+ shape_1 + "_0.ply")
358
+ mesh_fname2 = os.path.join(DATA_ROOT, "feat_pca_"+ shape_2 + "_0.ply")
359
+
360
+ #### To save output ####
361
+ os.makedirs(opts.output_fol, exist_ok=True)
362
+ ########################
363
+
364
+ # Initialize
365
+ ps.init()
366
+
367
+ mesh_dict = load_features(feature_fname1, mesh_fname1, opts.viz_mode)
368
+ prep_feature_mesh(mesh_dict)
369
+ mesh_dict["viz_mode"] = opts.viz_mode
370
+ opts.m = mesh_dict
371
+
372
+ mesh_dict_alt = load_features(feature_fname2, mesh_fname2, opts.viz_mode)
373
+ prep_feature_mesh(mesh_dict_alt, name='mesh_alt')
374
+ mesh_dict_alt['ps_mesh'].translate((2.5, 0., 0.))
375
+ mesh_dict_alt["viz_mode"] = opts.viz_mode
376
+ opts.m_alt = mesh_dict_alt
377
+
378
+ # Start the interactive UI
379
+ ps.set_user_callback(lambda : ps_callback(opts))
380
+ ps.show()
381
+
382
+
383
+ if __name__ == "__main__":
384
+ main()
385
+
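The co-segmentation branch above boils down to: cluster both feature sets with KMeans, then either seed the second shape with the first shape's cluster means ("Source Init") or cluster independently and relabel each of its clusters via the nearest mean feature. A condensed, non-interactive sketch of that logic follows; the function and variable names are illustrative, not part of the script.

import numpy as np
from scipy.spatial import KDTree
from sklearn.cluster import KMeans

def cosegment(feat1, feat2, n_clusters=8, source_init=True):
    # Normalize per-face features to unit length, as shape_pair.py does above.
    f1 = feat1 / np.linalg.norm(feat1, axis=-1, keepdims=True)
    f2 = feat2 / np.linalg.norm(feat2, axis=-1, keepdims=True)
    km1 = KMeans(n_clusters=n_clusters, random_state=0, n_init="auto").fit(f1)
    means1 = np.stack([f1[km1.labels_ == j].mean(axis=0) for j in range(n_clusters)])
    if source_init:
        # Seed shape 2 with shape 1's cluster means so the label ids tend to line up.
        km2 = KMeans(n_clusters=n_clusters, random_state=0, init=means1, n_init=1).fit(f2)
        labels2 = km2.labels_
    else:
        # Cluster independently, then map each shape-2 cluster to the nearest shape-1 mean.
        km2 = KMeans(n_clusters=n_clusters, random_state=0, n_init="auto").fit(f2)
        means2 = np.stack([f2[km2.labels_ == j].mean(axis=0) for j in range(n_clusters)])
        _, nn_idx = KDTree(means1).query(means2, k=1)
        labels2 = nn_idx[km2.labels_]
    return km1.labels_, labels2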
PartField/applications/single_shape.py ADDED
@@ -0,0 +1,758 @@
1
+ import numpy as np
2
+ import torch
3
+ import polyscope as ps
4
+ import polyscope.imgui as psim
5
+ import potpourri3d as pp3d
6
+ import trimesh
7
+ import igl
8
+ from dataclasses import dataclass
9
+ from simple_parsing import ArgumentParser
10
+ from arrgh import arrgh
11
+
12
+ ### For clustering
13
+ from collections import defaultdict
14
+ from sklearn.cluster import AgglomerativeClustering, DBSCAN, KMeans
15
+ from scipy.sparse import coo_matrix, csr_matrix
16
+ from scipy.spatial import KDTree
17
+ from scipy.sparse.csgraph import connected_components
18
+ from sklearn.neighbors import NearestNeighbors
19
+ import networkx as nx
20
+
21
+ from scipy.optimize import linear_sum_assignment
22
+
23
+ import os, sys
24
+ sys.path.append("..")
25
+ from partfield.utils import *
26
+
27
+ @dataclass
28
+ class Options:
29
+
30
+ """ Basic Options """
31
+ filename: str
32
+
33
+ """System Options"""
34
+ device: str = "cuda" # Device
35
+ debug: bool = False # enable debug checks
36
+ extras: bool = False # include extra output for viz/debugging
37
+
38
+ """ State """
39
+ mode: str = 'pca'
40
+ m: dict = None # mesh
41
+
42
+ # pca mode
43
+
44
+ # feature explore mode
45
+ i_feature: int = 0
46
+
47
+ i_cluster: int = 1
48
+
49
+ i_eps: float = 0.6
50
+
51
+ ### For mixing in clustering
52
+ weight_dist = 1.0
53
+ weight_feat = 1.0
54
+
55
+ ### For clustering visualization
56
+ feature_range: float = 0.1
57
+ continuous_explore: bool = False
58
+
59
+ viz_mode: str = "faces"
60
+
61
+ output_fol: str = "results_single"
62
+
63
+ ### For adj_matrix
64
+ adj_mode: str = "Vanilla"
65
+ add_knn_edges: bool = False
66
+
67
+ ### counter for screenshot
68
+ counter: int = 0
69
+
70
+ modes_list = ['pca', 'feature_viz', 'cluster_agglo', 'cluster_kmeans']
71
+ adj_mode_list = ["Vanilla", "Face_MST", "CC_MST"]
72
+
73
+ #### For clustering
74
+ class UnionFind:
75
+ def __init__(self, n):
76
+ self.parent = list(range(n))
77
+ self.rank = [1] * n
78
+
79
+ def find(self, x):
80
+ if self.parent[x] != x:
81
+ self.parent[x] = self.find(self.parent[x])
82
+ return self.parent[x]
83
+
84
+ def union(self, x, y):
85
+ rootX = self.find(x)
86
+ rootY = self.find(y)
87
+
88
+ if rootX != rootY:
89
+ if self.rank[rootX] > self.rank[rootY]:
90
+ self.parent[rootY] = rootX
91
+ elif self.rank[rootX] < self.rank[rootY]:
92
+ self.parent[rootX] = rootY
93
+ else:
94
+ self.parent[rootY] = rootX
95
+ self.rank[rootX] += 1
96
+
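A tiny usage check for the UnionFind helper above (the values are made up): union faces that share edges, then count the distinct roots to get the number of connected components.

uf = UnionFind(5)
uf.union(0, 1)
uf.union(1, 2)   # faces 0, 1, 2 end up in one component
uf.union(3, 4)   # faces 3 and 4 in another
n_components = len({uf.find(i) for i in range(5)})
print(n_components)  # 2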
97
+ #####################################
98
+ ## Face adjacency computation options
99
+ #####################################
100
+ def construct_face_adjacency_matrix_ccmst(face_list, vertices, k=10, with_knn=True):
101
+ """
102
+ Given a list of faces (each face is a 3-tuple of vertex indices),
103
+ construct a face-based adjacency matrix of shape (num_faces, num_faces).
104
+
105
+ Two faces are adjacent if they share an edge (the "mesh adjacency").
106
+ If multiple connected components remain, we:
107
+ 1) Compute the centroid of each connected component as the mean of all face centroids.
108
+ 2) Use a KNN graph (k=10) based on centroid distances on each connected component.
109
+ 3) Compute MST of that KNN graph.
110
+ 4) Add MST edges that connect different components as "dummy" edges
111
+ in the face adjacency matrix, ensuring one connected component. The selected face for
112
+ each connected component is the face closest to the component centroid.
113
+
114
+ Parameters
115
+ ----------
116
+ face_list : list of tuples
117
+ List of faces, each face is a tuple (v0, v1, v2) of vertex indices.
118
+ vertices : np.ndarray of shape (num_vertices, 3)
119
+ Array of vertex coordinates.
120
+ k : int, optional
121
+ Number of neighbors to use in centroid KNN. Default is 10.
122
+
123
+ Returns
124
+ -------
125
+ face_adjacency : scipy.sparse.csr_matrix
126
+ A CSR sparse matrix of shape (num_faces, num_faces),
127
+ containing 1s for adjacent faces (shared-edge adjacency)
128
+ plus dummy edges ensuring a single connected component.
129
+ """
130
+ num_faces = len(face_list)
131
+ if num_faces == 0:
132
+ # Return an empty matrix if no faces
133
+ return csr_matrix((0, 0))
134
+
135
+ #--------------------------------------------------------------------------
136
+ # 1) Build adjacency based on shared edges.
137
+ # (Same logic as the original code, plus import statements.)
138
+ #--------------------------------------------------------------------------
139
+ edge_to_faces = defaultdict(list)
140
+ uf = UnionFind(num_faces)
141
+ for f_idx, (v0, v1, v2) in enumerate(face_list):
142
+ # Sort each edge’s endpoints so (i, j) == (j, i)
143
+ edges = [
144
+ tuple(sorted((v0, v1))),
145
+ tuple(sorted((v1, v2))),
146
+ tuple(sorted((v2, v0)))
147
+ ]
148
+ for e in edges:
149
+ edge_to_faces[e].append(f_idx)
150
+
151
+ row = []
152
+ col = []
153
+ for edge, face_indices in edge_to_faces.items():
154
+ unique_faces = list(set(face_indices))
155
+ if len(unique_faces) > 1:
156
+ # For every pair of distinct faces that share this edge,
157
+ # mark them as mutually adjacent
158
+ for i in range(len(unique_faces)):
159
+ for j in range(i + 1, len(unique_faces)):
160
+ fi = unique_faces[i]
161
+ fj = unique_faces[j]
162
+ row.append(fi)
163
+ col.append(fj)
164
+ row.append(fj)
165
+ col.append(fi)
166
+ uf.union(fi, fj)
167
+
168
+ data = np.ones(len(row), dtype=np.int8)
169
+ face_adjacency = coo_matrix(
170
+ (data, (row, col)), shape=(num_faces, num_faces)
171
+ ).tocsr()
172
+
173
+ #--------------------------------------------------------------------------
174
+ # 2) Check if the graph from shared edges is already connected.
175
+ #--------------------------------------------------------------------------
176
+ n_components = 0
177
+ for i in range(num_faces):
178
+ if uf.find(i) == i:
179
+ n_components += 1
180
+ print("n_components", n_components)
181
+
182
+ if n_components == 1:
183
+ # Already a single connected component, no need for dummy edges
184
+ return face_adjacency
185
+
186
+ #--------------------------------------------------------------------------
187
+ # 3) Compute centroids of each face for building a KNN graph.
188
+ #--------------------------------------------------------------------------
189
+ face_centroids = []
190
+ for (v0, v1, v2) in face_list:
191
+ centroid = (vertices[v0] + vertices[v1] + vertices[v2]) / 3.0
192
+ face_centroids.append(centroid)
193
+ face_centroids = np.array(face_centroids)
194
+
195
+ #--------------------------------------------------------------------------
196
+ # 4b) Build a KNN graph on connected components
197
+ #--------------------------------------------------------------------------
198
+ # Group faces by their root representative in the Union-Find structure
199
+ component_dict = {}
200
+ for face_idx in range(num_faces):
201
+ root = uf.find(face_idx)
202
+ if root not in component_dict:
203
+ component_dict[root] = set()
204
+ component_dict[root].add(face_idx)
205
+
206
+ connected_components = list(component_dict.values())
207
+
208
+ print("Using connected component MST.")
209
+ component_centroid_face_idx = []
210
+ connected_component_centroids = []
211
+ knn = NearestNeighbors(n_neighbors=1, algorithm='auto')
212
+ for component in connected_components:
213
+ curr_component_faces = list(component)
214
+ curr_component_face_centroids = face_centroids[curr_component_faces]
215
+ component_centroid = np.mean(curr_component_face_centroids, axis=0)
216
+
217
+ ### Assign a face closest to the centroid
218
+ face_idx = curr_component_faces[np.argmin(np.linalg.norm(curr_component_face_centroids-component_centroid, axis=-1))]
219
+
220
+ connected_component_centroids.append(component_centroid)
221
+ component_centroid_face_idx.append(face_idx)
222
+
223
+ component_centroid_face_idx = np.array(component_centroid_face_idx)
224
+ connected_component_centroids = np.array(connected_component_centroids)
225
+
226
+ if n_components < k:
227
+ knn = NearestNeighbors(n_neighbors=n_components, algorithm='auto')
228
+ else:
229
+ knn = NearestNeighbors(n_neighbors=k, algorithm='auto')
230
+ knn.fit(connected_component_centroids)
231
+ distances, indices = knn.kneighbors(connected_component_centroids)
232
+
233
+ #--------------------------------------------------------------------------
234
+ # 5) Build a weighted graph in NetworkX using centroid-distances as edges
235
+ #--------------------------------------------------------------------------
236
+ G = nx.Graph()
237
+ # Add each face as a node in the graph
238
+ G.add_nodes_from(range(num_faces))
239
+
240
+ # For each face i, add edges (i -> j) for each neighbor j in the KNN
241
+ for idx1 in range(n_components):
242
+ i = component_centroid_face_idx[idx1]
243
+ for idx2, dist in zip(indices[idx1], distances[idx1]):
244
+ j = component_centroid_face_idx[idx2]
245
+ if i == j:
246
+ continue # skip self-loop
247
+ # Add an undirected edge with 'weight' = distance
248
+ # NetworkX handles parallel edges gracefully via last add_edge,
249
+ # but it typically overwrites the weight if (i, j) already exists.
250
+ G.add_edge(i, j, weight=dist)
251
+
252
+ #--------------------------------------------------------------------------
253
+ # 6) Compute MST on that KNN graph
254
+ #--------------------------------------------------------------------------
255
+ mst = nx.minimum_spanning_tree(G, weight='weight')
256
+ # Sort MST edges by ascending weight, so we add the shortest edges first
257
+ mst_edges_sorted = sorted(
258
+ mst.edges(data=True), key=lambda e: e[2]['weight']
259
+ )
260
+ print("mst edges sorted", len(mst_edges_sorted))
261
+ #--------------------------------------------------------------------------
262
+ # 7) Use a union-find structure to add MST edges only if they
263
+ # connect two currently disconnected components of the adjacency matrix
264
+ #--------------------------------------------------------------------------
265
+
266
+ # Convert face_adjacency to LIL format for efficient edge addition
267
+ adjacency_lil = face_adjacency.tolil()
268
+
269
+ # Now, step through MST edges in ascending order
270
+ for (u, v, attr) in mst_edges_sorted:
271
+ if uf.find(u) != uf.find(v):
272
+ # These belong to different components, so unify them
273
+ uf.union(u, v)
274
+ # And add a "dummy" edge to our adjacency matrix
275
+ adjacency_lil[u, v] = 1
276
+ adjacency_lil[v, u] = 1
277
+
278
+ # Convert back to CSR format and return
279
+ face_adjacency = adjacency_lil.tocsr()
280
+
281
+ if with_knn:
282
+ print("Adding KNN edges.")
283
+ ### Add KNN edges graph too
284
+ dummy_row = []
285
+ dummy_col = []
286
+ for idx1 in range(n_components):
287
+ i = component_centroid_face_idx[idx1]
288
+ for idx2 in indices[idx1]:
289
+ j = component_centroid_face_idx[idx2]
290
+ dummy_row.extend([i, j])
291
+ dummy_col.extend([j, i]) ### duplicates are handled by coo
292
+
293
+ dummy_data = np.ones(len(dummy_row), dtype=np.int16)
294
+ dummy_mat = coo_matrix(
295
+ (dummy_data, (dummy_row, dummy_col)),
296
+ shape=(num_faces, num_faces)
297
+ ).tocsr()
298
+ face_adjacency = face_adjacency + dummy_mat
299
+ ###########################
300
+
301
+ return face_adjacency
302
+ #########################
303
+
304
+ def construct_face_adjacency_matrix_facemst(face_list, vertices, k=10, with_knn=True):
305
+ """
306
+ Given a list of faces (each face is a 3-tuple of vertex indices),
307
+ construct a face-based adjacency matrix of shape (num_faces, num_faces).
308
+
309
+ Two faces are adjacent if they share an edge (the "mesh adjacency").
310
+ If multiple connected components remain, we:
311
+ 1) Compute the centroid of each face.
312
+ 2) Use a KNN graph (k=10) based on centroid distances.
313
+ 3) Compute MST of that KNN graph.
314
+ 4) Add MST edges that connect different components as "dummy" edges
315
+ in the face adjacency matrix, ensuring one connected component.
316
+
317
+ Parameters
318
+ ----------
319
+ face_list : list of tuples
320
+ List of faces, each face is a tuple (v0, v1, v2) of vertex indices.
321
+ vertices : np.ndarray of shape (num_vertices, 3)
322
+ Array of vertex coordinates.
323
+ k : int, optional
324
+ Number of neighbors to use in centroid KNN. Default is 10.
325
+
326
+ Returns
327
+ -------
328
+ face_adjacency : scipy.sparse.csr_matrix
329
+ A CSR sparse matrix of shape (num_faces, num_faces),
330
+ containing 1s for adjacent faces (shared-edge adjacency)
331
+ plus dummy edges ensuring a single connected component.
332
+ """
333
+ num_faces = len(face_list)
334
+ if num_faces == 0:
335
+ # Return an empty matrix if no faces
336
+ return csr_matrix((0, 0))
337
+
338
+ #--------------------------------------------------------------------------
339
+ # 1) Build adjacency based on shared edges.
340
+ # (Same logic as the original code, plus import statements.)
341
+ #--------------------------------------------------------------------------
342
+ edge_to_faces = defaultdict(list)
343
+ uf = UnionFind(num_faces)
344
+ for f_idx, (v0, v1, v2) in enumerate(face_list):
345
+ # Sort each edge’s endpoints so (i, j) == (j, i)
346
+ edges = [
347
+ tuple(sorted((v0, v1))),
348
+ tuple(sorted((v1, v2))),
349
+ tuple(sorted((v2, v0)))
350
+ ]
351
+ for e in edges:
352
+ edge_to_faces[e].append(f_idx)
353
+
354
+ row = []
355
+ col = []
356
+ for edge, face_indices in edge_to_faces.items():
357
+ unique_faces = list(set(face_indices))
358
+ if len(unique_faces) > 1:
359
+ # For every pair of distinct faces that share this edge,
360
+ # mark them as mutually adjacent
361
+ for i in range(len(unique_faces)):
362
+ for j in range(i + 1, len(unique_faces)):
363
+ fi = unique_faces[i]
364
+ fj = unique_faces[j]
365
+ row.append(fi)
366
+ col.append(fj)
367
+ row.append(fj)
368
+ col.append(fi)
369
+ uf.union(fi, fj)
370
+
371
+ data = np.ones(len(row), dtype=np.int8)
372
+ face_adjacency = coo_matrix(
373
+ (data, (row, col)), shape=(num_faces, num_faces)
374
+ ).tocsr()
375
+
376
+ #--------------------------------------------------------------------------
377
+ # 2) Check if the graph from shared edges is already connected.
378
+ #--------------------------------------------------------------------------
379
+ n_components = 0
380
+ for i in range(num_faces):
381
+ if uf.find(i) == i:
382
+ n_components += 1
383
+ print("n_components", n_components)
384
+
385
+ if n_components == 1:
386
+ # Already a single connected component, no need for dummy edges
387
+ return face_adjacency
388
+ #--------------------------------------------------------------------------
389
+ # 3) Compute centroids of each face for building a KNN graph.
390
+ #--------------------------------------------------------------------------
391
+ face_centroids = []
392
+ for (v0, v1, v2) in face_list:
393
+ centroid = (vertices[v0] + vertices[v1] + vertices[v2]) / 3.0
394
+ face_centroids.append(centroid)
395
+ face_centroids = np.array(face_centroids)
396
+
397
+ #--------------------------------------------------------------------------
398
+ # 4) Build a KNN graph (k=10) over face centroids using scikit‐learn
399
+ #--------------------------------------------------------------------------
400
+ knn = NearestNeighbors(n_neighbors=k, algorithm='auto')
401
+ knn.fit(face_centroids)
402
+ distances, indices = knn.kneighbors(face_centroids)
403
+ # 'distances[i]' are the distances from face i to each of its 'k' neighbors
404
+ # 'indices[i]' are the face indices of those neighbors
405
+
406
+ #--------------------------------------------------------------------------
407
+ # 5) Build a weighted graph in NetworkX using centroid-distances as edges
408
+ #--------------------------------------------------------------------------
409
+ G = nx.Graph()
410
+ # Add each face as a node in the graph
411
+ G.add_nodes_from(range(num_faces))
412
+
413
+ # For each face i, add edges (i -> j) for each neighbor j in the KNN
414
+ for i in range(num_faces):
415
+ for j, dist in zip(indices[i], distances[i]):
416
+ if i == j:
417
+ continue # skip self-loop
418
+ # Add an undirected edge with 'weight' = distance
419
+ # NetworkX handles parallel edges gracefully via last add_edge,
420
+ # but it typically overwrites the weight if (i, j) already exists.
421
+ G.add_edge(i, j, weight=dist)
422
+
423
+ #--------------------------------------------------------------------------
424
+ # 6) Compute MST on that KNN graph
425
+ #--------------------------------------------------------------------------
426
+ mst = nx.minimum_spanning_tree(G, weight='weight')
427
+ # Sort MST edges by ascending weight, so we add the shortest edges first
428
+ mst_edges_sorted = sorted(
429
+ mst.edges(data=True), key=lambda e: e[2]['weight']
430
+ )
431
+ print("mst edges sorted", len(mst_edges_sorted))
432
+ #--------------------------------------------------------------------------
433
+ # 7) Use a union-find structure to add MST edges only if they
434
+ # connect two currently disconnected components of the adjacency matrix
435
+ #--------------------------------------------------------------------------
436
+
437
+ # Convert face_adjacency to LIL format for efficient edge addition
438
+ adjacency_lil = face_adjacency.tolil()
439
+
440
+ # Now, step through MST edges in ascending order
441
+ for (u, v, attr) in mst_edges_sorted:
442
+ if uf.find(u) != uf.find(v):
443
+ # These belong to different components, so unify them
444
+ uf.union(u, v)
445
+ # And add a "dummy" edge to our adjacency matrix
446
+ adjacency_lil[u, v] = 1
447
+ adjacency_lil[v, u] = 1
448
+
449
+ # Convert back to CSR format and return
450
+ face_adjacency = adjacency_lil.tocsr()
451
+
452
+ if with_knn:
453
+ print("Adding KNN edges.")
454
+ ### Add KNN edges graph too
455
+ dummy_row = []
456
+ dummy_col = []
457
+ for i in range(num_faces):
458
+ for j in indices[i]:
459
+ dummy_row.extend([i, j])
460
+ dummy_col.extend([j, i]) ### duplicates are handled by coo
461
+
462
+ dummy_data = np.ones(len(dummy_row), dtype=np.int16)
463
+ dummy_mat = coo_matrix(
464
+ (dummy_data, (dummy_row, dummy_col)),
465
+ shape=(num_faces, num_faces)
466
+ ).tocsr()
467
+ face_adjacency = face_adjacency + dummy_mat
468
+ ###########################
469
+
470
+ return face_adjacency
471
+
472
+ def construct_face_adjacency_matrix_naive(face_list):
473
+ """
474
+ Given a list of faces (each face is a 3-tuple of vertex indices),
475
+ construct a face-based adjacency matrix of shape (num_faces, num_faces).
476
+ Two faces are adjacent if they share an edge.
477
+
478
+ If multiple connected components exist, dummy edges are added to
479
+ turn them into a single connected component. Edges are added naively by
480
+ randomly selecting a face and connecting consecutive components -- (comp_i, comp_i+1) ...
481
+
482
+ Parameters
483
+ ----------
484
+ face_list : list of tuples
485
+ List of faces, each face is a tuple (v0, v1, v2) of vertex indices.
486
+
487
+ Returns
488
+ -------
489
+ face_adjacency : scipy.sparse.csr_matrix
490
+ A CSR sparse matrix of shape (num_faces, num_faces),
491
+ containing 1s for adjacent faces and 0s otherwise.
492
+ Additional edges are added if the faces are in multiple components.
493
+ """
494
+
495
+ num_faces = len(face_list)
496
+ if num_faces == 0:
497
+ # Return an empty matrix if no faces
498
+ return csr_matrix((0, 0))
499
+
500
+ # Step 1: Map each undirected edge -> list of face indices that contain that edge
501
+ edge_to_faces = defaultdict(list)
502
+
503
+ # Populate the edge_to_faces dictionary
504
+ for f_idx, (v0, v1, v2) in enumerate(face_list):
505
+ # For an edge, we always store its endpoints in sorted order
506
+ # to avoid duplication (e.g. edge (2,5) is the same as (5,2)).
507
+ edges = [
508
+ tuple(sorted((v0, v1))),
509
+ tuple(sorted((v1, v2))),
510
+ tuple(sorted((v2, v0)))
511
+ ]
512
+ for e in edges:
513
+ edge_to_faces[e].append(f_idx)
514
+
515
+ # Step 2: Build the adjacency (row, col) lists among faces
516
+ row = []
517
+ col = []
518
+ for e, faces_sharing_e in edge_to_faces.items():
519
+ # If an edge is shared by multiple faces, make each pair of those faces adjacent
520
+ f_indices = list(set(faces_sharing_e)) # unique face indices for this edge
521
+ if len(f_indices) > 1:
522
+ # For each pair of faces, mark them as adjacent
523
+ for i in range(len(f_indices)):
524
+ for j in range(i + 1, len(f_indices)):
525
+ f_i = f_indices[i]
526
+ f_j = f_indices[j]
527
+ row.append(f_i)
528
+ col.append(f_j)
529
+ row.append(f_j)
530
+ col.append(f_i)
531
+
532
+ # Create a COO matrix, then convert it to CSR
533
+ data = np.ones(len(row), dtype=np.int8)
534
+ face_adjacency = coo_matrix(
535
+ (data, (row, col)),
536
+ shape=(num_faces, num_faces)
537
+ ).tocsr()
538
+
539
+ # Step 3: Ensure single connected component
540
+ # Use connected_components to see how many components exist
541
+ n_components, labels = connected_components(face_adjacency, directed=False)
542
+
543
+ if n_components > 1:
544
+ # We have multiple components; let's "connect" them via dummy edges
545
+ # The simplest approach is to pick one face from each component
546
+ # and connect them sequentially to enforce a single component.
547
+ component_representatives = []
548
+
549
+ for comp_id in range(n_components):
550
+ # indices of faces in this component
551
+ faces_in_comp = np.where(labels == comp_id)[0]
552
+ if len(faces_in_comp) > 0:
553
+ # take the first face in this component as a representative
554
+ component_representatives.append(faces_in_comp[0])
555
+
556
+ # Now, add edges between consecutive representatives
557
+ dummy_row = []
558
+ dummy_col = []
559
+ for i in range(len(component_representatives) - 1):
560
+ f_i = component_representatives[i]
561
+ f_j = component_representatives[i + 1]
562
+ dummy_row.extend([f_i, f_j])
563
+ dummy_col.extend([f_j, f_i])
564
+
565
+ if dummy_row:
566
+ dummy_data = np.ones(len(dummy_row), dtype=np.int8)
567
+ dummy_mat = coo_matrix(
568
+ (dummy_data, (dummy_row, dummy_col)),
569
+ shape=(num_faces, num_faces)
570
+ ).tocsr()
571
+ face_adjacency = face_adjacency + dummy_mat
572
+
573
+ return face_adjacency
574
+ #####################################
575
+
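A quick sanity check of the shared-edge adjacency above, using two triangles that share the edge (1, 2); this is illustrative only and assumes the functions above are in scope.

faces = [(0, 1, 2), (1, 3, 2)]
A = construct_face_adjacency_matrix_naive(faces)
print(A.toarray())
# [[0 1]
#  [1 0]]  -> faces 0 and 1 are mutually adjacent, already one connected component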
576
+ def load_features(feature_filename, mesh_filename, viz_mode):
577
+
578
+ print("Reading features:")
579
+ print(f" Feature filename: {feature_filename}")
580
+ print(f" Mesh filename: {mesh_filename}")
581
+
582
+ # load features
583
+ feat = np.load(feature_filename, allow_pickle=True)
584
+ feat = feat.astype(np.float32)
585
+
586
+ # load mesh things
587
+ tm = load_mesh_util(mesh_filename)
588
+
589
+ V = np.array(tm.vertices, dtype=np.float32)
590
+ F = np.array(tm.faces)
591
+
592
+ if viz_mode == "faces":
593
+ pca_colors = np.array(tm.visual.face_colors, dtype=np.float32)
594
+ pca_colors = pca_colors[:,:3] / 255.
595
+
596
+ else:
597
+ pca_colors = np.array(tm.visual.vertex_colors, dtype=np.float32)
598
+ pca_colors = pca_colors[:,:3] / 255.
599
+
600
+ arrgh(V, F, pca_colors, feat)
601
+
602
+ print(F)
603
+ print(V[F[1][0]])
604
+ print(V[F[1][1]])
605
+ print(V[F[1][2]])
606
+
607
+ return {
608
+ 'V' : V,
609
+ 'F' : F,
610
+ 'pca_colors' : pca_colors,
611
+ 'feat_np' : feat,
612
+ 'feat_pt' : torch.tensor(feat, device='cuda'),
613
+ 'trimesh' : tm,
614
+ 'label' : None,
615
+ 'num_cluster' : 1,
616
+ 'scalar' : None
617
+ }
618
+
619
+ def prep_feature_mesh(m, name='mesh'):
620
+ ps_mesh = ps.register_surface_mesh(name, m['V'], m['F'])
621
+ ps_mesh.set_selection_mode('faces_only')
622
+ m['ps_mesh'] = ps_mesh
623
+
624
+ def viz_pca_colors(m):
625
+ m['ps_mesh'].add_color_quantity('pca colors', m['pca_colors'], enabled=True, defined_on=m["viz_mode"])
626
+
627
+ def viz_feature(m, ind):
628
+ m['ps_mesh'].add_scalar_quantity('pca colors', m['feat_np'][:,ind], cmap='turbo', enabled=True, defined_on=m["viz_mode"])
629
+
630
+ def feature_distance_np(feats, query_feat):
631
+ # normalize
632
+ feats = feats / np.linalg.norm(feats,axis=1)[:,None]
633
+ query_feat = query_feat / np.linalg.norm(query_feat)
634
+ # cosine distance
635
+ cos_sim = np.dot(feats, query_feat)
636
+ cos_dist = (1 - cos_sim) / 2.
637
+ return cos_dist
638
+
639
+ def feature_distance_pt(feats, query_feat):
640
+ return (1. - torch.nn.functional.cosine_similarity(feats, query_feat[None,:], dim=-1)) / 2.
641
+
642
+
643
+ def ps_callback(opts):
644
+ m = opts.m
645
+
646
+ changed, ind = psim.Combo("Mode", modes_list.index(opts.mode), modes_list)
647
+ if changed:
648
+ opts.mode = modes_list[ind]
649
+ m['ps_mesh'].remove_all_quantities()
650
+
651
+ if opts.mode == 'pca':
652
+ psim.TextUnformatted("""3-dim PCA embeddeding of features is shown as rgb color""")
653
+ viz_pca_colors(m)
654
+
655
+ elif opts.mode == 'feature_viz':
656
+ psim.TextUnformatted("""Use the slider to scrub through all features.\nCtrl-click to type a particular index.""")
657
+
658
+ this_changed, opts.i_feature = psim.SliderInt("feature index", opts.i_feature, v_min=0, v_max=(m['feat_np'].shape[-1]-1))
659
+ this_changed = this_changed or changed
660
+
661
+ if this_changed:
662
+ viz_feature(m, opts.i_feature)
663
+
664
+ elif opts.mode == "cluster_agglo":
665
+ psim.TextUnformatted("""Use the slider to toggle the number of desired clusters.""")
666
+ cluster_changed, opts.i_cluster = psim.SliderInt("number of clusters", opts.i_cluster, v_min=1, v_max=30)
667
+
668
+ ### To handle different face adjacency options
669
+ mode_changed, ind = psim.Combo("Adj Matrix Def", adj_mode_list.index(opts.adj_mode), adj_mode_list)
670
+ knn_changed, opts.add_knn_edges = psim.Checkbox("Add KNN edges", opts.add_knn_edges)
671
+
672
+ if mode_changed:
673
+ opts.adj_mode = adj_mode_list[ind]
674
+
675
+ if psim.Button("Recompute"):
676
+
677
+ ### Run clustering algorithm
678
+ num_clusters = opts.i_cluster
679
+
680
+ ### Mesh 1
681
+ point_feat = m['feat_np']
682
+ point_feat = point_feat / np.linalg.norm(point_feat, axis=-1, keepdims=True)
683
+
684
+ ### Compute adjacency matrix ###
685
+ if opts.adj_mode == "Vanilla":
686
+ adj_matrix = construct_face_adjacency_matrix_naive(opts.m["F"])
687
+ elif opts.adj_mode == "Face_MST":
688
+ adj_matrix = construct_face_adjacency_matrix_facemst(opts.m["F"], opts.m["V"], with_knn=opts.add_knn_edges)
689
+ elif opts.adj_mode == "CC_MST":
690
+ adj_matrix = construct_face_adjacency_matrix_ccmst(opts.m["F"], opts.m["V"], with_knn=opts.add_knn_edges)
691
+ ################################
692
+
693
+ ## Agglomerative clustering
694
+ clustering = AgglomerativeClustering(connectivity= adj_matrix,
695
+ n_clusters=num_clusters,
696
+ ).fit(point_feat)
697
+
698
+ m['ps_mesh'].add_scalar_quantity("cluster", clustering.labels_, cmap='turbo', vminmax=(0, num_clusters-1), enabled=True, defined_on=m["viz_mode"])
699
+ print("Recomputed.")
700
+
701
+
702
+ elif opts.mode == "cluster_kmeans":
703
+ psim.TextUnformatted("""Use the slider to toggle the number of desired clusters.""")
704
+
705
+ cluster_changed, opts.i_cluster = psim.SliderInt("number of clusters", opts.i_cluster, v_min=1, v_max=30)
706
+
707
+ if psim.Button("Recompute"):
708
+
709
+ ### Run clustering algorithm
710
+ num_clusters = opts.i_cluster
711
+
712
+ ### Mesh 1
713
+ point_feat = m['feat_np']
714
+ point_feat = point_feat / np.linalg.norm(point_feat, axis=-1, keepdims=True)
715
+ clustering = KMeans(n_clusters=num_clusters, random_state=0, n_init="auto").fit(point_feat)
716
+
717
+ m['ps_mesh'].add_scalar_quantity("cluster", clustering.labels_, cmap='turbo', vminmax=(0, num_clusters-1), enabled=True, defined_on=m["viz_mode"])
718
+
719
+ def main():
720
+ ## Parse args
721
+ # Uses simple_parsing library to automatically construct parser from the dataclass Options
722
+ parser = ArgumentParser()
723
+ parser.add_arguments(Options, dest="options")
724
+ parser.add_argument('--data_root', default="../exp_results/partfield_features/trellis/", help='Path where the model features are stored.')
725
+ args = parser.parse_args()
726
+ opts: Options = args.options
727
+
728
+ DATA_ROOT = args.data_root
729
+
730
+ shape_1 = opts.filename
731
+
732
+ if os.path.exists(os.path.join(DATA_ROOT, "part_feat_"+ shape_1 + "_0.npy")):
733
+ feature_fname1 = os.path.join(DATA_ROOT, "part_feat_"+ shape_1 + "_0.npy")
734
+ mesh_fname1 = os.path.join(DATA_ROOT, "feat_pca_"+ shape_1 + "_0.ply")
735
+ else:
736
+ feature_fname1 = os.path.join(DATA_ROOT, "part_feat_"+ shape_1 + "_0_batch.npy")
737
+ mesh_fname1 = os.path.join(DATA_ROOT, "feat_pca_"+ shape_1 + "_0.ply")
738
+
739
+ #### To save output ####
740
+ os.makedirs(opts.output_fol, exist_ok=True)
741
+ ########################
742
+
743
+ # Initialize
744
+ ps.init()
745
+
746
+ mesh_dict = load_features(feature_fname1, mesh_fname1, opts.viz_mode)
747
+ prep_feature_mesh(mesh_dict)
748
+ mesh_dict["viz_mode"] = opts.viz_mode
749
+ opts.m = mesh_dict
750
+
751
+ # Start the interactive UI
752
+ ps.set_user_callback(lambda : ps_callback(opts))
753
+ ps.show()
754
+
755
+
756
+ if __name__ == "__main__":
757
+ main()
758
+
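Outside the UI, the cluster_agglo branch amounts to: build one of the face adjacency matrices defined above, normalize the per-face features, and run a connectivity-constrained agglomerative clustering. A minimal non-interactive sketch under those assumptions (names are illustrative):

import numpy as np
from sklearn.cluster import AgglomerativeClustering

def segment_faces(face_feats, faces, verts, n_clusters=10, adj_mode="Vanilla"):
    feats = face_feats / np.linalg.norm(face_feats, axis=-1, keepdims=True)
    if adj_mode == "Vanilla":
        adj = construct_face_adjacency_matrix_naive(faces)
    elif adj_mode == "Face_MST":
        adj = construct_face_adjacency_matrix_facemst(faces, verts)
    else:  # "CC_MST"
        adj = construct_face_adjacency_matrix_ccmst(faces, verts)
    clustering = AgglomerativeClustering(connectivity=adj, n_clusters=n_clusters).fit(feats)
    return clustering.labels_  # one part id per face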
PartField/compute_metric.py ADDED
@@ -0,0 +1,97 @@
1
+ import numpy as np
2
+ import json
3
+ from os.path import join
4
+ from typing import List
5
+ import os
6
+
7
+ def compute_iou(pred, gt):
8
+ intersection = np.logical_and(pred, gt).sum()
9
+ union = np.logical_or(pred, gt).sum()
10
+ if union != 0:
11
+ return (intersection / union) * 100
12
+ else:
13
+ return 0
14
+
15
+ def eval_single_gt_shape(gt_label, pred_masks):
16
+ # gt: [N,], label index
17
+ # pred: [B, N], B is the number of predicted parts, binary label
18
+ unique_gt_label = np.unique(gt_label)
19
+ best_ious = []
20
+ for label in unique_gt_label:
21
+ best_iou = 0
22
+ if label == -1:
23
+ continue
24
+ for mask in pred_masks:
25
+ iou = compute_iou(mask, gt_label == label)
26
+ best_iou = max(best_iou, iou)
27
+ best_ious.append(best_iou)
28
+ return np.mean(best_ious)
29
+
30
+ def eval_whole_dataset(pred_folder, merge_parts=False):
31
+ print(pred_folder)
32
+ meta = json.load(open("/home/mikaelaangel/Desktop/data/PartObjaverse-Tiny_semantic.json", "r"))
33
+
34
+ categories = meta.keys()
35
+ results_per_cat = {}
36
+ per_cat_mious = []
37
+ overall_mious = []
38
+
39
+ MAX_NUM_CLUSTERS = 20
40
+ view_id = 0
41
+
42
+ for cat in categories:
43
+ results_per_cat[cat] = []
44
+ for shape_id in meta[cat].keys():
45
+
46
+ try:
47
+ all_pred_labels = []
48
+ for num_cluster in range(2, MAX_NUM_CLUSTERS):
49
+ ### load each label
50
+ fname_clustering = os.path.join(pred_folder, "cluster_out", str(shape_id) + "_" + str(view_id) + "_" + str(num_cluster).zfill(2)) + ".npy"
51
+ pred_label = np.load(fname_clustering)
52
+ all_pred_labels.append(np.squeeze(pred_label))
53
+
54
+ all_pred_labels = np.array(all_pred_labels)
55
+
56
+ except:  # predictions for this shape are missing or incomplete; skip it
57
+ continue
58
+
59
+ pred_masks = []
60
+
61
+ #### Path for PartObjaverseTiny Labels
62
+ gt_labels_path = "PartObjaverse-Tiny_instance_gt"
63
+ #################################
64
+
65
+ gt_label = np.load(os.path.join(gt_labels_path, shape_id + ".npy"))
66
+
67
+ if merge_parts:
68
+ pred_masks = []
69
+ for result in all_pred_labels:
70
+ pred = result
71
+ assert pred.shape[0] == gt_label.shape[0]
72
+ for label in np.unique(pred):
73
+ pred_masks.append(pred == label)
74
+ miou = eval_single_gt_shape(gt_label, np.array(pred_masks))
75
+ results_per_cat[cat].append(miou)
76
+ else:
77
+ best_miou = 0
78
+ for result in all_pred_labels:
79
+ pred_masks = []
80
+ pred = result
81
+
82
+ for label in np.unique(pred):
83
+ pred_masks.append(pred == label)
84
+ miou = eval_single_gt_shape(gt_label, np.array(pred_masks))
85
+ best_miou = max(best_miou, miou)
86
+ results_per_cat[cat].append(best_miou)
87
+
88
+ print(np.mean(results_per_cat[cat]))
89
+ per_cat_mious.append(np.mean(results_per_cat[cat]))
90
+ overall_mious += results_per_cat[cat]
91
+ print(np.mean(per_cat_mious))
92
+ print(np.mean(overall_mious), len(overall_mious))
93
+
94
+
95
+ if __name__ == "__main__":
96
+ eval_whole_dataset("dump_partobjtiny_clustering")
97
+
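A worked micro-example of the metric above: compute_iou scores a single mask pair, and eval_single_gt_shape takes, for every ground-truth part, the best IoU over all predicted masks and averages those best scores (the arrays below are made up).

import numpy as np

pred = np.array([1, 1, 0, 0], dtype=bool)
gt   = np.array([1, 0, 0, 0], dtype=bool)
print(compute_iou(pred, gt))  # intersection 1, union 2 -> 50.0

gt_label   = np.array([0, 0, 1, 1])                 # two ground-truth parts
pred_masks = np.array([[1, 1, 0, 0],
                       [0, 0, 1, 0]], dtype=bool)   # two predicted parts
print(eval_single_gt_shape(gt_label, pred_masks))   # (100 + 50) / 2 = 75.0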
PartField/configs/final/correspondence_demo.yaml ADDED
@@ -0,0 +1,44 @@
1
+ result_name: partfield_features/correspondence_demo
2
+
3
+ continue_ckpt: model/model.ckpt
4
+
5
+ triplane_channels_low: 128
6
+ triplane_channels_high: 512
7
+ triplane_resolution: 128
8
+
9
+ vertex_feature: True
10
+ n_point_per_face: 1000
11
+ n_sample_each: 10000
12
+ is_pc: False
13
+ remesh_demo: False
14
+ correspondence_demo: True
15
+
16
+ preprocess_mesh: True
17
+
18
+ dataset:
19
+ type: "Mix"
20
+ data_path: data/DenseCorr3D
21
+ train_batch_size: 1
22
+ val_batch_size: 1
23
+ train_num_workers: 8
24
+ all_files:
25
+ # pairs of example to run correspondence
26
+ - animals/071b8_toy_animals_017/simple_mesh.obj
27
+ - animals/bdfd0_toy_animals_016/simple_mesh.obj
28
+ - animals/2d6b3_toy_animals_009/simple_mesh.obj
29
+ - animals/96615_toy_animals_018/simple_mesh.obj
30
+ - chairs/063d1_chair_006/simple_mesh.obj
31
+ - chairs/bea57_chair_012/simple_mesh.obj
32
+ - chairs/fe0fe_chair_004/simple_mesh.obj
33
+ - chairs/288dc_chair_011/simple_mesh.obj
34
+ # consider decimating animals/../color_mesh.obj yourself for better mesh topology than the provided simple_mesh.obj
35
+ # (e.g. <50k vertices for functional map efficiency).
36
+
37
+ loss:
38
+ triplet: 1.0
39
+
40
+ use_2d_feat: False
41
+ pvcnn:
42
+ point_encoder_type: 'pvcnn'
43
+ z_triplane_channels: 256
44
+ z_triplane_resolution: 128
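For reference, the all_files list is consumed as consecutive pairs: entries 2*i and 2*i+1 form the i-th correspondence pair, and each uid is the path with its extension stripped and "/" replaced by "_", matching run_smooth_functional_map.py earlier in this commit. A short sketch follows; cfg here is a hypothetical handle to the parsed YAML, not the project's own config loader.

all_files = cfg["dataset"]["all_files"]
for i in range(len(all_files) // 2):
    file0, file1 = all_files[2 * i], all_files[2 * i + 1]
    uid0 = file0.split(".")[-2].replace("/", "_")  # e.g. animals_071b8_toy_animals_017_simple_mesh
    uid1 = file1.split(".")[-2].replace("/", "_")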
PartField/configs/final/demo.yaml ADDED
@@ -0,0 +1,28 @@
1
+ result_name: demo_test
2
+
3
+ continue_ckpt: model/model.ckpt
4
+
5
+ triplane_channels_low: 128
6
+ triplane_channels_high: 512
7
+ triplane_resolution: 128
8
+
9
+ n_point_per_face: 1000
10
+ n_sample_each: 10000
11
+ is_pc : False
12
+ remesh_demo : False
13
+
14
+ dataset:
15
+ type: "Mix"
16
+ data_path: "objaverse_data"
17
+ train_batch_size: 1
18
+ val_batch_size: 1
19
+ train_num_workers: 8
20
+
21
+ loss:
22
+ triplet: 1.0
23
+
24
+ use_2d_feat: False
25
+ pvcnn:
26
+ point_encoder_type: 'pvcnn'
27
+ z_triplane_channels: 256
28
+ z_triplane_resolution: 128
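As a generic illustration of how the keys in this file nest, the snippet below reads it with plain PyYAML; this is not the project's own config machinery, just a way to inspect the values.

import yaml

with open("PartField/configs/final/demo.yaml") as f:
    cfg = yaml.safe_load(f)
print(cfg["result_name"])                  # demo_test
print(cfg["dataset"]["data_path"])         # objaverse_data
print(cfg["pvcnn"]["point_encoder_type"])  # pvcnn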
PartField/download_demo_data.sh ADDED
@@ -0,0 +1,19 @@
1
+ #!/bin/bash
2
+ mkdir data
3
+ cd data
4
+ mkdir objaverse_samples
5
+ cd objaverse_samples
6
+ wget https://huggingface.co/datasets/allenai/objaverse/resolve/main/glbs/000-050/00200996b8f34f55a2dd2f44d316d107.glb
7
+ wget https://huggingface.co/datasets/allenai/objaverse/resolve/main/glbs/000-042/002e462c8bfa4267a9c9f038c7966f3b.glb
8
+ wget https://huggingface.co/datasets/allenai/objaverse/resolve/main/glbs/000-046/0c3ca2b32545416f8f1e6f0e87def1a6.glb
9
+ wget https://huggingface.co/datasets/allenai/objaverse/resolve/main/glbs/000-063/65c6ffa083c6496eb84a0aa3c48d63ad.glb
10
+
11
+ cd ..
12
+ mkdir trellis_samples
13
+ cd trellis_samples
14
+ wget https://github.com/Trellis3D/trellis3d.github.io/raw/refs/heads/main/assets/scenes/blacksmith/glbs/dwarf.glb
15
+ wget https://github.com/Trellis3D/trellis3d.github.io/raw/refs/heads/main/assets/img2/glbs/goblin.glb
16
+ wget https://github.com/Trellis3D/trellis3d.github.io/raw/refs/heads/main/assets/img2/glbs/excavator.glb
17
+ wget https://github.com/Trellis3D/trellis3d.github.io/raw/refs/heads/main/assets/img2/glbs/elephant.glb
18
+ cd ..
19
+ cd ..
PartField/environment.yml ADDED
@@ -0,0 +1,772 @@
1
+ name: partfield
2
+ channels:
3
+ - nvidia/label/cuda-12.4.0
4
+ - conda-forge
5
+ - defaults
6
+ dependencies:
7
+ - _anaconda_depends=2025.03=py310_mkl_0
8
+ - _libgcc_mutex=0.1=conda_forge
9
+ - _openmp_mutex=4.5=2_gnu
10
+ - aiobotocore=2.21.1=pyhd8ed1ab_0
11
+ - aiohappyeyeballs=2.6.1=pyhd8ed1ab_0
12
+ - aiohttp=3.11.14=py310h89163eb_0
13
+ - aioitertools=0.12.0=pyhd8ed1ab_1
14
+ - aiosignal=1.3.2=pyhd8ed1ab_0
15
+ - alabaster=1.0.0=pyhd8ed1ab_1
16
+ - alsa-lib=1.2.13=hb9d3cd8_0
17
+ - altair=5.5.0=pyhd8ed1ab_1
18
+ - anaconda=custom=py310_3
19
+ - anyio=4.9.0=pyh29332c3_0
20
+ - aom=3.9.1=hac33072_0
21
+ - appdirs=1.4.4=pyhd8ed1ab_1
22
+ - argon2-cffi=23.1.0=pyhd8ed1ab_1
23
+ - argon2-cffi-bindings=21.2.0=py310ha75aee5_5
24
+ - arrow=1.3.0=pyhd8ed1ab_1
25
+ - astroid=3.3.9=py310hff52083_0
26
+ - astropy=6.1.7=py310hf462985_0
27
+ - astropy-iers-data=0.2025.3.31.0.36.18=pyhd8ed1ab_0
28
+ - asttokens=3.0.0=pyhd8ed1ab_1
29
+ - async-lru=2.0.5=pyh29332c3_0
30
+ - async-timeout=5.0.1=pyhd8ed1ab_1
31
+ - asyncssh=2.20.0=pyhd8ed1ab_0
32
+ - atomicwrites=1.4.1=pyhd8ed1ab_1
33
+ - attr=2.5.1=h166bdaf_1
34
+ - attrs=25.3.0=pyh71513ae_0
35
+ - automat=24.8.1=pyhd8ed1ab_1
36
+ - autopep8=2.0.4=pyhd8ed1ab_0
37
+ - aws-c-auth=0.8.6=hd08a7f5_4
38
+ - aws-c-cal=0.8.7=h043a21b_0
39
+ - aws-c-common=0.12.0=hb9d3cd8_0
40
+ - aws-c-compression=0.3.1=h3870646_2
41
+ - aws-c-event-stream=0.5.4=h04a3f94_2
42
+ - aws-c-http=0.9.4=hb9b18c6_4
43
+ - aws-c-io=0.17.0=h3dad3f2_6
44
+ - aws-c-mqtt=0.12.2=h108da3e_2
45
+ - aws-c-s3=0.7.13=h822ba82_2
46
+ - aws-c-sdkutils=0.2.3=h3870646_2
47
+ - aws-checksums=0.2.3=h3870646_2
48
+ - aws-crt-cpp=0.31.0=h55f77e1_4
49
+ - aws-sdk-cpp=1.11.510=h37a5c72_3
50
+ - azure-core-cpp=1.14.0=h5cfcd09_0
51
+ - azure-identity-cpp=1.10.0=h113e628_0
52
+ - azure-storage-blobs-cpp=12.13.0=h3cf044e_1
53
+ - azure-storage-common-cpp=12.8.0=h736e048_1
54
+ - azure-storage-files-datalake-cpp=12.12.0=ha633028_1
55
+ - babel=2.17.0=pyhd8ed1ab_0
56
+ - backports=1.0=pyhd8ed1ab_5
57
+ - backports.tarfile=1.2.0=pyhd8ed1ab_1
58
+ - bcrypt=4.3.0=py310h505e2c1_0
59
+ - beautifulsoup4=4.13.3=pyha770c72_0
60
+ - binaryornot=0.4.4=pyhd8ed1ab_2
61
+ - binutils=2.43=h4852527_4
62
+ - binutils_impl_linux-64=2.43=h4bf12b8_4
63
+ - binutils_linux-64=2.43=h4852527_4
64
+ - black=25.1.0=pyha5154f8_0
65
+ - blas=1.0=mkl
66
+ - bleach=6.2.0=pyh29332c3_4
67
+ - bleach-with-css=6.2.0=h82add2a_4
68
+ - blinker=1.9.0=pyhff2d567_0
69
+ - blosc=1.21.6=he440d0b_1
70
+ - bokeh=3.7.0=pyhd8ed1ab_0
71
+ - brotli=1.1.0=hb9d3cd8_2
72
+ - brotli-bin=1.1.0=hb9d3cd8_2
73
+ - brotli-python=1.1.0=py310hf71b8c6_2
74
+ - brunsli=0.1=h9c3ff4c_0
75
+ - bzip2=1.0.8=h4bc722e_7
76
+ - c-ares=1.34.4=hb9d3cd8_0
77
+ - c-blosc2=2.15.2=h3122c55_1
78
+ - c-compiler=1.9.0=h2b85faf_0
79
+ - ca-certificates=2025.1.31=hbcca054_0
80
+ - cached-property=1.5.2=hd8ed1ab_1
81
+ - cached_property=1.5.2=pyha770c72_1
82
+ - cachetools=5.5.2=pyhd8ed1ab_0
83
+ - cairo=1.18.4=h3394656_0
84
+ - certifi=2025.1.31=pyhd8ed1ab_0
85
+ - cffi=1.17.1=py310h8deb56e_0
86
+ - chardet=5.2.0=pyhd8ed1ab_3
87
+ - charls=2.4.2=h59595ed_0
88
+ - charset-normalizer=3.4.1=pyhd8ed1ab_0
89
+ - click=8.1.8=pyh707e725_0
90
+ - cloudpickle=3.1.1=pyhd8ed1ab_0
91
+ - colorama=0.4.6=pyhd8ed1ab_1
92
+ - colorcet=3.1.0=pyhd8ed1ab_1
93
+ - comm=0.2.2=pyhd8ed1ab_1
94
+ - constantly=15.1.0=py_0
95
+ - contourpy=1.3.1=py310h3788b33_0
96
+ - cookiecutter=2.6.0=pyhd8ed1ab_1
97
+ - cpython=3.10.16=py310hd8ed1ab_1
98
+ - cryptography=44.0.2=py310h6c63255_0
99
+ - cssselect=1.2.0=pyhd8ed1ab_1
100
+ - cuda=12.4.0=0
101
+ - cuda-cccl_linux-64=12.8.90=ha770c72_1
102
+ - cuda-command-line-tools=12.8.1=ha770c72_0
103
+ - cuda-compiler=12.8.1=hbad6d8a_0
104
+ - cuda-crt-dev_linux-64=12.8.93=ha770c72_1
105
+ - cuda-crt-tools=12.8.93=ha770c72_1
106
+ - cuda-cudart=12.8.90=h5888daf_1
107
+ - cuda-cudart-dev=12.8.90=h5888daf_1
108
+ - cuda-cudart-dev_linux-64=12.8.90=h3f2d84a_1
109
+ - cuda-cudart-static=12.8.90=h5888daf_1
110
+ - cuda-cudart-static_linux-64=12.8.90=h3f2d84a_1
111
+ - cuda-cudart_linux-64=12.8.90=h3f2d84a_1
112
+ - cuda-cuobjdump=12.8.90=hbd13f7d_1
113
+ - cuda-cupti=12.8.90=hbd13f7d_0
114
+ - cuda-cupti-dev=12.8.90=h5888daf_0
115
+ - cuda-cuxxfilt=12.8.90=hbd13f7d_1
116
+ - cuda-demo-suite=12.4.99=0
117
+ - cuda-driver-dev=12.8.90=h5888daf_1
118
+ - cuda-driver-dev_linux-64=12.8.90=h3f2d84a_1
119
+ - cuda-gdb=12.8.90=h50b4baa_0
120
+ - cuda-libraries=12.8.1=ha770c72_0
121
+ - cuda-libraries-dev=12.8.1=ha770c72_0
122
+ - cuda-nsight=12.8.90=h7938cbb_1
123
+ - cuda-nvcc=12.8.93=hcdd1206_1
124
+ - cuda-nvcc-dev_linux-64=12.8.93=he91c749_1
125
+ - cuda-nvcc-impl=12.8.93=h85509e4_1
126
+ - cuda-nvcc-tools=12.8.93=he02047a_1
127
+ - cuda-nvcc_linux-64=12.8.93=h04802cd_1
128
+ - cuda-nvdisasm=12.8.90=hbd13f7d_1
129
+ - cuda-nvml-dev=12.8.90=hbd13f7d_0
130
+ - cuda-nvprof=12.8.90=hbd13f7d_0
131
+ - cuda-nvprune=12.8.90=hbd13f7d_1
132
+ - cuda-nvrtc=12.8.93=h5888daf_1
133
+ - cuda-nvrtc-dev=12.8.93=h5888daf_1
134
+ - cuda-nvtx=12.8.90=hbd13f7d_0
135
+ - cuda-nvvm-dev_linux-64=12.8.93=ha770c72_1
136
+ - cuda-nvvm-impl=12.8.93=he02047a_1
137
+ - cuda-nvvm-tools=12.8.93=he02047a_1
138
+ - cuda-nvvp=12.8.93=hbd13f7d_1
139
+ - cuda-opencl=12.8.90=hbd13f7d_0
140
+ - cuda-opencl-dev=12.8.90=h5888daf_0
141
+ - cuda-profiler-api=12.8.90=h7938cbb_1
142
+ - cuda-sanitizer-api=12.8.93=hbd13f7d_1
143
+ - cuda-toolkit=12.8.1=ha804496_0
144
+ - cuda-tools=12.8.1=ha770c72_0
145
+ - cuda-version=12.8=h5d125a7_3
146
+ - cuda-visual-tools=12.8.1=ha770c72_0
147
+ - curl=8.12.1=h332b0f4_0
148
+ - cxx-compiler=1.9.0=h1a2810e_0
149
+ - cycler=0.12.1=pyhd8ed1ab_1
150
+ - cyrus-sasl=2.1.27=h54b06d7_7
151
+ - cytoolz=1.0.1=py310ha75aee5_0
152
+ - datashader=0.17.0=pyhd8ed1ab_0
153
+ - dav1d=1.2.1=hd590300_0
154
+ - dbus=1.13.6=h5008d03_3
155
+ - debugpy=1.8.13=py310hf71b8c6_0
156
+ - decorator=5.2.1=pyhd8ed1ab_0
157
+ - defusedxml=0.7.1=pyhd8ed1ab_0
158
+ - deprecated=1.2.18=pyhd8ed1ab_0
159
+ - diff-match-patch=20241021=pyhd8ed1ab_1
160
+ - dill=0.3.9=pyhd8ed1ab_1
161
+ - docstring-to-markdown=0.16=pyh29332c3_1
162
+ - docutils=0.21.2=pyhd8ed1ab_1
163
+ - double-conversion=3.3.1=h5888daf_0
164
+ - et_xmlfile=2.0.0=pyhd8ed1ab_1
165
+ - exceptiongroup=1.2.2=pyhd8ed1ab_1
166
+ - executing=2.1.0=pyhd8ed1ab_1
167
+ - expat=2.7.0=h5888daf_0
168
+ - fcitx-qt5=1.2.7=h748e8b9_2
169
+ - filelock=3.18.0=pyhd8ed1ab_0
170
+ - flake8=7.1.2=pyhd8ed1ab_0
171
+ - font-ttf-dejavu-sans-mono=2.37=hab24e00_0
172
+ - font-ttf-inconsolata=3.000=h77eed37_0
173
+ - font-ttf-source-code-pro=2.038=h77eed37_0
174
+ - font-ttf-ubuntu=0.83=h77eed37_3
175
+ - fontconfig=2.15.0=h7e30c49_1
176
+ - fonts-conda-ecosystem=1=0
177
+ - fonts-conda-forge=1=0
178
+ - fonttools=4.56.0=py310h89163eb_0
179
+ - fqdn=1.5.1=pyhd8ed1ab_1
180
+ - freetype=2.13.3=h48d6fc4_0
181
+ - frozenlist=1.5.0=py310h89163eb_1
182
+ - fzf=0.61.0=h59e48b9_0
183
+ - gcc=13.3.0=h9576a4e_2
184
+ - gcc_impl_linux-64=13.3.0=h1e990d8_2
185
+ - gcc_linux-64=13.3.0=hc28eda2_8
186
+ - gds-tools=1.13.1.3=h5888daf_0
187
+ - gettext=0.23.1=h5888daf_0
188
+ - gettext-tools=0.23.1=h5888daf_0
189
+ - gflags=2.2.2=h5888daf_1005
190
+ - giflib=5.2.2=hd590300_0
191
+ - gitdb=4.0.12=pyhd8ed1ab_0
192
+ - gitpython=3.1.44=pyhff2d567_0
193
+ - glib=2.84.0=h07242d1_0
194
+ - glib-tools=2.84.0=h4833e2c_0
195
+ - glog=0.7.1=hbabe93e_0
196
+ - gmp=6.3.0=hac33072_2
197
+ - gmpy2=2.1.5=py310he8512ff_3
198
+ - graphite2=1.3.13=h59595ed_1003
199
+ - greenlet=3.1.1=py310hf71b8c6_1
200
+ - gst-plugins-base=1.24.7=h0a52356_0
201
+ - gstreamer=1.24.7=hf3bb09a_0
202
+ - gxx=13.3.0=h9576a4e_2
203
+ - gxx_impl_linux-64=13.3.0=hae580e1_2
204
+ - gxx_linux-64=13.3.0=h6834431_8
205
+ - h11=0.14.0=pyhd8ed1ab_1
206
+ - h2=4.2.0=pyhd8ed1ab_0
207
+ - h5py=3.13.0=nompi_py310h60e0fe6_100
208
+ - harfbuzz=10.4.0=h76408a6_0
209
+ - hdf5=1.14.3=nompi_h2d575fe_109
210
+ - holoviews=1.20.2=pyhd8ed1ab_0
211
+ - hpack=4.1.0=pyhd8ed1ab_0
212
+ - httpcore=1.0.7=pyh29332c3_1
213
+ - httpx=0.28.1=pyhd8ed1ab_0
214
+ - hvplot=0.11.2=pyhd8ed1ab_0
215
+ - hyperframe=6.1.0=pyhd8ed1ab_0
216
+ - hyperlink=21.0.0=pyh29332c3_1
217
+ - icu=75.1=he02047a_0
218
+ - idna=3.10=pyhd8ed1ab_1
219
+ - imagecodecs=2024.12.30=py310h78a9a29_0
220
+ - imageio=2.37.0=pyhfb79c49_0
221
+ - imagesize=1.4.1=pyhd8ed1ab_0
222
+ - imbalanced-learn=0.13.0=pyhd8ed1ab_0
223
+ - importlib-metadata=8.6.1=pyha770c72_0
224
+ - importlib_resources=6.5.2=pyhd8ed1ab_0
225
+ - incremental=24.7.2=pyhd8ed1ab_1
226
+ - inflection=0.5.1=pyhd8ed1ab_1
227
+ - iniconfig=2.0.0=pyhd8ed1ab_1
228
+ - intake=2.0.8=pyhd8ed1ab_0
229
+ - intervaltree=3.1.0=pyhd8ed1ab_1
230
+ - ipykernel=6.29.5=pyh3099207_0
231
+ - ipython=8.34.0=pyh907856f_0
232
+ - ipython_genutils=0.2.0=pyhd8ed1ab_2
233
+ - isoduration=20.11.0=pyhd8ed1ab_1
234
+ - isort=6.0.1=pyhd8ed1ab_0
235
+ - itemadapter=0.11.0=pyhd8ed1ab_0
236
+ - itemloaders=1.3.2=pyhd8ed1ab_1
237
+ - itsdangerous=2.2.0=pyhd8ed1ab_1
238
+ - jaraco.classes=3.4.0=pyhd8ed1ab_2
239
+ - jaraco.context=6.0.1=pyhd8ed1ab_0
240
+ - jaraco.functools=4.1.0=pyhd8ed1ab_0
241
+ - jedi=0.19.2=pyhd8ed1ab_1
242
+ - jeepney=0.9.0=pyhd8ed1ab_0
243
+ - jellyfish=1.1.3=py310h505e2c1_0
244
+ - jinja2=3.1.6=pyhd8ed1ab_0
245
+ - jmespath=1.0.1=pyhd8ed1ab_1
246
+ - joblib=1.4.2=pyhd8ed1ab_1
247
+ - jq=1.7.1=hd590300_0
248
+ - json5=0.10.0=pyhd8ed1ab_1
249
+ - jsonpointer=3.0.0=py310hff52083_1
250
+ - jsonschema=4.23.0=pyhd8ed1ab_1
251
+ - jsonschema-specifications=2024.10.1=pyhd8ed1ab_1
252
+ - jsonschema-with-format-nongpl=4.23.0=hd8ed1ab_1
253
+ - jupyter=1.1.1=pyhd8ed1ab_1
254
+ - jupyter-lsp=2.2.5=pyhd8ed1ab_1
255
+ - jupyter_client=8.6.3=pyhd8ed1ab_1
256
+ - jupyter_console=6.6.3=pyhd8ed1ab_1
257
+ - jupyter_core=5.7.2=pyh31011fe_1
258
+ - jupyter_events=0.12.0=pyh29332c3_0
259
+ - jupyter_server=2.15.0=pyhd8ed1ab_0
260
+ - jupyter_server_terminals=0.5.3=pyhd8ed1ab_1
261
+ - jupyterlab=4.3.6=pyhd8ed1ab_0
262
+ - jupyterlab-variableinspector=3.2.4=pyhd8ed1ab_0
263
+ - jupyterlab_pygments=0.3.0=pyhd8ed1ab_2
264
+ - jupyterlab_server=2.27.3=pyhd8ed1ab_1
265
+ - jxrlib=1.1=hd590300_3
266
+ - kernel-headers_linux-64=3.10.0=he073ed8_18
267
+ - keyring=25.6.0=pyha804496_0
268
+ - keyutils=1.6.1=h166bdaf_0
269
+ - kiwisolver=1.4.7=py310h3788b33_0
270
+ - krb5=1.21.3=h659f571_0
271
+ - lame=3.100=h166bdaf_1003
272
+ - lazy-loader=0.4=pyhd8ed1ab_2
273
+ - lazy_loader=0.4=pyhd8ed1ab_2
274
+ - lcms2=2.17=h717163a_0
275
+ - ld_impl_linux-64=2.43=h712a8e2_4
276
+ - lerc=4.0.0=h27087fc_0
277
+ - libabseil=20250127.1=cxx17_hbbce691_0
278
+ - libaec=1.1.3=h59595ed_0
279
+ - libarrow=19.0.1=h120c447_5_cpu
280
+ - libarrow-acero=19.0.1=hcb10f89_5_cpu
281
+ - libarrow-dataset=19.0.1=hcb10f89_5_cpu
282
+ - libarrow-substrait=19.0.1=h1bed206_5_cpu
283
+ - libasprintf=0.23.1=h8e693c7_0
284
+ - libasprintf-devel=0.23.1=h8e693c7_0
285
+ - libavif16=1.2.1=hbb36593_2
286
+ - libblas=3.9.0=1_h86c2bf4_netlib
287
+ - libbrotlicommon=1.1.0=hb9d3cd8_2
288
+ - libbrotlidec=1.1.0=hb9d3cd8_2
289
+ - libbrotlienc=1.1.0=hb9d3cd8_2
290
+ - libcap=2.75=h39aace5_0
291
+ - libcblas=3.9.0=8_h3b12eaf_netlib
292
+ - libclang-cpp19.1=19.1.7=default_hb5137d0_2
293
+ - libclang-cpp20.1=20.1.1=default_hb5137d0_0
294
+ - libclang13=20.1.1=default_h9c6a7e4_0
295
+ - libcrc32c=1.1.2=h9c3ff4c_0
296
+ - libcublas=12.8.4.1=h9ab20c4_1
297
+ - libcublas-dev=12.8.4.1=h9ab20c4_1
298
+ - libcufft=11.3.3.83=h5888daf_1
299
+ - libcufft-dev=11.3.3.83=h5888daf_1
300
+ - libcufile=1.13.1.3=h12f29b5_0
301
+ - libcufile-dev=1.13.1.3=h5888daf_0
302
+ - libcups=2.3.3=h4637d8d_4
303
+ - libcurand=10.3.9.90=h9ab20c4_1
304
+ - libcurand-dev=10.3.9.90=h9ab20c4_1
305
+ - libcurl=8.12.1=h332b0f4_0
306
+ - libcusolver=11.7.3.90=h9ab20c4_1
307
+ - libcusolver-dev=11.7.3.90=h9ab20c4_1
308
+ - libcusparse=12.5.8.93=hbd13f7d_0
309
+ - libcusparse-dev=12.5.8.93=h5888daf_0
310
+ - libdeflate=1.23=h4ddbbb0_0
311
+ - libdrm=2.4.124=hb9d3cd8_0
312
+ - libedit=3.1.20250104=pl5321h7949ede_0
313
+ - libegl=1.7.0=ha4b6fd6_2
314
+ - libev=4.33=hd590300_2
315
+ - libevent=2.1.12=hf998b51_1
316
+ - libexpat=2.7.0=h5888daf_0
317
+ - libffi=3.4.6=h2dba641_1
318
+ - libflac=1.4.3=h59595ed_0
319
+ - libgcc=14.2.0=h767d61c_2
320
+ - libgcc-devel_linux-64=13.3.0=hc03c837_102
321
+ - libgcc-ng=14.2.0=h69a702a_2
322
+ - libgcrypt-lib=1.11.0=hb9d3cd8_2
323
+ - libgettextpo=0.23.1=h5888daf_0
324
+ - libgettextpo-devel=0.23.1=h5888daf_0
325
+ - libgfortran=14.2.0=h69a702a_2
326
+ - libgfortran-ng=14.2.0=h69a702a_2
327
+ - libgfortran5=14.2.0=hf1ad2bd_2
328
+ - libgl=1.7.0=ha4b6fd6_2
329
+ - libglib=2.84.0=h2ff4ddf_0
330
+ - libglvnd=1.7.0=ha4b6fd6_2
331
+ - libglx=1.7.0=ha4b6fd6_2
332
+ - libgomp=14.2.0=h767d61c_2
333
+ - libgoogle-cloud=2.36.0=hc4361e1_1
334
+ - libgoogle-cloud-storage=2.36.0=h0121fbd_1
335
+ - libgpg-error=1.51=hbd13f7d_1
336
+ - libgrpc=1.71.0=he753a82_0
337
+ - libhwy=1.1.0=h00ab1b0_0
338
+ - libiconv=1.18=h4ce23a2_1
339
+ - libjpeg-turbo=3.0.0=hd590300_1
340
+ - libjxl=0.11.1=hdb8da77_0
341
+ - liblapack=3.9.0=8_h3b12eaf_netlib
342
+ - libllvm19=19.1.7=ha7bfdaf_1
343
+ - libllvm20=20.1.1=ha7bfdaf_0
344
+ - liblzma=5.6.4=hb9d3cd8_0
345
+ - libnghttp2=1.64.0=h161d5f1_0
346
+ - libnl=3.11.0=hb9d3cd8_0
347
+ - libnpp=12.3.3.100=h9ab20c4_1
348
+ - libnpp-dev=12.3.3.100=h9ab20c4_1
349
+ - libnsl=2.0.1=hd590300_0
350
+ - libntlm=1.8=hb9d3cd8_0
351
+ - libnuma=2.0.18=h4ab18f5_2
352
+ - libnvfatbin=12.8.90=hbd13f7d_0
353
+ - libnvfatbin-dev=12.8.90=h5888daf_0
354
+ - libnvjitlink=12.8.93=h5888daf_1
355
+ - libnvjitlink-dev=12.8.93=h5888daf_1
356
+ - libnvjpeg=12.3.5.92=h97fd463_0
357
+ - libnvjpeg-dev=12.3.5.92=ha770c72_0
358
+ - libogg=1.3.5=h4ab18f5_0
359
+ - libopengl=1.7.0=ha4b6fd6_2
360
+ - libopentelemetry-cpp=1.19.0=hd1b1c89_0
361
+ - libopentelemetry-cpp-headers=1.19.0=ha770c72_0
362
+ - libopus=1.3.1=h7f98852_1
363
+ - libparquet=19.0.1=h081d1f1_5_cpu
364
+ - libpciaccess=0.18=hd590300_0
365
+ - libpng=1.6.47=h943b412_0
366
+ - libpq=17.4=h27ae623_0
367
+ - libprotobuf=5.29.3=h501fc15_0
368
+ - libre2-11=2024.07.02=hba17884_3
369
+ - libsanitizer=13.3.0=he8ea267_2
370
+ - libsndfile=1.2.2=hc60ed4a_1
371
+ - libsodium=1.0.20=h4ab18f5_0
372
+ - libspatialindex=2.1.0=he57a185_0
373
+ - libsqlite=3.49.1=hee588c1_2
374
+ - libssh2=1.11.1=hf672d98_0
375
+ - libstdcxx=14.2.0=h8f9b012_2
376
+ - libstdcxx-devel_linux-64=13.3.0=hc03c837_102
377
+ - libstdcxx-ng=14.2.0=h4852527_2
378
+ - libsystemd0=257.4=h4e0b6ca_1
379
+ - libthrift=0.21.0=h0e7cc3e_0
380
+ - libtiff=4.7.0=hd9ff511_3
381
+ - libudev1=257.4=hbe16f8c_1
382
+ - libutf8proc=2.10.0=h4c51ac1_0
383
+ - libuuid=2.38.1=h0b41bf4_0
384
+ - libvorbis=1.3.7=h9c3ff4c_0
385
+ - libwebp=1.5.0=hae8dbeb_0
386
+ - libwebp-base=1.5.0=h851e524_0
387
+ - libxcb=1.17.0=h8a09558_0
388
+ - libxcrypt=4.4.36=hd590300_1
389
+ - libxkbcommon=1.8.1=hc4a0caf_0
390
+ - libxkbfile=1.1.0=h166bdaf_1
391
+ - libxml2=2.13.7=h8d12d68_0
392
+ - libxslt=1.1.39=h76b75d6_0
393
+ - libzlib=1.3.1=hb9d3cd8_2
394
+ - libzopfli=1.0.3=h9c3ff4c_0
395
+ - linkify-it-py=2.0.3=pyhd8ed1ab_1
396
+ - locket=1.0.0=pyhd8ed1ab_0
397
+ - lxml=5.3.1=py310h6ee67d5_0
398
+ - lz4=4.3.3=py310h80b8a69_2
399
+ - lz4-c=1.10.0=h5888daf_1
400
+ - markdown=3.6=pyhd8ed1ab_0
401
+ - markdown-it-py=3.0.0=pyhd8ed1ab_1
402
+ - markupsafe=3.0.2=py310h89163eb_1
403
+ - matplotlib=3.10.1=py310hff52083_0
404
+ - matplotlib-base=3.10.1=py310h68603db_0
405
+ - matplotlib-inline=0.1.7=pyhd8ed1ab_1
406
+ - mccabe=0.7.0=pyhd8ed1ab_1
407
+ - mdit-py-plugins=0.4.2=pyhd8ed1ab_1
408
+ - mdurl=0.1.2=pyhd8ed1ab_1
409
+ - mistune=3.1.3=pyh29332c3_0
410
+ - more-itertools=10.6.0=pyhd8ed1ab_0
411
+ - mpc=1.3.1=h24ddda3_1
412
+ - mpfr=4.2.1=h90cbb55_3
413
+ - mpg123=1.32.9=hc50e24c_0
414
+ - mpmath=1.3.0=pyhd8ed1ab_1
415
+ - msgpack-python=1.1.0=py310h3788b33_0
416
+ - multidict=6.2.0=py310h89163eb_0
417
+ - multipledispatch=0.6.0=pyhd8ed1ab_1
418
+ - munkres=1.1.4=pyh9f0ad1d_0
419
+ - mypy=1.15.0=py310ha75aee5_0
420
+ - mypy_extensions=1.0.0=pyha770c72_1
421
+ - mysql-common=9.0.1=h266115a_5
422
+ - mysql-libs=9.0.1=he0572af_5
423
+ - narwhals=1.32.0=pyhd8ed1ab_0
424
+ - nbclient=0.10.2=pyhd8ed1ab_0
425
+ - nbconvert=7.16.6=hb482800_0
426
+ - nbconvert-core=7.16.6=pyh29332c3_0
427
+ - nbconvert-pandoc=7.16.6=hed9df3c_0
428
+ - nbformat=5.10.4=pyhd8ed1ab_1
429
+ - ncurses=6.5=h2d0b736_3
430
+ - nest-asyncio=1.6.0=pyhd8ed1ab_1
431
+ - networkx=3.4.2=pyh267e887_2
432
+ - nlohmann_json=3.11.3=he02047a_1
433
+ - nltk=3.9.1=pyhd8ed1ab_1
434
+ - nomkl=1.0=h5ca1d4c_0
435
+ - notebook=7.3.3=pyhd8ed1ab_0
436
+ - notebook-shim=0.2.4=pyhd8ed1ab_1
437
+ - nsight-compute=2025.1.1.2=hb5ebaad_0
438
+ - nspr=4.36=h5888daf_0
439
+ - nss=3.110=h159eef7_0
440
+ - numexpr=2.10.2=py310hdb6e06b_100
441
+ - numpydoc=1.8.0=pyhd8ed1ab_1
442
+ - ocl-icd=2.3.2=hb9d3cd8_2
443
+ - oniguruma=6.9.10=hb9d3cd8_0
444
+ - opencl-headers=2024.10.24=h5888daf_0
445
+ - openjpeg=2.5.3=h5fbd93e_0
446
+ - openldap=2.6.9=he970967_0
447
+ - openpyxl=3.1.5=py310h0999ad4_1
448
+ - openssl=3.4.1=h7b32b05_0
449
+ - orc=2.1.1=h17f744e_1
450
+ - overrides=7.7.0=pyhd8ed1ab_1
451
+ - packaging=24.2=pyhd8ed1ab_2
452
+ - pandas=2.2.3=py310h5eaa309_1
453
+ - pandoc=3.6.4=ha770c72_0
454
+ - pandocfilters=1.5.0=pyhd8ed1ab_0
455
+ - panel=1.6.2=pyhd8ed1ab_0
456
+ - param=2.2.0=pyhd8ed1ab_0
457
+ - parsel=1.10.0=pyhd8ed1ab_0
458
+ - parso=0.8.4=pyhd8ed1ab_1
459
+ - partd=1.4.2=pyhd8ed1ab_0
460
+ - pathspec=0.12.1=pyhd8ed1ab_1
461
+ - patsy=1.0.1=pyhd8ed1ab_1
462
+ - pcre2=10.44=hba22ea6_2
463
+ - pexpect=4.9.0=pyhd8ed1ab_1
464
+ - pickleshare=0.7.5=pyhd8ed1ab_1004
465
+ - pillow=11.1.0=py310h7e6dc6c_0
466
+ - pip=25.0.1=pyh8b19718_0
467
+ - pixman=0.44.2=h29eaf8c_0
468
+ - pkgutil-resolve-name=1.3.10=pyhd8ed1ab_2
469
+ - platformdirs=4.3.7=pyh29332c3_0
470
+ - plotly=6.0.1=pyhd8ed1ab_0
471
+ - pluggy=1.5.0=pyhd8ed1ab_1
472
+ - ply=3.11=pyhd8ed1ab_3
473
+ - prometheus-cpp=1.3.0=ha5d0236_0
474
+ - prometheus_client=0.21.1=pyhd8ed1ab_0
475
+ - prompt-toolkit=3.0.50=pyha770c72_0
476
+ - prompt_toolkit=3.0.50=hd8ed1ab_0
477
+ - propcache=0.2.1=py310h89163eb_1
478
+ - protego=0.4.0=pyhd8ed1ab_0
479
+ - protobuf=5.29.3=py310hcba5963_0
480
+ - psutil=7.0.0=py310ha75aee5_0
481
+ - pthread-stubs=0.4=hb9d3cd8_1002
482
+ - ptyprocess=0.7.0=pyhd8ed1ab_1
483
+ - pulseaudio-client=17.0=hac146a9_1
484
+ - pure_eval=0.2.3=pyhd8ed1ab_1
485
+ - py-cpuinfo=9.0.0=pyhd8ed1ab_1
486
+ - pyarrow=19.0.1=py310hff52083_0
487
+ - pyarrow-core=19.0.1=py310hac404ae_0_cpu
488
+ - pyasn1=0.6.1=pyhd8ed1ab_2
489
+ - pyasn1-modules=0.4.2=pyhd8ed1ab_0
490
+ - pycodestyle=2.12.1=pyhd8ed1ab_1
491
+ - pyconify=0.2.1=pyhd8ed1ab_0
492
+ - pycparser=2.22=pyh29332c3_1
493
+ - pyct=0.5.0=pyhd8ed1ab_1
494
+ - pycurl=7.45.6=py310h6811363_0
495
+ - pydeck=0.9.1=pyhd8ed1ab_0
496
+ - pydispatcher=2.0.5=py_1
497
+ - pydocstyle=6.3.0=pyhd8ed1ab_1
498
+ - pyerfa=2.0.1.5=py310hf462985_0
499
+ - pyflakes=3.2.0=pyhd8ed1ab_1
500
+ - pygithub=2.6.1=pyhd8ed1ab_0
501
+ - pygments=2.19.1=pyhd8ed1ab_0
502
+ - pyjwt=2.10.1=pyhd8ed1ab_0
503
+ - pylint=3.3.5=pyh29332c3_0
504
+ - pylint-venv=3.0.4=pyhd8ed1ab_1
505
+ - pyls-spyder=0.4.0=pyhd8ed1ab_1
506
+ - pynacl=1.5.0=py310ha75aee5_4
507
+ - pyodbc=5.2.0=py310hf71b8c6_0
508
+ - pyopenssl=25.0.0=pyhd8ed1ab_0
509
+ - pyparsing=3.2.3=pyhd8ed1ab_1
510
+ - pyqt=5.15.9=py310h04931ad_5
511
+ - pyqt5-sip=12.12.2=py310hc6cd4ac_5
512
+ - pyqtwebengine=5.15.9=py310h704022c_5
513
+ - pyside6=6.8.3=py310hfd10a26_0
514
+ - pysocks=1.7.1=pyha55dd90_7
515
+ - pytables=3.10.1=py310h1affd9f_4
516
+ - pytest=8.3.5=pyhd8ed1ab_0
517
+ - python=3.10.16=he725a3c_1_cpython
518
+ - python-dateutil=2.9.0.post0=pyhff2d567_1
519
+ - python-fastjsonschema=2.21.1=pyhd8ed1ab_0
520
+ - python-gssapi=1.9.0=py310h695cd88_1
521
+ - python-json-logger=2.0.7=pyhd8ed1ab_0
522
+ - python-lsp-black=2.0.0=pyhff2d567_1
523
+ - python-lsp-jsonrpc=1.1.2=pyhff2d567_1
524
+ - python-lsp-server=1.12.2=pyhff2d567_0
525
+ - python-lsp-server-base=1.12.2=pyhd8ed1ab_0
526
+ - python-slugify=8.0.4=pyhd8ed1ab_1
527
+ - python-tzdata=2025.2=pyhd8ed1ab_0
528
+ - python_abi=3.10=5_cp310
529
+ - pytoolconfig=1.2.5=pyhd8ed1ab_1
530
+ - pytz=2024.1=pyhd8ed1ab_0
531
+ - pyuca=1.2=pyhd8ed1ab_2
532
+ - pyviz_comms=3.0.4=pyhd8ed1ab_1
533
+ - pywavelets=1.8.0=py310hf462985_0
534
+ - pyxdg=0.28=pyhd8ed1ab_0
535
+ - pyyaml=6.0.2=py310h89163eb_2
536
+ - pyzmq=26.3.0=py310h71f11fc_0
537
+ - qdarkstyle=3.2.3=pyhd8ed1ab_1
538
+ - qhull=2020.2=h434a139_5
539
+ - qstylizer=0.2.4=pyhff2d567_0
540
+ - qt-main=5.15.15=hc3cb62f_2
541
+ - qt-webengine=5.15.15=h0071231_2
542
+ - qt6-main=6.8.3=h588cce1_0
543
+ - qtawesome=1.4.0=pyh9208f05_1
544
+ - qtconsole=5.6.1=pyhd8ed1ab_1
545
+ - qtconsole-base=5.6.1=pyha770c72_1
546
+ - qtpy=2.4.3=pyhd8ed1ab_0
547
+ - queuelib=1.8.0=pyhd8ed1ab_0
548
+ - rav1e=0.6.6=he8a937b_2
549
+ - rdma-core=56.0=h5888daf_0
550
+ - re2=2024.07.02=h9925aae_3
551
+ - readline=8.2=h8c095d6_2
552
+ - referencing=0.36.2=pyh29332c3_0
553
+ - regex=2024.11.6=py310ha75aee5_0
554
+ - requests=2.32.3=pyhd8ed1ab_1
555
+ - requests-file=2.1.0=pyhd8ed1ab_1
556
+ - rfc3339-validator=0.1.4=pyhd8ed1ab_1
557
+ - rfc3986-validator=0.1.1=pyh9f0ad1d_0
558
+ - rich=14.0.0=pyh29332c3_0
559
+ - rope=1.13.0=pyhd8ed1ab_1
560
+ - rpds-py=0.24.0=py310hc1293b2_0
561
+ - rtree=1.4.0=pyh11ca60a_1
562
+ - s2n=1.5.14=h6c98b2b_0
563
+ - s3fs=2025.3.1=pyhd8ed1ab_0
564
+ - scikit-image=0.25.2=py310h5eaa309_0
565
+ - scikit-learn=1.6.1=py310h27f47ee_0
566
+ - scipy=1.15.2=py310h1d65ade_0
567
+ - scrapy=2.12.0=py310hff52083_1
568
+ - seaborn=0.13.2=hd8ed1ab_3
569
+ - seaborn-base=0.13.2=pyhd8ed1ab_3
570
+ - secretstorage=3.3.3=py310hff52083_3
571
+ - send2trash=1.8.3=pyh0d859eb_1
572
+ - service-identity=24.2.0=pyha770c72_1
573
+ - service_identity=24.2.0=hd8ed1ab_1
574
+ - setuptools=75.8.2=pyhff2d567_0
575
+ - sip=6.7.12=py310hc6cd4ac_0
576
+ - six=1.17.0=pyhd8ed1ab_0
577
+ - sklearn-compat=0.1.3=pyhd8ed1ab_0
578
+ - smmap=5.0.2=pyhd8ed1ab_0
579
+ - snappy=1.2.1=h8bd8927_1
580
+ - sniffio=1.3.1=pyhd8ed1ab_1
581
+ - snowballstemmer=2.2.0=pyhd8ed1ab_0
582
+ - sortedcontainers=2.4.0=pyhd8ed1ab_1
583
+ - soupsieve=2.5=pyhd8ed1ab_1
584
+ - sphinx=8.1.3=pyhd8ed1ab_1
585
+ - sphinxcontrib-applehelp=2.0.0=pyhd8ed1ab_1
586
+ - sphinxcontrib-devhelp=2.0.0=pyhd8ed1ab_1
587
+ - sphinxcontrib-htmlhelp=2.1.0=pyhd8ed1ab_1
588
+ - sphinxcontrib-jsmath=1.0.1=pyhd8ed1ab_1
589
+ - sphinxcontrib-qthelp=2.0.0=pyhd8ed1ab_1
590
+ - sphinxcontrib-serializinghtml=1.1.10=pyhd8ed1ab_1
591
+ - spyder=6.0.5=hd8ed1ab_0
592
+ - spyder-base=6.0.5=linux_pyh62a8a7d_0
593
+ - spyder-kernels=3.0.3=unix_pyh707e725_0
594
+ - sqlalchemy=2.0.40=py310ha75aee5_0
595
+ - stack_data=0.6.3=pyhd8ed1ab_1
596
+ - statsmodels=0.14.4=py310hf462985_0
597
+ - streamlit=1.44.0=pyhd8ed1ab_1
598
+ - superqt=0.7.3=pyhb6d5dde_0
599
+ - svt-av1=3.0.2=h5888daf_0
600
+ - sympy=1.13.3=pyh2585a3b_105
601
+ - sysroot_linux-64=2.17=h0157908_18
602
+ - tabulate=0.9.0=pyhd8ed1ab_2
603
+ - tblib=3.0.0=pyhd8ed1ab_1
604
+ - tenacity=9.0.0=pyhd8ed1ab_1
605
+ - terminado=0.18.1=pyh0d859eb_0
606
+ - text-unidecode=1.3=pyhd8ed1ab_2
607
+ - textdistance=4.6.3=pyhd8ed1ab_1
608
+ - threadpoolctl=3.6.0=pyhecae5ae_0
609
+ - three-merge=0.1.1=pyhd8ed1ab_1
610
+ - tifffile=2025.3.30=pyhd8ed1ab_0
611
+ - tinycss2=1.4.0=pyhd8ed1ab_0
612
+ - tk=8.6.13=noxft_h4845f30_101
613
+ - tldextract=5.1.3=pyhd8ed1ab_1
614
+ - toml=0.10.2=pyhd8ed1ab_1
615
+ - tomli=2.2.1=pyhd8ed1ab_1
616
+ - tomlkit=0.13.2=pyha770c72_1
617
+ - toolz=1.0.0=pyhd8ed1ab_1
618
+ - tornado=6.4.2=py310ha75aee5_0
619
+ - tqdm=4.67.1=pyhd8ed1ab_1
620
+ - traitlets=5.14.3=pyhd8ed1ab_1
621
+ - twisted=24.11.0=py310ha75aee5_0
622
+ - types-python-dateutil=2.9.0.20241206=pyhd8ed1ab_0
623
+ - typing-extensions=4.13.0=h9fa5a19_1
624
+ - typing_extensions=4.13.0=pyh29332c3_1
625
+ - typing_utils=0.1.0=pyhd8ed1ab_1
626
+ - tzdata=2025b=h78e105d_0
627
+ - uc-micro-py=1.0.3=pyhd8ed1ab_1
628
+ - ujson=5.10.0=py310hf71b8c6_1
629
+ - unicodedata2=16.0.0=py310ha75aee5_0
630
+ - unixodbc=2.3.12=h661eb56_0
631
+ - uri-template=1.3.0=pyhd8ed1ab_1
632
+ - urllib3=2.3.0=pyhd8ed1ab_0
633
+ - w3lib=2.3.1=pyhd8ed1ab_0
634
+ - watchdog=6.0.0=py310hff52083_0
635
+ - wayland=1.23.1=h3e06ad9_0
636
+ - wcwidth=0.2.13=pyhd8ed1ab_1
637
+ - webcolors=24.11.1=pyhd8ed1ab_0
638
+ - webencodings=0.5.1=pyhd8ed1ab_3
639
+ - websocket-client=1.8.0=pyhd8ed1ab_1
640
+ - whatthepatch=1.0.7=pyhd8ed1ab_1
641
+ - wheel=0.45.1=pyhd8ed1ab_1
642
+ - wrapt=1.17.2=py310ha75aee5_0
643
+ - wurlitzer=3.1.1=pyhd8ed1ab_1
644
+ - xarray=2025.3.1=pyhd8ed1ab_0
645
+ - xcb-util=0.4.1=hb711507_2
646
+ - xcb-util-cursor=0.1.5=hb9d3cd8_0
647
+ - xcb-util-image=0.4.0=hb711507_2
648
+ - xcb-util-keysyms=0.4.1=hb711507_0
649
+ - xcb-util-renderutil=0.3.10=hb711507_0
650
+ - xcb-util-wm=0.4.2=hb711507_0
651
+ - xkeyboard-config=2.43=hb9d3cd8_0
652
+ - xorg-libice=1.1.2=hb9d3cd8_0
653
+ - xorg-libsm=1.2.6=he73a12e_0
654
+ - xorg-libx11=1.8.12=h4f16b4b_0
655
+ - xorg-libxau=1.0.12=hb9d3cd8_0
656
+ - xorg-libxcomposite=0.4.6=hb9d3cd8_2
657
+ - xorg-libxcursor=1.2.3=hb9d3cd8_0
658
+ - xorg-libxdamage=1.1.6=hb9d3cd8_0
659
+ - xorg-libxdmcp=1.1.5=hb9d3cd8_0
660
+ - xorg-libxext=1.3.6=hb9d3cd8_0
661
+ - xorg-libxfixes=6.0.1=hb9d3cd8_0
662
+ - xorg-libxi=1.8.2=hb9d3cd8_0
663
+ - xorg-libxrandr=1.5.4=hb9d3cd8_0
664
+ - xorg-libxrender=0.9.12=hb9d3cd8_0
665
+ - xorg-libxtst=1.2.5=hb9d3cd8_3
666
+ - xorg-libxxf86vm=1.1.6=hb9d3cd8_0
667
+ - xyzservices=2025.1.0=pyhd8ed1ab_0
668
+ - yaml=0.2.5=h7f98852_2
669
+ - yapf=0.43.0=pyhd8ed1ab_1
670
+ - yarl=1.18.3=py310h89163eb_1
671
+ - zeromq=4.3.5=h3b0a872_7
672
+ - zfp=1.0.1=h5888daf_2
673
+ - zict=3.0.0=pyhd8ed1ab_1
674
+ - zipp=3.21.0=pyhd8ed1ab_1
675
+ - zlib=1.3.1=hb9d3cd8_2
676
+ - zlib-ng=2.2.4=h7955e40_0
677
+ - zope.interface=7.2=py310ha75aee5_0
678
+ - zstandard=0.23.0=py310ha75aee5_1
679
+ - zstd=1.5.7=hb8e6e7a_2
680
+ - pip:
681
+ - addict==2.4.0
682
+ - arrgh==1.0.0
683
+ - boto3==1.37.24
684
+ - botocore==1.37.24
685
+ - configargparse==1.7
686
+ - cuda-bindings==12.8.0
687
+ - cuda-python==12.8.0
688
+ - cudf-cu12==25.2.2
689
+ - cuml-cu12==25.2.1
690
+ - cupy-cuda12x==13.4.1
691
+ - cuvs-cu12==25.2.1
692
+ - dash==3.0.2
693
+ - dask==2024.12.1
694
+ - dask-cuda==25.2.0
695
+ - dask-cudf-cu12==25.2.2
696
+ - dask-expr==1.1.21
697
+ - distributed==2024.12.1
698
+ - distributed-ucxx-cu12==0.42.0
699
+ - docstring-parser==0.16
700
+ - einops==0.8.1
701
+ - fastrlock==0.8.3
702
+ - flask==3.0.3
703
+ - fsspec==2024.12.0
704
+ - ipywidgets==8.1.5
705
+ - jupyterlab-widgets==3.0.13
706
+ - libcudf-cu12==25.2.2
707
+ - libcuml-cu12==25.2.1
708
+ - libcuvs-cu12==25.2.1
709
+ - libigl==2.5.1
710
+ - libkvikio-cu12==25.2.1
711
+ - libraft-cu12==25.2.0
712
+ - libucx-cu12==1.18.0
713
+ - libucxx-cu12==0.42.0
714
+ - lightning==2.2.0
715
+ - lightning-utilities==0.14.2
716
+ - llvmlite==0.43.0
717
+ - loguru==0.7.3
718
+ - mesh2sdf==1.1.0
719
+ - numba==0.60.0
720
+ - numba-cuda==0.2.0
721
+ - numpy==2.0.2
722
+ - nvidia-cublas-cu12==12.4.2.65
723
+ - nvidia-cuda-cupti-cu12==12.4.99
724
+ - nvidia-cuda-nvrtc-cu12==12.4.99
725
+ - nvidia-cuda-runtime-cu12==12.4.99
726
+ - nvidia-cudnn-cu12==9.1.0.70
727
+ - nvidia-cufft-cu12==11.2.0.44
728
+ - nvidia-curand-cu12==10.3.5.119
729
+ - nvidia-cusolver-cu12==11.6.0.99
730
+ - nvidia-cusparse-cu12==12.3.0.142
731
+ - nvidia-ml-py==12.570.86
732
+ - nvidia-nccl-cu12==2.20.5
733
+ - nvidia-nvcomp-cu12==4.2.0.11
734
+ - nvidia-nvjitlink-cu12==12.4.99
735
+ - nvidia-nvtx-cu12==12.4.99
736
+ - nvtx==0.2.11
737
+ - open3d==0.19.0
738
+ - plyfile==1.1
739
+ - polyscope==2.4.0
740
+ - pooch==1.8.2
741
+ - potpourri3d==1.2.1
742
+ - pylibcudf-cu12==25.2.2
743
+ - pylibraft-cu12==25.2.0
744
+ - pymeshlab==2023.12.post3
745
+ - pynvjitlink-cu12==0.5.2
746
+ - pynvml==12.0.0
747
+ - pyquaternion==0.9.9
748
+ - pytorch-lightning==2.5.1
749
+ - pyvista==0.44.2
750
+ - raft-dask-cu12==25.2.0
751
+ - rapids-dask-dependency==25.2.0
752
+ - retrying==1.3.4
753
+ - rmm-cu12==25.2.0
754
+ - s3transfer==0.11.4
755
+ - scooby==0.10.0
756
+ - simple-parsing==0.1.7
757
+ - tetgen==0.6.5
758
+ - torch==2.4.0+cu124
759
+ - torch-scatter==2.1.2+pt24cu124
760
+ - torchaudio==2.4.0+cu124
761
+ - torchmetrics==1.7.0
762
+ - torchvision==0.19.0+cu124
763
+ - treelite==4.4.1
764
+ - trimesh==4.6.6
765
+ - triton==3.0.0
766
+ - ucx-py-cu12==0.42.0
767
+ - ucxx-cu12==0.42.0
768
+ - vtk==9.3.1
769
+ - werkzeug==3.0.6
770
+ - widgetsnbextension==4.0.13
771
+ - xgboost==3.0.0
772
+ - yacs==0.1.8
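The list above pins a CUDA 12.4 PyTorch stack (torch 2.4.0+cu124, torch-scatter 2.1.2+pt24cu124) alongside the RAPIDS/cuML packages. A minimal sketch (not part of the commit) for sanity-checking that the environment resolved to the pinned versions after running conda env create -f PartField/environment.yml:

# Sketch only: confirm the pinned GPU stack is importable and at the
# expected versions (see the environment.yml entries above).
import torch
import torch_scatter

print(torch.__version__)           # expected: 2.4.0+cu124
print(torch.version.cuda)          # expected: 12.4
print(torch.cuda.is_available())   # True on a machine with a CUDA 12.x driver
print(torch_scatter.__version__)   # expected: 2.1.2+pt24cu124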
PartField/partfield/__pycache__/dataloader.cpython-310.pyc ADDED
Binary file (7.71 kB). View file
 
PartField/partfield/__pycache__/model_trainer_pvcnn_only_demo.cpython-310.pyc ADDED
Binary file (7.72 kB). View file
 
PartField/partfield/__pycache__/utils.cpython-310.pyc ADDED
Binary file (372 Bytes). View file
 
PartField/partfield/config/__init__.py ADDED
@@ -0,0 +1,26 @@
1
+ import argparse
2
+ import os.path as osp
3
+ from datetime import datetime
4
+ import pytz
5
+
6
+ def default_argument_parser(add_help=True, default_config_file=""):
7
+ parser = argparse.ArgumentParser(add_help=add_help)
8
+ parser.add_argument("--config-file", '-c', default=default_config_file, metavar="FILE", help="path to config file")
9
+ parser.add_argument(
10
+ "--opts",
11
+ help="Modify config options using the command-line",
12
+ default=None,
13
+ nargs=argparse.REMAINDER,
14
+ )
15
+ return parser
16
+
17
+ def setup(args, freeze=True):
18
+ from .defaults import _C as cfg
19
+ cfg = cfg.clone()
20
+ cfg.merge_from_file(args.config_file)
21
+ cfg.merge_from_list(args.opts)
22
+ dt = datetime.now(pytz.timezone('America/Los_Angeles')).strftime('%y%m%d-%H%M%S')
23
+ cfg.output_dir = osp.join(cfg.output_dir, cfg.name, dt)
24
+ if freeze:
25
+ cfg.freeze()
26
+ return cfg
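default_argument_parser and setup above form the config entry point for the demo scripts: parse --config-file and --opts, merge them into the yacs defaults, and timestamp the output directory. A minimal usage sketch (the config path and override keys are illustrative, not part of this commit):

# Sketch: drive the two helpers above with an explicit argv.
from partfield.config import default_argument_parser, setup

args = default_argument_parser().parse_args([
    "--config-file", "PartField/configs/final/demo.yaml",  # example config path
    "--opts", "result_name", "demo_run",                    # key/value overrides
])
cfg = setup(args, freeze=True)   # merged and frozen yacs CfgNode
print(cfg.output_dir)            # e.g. results/<name>/<YYMMDD-HHMMSS>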
PartField/partfield/config/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.15 kB). View file
 
PartField/partfield/config/__pycache__/defaults.cpython-310.pyc ADDED
Binary file (2.17 kB). View file
 
PartField/partfield/config/defaults.py ADDED
@@ -0,0 +1,92 @@
1
+ from yacs.config import CfgNode as CN
2
+
3
+ _C = CN()
4
+ _C.seed = 0
5
+ _C.output_dir = "results"
6
+ _C.result_name = "test_all"
7
+
8
+ _C.triplet_sampling = "random"
9
+ _C.load_original_mesh = False
10
+
11
+ _C.num_pos = 64
12
+ _C.num_neg_random = 256
13
+ _C.num_neg_hard_pc = 128
14
+ _C.num_neg_hard_emb = 128
15
+
16
+ _C.vertex_feature = False # if true, sample feature on vertices; if false, sample feature on faces
17
+ _C.n_point_per_face = 2000
18
+ _C.n_sample_each = 10000
19
+ _C.preprocess_mesh = False
20
+
21
+ _C.regress_2d_feat = False
22
+
23
+ _C.is_pc = False
24
+
25
+ _C.cut_manifold = False
26
+ _C.remesh_demo = False
27
+ _C.correspondence_demo = False
28
+
29
+ _C.save_every_epoch = 10
30
+ _C.training_epochs = 30
31
+ _C.continue_training = False
32
+
33
+ _C.continue_ckpt = None
34
+ _C.epoch_selected = "epoch=50.ckpt"
35
+
36
+ _C.triplane_resolution = 128
37
+ _C.triplane_channels_low = 128
38
+ _C.triplane_channels_high = 512
39
+ _C.lr = 1e-3
40
+ _C.train = True
41
+ _C.test = False
42
+
43
+ _C.inference_save_pred_sdf_to_mesh=True
44
+ _C.inference_save_feat_pca=True
45
+ _C.name = "test"
46
+ _C.test_subset = False
47
+ _C.test_corres = False
48
+ _C.test_partobjaversetiny = False
49
+
50
+ _C.dataset = CN()
51
+ _C.dataset.type = "Demo_Dataset"
52
+ _C.dataset.data_path = "objaverse_data/"
53
+ _C.dataset.train_num_workers = 64
54
+ _C.dataset.val_num_workers = 32
55
+ _C.dataset.train_batch_size = 2
56
+ _C.dataset.val_batch_size = 2
57
+ _C.dataset.all_files = [] # only used for correspondence demo
58
+
59
+ _C.voxel2triplane = CN()
60
+ _C.voxel2triplane.transformer_dim = 1024
61
+ _C.voxel2triplane.transformer_layers = 6
62
+ _C.voxel2triplane.transformer_heads = 8
63
+ _C.voxel2triplane.triplane_low_res = 32
64
+ _C.voxel2triplane.triplane_high_res = 256
65
+ _C.voxel2triplane.triplane_dim = 64
66
+ _C.voxel2triplane.normalize_vox_feat = False
67
+
68
+
69
+ _C.loss = CN()
70
+ _C.loss.triplet = 0.0
71
+ _C.loss.sdf = 1.0
72
+ _C.loss.feat = 10.0
73
+ _C.loss.l1 = 0.0
74
+
75
+ _C.use_pvcnn = False
76
+ _C.use_pvcnnonly = True
77
+
78
+ _C.pvcnn = CN()
79
+ _C.pvcnn.point_encoder_type = 'pvcnn'
80
+ _C.pvcnn.use_point_scatter = True
81
+ _C.pvcnn.z_triplane_channels = 64
82
+ _C.pvcnn.z_triplane_resolution = 256
83
+ _C.pvcnn.unet_cfg = CN()
84
+ _C.pvcnn.unet_cfg.depth = 3
85
+ _C.pvcnn.unet_cfg.enabled = True
86
+ _C.pvcnn.unet_cfg.rolled = True
87
+ _C.pvcnn.unet_cfg.use_3d_aware = True
88
+ _C.pvcnn.unet_cfg.start_hidden_channels = 32
89
+ _C.pvcnn.unet_cfg.use_initial_conv = False
90
+
91
+ _C.use_2d_feat = False
92
+ _C.inference_metrics_only = False
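Every field above is a yacs CfgNode entry, so any of them can also be overridden programmatically (or via the --opts remainder handled in config/__init__.py) with merge_from_list. A small sketch using the defaults module above:

# Sketch: clone the defaults and override nested keys without a YAML file.
from partfield.config.defaults import _C

cfg = _C.clone()
cfg.merge_from_list(["is_pc", "True", "dataset.train_batch_size", "4"])
print(cfg.is_pc, cfg.dataset.train_batch_size)  # True 4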
PartField/partfield/dataloader.py ADDED
@@ -0,0 +1,366 @@
1
+ import torch
2
+ import boto3
3
+ import json
4
+ from os import path as osp
5
+ # from botocore.config import Config
6
+ # from botocore.exceptions import ClientError
7
+ import h5py
8
+ import io
9
+ import numpy as np
10
+ import skimage
11
+ import trimesh
12
+ import os
13
+ from scipy.spatial import KDTree
14
+ import gc
15
+ from plyfile import PlyData
16
+
17
+ ## For remeshing
18
+ import mesh2sdf
19
+ import tetgen
20
+ import vtk
21
+ import math
22
+ import tempfile
23
+
24
+ ### For mesh processing
25
+ import pymeshlab
26
+
27
+ from partfield.utils import *
28
+
29
+ #########################
30
+ ## To handle quad inputs
31
+ #########################
32
+ def quad_to_triangle_mesh(F):
33
+ """
34
+ Converts a quad-dominant mesh into a pure triangle mesh by splitting quads into two triangles.
35
+
36
+ Parameters:
37
+ F (array-like): face index array; may contain quad (4-vertex) faces.
38
+
39
+ Returns:
40
+ numpy.ndarray: face array containing only triangle faces.
41
+ """
42
+ faces = F
43
+
44
+ ### If already a triangle mesh -- skip
45
+ if len(faces[0]) == 3:
46
+ return F
47
+
48
+ new_faces = []
49
+
50
+ for face in faces:
51
+ if len(face) == 4: # Quad face
52
+ # Split into two triangles
53
+ new_faces.append([face[0], face[1], face[2]]) # Triangle 1
54
+ new_faces.append([face[0], face[2], face[3]]) # Triangle 2
55
+ else:
56
+ print(f"Warning: Skipping non-triangle/non-quad face {face}")
57
+
58
+ new_faces = np.array(new_faces)
59
+
60
+ return new_faces
61
+ #########################
62
+
63
+ class Demo_Dataset(torch.utils.data.Dataset):
64
+ def __init__(self, cfg):
65
+ super().__init__()
66
+
67
+ self.data_path = cfg.dataset.data_path
68
+ self.is_pc = cfg.is_pc
69
+
70
+ all_files = os.listdir(self.data_path)
71
+
72
+ selected = []
73
+ for f in all_files:
74
+ if ".ply" in f and self.is_pc:
75
+ selected.append(f)
76
+ elif (".obj" in f or ".glb" in f or ".off" in f) and not self.is_pc:
77
+ selected.append(f)
78
+
79
+ self.data_list = selected
80
+ self.pc_num_pts = 100000
81
+
82
+ self.preprocess_mesh = cfg.preprocess_mesh
83
+ self.result_name = cfg.result_name
84
+
85
+ print("val dataset len:", len(self.data_list))
86
+
87
+
88
+ def __len__(self):
89
+ return len(self.data_list)
90
+
91
+ def load_ply_to_numpy(self, filename):
92
+ """
93
+ Load a PLY file and extract the point cloud as a (N, 3) NumPy array.
94
+
95
+ Parameters:
96
+ filename (str): Path to the PLY file.
97
+
98
+ Returns:
99
+ numpy.ndarray: Point cloud array of shape (N, 3).
100
+ """
101
+ ply_data = PlyData.read(filename)
102
+
103
+ # Extract vertex data
104
+ vertex_data = ply_data["vertex"]
105
+
106
+ # Convert to NumPy array (x, y, z)
107
+ points = np.vstack([vertex_data["x"], vertex_data["y"], vertex_data["z"]]).T
108
+
109
+ return points
110
+
111
+ def get_model(self, ply_file):
112
+
113
+ uid = ply_file.split(".")[-2].replace("/", "_")
114
+
115
+ ####
116
+ if self.is_pc:
117
+ ply_file_read = os.path.join(self.data_path, ply_file)
118
+ pc = self.load_ply_to_numpy(ply_file_read)
119
+
120
+ bbmin = pc.min(0)
121
+ bbmax = pc.max(0)
122
+ center = (bbmin + bbmax) * 0.5
123
+ scale = 2.0 * 0.9 / (bbmax - bbmin).max()
124
+ pc = (pc - center) * scale
125
+
126
+ else:
127
+ obj_path = os.path.join(self.data_path, ply_file)
128
+ mesh = load_mesh_util(obj_path)
129
+ vertices = mesh.vertices
130
+ faces = mesh.faces
131
+
132
+ bbmin = vertices.min(0)
133
+ bbmax = vertices.max(0)
134
+ center = (bbmin + bbmax) * 0.5
135
+ scale = 2.0 * 0.9 / (bbmax - bbmin).max()
136
+ vertices = (vertices - center) * scale
137
+ mesh.vertices = vertices
138
+
139
+ ### Make sure it is a triangle mesh -- just convert the quad
140
+ mesh.faces = quad_to_triangle_mesh(faces)
141
+
142
+ print("before preprocessing...")
143
+ print(mesh.vertices.shape)
144
+ print(mesh.faces.shape)
145
+ print()
146
+
147
+ ### Pre-process mesh
148
+ if self.preprocess_mesh:
149
+ # Create a PyMeshLab mesh directly from vertices and faces
150
+ ml_mesh = pymeshlab.Mesh(vertex_matrix=mesh.vertices, face_matrix=mesh.faces)
151
+
152
+ # Create a MeshSet and add your mesh
153
+ ms = pymeshlab.MeshSet()
154
+ ms.add_mesh(ml_mesh, "from_trimesh")
155
+
156
+ # Apply filters
157
+ ms.apply_filter('meshing_remove_duplicate_faces')
158
+ ms.apply_filter('meshing_remove_duplicate_vertices')
159
+ percentageMerge = pymeshlab.PercentageValue(0.5)
160
+ ms.apply_filter('meshing_merge_close_vertices', threshold=percentageMerge)
161
+ ms.apply_filter('meshing_remove_unreferenced_vertices')
162
+
163
+ # Save or extract mesh
164
+ processed = ms.current_mesh()
165
+ mesh.vertices = processed.vertex_matrix()
166
+ mesh.faces = processed.face_matrix()
167
+
168
+ print("after preprocessing...")
169
+ print(mesh.vertices.shape)
170
+ print(mesh.faces.shape)
171
+
172
+ ### Save input
173
+ save_dir = f"exp_results/{self.result_name}"
174
+ os.makedirs(save_dir, exist_ok=True)
175
+ view_id = 0
176
+ mesh.export(f'{save_dir}/input_{uid}_{view_id}.ply')
177
+
178
+
179
+ pc, _ = trimesh.sample.sample_surface(mesh, self.pc_num_pts)
180
+
181
+ result = {
182
+ 'uid': uid
183
+ }
184
+
185
+ result['pc'] = torch.tensor(pc, dtype=torch.float32)
186
+
187
+ if not self.is_pc:
188
+ result['vertices'] = mesh.vertices
189
+ result['faces'] = mesh.faces
190
+
191
+ return result
192
+
193
+ def __getitem__(self, index):
194
+
195
+ gc.collect()
196
+
197
+ return self.get_model(self.data_list[index])
198
+
199
+ ##############
200
+
201
+ ###############################
202
+ class Demo_Remesh_Dataset(torch.utils.data.Dataset):
203
+ def __init__(self, cfg):
204
+ super().__init__()
205
+
206
+ self.data_path = cfg.dataset.data_path
207
+
208
+ all_files = os.listdir(self.data_path)
209
+
210
+ selected = []
211
+ for f in all_files:
212
+ if (".obj" in f or ".glb" in f):
213
+ selected.append(f)
214
+
215
+ self.data_list = selected
216
+ self.pc_num_pts = 100000
217
+
218
+ self.preprocess_mesh = cfg.preprocess_mesh
219
+ self.result_name = cfg.result_name
220
+
221
+ print("val dataset len:", len(self.data_list))
222
+
223
+
224
+ def __len__(self):
225
+ return len(self.data_list)
226
+
227
+
228
+ def get_model(self, ply_file):
229
+
230
+ uid = ply_file.split(".")[-2]
231
+
232
+ ####
233
+ obj_path = os.path.join(self.data_path, ply_file)
234
+ mesh = load_mesh_util(obj_path)
235
+ vertices = mesh.vertices
236
+ faces = mesh.faces
237
+
238
+ bbmin = vertices.min(0)
239
+ bbmax = vertices.max(0)
240
+ center = (bbmin + bbmax) * 0.5
241
+ scale = 2.0 * 0.9 / (bbmax - bbmin).max()
242
+ vertices = (vertices - center) * scale
243
+ mesh.vertices = vertices
244
+
245
+ ### Pre-process mesh
246
+ if self.preprocess_mesh:
247
+ # Create a PyMeshLab mesh directly from vertices and faces
248
+ ml_mesh = pymeshlab.Mesh(vertex_matrix=mesh.vertices, face_matrix=mesh.faces)
249
+
250
+ # Create a MeshSet and add your mesh
251
+ ms = pymeshlab.MeshSet()
252
+ ms.add_mesh(ml_mesh, "from_trimesh")
253
+
254
+ # Apply filters
255
+ ms.apply_filter('meshing_remove_duplicate_faces')
256
+ ms.apply_filter('meshing_remove_duplicate_vertices')
257
+ percentageMerge = pymeshlab.PercentageValue(0.5)
258
+ ms.apply_filter('meshing_merge_close_vertices', threshold=percentageMerge)
259
+ ms.apply_filter('meshing_remove_unreferenced_vertices')
260
+
261
+
262
+ # Save or extract mesh
263
+ processed = ms.current_mesh()
264
+ mesh.vertices = processed.vertex_matrix()
265
+ mesh.faces = processed.face_matrix()
266
+
267
+ print("after preprocessing...")
268
+ print(mesh.vertices.shape)
269
+ print(mesh.faces.shape)
270
+
271
+ ### Save input
272
+ save_dir = f"exp_results/{self.result_name}"
273
+ os.makedirs(save_dir, exist_ok=True)
274
+ view_id = 0
275
+ mesh.export(f'{save_dir}/input_{uid}_{view_id}.ply')
276
+
277
+ try:
278
+ ###### Remesh ######
279
+ size= 256
280
+ level = 2 / size
281
+
282
+ sdf = mesh2sdf.core.compute(mesh.vertices, mesh.faces, size)
283
+ # NOTE: the negative value is not reliable if the mesh is not watertight
284
+ udf = np.abs(sdf)
285
+ vertices, faces, _, _ = skimage.measure.marching_cubes(udf, level)
286
+
287
+ #### Only use SDF mesh ###
288
+ # new_mesh = trimesh.Trimesh(vertices, faces)
289
+ ##########################
290
+
291
+ #### Make tet #####
292
+ components = trimesh.Trimesh(vertices, faces).split(only_watertight=False)
293
+ new_mesh = [] #trimesh.Trimesh()
294
+ if len(components) > 100000:
295
+ raise NotImplementedError
296
+ for i, c in enumerate(components):
297
+ c.fix_normals()
298
+ new_mesh.append(c) #trimesh.util.concatenate(new_mesh, c)
299
+ new_mesh = trimesh.util.concatenate(new_mesh)
300
+
301
+ # generate tet mesh
302
+ tet = tetgen.TetGen(new_mesh.vertices, new_mesh.faces)
303
+ tet.tetrahedralize(plc=True, nobisect=1., quality=True, fixedvolume=True, maxvolume=math.sqrt(2) / 12 * (2 / size) ** 3)
304
+ tmp_vtk = tempfile.NamedTemporaryFile(suffix='.vtk', delete=True)
305
+ tet.grid.save(tmp_vtk.name)
306
+
307
+ # extract surface mesh from tet mesh
308
+ reader = vtk.vtkUnstructuredGridReader()
309
+ reader.SetFileName(tmp_vtk.name)
310
+ reader.Update()
311
+ surface_filter = vtk.vtkDataSetSurfaceFilter()
312
+ surface_filter.SetInputConnection(reader.GetOutputPort())
313
+ surface_filter.Update()
314
+ polydata = surface_filter.GetOutput()
315
+ writer = vtk.vtkOBJWriter()
316
+ tmp_obj = tempfile.NamedTemporaryFile(suffix='.obj', delete=True)
317
+ writer.SetFileName(tmp_obj.name)
318
+ writer.SetInputData(polydata)
319
+ writer.Update()
320
+ new_mesh = load_mesh_util(tmp_obj.name)
321
+ ##########################
322
+
323
+ new_mesh.vertices = new_mesh.vertices * (2.0 / size) - 1.0 # normalize it to [-1, 1]
324
+
325
+ mesh = new_mesh
326
+ ####################
327
+
328
+ except:
329
+ print("Error in tet.")
330
+ mesh = mesh  # remeshing failed; fall back to the original mesh
331
+
332
+ pc, _ = trimesh.sample.sample_surface(mesh, self.pc_num_pts)
333
+
334
+ result = {
335
+ 'uid': uid
336
+ }
337
+
338
+ result['pc'] = torch.tensor(pc, dtype=torch.float32)
339
+ result['vertices'] = mesh.vertices
340
+ result['faces'] = mesh.faces
341
+
342
+ return result
343
+
344
+ def __getitem__(self, index):
345
+
346
+ gc.collect()
347
+
348
+ return self.get_model(self.data_list[index])
349
+
350
+
351
+ class Correspondence_Demo_Dataset(Demo_Dataset):
352
+ def __init__(self, cfg):
353
+ super().__init__(cfg)
354
+
355
+ self.data_path = cfg.dataset.data_path
356
+ self.is_pc = cfg.is_pc
357
+
358
+ self.data_list = cfg.dataset.all_files
359
+
360
+ self.pc_num_pts = 100000
361
+
362
+ self.preprocess_mesh = cfg.preprocess_mesh
363
+ self.result_name = cfg.result_name
364
+
365
+ print("val dataset len:", len(self.data_list))
366
+
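The datasets above all take a yacs cfg, normalize each shape to roughly [-0.9, 0.9], and return a dict with a 100,000-point surface sample (plus vertices/faces for meshes). A rough sketch of exercising Demo_Dataset outside the Lightning trainer (the folder path is a placeholder; assumes the full environment above, including pymeshlab/mesh2sdf/tetgen, is installed):

# Sketch only: wire Demo_Dataset to a plain DataLoader for inspection.
import torch
from partfield.config.defaults import _C
from partfield.dataloader import Demo_Dataset

cfg = _C.clone()
cfg.dataset.data_path = "PartField/objaverse_data/"  # placeholder folder of .obj/.glb/.off files
cfg.is_pc = False

ds = Demo_Dataset(cfg)
loader = torch.utils.data.DataLoader(ds, batch_size=1, num_workers=0)
batch = next(iter(loader))
print(batch["uid"], batch["pc"].shape)  # pc: torch.Size([1, 100000, 3])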
PartField/partfield/model/PVCNN/__pycache__/conv_pointnet.cpython-310.pyc ADDED
Binary file (6.6 kB). View file
 
PartField/partfield/model/PVCNN/__pycache__/dnnlib_util.cpython-310.pyc ADDED
Binary file (33.9 kB). View file
 
PartField/partfield/model/PVCNN/__pycache__/encoder_pc.cpython-310.pyc ADDED
Binary file (6.09 kB). View file
 
PartField/partfield/model/PVCNN/__pycache__/pc_encoder.cpython-310.pyc ADDED
Binary file (3.44 kB). View file
 
PartField/partfield/model/PVCNN/__pycache__/unet_3daware.cpython-310.pyc ADDED
Binary file (11.2 kB). View file
 
PartField/partfield/model/PVCNN/conv_pointnet.py ADDED
@@ -0,0 +1,251 @@
1
+ """
2
+ Taken from gensdf
3
+ https://github.com/princeton-computational-imaging/gensdf
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ # from dnnlib.util import printarr
10
+ try:
11
+ from torch_scatter import scatter_mean, scatter_max
12
+ except:
13
+ pass
14
+ # from .unet import UNet
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+
19
+
20
+ # Resnet Blocks
21
+ class ResnetBlockFC(nn.Module):
22
+ ''' Fully connected ResNet Block class.
23
+ Args:
24
+ size_in (int): input dimension
25
+ size_out (int): output dimension
26
+ size_h (int): hidden dimension
27
+ '''
28
+
29
+ def __init__(self, size_in, size_out=None, size_h=None):
30
+ super().__init__()
31
+ # Attributes
32
+ if size_out is None:
33
+ size_out = size_in
34
+
35
+ if size_h is None:
36
+ size_h = min(size_in, size_out)
37
+
38
+ self.size_in = size_in
39
+ self.size_h = size_h
40
+ self.size_out = size_out
41
+ # Submodules
42
+ self.fc_0 = nn.Linear(size_in, size_h)
43
+ self.fc_1 = nn.Linear(size_h, size_out)
44
+ self.actvn = nn.ReLU()
45
+
46
+ if size_in == size_out:
47
+ self.shortcut = None
48
+ else:
49
+ self.shortcut = nn.Linear(size_in, size_out, bias=False)
50
+ # Initialization
51
+ nn.init.zeros_(self.fc_1.weight)
52
+
53
+ def forward(self, x):
54
+ net = self.fc_0(self.actvn(x))
55
+ dx = self.fc_1(self.actvn(net))
56
+
57
+ if self.shortcut is not None:
58
+ x_s = self.shortcut(x)
59
+ else:
60
+ x_s = x
61
+
62
+ return x_s + dx
63
+
64
+
65
+ class ConvPointnet(nn.Module):
66
+ ''' PointNet-based encoder network with ResNet blocks for each point.
67
+ Number of input points are fixed.
68
+
69
+ Args:
70
+ c_dim (int): dimension of latent code c
71
+ dim (int): input points dimension
72
+ hidden_dim (int): hidden dimension of the network
73
+ scatter_type (str): feature aggregation when doing local pooling
74
+ unet (bool): whether to use U-Net
75
+ unet_kwargs (str): U-Net parameters
76
+ plane_resolution (int): defined resolution for plane feature
77
+ plane_type (str): feature type, 'xz' - 1-plane, ['xz', 'xy', 'yz'] - 3-plane, ['grid'] - 3D grid volume
78
+ padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
79
+ n_blocks (int): number of ResnetBlockFC blocks
80
+ '''
81
+
82
+ def __init__(self, c_dim=128, dim=3, hidden_dim=128, scatter_type='max',
83
+ # unet=False, unet_kwargs=None,
84
+ plane_resolution=None, plane_type=['xz', 'xy', 'yz'], padding=0.1, n_blocks=5):
85
+ super().__init__()
86
+ self.c_dim = c_dim
87
+
88
+ self.fc_pos = nn.Linear(dim, 2*hidden_dim)
89
+ self.blocks = nn.ModuleList([
90
+ ResnetBlockFC(2*hidden_dim, hidden_dim) for i in range(n_blocks)
91
+ ])
92
+ self.fc_c = nn.Linear(hidden_dim, c_dim)
93
+
94
+ self.actvn = nn.ReLU()
95
+ self.hidden_dim = hidden_dim
96
+
97
+ # if unet:
98
+ # self.unet = UNet(c_dim, in_channels=c_dim, **unet_kwargs)
99
+ # else:
100
+ # self.unet = None
101
+
102
+ self.reso_plane = plane_resolution
103
+ self.plane_type = plane_type
104
+ self.padding = padding
105
+
106
+ if scatter_type == 'max':
107
+ self.scatter = scatter_max
108
+ elif scatter_type == 'mean':
109
+ self.scatter = scatter_mean
110
+
111
+
112
+ # takes in "p": point cloud and "query": sdf_xyz
113
+ # sample plane features for unlabeled_query as well
114
+ def forward(self, p):#, query2):
115
+ batch_size, T, D = p.size()
116
+
117
+ # acquire the index for each point
118
+ coord = {}
119
+ index = {}
120
+ if 'xz' in self.plane_type:
121
+ coord['xz'] = self.normalize_coordinate(p.clone(), plane='xz', padding=self.padding)
122
+ index['xz'] = self.coordinate2index(coord['xz'], self.reso_plane)
123
+ if 'xy' in self.plane_type:
124
+ coord['xy'] = self.normalize_coordinate(p.clone(), plane='xy', padding=self.padding)
125
+ index['xy'] = self.coordinate2index(coord['xy'], self.reso_plane)
126
+ if 'yz' in self.plane_type:
127
+ coord['yz'] = self.normalize_coordinate(p.clone(), plane='yz', padding=self.padding)
128
+ index['yz'] = self.coordinate2index(coord['yz'], self.reso_plane)
129
+
130
+
131
+ net = self.fc_pos(p)
132
+
133
+ net = self.blocks[0](net)
134
+ for block in self.blocks[1:]:
135
+ pooled = self.pool_local(coord, index, net)
136
+ net = torch.cat([net, pooled], dim=2)
137
+ net = block(net)
138
+
139
+ c = self.fc_c(net)
140
+
141
+ fea = {}
142
+ plane_feat_sum = 0
143
+ #second_sum = 0
144
+ if 'xz' in self.plane_type:
145
+ fea['xz'] = self.generate_plane_features(p, c, plane='xz') # shape: batch, latent size, resolution, resolution (e.g. 16, 256, 64, 64)
146
+ # plane_feat_sum += self.sample_plane_feature(query, fea['xz'], 'xz')
147
+ #second_sum += self.sample_plane_feature(query2, fea['xz'], 'xz')
148
+ if 'xy' in self.plane_type:
149
+ fea['xy'] = self.generate_plane_features(p, c, plane='xy')
150
+ # plane_feat_sum += self.sample_plane_feature(query, fea['xy'], 'xy')
151
+ #second_sum += self.sample_plane_feature(query2, fea['xy'], 'xy')
152
+ if 'yz' in self.plane_type:
153
+ fea['yz'] = self.generate_plane_features(p, c, plane='yz')
154
+ # plane_feat_sum += self.sample_plane_feature(query, fea['yz'], 'yz')
155
+ #second_sum += self.sample_plane_feature(query2, fea['yz'], 'yz')
156
+ return fea
157
+
158
+ # return plane_feat_sum.transpose(2,1)#, second_sum.transpose(2,1)
159
+
160
+
161
+ def normalize_coordinate(self, p, padding=0.1, plane='xz'):
162
+ ''' Normalize coordinate to [0, 1] for unit cube experiments
163
+
164
+ Args:
165
+ p (tensor): point
166
+ padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
167
+ plane (str): plane feature type, ['xz', 'xy', 'yz']
168
+ '''
169
+ if plane == 'xz':
170
+ xy = p[:, :, [0, 2]]
171
+ elif plane =='xy':
172
+ xy = p[:, :, [0, 1]]
173
+ else:
174
+ xy = p[:, :, [1, 2]]
175
+
176
+ xy_new = xy / (1 + padding + 10e-6) # (-0.5, 0.5)
177
+ xy_new = xy_new + 0.5 # range (0, 1)
178
+
179
+ # if there are outliers out of the range
180
+ if xy_new.max() >= 1:
181
+ xy_new[xy_new >= 1] = 1 - 10e-6
182
+ if xy_new.min() < 0:
183
+ xy_new[xy_new < 0] = 0.0
184
+ return xy_new
185
+
186
+
187
+ def coordinate2index(self, x, reso):
188
+ ''' Convert normalized coordinates in [0, 1] to flattened plane indices.
189
+ Corresponds to our 3D model
190
+
191
+ Args:
192
+ x (tensor): coordinate
193
+ reso (int): defined resolution
194
+ coord_type (str): coordinate type
195
+ '''
196
+ x = (x * reso).long()
197
+ index = x[:, :, 0] + reso * x[:, :, 1]
198
+ index = index[:, None, :]
199
+ return index
200
+
201
+
202
+ # xy is the normalized coordinates of the point cloud of each plane
203
+ # I'm pretty sure the keys of xy are the same as those of index, so xy isn't needed here as input
204
+ def pool_local(self, xy, index, c):
205
+ bs, fea_dim = c.size(0), c.size(2)
206
+ keys = xy.keys()
207
+
208
+ c_out = 0
209
+ for key in keys:
210
+ # scatter plane features from points
211
+ fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.reso_plane**2)
212
+ if self.scatter == scatter_max:
213
+ fea = fea[0]
214
+ # gather feature back to points
215
+ fea = fea.gather(dim=2, index=index[key].expand(-1, fea_dim, -1))
216
+ c_out += fea
217
+ return c_out.permute(0, 2, 1)
218
+
219
+
220
+ def generate_plane_features(self, p, c, plane='xz'):
221
+ # acquire indices of features in plane
222
+ xy = self.normalize_coordinate(p.clone(), plane=plane, padding=self.padding) # normalize to the range of (0, 1)
223
+ index = self.coordinate2index(xy, self.reso_plane)
224
+
225
+ # scatter plane features from points
226
+ fea_plane = c.new_zeros(p.size(0), self.c_dim, self.reso_plane**2)
227
+ c = c.permute(0, 2, 1) # B x 512 x T
228
+ fea_plane = scatter_mean(c, index, out=fea_plane) # B x 512 x reso^2
229
+ fea_plane = fea_plane.reshape(p.size(0), self.c_dim, self.reso_plane, self.reso_plane) # sparse matrix (B x 512 x reso x reso)
230
+
231
+ # printarr(fea_plane, c, p, xy, index)
232
+ # import pdb; pdb.set_trace()
233
+
234
+ # process the plane features with UNet
235
+ # if self.unet is not None:
236
+ # fea_plane = self.unet(fea_plane)
237
+
238
+ return fea_plane
239
+
240
+
241
+ # sample_plane_feature function copied from /src/conv_onet/models/decoder.py
242
+ # uses values from plane_feature and pixel locations from vgrid to interpolate feature
243
+ def sample_plane_feature(self, query, plane_feature, plane):
244
+ xy = self.normalize_coordinate(query.clone(), plane=plane, padding=self.padding)
245
+ xy = xy[:, :, None].float()
246
+ vgrid = 2.0 * xy - 1.0 # normalize to (-1, 1)
247
+ sampled_feat = F.grid_sample(plane_feature, vgrid, padding_mode='border', align_corners=True, mode='bilinear').squeeze(-1)
248
+ return sampled_feat
249
+
250
+
251
+
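ConvPointnet above scatters per-point ResNet features onto three axis-aligned planes (xz, xy, yz); forward returns a dict of (B, c_dim, reso, reso) plane features, and sample_plane_feature bilinearly samples them back at query points. A quick shape-check sketch (random inputs; assumes torch_scatter is installed, otherwise the guarded import above leaves the scatter functions undefined):

# Sketch: shape check for the tri-plane point encoder defined above.
import torch
from partfield.model.PVCNN.conv_pointnet import ConvPointnet

enc = ConvPointnet(c_dim=64, dim=3, hidden_dim=64, plane_resolution=32)
pts = torch.rand(2, 1024, 3) - 0.5                 # points roughly in [-0.5, 0.5]^3
planes = enc(pts)                                   # {'xz': ..., 'xy': ..., 'yz': ...}
print({k: tuple(v.shape) for k, v in planes.items()})  # (2, 64, 32, 32) each

query = torch.rand(2, 16, 3) - 0.5
feat = enc.sample_plane_feature(query, planes["xz"], "xz")
print(feat.shape)                                   # torch.Size([2, 64, 16])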
PartField/partfield/model/PVCNN/dnnlib_util.py ADDED
@@ -0,0 +1,1074 @@
1
+ # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
+
9
+ """Miscellaneous utility classes and functions."""
10
+ from collections import namedtuple
11
+ import time
12
+ import ctypes
13
+ import fnmatch
14
+ import importlib
15
+ import inspect
16
+ import numpy as np
17
+ import json
18
+ import os
19
+ import shutil
20
+ import sys
21
+ import types
22
+ import io
23
+ import pickle
24
+ import re
25
+ # import requests
26
+ import html
27
+ import hashlib
28
+ import glob
29
+ import tempfile
30
+ import urllib
31
+ import urllib.request
32
+ import uuid
33
+ import boto3
34
+ import threading
35
+ from contextlib import ContextDecorator
36
+ from contextlib import contextmanager, nullcontext
37
+
38
+ from distutils.util import strtobool
39
+ from typing import Any, List, Tuple, Union
40
+ import importlib
41
+ from loguru import logger
42
+ # import wandb
43
+ import torch
44
+ import psutil
45
+ import subprocess
46
+
47
+ import random
48
+ import string
49
+ import pdb
50
+
51
+ # Util classes
52
+ # ------------------------------------------------------------------------------------------
53
+
54
+
55
+ class EasyDict(dict):
56
+ """Convenience class that behaves like a dict but allows access with the attribute syntax."""
57
+
58
+ def __getattr__(self, name: str) -> Any:
59
+ try:
60
+ return self[name]
61
+ except KeyError:
62
+ raise AttributeError(name)
63
+
64
+ def __setattr__(self, name: str, value: Any) -> None:
65
+ self[name] = value
66
+
67
+ def __delattr__(self, name: str) -> None:
68
+ del self[name]
69
+
70
+
71
+ class Logger(object):
72
+ """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
73
+
74
+ def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
75
+ self.file = None
76
+
77
+ if file_name is not None:
78
+ self.file = open(file_name, file_mode)
79
+
80
+ self.should_flush = should_flush
81
+ self.stdout = sys.stdout
82
+ self.stderr = sys.stderr
83
+
84
+ sys.stdout = self
85
+ sys.stderr = self
86
+
87
+ def __enter__(self) -> "Logger":
88
+ return self
89
+
90
+ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
91
+ self.close()
92
+
93
+ def write(self, text: Union[str, bytes]) -> None:
94
+ """Write text to stdout (and a file) and optionally flush."""
95
+ if isinstance(text, bytes):
96
+ text = text.decode()
97
+ if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
98
+ return
99
+
100
+ if self.file is not None:
101
+ self.file.write(text)
102
+
103
+ self.stdout.write(text)
104
+
105
+ if self.should_flush:
106
+ self.flush()
107
+
108
+ def flush(self) -> None:
109
+ """Flush written text to both stdout and a file, if open."""
110
+ if self.file is not None:
111
+ self.file.flush()
112
+
113
+ self.stdout.flush()
114
+
115
+ def close(self) -> None:
116
+ """Flush, close possible files, and remove stdout/stderr mirroring."""
117
+ self.flush()
118
+
119
+ # if using multiple loggers, prevent closing in wrong order
120
+ if sys.stdout is self:
121
+ sys.stdout = self.stdout
122
+ if sys.stderr is self:
123
+ sys.stderr = self.stderr
124
+
125
+ if self.file is not None:
126
+ self.file.close()
127
+ self.file = None
128
+
129
+
130
+ # Cache directories
131
+ # ------------------------------------------------------------------------------------------
132
+
133
+ _dnnlib_cache_dir = None
134
+
135
+
136
+ def set_cache_dir(path: str) -> None:
137
+ global _dnnlib_cache_dir
138
+ _dnnlib_cache_dir = path
139
+
140
+
141
+ def make_cache_dir_path(*paths: str) -> str:
142
+ if _dnnlib_cache_dir is not None:
143
+ return os.path.join(_dnnlib_cache_dir, *paths)
144
+ if 'DNNLIB_CACHE_DIR' in os.environ:
145
+ return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
146
+ if 'HOME' in os.environ:
147
+ return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
148
+ if 'USERPROFILE' in os.environ:
149
+ return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
150
+ return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
151
+
152
+
153
+ # Small util functions
154
+ # ------------------------------------------------------------------------------------------
155
+
156
+
157
+ def format_time(seconds: Union[int, float]) -> str:
158
+ """Convert the seconds to human readable string with days, hours, minutes and seconds."""
159
+ s = int(np.rint(seconds))
160
+
161
+ if s < 60:
162
+ return "{0}s".format(s)
163
+ elif s < 60 * 60:
164
+ return "{0}m {1:02}s".format(s // 60, s % 60)
165
+ elif s < 24 * 60 * 60:
166
+ return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
167
+ else:
168
+ return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
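A few illustrative values, following the branches above:

    assert format_time(59) == '59s'
    assert format_time(3661) == '1h 01m 01s'
    assert format_time(90000) == '1d 01h 00m'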
169
+
170
+
171
+ def format_time_brief(seconds: Union[int, float]) -> str:
172
+ """Convert seconds to a brief human-readable string with days, hours, minutes and seconds."""
173
+ s = int(np.rint(seconds))
174
+
175
+ if s < 60:
176
+ return "{0}s".format(s)
177
+ elif s < 60 * 60:
178
+ return "{0}m {1:02}s".format(s // 60, s % 60)
179
+ elif s < 24 * 60 * 60:
180
+ return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60)
181
+ else:
182
+ return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24)
183
+
184
+
185
+ def ask_yes_no(question: str) -> bool:
186
+ """Ask the user the question until the user inputs a valid answer."""
187
+ while True:
188
+ try:
189
+ print("{0} [y/n]".format(question))
190
+ return strtobool(input().lower())
191
+ except ValueError:
192
+ pass
193
+
194
+
195
+ def tuple_product(t: Tuple) -> Any:
196
+ """Calculate the product of the tuple elements."""
197
+ result = 1
198
+
199
+ for v in t:
200
+ result *= v
201
+
202
+ return result
203
+
204
+
205
+ _str_to_ctype = {
206
+ "uint8": ctypes.c_ubyte,
207
+ "uint16": ctypes.c_uint16,
208
+ "uint32": ctypes.c_uint32,
209
+ "uint64": ctypes.c_uint64,
210
+ "int8": ctypes.c_byte,
211
+ "int16": ctypes.c_int16,
212
+ "int32": ctypes.c_int32,
213
+ "int64": ctypes.c_int64,
214
+ "float32": ctypes.c_float,
215
+ "float64": ctypes.c_double
216
+ }
217
+
218
+
219
+ def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
220
+ """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
221
+ type_str = None
222
+
223
+ if isinstance(type_obj, str):
224
+ type_str = type_obj
225
+ elif hasattr(type_obj, "__name__"):
226
+ type_str = type_obj.__name__
227
+ elif hasattr(type_obj, "name"):
228
+ type_str = type_obj.name
229
+ else:
230
+ raise RuntimeError("Cannot infer type name from input")
231
+
232
+ assert type_str in _str_to_ctype.keys()
233
+
234
+ my_dtype = np.dtype(type_str)
235
+ my_ctype = _str_to_ctype[type_str]
236
+
237
+ assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
238
+
239
+ return my_dtype, my_ctype
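Illustrative calls: the input may be a type-name string or anything exposing __name__ (e.g. a NumPy scalar type).

    import ctypes
    import numpy as np
    dt, ct = get_dtype_and_ctype('float32')
    assert dt == np.dtype('float32') and ct is ctypes.c_float
    dt, ct = get_dtype_and_ctype(np.int64)   # __name__ == 'int64'
    assert ct is ctypes.c_int64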
240
+
241
+
242
+ def is_pickleable(obj: Any) -> bool:
243
+ try:
244
+ with io.BytesIO() as stream:
245
+ pickle.dump(obj, stream)
246
+ return True
247
+ except:
248
+ return False
249
+
250
+
251
+ # Functionality to import modules/objects by name, and call functions by name
252
+ # ------------------------------------------------------------------------------------------
253
+
254
+ def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
255
+ """Searches for the underlying module behind the name to some python object.
256
+ Returns the module and the object name (original name with module part removed)."""
257
+
258
+ # allow convenience shorthands, substitute them by full names
259
+ obj_name = re.sub("^np.", "numpy.", obj_name)
260
+ obj_name = re.sub("^tf.", "tensorflow.", obj_name)
261
+
262
+ # list alternatives for (module_name, local_obj_name)
263
+ parts = obj_name.split(".")
264
+ name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
265
+
266
+ # try each alternative in turn
267
+ for module_name, local_obj_name in name_pairs:
268
+ try:
269
+ module = importlib.import_module(module_name) # may raise ImportError
270
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
271
+ return module, local_obj_name
272
+ except:
273
+ pass
274
+
275
+ # maybe some of the modules themselves contain errors?
276
+ for module_name, _local_obj_name in name_pairs:
277
+ try:
278
+ importlib.import_module(module_name) # may raise ImportError
279
+ except ImportError:
280
+ if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
281
+ raise
282
+
283
+ # maybe the requested attribute is missing?
284
+ for module_name, local_obj_name in name_pairs:
285
+ try:
286
+ module = importlib.import_module(module_name) # may raise ImportError
287
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
288
+ except ImportError:
289
+ pass
290
+
291
+ # we are out of luck, but we have no idea why
292
+ raise ImportError(obj_name)
293
+
294
+
295
+ def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
296
+ """Traverses the object name and returns the last (rightmost) python object."""
297
+ if obj_name == '':
298
+ return module
299
+ obj = module
300
+ for part in obj_name.split("."):
301
+ obj = getattr(obj, part)
302
+ return obj
303
+
304
+
305
+ def get_obj_by_name(name: str) -> Any:
306
+ """Finds the python object with the given name."""
307
+ module, obj_name = get_module_from_obj_name(name)
308
+ return get_obj_from_module(module, obj_name)
309
+
310
+
311
+ def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
312
+ """Finds the python object with the given name and calls it as a function."""
313
+ assert func_name is not None
314
+ func_obj = get_obj_by_name(func_name)
315
+ assert callable(func_obj)
316
+ return func_obj(*args, **kwargs)
317
+
318
+
319
+ def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
320
+ """Finds the python class with the given name and constructs it with the given arguments."""
321
+ return call_func_by_name(*args, func_name=class_name, **kwargs)
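Sketch of the name-based construction path; the target class is only an illustration. Note that the 'np.' shorthand is expanded to 'numpy.' by get_module_from_obj_name().

    rng = construct_class_by_name(class_name='np.random.RandomState', seed=0)
    print(rng.randint(10))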
322
+
323
+
324
+ def get_module_dir_by_obj_name(obj_name: str) -> str:
325
+ """Get the directory path of the module containing the given object name."""
326
+ module, _ = get_module_from_obj_name(obj_name)
327
+ return os.path.dirname(inspect.getfile(module))
328
+
329
+
330
+ def is_top_level_function(obj: Any) -> bool:
331
+ """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
332
+ return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
333
+
334
+
335
+ def get_top_level_function_name(obj: Any) -> str:
336
+ """Return the fully-qualified name of a top-level function."""
337
+ assert is_top_level_function(obj)
338
+ module = obj.__module__
339
+ if module == '__main__':
340
+ module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
341
+ return module + "." + obj.__name__
342
+
343
+
344
+ # File system helpers
345
+ # ------------------------------------------------------------------------------------------
346
+
347
+ def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
348
+ """List all files recursively in a given directory while ignoring given file and directory names.
349
+ Returns list of tuples containing both absolute and relative paths."""
350
+ assert os.path.isdir(dir_path)
351
+ base_name = os.path.basename(os.path.normpath(dir_path))
352
+
353
+ if ignores is None:
354
+ ignores = []
355
+
356
+ result = []
357
+
358
+ for root, dirs, files in os.walk(dir_path, topdown=True):
359
+ for ignore_ in ignores:
360
+ dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
361
+
362
+ # dirs need to be edited in-place
363
+ for d in dirs_to_remove:
364
+ dirs.remove(d)
365
+
366
+ files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
367
+
368
+ absolute_paths = [os.path.join(root, f) for f in files]
369
+ relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
370
+
371
+ if add_base_to_relative:
372
+ relative_paths = [os.path.join(base_name, p) for p in relative_paths]
373
+
374
+ assert len(absolute_paths) == len(relative_paths)
375
+ result += zip(absolute_paths, relative_paths)
376
+
377
+ return result
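Hedged example; the directory and ignore patterns are illustrative.

    pairs = list_dir_recursively_with_ignore('.', ignores=['__pycache__', '*.pyc'], add_base_to_relative=True)
    for abs_path, rel_path in pairs:
        print(rel_path)   # e.g. '<dir_basename>/subdir/file.py'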
378
+
379
+
380
+ def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
381
+ """Takes in a list of tuples of (src, dst) paths and copies files.
382
+ Will create all necessary directories."""
383
+ for file in files:
384
+ target_dir_name = os.path.dirname(file[1])
385
+
386
+ # will create all intermediate-level directories
387
+ if not os.path.exists(target_dir_name):
388
+ os.makedirs(target_dir_name)
389
+
390
+ shutil.copyfile(file[0], file[1])
391
+
392
+
393
+ # URL helpers
394
+ # ------------------------------------------------------------------------------------------
395
+
396
+ def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
397
+ """Determine whether the given object is a valid URL string."""
398
+ if not isinstance(obj, str) or not "://" in obj:
399
+ return False
400
+ if allow_file_urls and obj.startswith('file://'):
401
+ return True
402
+ try:
403
+ res = requests.compat.urlparse(obj)
404
+ if not res.scheme or not res.netloc or not "." in res.netloc:
405
+ return False
406
+ res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
407
+ if not res.scheme or not res.netloc or not "." in res.netloc:
408
+ return False
409
+ except:
410
+ return False
411
+ return True
412
+
413
+
414
+ def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
415
+ """Download the given URL and return a binary-mode file object to access the data."""
416
+ assert num_attempts >= 1
417
+ assert not (return_filename and (not cache))
418
+
419
+ # Doesn't look like a URL scheme, so interpret it as a local filename.
420
+ if not re.match('^[a-z]+://', url):
421
+ return url if return_filename else open(url, "rb")
422
+
423
+ # Handle file URLs. This code handles unusual file:// patterns that
424
+ # arise on Windows:
425
+ #
426
+ # file:///c:/foo.txt
427
+ #
428
+ # which would translate to a local '/c:/foo.txt' filename that's
429
+ # invalid. Drop the forward slash for such pathnames.
430
+ #
431
+ # If you touch this code path, you should test it on both Linux and
432
+ # Windows.
433
+ #
434
+ # Some internet resources suggest using urllib.request.url2pathname() but
435
+ # that converts forward slashes to backslashes, and this causes
436
+ # its own set of problems.
437
+ if url.startswith('file://'):
438
+ filename = urllib.parse.urlparse(url).path
439
+ if re.match(r'^/[a-zA-Z]:', filename):
440
+ filename = filename[1:]
441
+ return filename if return_filename else open(filename, "rb")
442
+
443
+ assert is_url(url)
444
+
445
+ # Lookup from cache.
446
+ if cache_dir is None:
447
+ cache_dir = make_cache_dir_path('downloads')
448
+
449
+ url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
450
+ if cache:
451
+ cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
452
+ if len(cache_files) == 1:
453
+ filename = cache_files[0]
454
+ return filename if return_filename else open(filename, "rb")
455
+
456
+ # Download.
457
+ url_name = None
458
+ url_data = None
459
+ with requests.Session() as session:
460
+ if verbose:
461
+ print("Downloading %s ..." % url, end="", flush=True)
462
+ for attempts_left in reversed(range(num_attempts)):
463
+ try:
464
+ with session.get(url) as res:
465
+ res.raise_for_status()
466
+ if len(res.content) == 0:
467
+ raise IOError("No data received")
468
+
469
+ if len(res.content) < 8192:
470
+ content_str = res.content.decode("utf-8")
471
+ if "download_warning" in res.headers.get("Set-Cookie", ""):
472
+ links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
473
+ if len(links) == 1:
474
+ url = requests.compat.urljoin(url, links[0])
475
+ raise IOError("Google Drive virus checker nag")
476
+ if "Google Drive - Quota exceeded" in content_str:
477
+ raise IOError("Google Drive download quota exceeded -- please try again later")
478
+
479
+ match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
480
+ url_name = match[1] if match else url
481
+ url_data = res.content
482
+ if verbose:
483
+ print(" done")
484
+ break
485
+ except KeyboardInterrupt:
486
+ raise
487
+ except:
488
+ if not attempts_left:
489
+ if verbose:
490
+ print(" failed")
491
+ raise
492
+ if verbose:
493
+ print(".", end="", flush=True)
494
+
495
+ # Save to cache.
496
+ if cache:
497
+ safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
498
+ cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
499
+ temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
500
+ os.makedirs(cache_dir, exist_ok=True)
501
+ with open(temp_file, "wb") as f:
502
+ f.write(url_data)
503
+ os.replace(temp_file, cache_file) # atomic
504
+ if return_filename:
505
+ return cache_file
506
+
507
+ # Return data as file object.
508
+ assert not return_filename
509
+ return io.BytesIO(url_data)
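Typical call pattern, sketched with a placeholder URL; cached downloads are stored under make_cache_dir_path('downloads').

    with open_url('https://example.com/model.pkl', cache=True, verbose=False) as f:
        data = f.read()
    # or ask for the cached file path instead of a file object
    local_path = open_url('https://example.com/model.pkl', return_filename=True)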
510
+
511
+ # ------------------------------------------------------------------------------------------
512
+ # util function modified from https://github.com/nv-tlabs/LION/blob/0467d2199076e95a7e88bafd99dcd7d48a04b4a7/utils/model_helper.py
513
+ def import_class(model_str):
514
+ from torch_utils.dist_utils import is_rank0
515
+ if is_rank0():
516
+ logger.info('import: {}', model_str)
517
+ p, m = model_str.rsplit('.', 1)
518
+ mod = importlib.import_module(p)
519
+ Model = getattr(mod, m)
520
+ return Model
521
+
522
+ class ScopedTorchProfiler(ContextDecorator):
523
+ """
524
+ Marks ranges for both nvtx profiling (with nsys) and torch autograd profiler
525
+ """
526
+ __global_counts = {}
527
+ enabled=False
528
+
529
+ def __init__(self, unique_name: str):
530
+ """
531
+ Names must be unique!
532
+ """
533
+ ScopedTorchProfiler.__global_counts[unique_name] = 0
534
+ self._name = unique_name
535
+ self._autograd_scope = torch.profiler.record_function(unique_name)
536
+
537
+ def __enter__(self):
538
+ if ScopedTorchProfiler.enabled:
539
+ torch.cuda.nvtx.range_push(self._name)
540
+ self._autograd_scope.__enter__()
541
+
542
+ def __exit__(self, exc_type, exc_value, traceback):
543
+ self._autograd_scope.__exit__(exc_type, exc_value, traceback)
544
+ if ScopedTorchProfiler.enabled:
545
+ torch.cuda.nvtx.range_pop()
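Because it derives from ContextDecorator it can wrap a block or decorate a function; nvtx ranges are emitted only when the class-level flag is set. The `loader` below is assumed to exist.

    ScopedTorchProfiler.enabled = True   # also emit nvtx ranges for nsys

    @ScopedTorchProfiler('forward_pass')
    def forward_pass(model, x):
        return model(x)

    with ScopedTorchProfiler('data_loading'):
        batch = next(iter(loader))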
546
+
547
+ class TimingsMonitor():
548
+ CUDATimer = namedtuple('CUDATimer', ['start', 'end'])
549
+ def __init__(self, device, enabled=True, timing_names:List[str]=[], cuda_timing_names:List[str]=[]):
550
+ """
551
+ Usage:
552
+ tmonitor = TimingsMonitor(device)
553
+ for i in range(n_iter):
554
+ # Record arbitrary scopes
555
+ with tmonitor.timing_scope('regular_scope_name'):
556
+ ...
557
+ with tmonitor.cuda_timing_scope('nested_scope_name'):
558
+ ...
559
+ with tmonitor.cuda_timing_scope('cuda_scope_name'):
560
+ ...
561
+ tmonitor.record_timing('duration_name', end_time - start_time)
562
+
563
+ # Gather timings
564
+ tmonitor.record_all_cuda_timings()
565
+ tmonitor.update_all_averages()
566
+ averages = tmonitor.get_average_timings()
567
+ all_timings = tmonitor.get_timings()
568
+
569
+ Two types of timers, standard report timing and cuda timings.
570
+ Cuda timing supports the scoped context manager cuda_timing_scope.
571
+ Args:
572
+ device: device to time on (needed for cuda timers)
573
+ # enabled: HACK to only report timings from rank 0, set enabled=(global_rank==0)
574
+ timing_names: timings to report optional (will auto add new names)
575
+ cuda_timing_names: cuda periods to time optional (will auto add new names)
576
+ """
577
+ self.enabled=enabled
578
+ self.device = device
579
+
580
+ # Normal timing
581
+ # self.all_timings_dict = {k:None for k in timing_names + cuda_timing_names}
582
+ self.all_timings_dict = {}
583
+ self.avg_meter_dict = {}
584
+
585
+ # Cuda event timers to measure time spent on pushing data to gpu and on training step
586
+ self.cuda_event_timers = {}
587
+
588
+ for k in timing_names:
589
+ self.add_new_timing(k)
590
+
591
+ for k in cuda_timing_names:
592
+ self.add_new_cuda_timing(k)
593
+
594
+ # Running averages
595
+ # self.avg_meter_dict = {k:AverageMeter() for k in self.all_timings_dict}
596
+
597
+ def add_new_timing(self, name):
598
+ self.avg_meter_dict[name] = AverageMeter()
599
+ self.all_timings_dict[name] = None
600
+
601
+ def add_new_cuda_timing(self, name):
602
+ start_event = torch.cuda.Event(enable_timing=True)
603
+ end_event = torch.cuda.Event(enable_timing=True)
604
+ self.cuda_event_timers[name] = self.CUDATimer(start=start_event, end=end_event)
605
+ self.add_new_timing(name)
606
+
607
+ def clear_timings(self):
608
+ self.all_timings_dict = {k:None for k in self.all_timings_dict}
609
+
610
+ def get_timings(self):
611
+ return self.all_timings_dict
612
+
613
+ def get_average_timings(self):
614
+ return {k:v.avg for k,v in self.avg_meter_dict.items()}
615
+
616
+ def update_all_averages(self):
617
+ """
618
+ Once per iter, when timings have been finished recording, one should
619
+ call update_average_iter to keep running average of timings.
620
+ """
621
+ for k,v in self.all_timings_dict.items():
622
+ if v is None:
623
+ print("none_timing", k)
624
+ continue
625
+ self.avg_meter_dict[k].update(v)
626
+
627
+ def record_timing(self, name, value):
628
+ if name not in self.all_timings_dict: self.add_new_timing(name)
629
+ # assert name in self.all_timings_dict
630
+ self.all_timings_dict[name] = value
631
+
632
+ def _record_cuda_event_start(self, name):
633
+ if name in self.cuda_event_timers:
634
+ self.cuda_event_timers[name].start.record(
635
+ torch.cuda.current_stream(self.device))
636
+
637
+ def _record_cuda_event_end(self, name):
638
+ if name in self.cuda_event_timers:
639
+ self.cuda_event_timers[name].end.record(
640
+ torch.cuda.current_stream(self.device))
641
+
642
+ @contextmanager
643
+ def cuda_timing_scope(self, name, profile=True):
644
+ if name not in self.all_timings_dict: self.add_new_cuda_timing(name)
645
+ with ScopedTorchProfiler(name) if profile else nullcontext():
646
+ self._record_cuda_event_start(name)
647
+ try:
648
+ yield
649
+ finally:
650
+ self._record_cuda_event_end(name)
651
+
652
+ @contextmanager
653
+ def timing_scope(self, name, profile=True):
654
+ if name not in self.all_timings_dict: self.add_new_timing(name)
655
+ with ScopedTorchProfiler(name) if profile else nullcontext():
656
+ start_time = time.time()
657
+ try:
658
+ yield
659
+ finally:
660
+ self.record_timing(name, time.time()-start_time)
661
+
662
+ def record_all_cuda_timings(self):
663
+ """ After all the cuda events call this to synchronize and record down the cuda timings. """
664
+ for k, events in self.cuda_event_timers.items():
665
+ with torch.no_grad():
666
+ events.end.synchronize()
667
+ # Convert to seconds
668
+ time_elapsed = events.start.elapsed_time(events.end)/1000.
669
+ self.all_timings_dict[k] = time_elapsed
670
+
671
+ def init_s3(config_file):
672
+ config = json.load(open(config_file, 'r'))
673
+ s3_client = boto3.client("s3", **config)
674
+ return s3_client
675
+
676
+ def download_from_s3(file_path, target_path, cfg):
677
+ tic = time.time()
678
+ s3_client = init_s3(cfg.checkpoint.write_s3_config) # use to test the s3_client can be init
679
+ bucket_name = file_path.split('/')[2]
680
+ file_key = file_path.split(bucket_name+'/')[-1]
681
+ print(bucket_name, file_key)
682
+ s3_client.download_file(bucket_name, file_key, target_path)
683
+ logger.info(f'finished download from s3://{bucket_name}/{file_key} to {target_path} %.1f sec'%(
684
+ time.time() - tic))
685
+
686
+ def upload_to_s3(buffer, bucket_name, key, config_dict):
687
+ logger.info(f'start upload_to_s3! bucket_name={bucket_name}, key={key}')
688
+ tic = time.time()
689
+ s3 = boto3.client('s3', **config_dict)
690
+ s3.put_object(Bucket=bucket_name, Key=key, Body=buffer.getvalue())
691
+ logger.info(f'finish upload_to_s3! s3://{bucket_name}/{key} %.1f sec'%(time.time() - tic))
692
+
693
+ def write_ckpt_to_s3(cfg, all_model_dict, ckpt_name):
694
+ buffer = io.BytesIO()
695
+ tic = time.time()
696
+ torch.save(all_model_dict, buffer) # take ~0.25 sec
697
+ # logger.info('write ckpt to buffer: %.2f sec'%(time.time() - tic))
698
+ group, name = cfg.outdir.rstrip("/").split("/")[-2:]
699
+ key = f"checkpoints/{group}/{name}/ckpt/{ckpt_name}"
700
+ bucket_name = cfg.checkpoint.write_s3_bucket
701
+
702
+ s3_client = init_s3(cfg.checkpoint.write_s3_config) # use to test the s3_client can be init
703
+
704
+ config_dict = json.load(open(cfg.checkpoint.write_s3_config, 'r'))
705
+ upload_thread = threading.Thread(target=upload_to_s3, args=(buffer, bucket_name, key, config_dict))
706
+ upload_thread.start()
707
+ path = f"s3://{bucket_name}/{key}"
708
+ return path
709
+
710
+ def upload_file_to_s3(cfg, file_path, key_name=None):
711
+ # file_path is the local file path, can be a yaml file
712
+ # this function is used to upload the checkpoint only
713
+ tic = time.time()
714
+ group, name = cfg.outdir.rstrip("/").split("/")[-2:]
715
+ if key_name is None:
716
+ key = os.path.basename(file_path)
717
+ key = f"checkpoints/{group}/{name}/{key}"
718
+ bucket_name = cfg.checkpoint.write_s3_bucket
719
+ s3_client = init_s3(cfg.checkpoint.write_s3_config)
720
+ # Upload the file
721
+ with open(file_path, 'rb') as f:
722
+ s3_client.upload_fileobj(f, bucket_name, key)
723
+ full_s3_path = f"s3://{bucket_name}/{key}"
724
+ logger.info(f'upload_to_s3: {file_path} {full_s3_path} | use time: {time.time()-tic}')
725
+
726
+ return full_s3_path
727
+
728
+
729
+ def load_from_s3(file_path, cfg, load_fn):
730
+ """
731
+ ckpt_path example:
732
+ s3://xzeng/checkpoints/2023_0413/vae_kl_5e-1/ckpt/snapshot_epo000163_iter164000.pt
733
+ """
734
+ s3_client = init_s3(cfg.checkpoint.write_s3_config) # use to test the s3_client can be init
735
+ bucket_name = file_path.split("s3://")[-1].split('/')[0]
736
+ key = file_path.split(f'{bucket_name}/')[-1]
737
+ # logger.info(f"-> try to load s3://{bucket_name}/{key} ")
738
+ tic = time.time()
739
+ for attempt in range(10):
740
+ try:
741
+ # Download the state dict from S3 into memory (as a binary stream)
742
+ with io.BytesIO() as buffer:
743
+ s3_client.download_fileobj(bucket_name, key, buffer)
744
+ buffer.seek(0)
745
+
746
+ # Load the state dict into a PyTorch model
747
+ # out = torch.load(buffer, map_location=torch.device("cpu"))
748
+ out = load_fn(buffer)
749
+ break
750
+ except:
751
+ logger.info(f"fail to load s3://{bucket_name}/{key} attempt: {attempt}")
752
+ from torch_utils.dist_utils import is_rank0
753
+ if is_rank0():
754
+ logger.info(f'loaded {file_path} | use time: {time.time()-tic:.1f} sec')
755
+ return out
756
+
757
+ def load_torch_dict_from_s3(ckpt_path, cfg):
758
+ """
759
+ ckpt_path example:
760
+ s3://xzeng/checkpoints/2023_0413/vae_kl_5e-1/ckpt/snapshot_epo000163_iter164000.pt
761
+ """
762
+ s3_client = init_s3(cfg.checkpoint.write_s3_config) # use to test the s3_client can be init
763
+ bucket_name = ckpt_path.split("s3://")[-1].split('/')[0]
764
+ key = ckpt_path.split(f'{bucket_name}/')[-1]
765
+ for attempt in range(10):
766
+ try:
767
+ # Download the state dict from S3 into memory (as a binary stream)
768
+ with io.BytesIO() as buffer:
769
+ s3_client.download_fileobj(bucket_name, key, buffer)
770
+ buffer.seek(0)
771
+
772
+ # Load the state dict into a PyTorch model
773
+ out = torch.load(buffer, map_location=torch.device("cpu"))
774
+ break
775
+ except:
776
+ logger.info(f"fail to load s3://{bucket_name}/{key} attempt: {attempt}")
777
+ return out
778
+
779
+ def count_parameters_in_M(model):
780
+ return np.sum(np.prod(v.size()) for name, v in model.named_parameters() if "auxiliary" not in name) / 1e6
781
+
782
+ def printarr(*arrs, float_width=6, **kwargs):
783
+ """
784
+ Print a pretty table giving name, shape, dtype, type, and content information for input tensors or scalars.
785
+
786
+ Call like: printarr(my_arr, some_other_arr, maybe_a_scalar). Accepts a variable number of arguments.
787
+
788
+ Inputs can be:
789
+ - Numpy tensor arrays
790
+ - Pytorch tensor arrays
791
+ - Jax tensor arrays
792
+ - Python ints / floats
793
+ - None
794
+
795
+ It may also work with other array-like types, but they have not been tested.
796
+
797
+ Use the `float_width` option specify the precision to which floating point types are printed.
798
+
799
+ Author: Nicholas Sharp (nmwsharp.com)
800
+ Canonical source: https://gist.github.com/nmwsharp/54d04af87872a4988809f128e1a1d233
801
+ License: This snippet may be used under an MIT license, and it is also released into the public domain.
802
+ Please retain this docstring as a reference.
803
+ """
804
+
805
+ frame = inspect.currentframe().f_back
806
+ default_name = "[temporary]"
807
+
808
+ ## helpers to gather data about each array
809
+ def name_from_outer_scope(a):
810
+ if a is None:
811
+ return '[None]'
812
+ name = default_name
813
+ for k, v in frame.f_locals.items():
814
+ if v is a:
815
+ name = k
816
+ break
817
+ return name
818
+
819
+ def type_strip(type_str):
820
+ return type_str.lstrip('<class ').rstrip('>').replace('torch.', '').strip("'")
821
+
822
+ def dtype_str(a):
823
+ if a is None:
824
+ return 'None'
825
+ if isinstance(a, int):
826
+ return 'int'
827
+ if isinstance(a, float):
828
+ return 'float'
829
+ if isinstance(a, list) and len(a)>0:
830
+ return type_strip(str(type(a[0])))
831
+ if hasattr(a, 'dtype'):
832
+ return type_strip(str(a.dtype))
833
+ else:
834
+ return ''
835
+ def shape_str(a):
836
+ if a is None:
837
+ return 'N/A'
838
+ if isinstance(a, int):
839
+ return 'scalar'
840
+ if isinstance(a, float):
841
+ return 'scalar'
842
+ if isinstance(a, list):
843
+ return f"[{shape_str(a[0]) if len(a)>0 else '?'}]*{len(a)}"
844
+ if hasattr(a, 'shape'):
845
+ return str(tuple(a.shape))
846
+ else:
847
+ return ''
848
+ def type_str(a):
849
+ return type_strip(str(type(a))) # TODO this is weird... what's the better way?
850
+ def device_str(a):
851
+ if hasattr(a, 'device'):
852
+ device_str = str(a.device)
853
+ if len(device_str) < 10:
854
+ # heuristic: jax returns some goofy long string we don't want, ignore it
855
+ return device_str
856
+ return ""
857
+ def format_float(x):
858
+ return f"{x:{float_width}g}"
859
+ def minmaxmean_str(a):
860
+ if a is None:
861
+ return ('N/A', 'N/A', 'N/A', 'N/A')
862
+ if isinstance(a, int) or isinstance(a, float):
863
+ return (format_float(a),)*4
864
+
865
+ # compute min/max/mean. if anything goes wrong, just print 'N/A'
866
+ min_str = "N/A"
867
+ try: min_str = format_float(a.min())
868
+ except: pass
869
+ max_str = "N/A"
870
+ try: max_str = format_float(a.max())
871
+ except: pass
872
+ mean_str = "N/A"
873
+ try: mean_str = format_float(a.mean())
874
+ except: pass
875
+ try: median_str = format_float(a.median())
876
+ except:
877
+ try: median_str = format_float(np.median(np.array(a)))
878
+ except: median_str = 'N/A'
879
+ return (min_str, max_str, mean_str, median_str)
880
+
881
+ def get_prop_dict(a,k=None):
882
+ minmaxmean = minmaxmean_str(a)
883
+ props = {
884
+ 'name' : name_from_outer_scope(a) if k is None else k,
885
+ # 'type' : str(type(a)).replace('torch.',''),
886
+ 'dtype' : dtype_str(a),
887
+ 'shape' : shape_str(a),
888
+ 'type' : type_str(a),
889
+ 'device' : device_str(a),
890
+ 'min' : minmaxmean[0],
891
+ 'max' : minmaxmean[1],
892
+ 'mean' : minmaxmean[2],
893
+ 'median': minmaxmean[3]
894
+ }
895
+ return props
896
+
897
+ try:
898
+
899
+ props = ['name', 'type', 'dtype', 'shape', 'device', 'min', 'max', 'mean', 'median']
900
+
901
+ # precompute all of the properties for each input
902
+ str_props = []
903
+ for a in arrs:
904
+ str_props.append(get_prop_dict(a))
905
+ for k,a in kwargs.items():
906
+ str_props.append(get_prop_dict(a, k=k))
907
+
908
+ # for each property, compute its length
909
+ maxlen = {}
910
+ for p in props: maxlen[p] = 0
911
+ for sp in str_props:
912
+ for p in props:
913
+ maxlen[p] = max(maxlen[p], len(sp[p]))
914
+
915
+ # if any property got all empty strings, don't bother printing it, remove if from the list
916
+ props = [p for p in props if maxlen[p] > 0]
917
+
918
+ # print a header
919
+ header_str = ""
920
+ for p in props:
921
+ prefix = "" if p == 'name' else " | "
922
+ fmt_key = ">" if p == 'name' else "<"
923
+ header_str += f"{prefix}{p:{fmt_key}{maxlen[p]}}"
924
+ print(header_str)
925
+ print("-"*len(header_str))
926
+
927
+ # now print the actual arrays
928
+ for strp in str_props:
929
+ for p in props:
930
+ prefix = "" if p == 'name' else " | "
931
+ fmt_key = ">" if p == 'name' else "<"
932
+ print(f"{prefix}{strp[p]:{fmt_key}{maxlen[p]}}", end='')
933
+ print("")
934
+
935
+ finally:
936
+ del frame
937
+
938
+ def debug_print_all_tensor_sizes(min_tot_size = 0):
939
+ import gc
940
+ print("---------------------------------------"*3)
941
+ for obj in gc.get_objects():
942
+ try:
943
+ if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
944
+ if np.prod(obj.size())>=min_tot_size:
945
+ print(type(obj), obj.size())
946
+ except:
947
+ pass
948
+ def print_cpu_usage():
949
+
950
+ # Get current CPU usage as a percentage
951
+ cpu_usage = psutil.cpu_percent()
952
+
953
+ # Get current memory usage
954
+ memory_usage = psutil.virtual_memory().used
955
+
956
+ # Convert memory usage to a human-readable format
957
+ memory_usage_str = psutil._common.bytes2human(memory_usage)
958
+
959
+ # Print CPU and memory usage
960
+ msg = f"Current CPU usage: {cpu_usage}% | "
961
+ msg += f"Current memory usage: {memory_usage_str}"
962
+ return msg
963
+
964
+ def calmsize(num_bytes):
965
+ if math.isnan(num_bytes):
966
+ return ''
967
+ for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
968
+ if abs(num_bytes) < 1024.0:
969
+ return "{:.1f}{}B".format(num_bytes, unit)
970
+ num_bytes /= 1024.0
971
+ return "{:.1f}{}B".format(num_bytes, 'Y')
972
+
973
+ def readable_size(num_bytes: int) -> str:
974
+ return calmsize(num_bytes) ## '' if math.isnan(num_bytes) else '{:.1f}'.format(calmsize(num_bytes))
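Quick check of the unit scaling:

    assert readable_size(512) == '512.0B'
    assert readable_size(2048) == '2.0KB'
    assert readable_size(3 * 1024**3) == '3.0GB'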
975
+
976
+ def get_gpu_memory():
977
+ """
978
+ Get the current GPU memory usage for each device as a dictionary
979
+ """
980
+ output = subprocess.check_output(["nvidia-smi", "--query-gpu=memory.used", "--format=csv"])
981
+ output = output.decode("utf-8")
982
+ gpu_memory_values = output.split("\n")[1:-1]
983
+ gpu_memory_values = [int(x.strip().split()[0]) for x in gpu_memory_values]
984
+ gpu_memory = dict(zip(range(len(gpu_memory_values)), gpu_memory_values))
985
+ return gpu_memory
986
+
987
+ def get_gpu_util():
988
+ """
989
+ Get the current GPU utilization for each device as a dictionary
990
+ """
991
+ output = subprocess.check_output(["nvidia-smi", "--query-gpu=utilization.gpu", "--format=csv"])
992
+ output = output.decode("utf-8")
993
+ gpu_memory_values = output.split("\n")[1:-1]
994
+ gpu_memory_values = [int(x.strip().split()[0]) for x in gpu_memory_values]
995
+ gpu_util = dict(zip(range(len(gpu_memory_values)), gpu_memory_values))
996
+ return gpu_util
997
+
998
+
999
+ def print_gpu_usage():
1000
+ usage = get_gpu_memory()
1001
+ msg = f" | GPU usage: "
1002
+ for k, v in usage.items():
1003
+ msg += f"{k}: {v} MB "
1004
+ # utilization = get_gpu_util()
1005
+ # msg + ' | util '
1006
+ # for k, v in utilization.items():
1007
+ # msg += f"{k}: {v} % "
1008
+ return msg
1009
+
1010
+ class AverageMeter(object):
1011
+
1012
+ def __init__(self):
1013
+ self.reset()
1014
+
1015
+ def reset(self):
1016
+ self.avg = 0
1017
+ self.sum = 0
1018
+ self.cnt = 0
1019
+
1020
+ def update(self, val, n=1):
1021
+ self.sum += val * n
1022
+ self.cnt += n
1023
+ self.avg = self.sum / self.cnt
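Minimal sketch of the running average:

    meter = AverageMeter()
    meter.update(2.0)        # avg = 2.0
    meter.update(4.0, n=3)   # sum = 14, cnt = 4
    assert meter.avg == 3.5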
1024
+
1025
+
1026
+ def generate_random_string(length):
1027
+ # Generate a string of `length` random ASCII letters (both lowercase and uppercase).
1028
+ # You can adjust the length parameter to fit your needs.
1029
+ letters = string.ascii_letters
1030
+ return ''.join(random.choice(letters) for _ in range(length))
1031
+
1032
+
1033
+ class ForkedPdb(pdb.Pdb):
1034
+ """
1035
+ PDB Subclass for debugging multi-processed code
1036
+ Suggested in: https://stackoverflow.com/questions/4716533/how-to-attach-debugger-to-a-python-subproccess
1037
+ """
1038
+ def interaction(self, *args, **kwargs):
1039
+ _stdin = sys.stdin
1040
+ try:
1041
+ sys.stdin = open('/dev/stdin')
1042
+ pdb.Pdb.interaction(self, *args, **kwargs)
1043
+ finally:
1044
+ sys.stdin = _stdin
1045
+
1046
+ def check_exist_in_s3(file_path, s3_config):
1047
+ s3 = init_s3(s3_config)
1048
+ bucket_name, object_name = s3path_to_bucket_key(file_path)
1049
+
1050
+ try:
1051
+ s3.head_object(Bucket=bucket_name, Key=object_name)
1052
+ return 1
1053
+ except:
1054
+ logger.info(f'file not found: s3://{bucket_name}/{object_name}')
1055
+ return 0
1056
+
1057
+ def s3path_to_bucket_key(file_path):
1058
+ bucket_name = file_path.split('/')[2]
1059
+ object_name = file_path.split(bucket_name + '/')[-1]
1060
+ return bucket_name, object_name
1061
+
1062
+ def copy_file_to_s3(cfg, file_path_local, file_path_s3):
1063
+ # work similar as upload_file_to_s3, but not trying to parse the file path
1064
+ # file_path_s3: s3://{bucket}/{key}
1065
+ bucket_name, key = s3path_to_bucket_key(file_path_s3)
1066
+ tic = time.time()
1067
+ s3_client = init_s3(cfg.checkpoint.write_s3_config)
1068
+
1069
+ # Upload the file
1070
+ with open(file_path_local, 'rb') as f:
1071
+ s3_client.upload_fileobj(f, bucket_name, key)
1072
+ full_s3_path = f"s3://{bucket_name}/{key}"
1073
+ logger.info(f'copy file: {file_path_local} {full_s3_path} | use time: {time.time()-tic}')
1074
+ return full_s3_path
PartField/partfield/model/PVCNN/encoder_pc.py ADDED
@@ -0,0 +1,243 @@
1
+ # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
+
9
+ from ast import Dict
10
+ import math
11
+
12
+ import numpy as np
13
+ import torch
14
+ from torch import nn
15
+ import torch.nn.functional as F
16
+ from torch_scatter import scatter_mean #, scatter_max
17
+
18
+ from .unet_3daware import setup_unet #UNetTriplane3dAware
19
+ from .conv_pointnet import ConvPointnet
20
+
21
+ from .pc_encoder import PVCNNEncoder #PointNet
22
+
23
+ import einops
24
+
25
+ from .dnnlib_util import ScopedTorchProfiler, printarr
26
+
27
+ def generate_plane_features(p, c, resolution, plane='xz'):
28
+ """
29
+ Args:
30
+ p: (B, n_p, 3)
31
+ c: (B,C,n_p)
32
+ """
33
+ padding = 0.
34
+ c_dim = c.size(1)
35
+ # acquire indices of features in plane
36
+ xy = normalize_coordinate(p.clone(), plane=plane, padding=padding) # normalize to the range of (0, 1)
37
+ index = coordinate2index(xy, resolution)
38
+
39
+ # scatter plane features from points
40
+ fea_plane = c.new_zeros(p.size(0), c_dim, resolution**2)
41
+ fea_plane = scatter_mean(c, index, out=fea_plane) # B x 512 x reso^2
42
+ fea_plane = fea_plane.reshape(p.size(0), c_dim, resolution, resolution) # sparse matrix (B x 512 x reso x reso)
43
+ return fea_plane
44
+
45
+ def normalize_coordinate(p, padding=0.1, plane='xz'):
46
+ ''' Normalize coordinate to [0, 1] for unit cube experiments
47
+
48
+ Args:
49
+ p (tensor): point
50
+ padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
51
+ plane (str): plane feature type, ['xz', 'xy', 'yz']
52
+ '''
53
+ if plane == 'xz':
54
+ xy = p[:, :, [0, 2]]
55
+ elif plane =='xy':
56
+ xy = p[:, :, [0, 1]]
57
+ else:
58
+ xy = p[:, :, [1, 2]]
59
+
60
+ xy_new = xy / (1 + padding + 10e-6) # (-0.5, 0.5)
61
+ xy_new = xy_new + 0.5 # range (0, 1)
62
+
63
+ # if there are outliers out of the range
64
+ if xy_new.max() >= 1:
65
+ xy_new[xy_new >= 1] = 1 - 10e-6
66
+ if xy_new.min() < 0:
67
+ xy_new[xy_new < 0] = 0.0
68
+ return xy_new
69
+
70
+
71
+ def coordinate2index(x, resolution):
72
+ ''' Convert normalized coordinates in [0, 1] to flattened indices on a 2D feature plane.
73
+ Used to scatter per-point features onto a plane (see generate_plane_features).
74
+
75
+ Args:
76
+ x (tensor): coordinate
77
+ resolution (int): defined resolution
78
+ coord_type (str): coordinate type
79
+ '''
80
+ x = (x * resolution).long()
81
+ index = x[:, :, 0] + resolution * x[:, :, 1]
82
+ index = index[:, None, :]
83
+ return index
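Shape sketch of the plane-index computation used by generate_plane_features above; sizes are illustrative.

    import torch
    p = torch.rand(2, 1024, 3) - 0.5                    # (B, N, 3) points in [-0.5, 0.5]
    xy = normalize_coordinate(p.clone(), plane='xz')    # (B, N, 2) in [0, 1)
    idx = coordinate2index(xy, resolution=64)           # (B, 1, N) flat indices in [0, 64*64)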
84
+
85
+ def softclip(x, min, max, hardness=5):
86
+ # Soft clipping for the logsigma
87
+ x = min + F.softplus(hardness*(x - min))/hardness
88
+ x = max - F.softplus(-hardness*(x - max))/hardness
89
+ return x
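Behaviour sketch: values well inside [min, max] pass through nearly unchanged, values outside are pulled smoothly to the bounds.

    import torch
    x = torch.tensor([-10.0, 0.0, 10.0])
    y = softclip(x, min=-3.0, max=3.0)   # approximately [-3.0, 0.0, 3.0]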
90
+
91
+
92
+ def sample_triplane_feat(feature_triplane, normalized_pos):
93
+ '''
94
+ normalized_pos [-1, 1]
95
+ '''
96
+ tri_plane = torch.unbind(feature_triplane, dim=1)
97
+
98
+ x_feat = F.grid_sample(
99
+ tri_plane[0],
100
+ torch.cat(
101
+ [normalized_pos[:, :, 0:1], normalized_pos[:, :, 1:2]],
102
+ dim=-1).unsqueeze(dim=1), padding_mode='border',
103
+ align_corners=True)
104
+ y_feat = F.grid_sample(
105
+ tri_plane[1],
106
+ torch.cat(
107
+ [normalized_pos[:, :, 1:2], normalized_pos[:, :, 2:3]],
108
+ dim=-1).unsqueeze(dim=1), padding_mode='border',
109
+ align_corners=True)
110
+
111
+ z_feat = F.grid_sample(
112
+ tri_plane[2],
113
+ torch.cat(
114
+ [normalized_pos[:, :, 0:1], normalized_pos[:, :, 2:3]],
115
+ dim=-1).unsqueeze(dim=1), padding_mode='border',
116
+ align_corners=True)
117
+ final_feat = (x_feat + y_feat + z_feat)
118
+ final_feat = final_feat.squeeze(dim=2).permute(0, 2, 1) # 32dimension
119
+ return final_feat
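Shape sketch of triplane sampling; sizes are illustrative.

    import torch
    triplane = torch.rand(2, 3, 32, 128, 128)     # (B, 3 planes, C, H, W)
    query = torch.rand(2, 4096, 3) * 2 - 1        # (B, N, 3) positions in [-1, 1]
    feat = sample_triplane_feat(triplane, query)  # (B, N, C)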
120
+
121
+
122
+ # @persistence.persistent_class
123
+ class TriPlanePC2Encoder(torch.nn.Module):
124
+ # Encoder that encodes a point cloud into a triplane feature, similar to ConvOccNet
125
+ def __init__(
126
+ self,
127
+ cfg,
128
+ device='cuda',
129
+ shape_min=-1.0,
130
+ shape_length=2.0,
131
+ use_2d_feat=False,
132
+ # point_encoder='pvcnn',
133
+ # use_point_scatter=False
134
+ ):
135
+ """
136
+ Outputs latent triplane from PC input
137
+ Configs:
138
+ max_logsigma: (float) Soft clip upper range for logsigma
139
+ min_logsigma: (float)
140
+ point_encoder_type: (str) one of ['pvcnn', 'pointnet']
141
+ pvcnn_flatten_voxels: (bool) for pvcnn whether to reduce voxel
142
+ features (instead of scattering point features)
143
+ unet_cfg: (dict)
144
+ z_triplane_channels: (int) output latent triplane
145
+ z_triplane_resolution: (int)
146
+ Args:
147
+
148
+ """
149
+ # assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0
150
+ super().__init__()
151
+ self.device = device
152
+
153
+ self.cfg = cfg
154
+
155
+ self.shape_min = shape_min
156
+ self.shape_length = shape_length
157
+
158
+ self.z_triplane_resolution = cfg.z_triplane_resolution
159
+ z_triplane_channels = cfg.z_triplane_channels
160
+
161
+ point_encoder_out_dim = z_triplane_channels #* 2
162
+
163
+ in_channels = 6
164
+ # self.resample_filter=[1, 3, 3, 1]
165
+ if cfg.point_encoder_type == 'pvcnn':
166
+ self.pc_encoder = PVCNNEncoder(point_encoder_out_dim,
167
+ device=self.device, in_channels=in_channels, use_2d_feat=use_2d_feat) # Encode it to a volume vector.
168
+ elif cfg.point_encoder_type == 'pointnet':
169
+ # TODO the pointnet was buggy, investigate
170
+ self.pc_encoder = ConvPointnet(c_dim=point_encoder_out_dim,
171
+ dim=in_channels, hidden_dim=32,
172
+ plane_resolution=self.z_triplane_resolution,
173
+ padding=0)
174
+ else:
175
+ raise NotImplementedError(f"Point encoder {cfg.point_encoder_type} not implemented")
176
+
177
+ if cfg.unet_cfg.enabled:
178
+ self.unet_encoder = setup_unet(
179
+ output_channels=point_encoder_out_dim,
180
+ input_channels=point_encoder_out_dim,
181
+ unet_cfg=cfg.unet_cfg)
182
+ else:
183
+ self.unet_encoder = None
184
+
185
+ # @ScopedTorchProfiler('encode')
186
+ def encode(self, point_cloud_xyz, point_cloud_feature, mv_feat=None, pc2pc_idx=None) -> Dict:
187
+ # output = AttrDict()
188
+ point_cloud_xyz = (point_cloud_xyz - self.shape_min) / self.shape_length # [0, 1]
189
+ point_cloud_xyz = point_cloud_xyz - 0.5 # [-0.5, 0.5]
190
+ point_cloud = torch.cat([point_cloud_xyz, point_cloud_feature], dim=-1)
191
+
192
+ if self.cfg.point_encoder_type == 'pvcnn':
193
+ if mv_feat is not None:
194
+ pc_feat, points_feat = self.pc_encoder(point_cloud, mv_feat, pc2pc_idx)
195
+ else:
196
+ pc_feat, points_feat = self.pc_encoder(point_cloud) # 3D feature volume: BxDx32x32x32
197
+ if self.cfg.use_point_scatter:
198
+ # Scattering from PVCNN point features
199
+ points_feat_ = points_feat[0]
200
+ # shape: batch, latent size, resolution, resolution (e.g. 16, 256, 64, 64)
201
+ pc_feat_1 = generate_plane_features(point_cloud_xyz, points_feat_,
202
+ resolution=self.z_triplane_resolution, plane='xy')
203
+ pc_feat_2 = generate_plane_features(point_cloud_xyz, points_feat_,
204
+ resolution=self.z_triplane_resolution, plane='yz')
205
+ pc_feat_3 = generate_plane_features(point_cloud_xyz, points_feat_,
206
+ resolution=self.z_triplane_resolution, plane='xz')
207
+ pc_feat = pc_feat[0]
208
+
209
+ else:
210
+ pc_feat = pc_feat[0]
211
+ sf = self.z_triplane_resolution//32 # 32 is PVCNN's voxel dim
212
+
213
+ pc_feat_1 = torch.mean(pc_feat, dim=-1) #xy_plane, normalize in z plane
214
+ pc_feat_2 = torch.mean(pc_feat, dim=-3) #yz_plane, normalize in x plane
215
+ pc_feat_3 = torch.mean(pc_feat, dim=-2) #xz_plane, normalize in y plane
216
+
217
+ # nearest upsample
218
+ pc_feat_1 = einops.repeat(pc_feat_1, 'b c h w -> b c (h hm ) (w wm)', hm = sf, wm = sf)
219
+ pc_feat_2 = einops.repeat(pc_feat_2, 'b c h w -> b c (h hm) (w wm)', hm = sf, wm = sf)
220
+ pc_feat_3 = einops.repeat(pc_feat_3, 'b c h w -> b c (h hm) (w wm)', hm = sf, wm = sf)
221
+ elif self.cfg.point_encoder_type == 'pointnet':
222
+ assert self.cfg.use_point_scatter
223
+ # Run ConvPointnet
224
+ pc_feat = self.pc_encoder(point_cloud)
225
+ pc_feat_1 = pc_feat['xy'] #
226
+ pc_feat_2 = pc_feat['yz']
227
+ pc_feat_3 = pc_feat['xz']
228
+ else:
229
+ raise NotImplementedError()
230
+
231
+ if self.unet_encoder is not None:
232
+ # TODO eval adding a skip connection
233
+ # Unet expects B, 3, C, H, W
234
+ pc_feat_tri_plane_stack_pre = torch.stack([pc_feat_1, pc_feat_2, pc_feat_3], dim=1)
235
+ # dpc_feat_tri_plane_stack = self.unet_encoder(pc_feat_tri_plane_stack_pre)
236
+ # pc_feat_tri_plane_stack = pc_feat_tri_plane_stack_pre + dpc_feat_tri_plane_stack
237
+ pc_feat_tri_plane_stack = self.unet_encoder(pc_feat_tri_plane_stack_pre)
238
+ pc_feat_1, pc_feat_2, pc_feat_3 = torch.unbind(pc_feat_tri_plane_stack, dim=1)
239
+
240
+ return torch.stack([pc_feat_1, pc_feat_2, pc_feat_3], dim=1)
241
+
242
+ def forward(self, point_cloud_xyz, point_cloud_feature=None, mv_feat=None, pc2pc_idx=None):
243
+ return self.encode(point_cloud_xyz, point_cloud_feature=point_cloud_feature, mv_feat=mv_feat, pc2pc_idx=pc2pc_idx)
PartField/partfield/model/PVCNN/pc_encoder.py ADDED
@@ -0,0 +1,90 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import functools
5
+
6
+ from .pv_module import SharedMLP, PVConv
7
+
8
+ def create_pointnet_components(
9
+ blocks, in_channels, with_se=False, normalize=True, eps=0,
10
+ width_multiplier=1, voxel_resolution_multiplier=1, scale_pvcnn=False, device='cuda'):
11
+ r, vr = width_multiplier, voxel_resolution_multiplier
12
+ layers, concat_channels = [], 0
13
+ for out_channels, num_blocks, voxel_resolution in blocks:
14
+ out_channels = int(r * out_channels)
15
+ if voxel_resolution is None:
16
+ block = functools.partial(SharedMLP, device=device)
17
+ else:
18
+ block = functools.partial(
19
+ PVConv, kernel_size=3, resolution=int(vr * voxel_resolution),
20
+ with_se=with_se, normalize=normalize, eps=eps, scale_pvcnn=scale_pvcnn, device=device)
21
+ for _ in range(num_blocks):
22
+ layers.append(block(in_channels, out_channels))
23
+ in_channels = out_channels
24
+ concat_channels += out_channels
25
+ return layers, in_channels, concat_channels
26
+
27
+ class PCMerger(nn.Module):
28
+ # merge surface sampled PC and rendering backprojected PC (w/ 2D features):
29
+ def __init__(self, in_channels=204, device="cuda"):
30
+ super(PCMerger, self).__init__()
31
+ self.mlp_normal = SharedMLP(3, [128, 128], device=device)
32
+ self.mlp_rgb = SharedMLP(3, [128, 128], device=device)
33
+ self.mlp_sam = SharedMLP(204 - 6, [128, 128], device=device)
34
+
35
+ def forward(self, feat, mv_feat, pc2pc_idx):
36
+ mv_feat_normal = self.mlp_normal(mv_feat[:, :3, :])
37
+ mv_feat_rgb = self.mlp_rgb(mv_feat[:, 3:6, :])
38
+ mv_feat_sam = self.mlp_sam(mv_feat[:, 6:, :])
39
+
40
+ mv_feat_normal = mv_feat_normal.permute(0, 2, 1)
41
+ mv_feat_rgb = mv_feat_rgb.permute(0, 2, 1)
42
+ mv_feat_sam = mv_feat_sam.permute(0, 2, 1)
43
+ feat = feat.permute(0, 2, 1)
44
+
45
+ for i in range(mv_feat.shape[0]):
46
+ mask = (pc2pc_idx[i] != -1).reshape(-1)
47
+ idx = pc2pc_idx[i][mask].reshape(-1)
48
+ feat[i][mask] += mv_feat_normal[i][idx] + mv_feat_rgb[i][idx] + mv_feat_sam[i][idx]
49
+
50
+ return feat.permute(0, 2, 1)
51
+
52
+
53
+ class PVCNNEncoder(nn.Module):
54
+ def __init__(self, pvcnn_feat_dim, device='cuda', in_channels=3, use_2d_feat=False):
55
+ super(PVCNNEncoder, self).__init__()
56
+ self.device = device
57
+ self.blocks = ((pvcnn_feat_dim, 1, 32), (128, 2, 16), (256, 1, 8))
58
+ self.use_2d_feat=use_2d_feat
59
+ if in_channels == 6:
60
+ self.append_channel = 2
61
+ elif in_channels == 3:
62
+ self.append_channel = 1
63
+ else:
64
+ raise NotImplementedError
65
+ layers, channels_point, concat_channels_point = create_pointnet_components(
66
+ blocks=self.blocks, in_channels=in_channels + self.append_channel, with_se=False, normalize=False,
67
+ width_multiplier=1, voxel_resolution_multiplier=1, scale_pvcnn=True,
68
+ device=device
69
+ )
70
+ self.encoder = nn.ModuleList(layers)#.to(self.device)
71
+ if self.use_2d_feat:
72
+ self.merger = PCMerger()
73
+
74
+
75
+
76
+ def forward(self, input_pc, mv_feat=None, pc2pc_idx=None):
77
+ features = input_pc.permute(0, 2, 1) * 2 # make point cloud [-1, 1]
78
+ coords = features[:, :3, :]
79
+ out_features_list = []
80
+ voxel_feature_list = []
81
+ zero_padding = torch.zeros(features.shape[0], self.append_channel, features.shape[-1], device=features.device, dtype=torch.float)
82
+ features = torch.cat([features, zero_padding], dim=1)
83
+
84
+ for i in range(len(self.encoder)):
85
+ features, _, voxel_feature = self.encoder[i]((features, coords))
86
+ if i == 0 and mv_feat is not None:
87
+ features = self.merger(features, mv_feat.permute(0, 2, 1), pc2pc_idx)
88
+ out_features_list.append(features)
89
+ voxel_feature_list.append(voxel_feature)
90
+ return voxel_feature_list, out_features_list
PartField/partfield/model/PVCNN/pv_module/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .pvconv import PVConv
2
+ from .shared_mlp import SharedMLP
PartField/partfield/model/PVCNN/pv_module/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (284 Bytes). View file
 
PartField/partfield/model/PVCNN/pv_module/__pycache__/pvconv.cpython-310.pyc ADDED
Binary file (1.62 kB). View file
 
PartField/partfield/model/PVCNN/pv_module/__pycache__/shared_mlp.cpython-310.pyc ADDED
Binary file (1.18 kB). View file
 
PartField/partfield/model/PVCNN/pv_module/__pycache__/voxelization.cpython-310.pyc ADDED
Binary file (2.23 kB). View file
 
PartField/partfield/model/PVCNN/pv_module/ball_query.py ADDED
@@ -0,0 +1,34 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from . import functional as F
5
+
6
+ __all__ = ['BallQuery']
7
+
8
+
9
+ class BallQuery(nn.Module):
10
+ def __init__(self, radius, num_neighbors, include_coordinates=True):
11
+ super().__init__()
12
+ self.radius = radius
13
+ self.num_neighbors = num_neighbors
14
+ self.include_coordinates = include_coordinates
15
+
16
+ def forward(self, points_coords, centers_coords, points_features=None):
17
+ points_coords = points_coords.contiguous()
18
+ centers_coords = centers_coords.contiguous()
19
+ neighbor_indices = F.ball_query(centers_coords, points_coords, self.radius, self.num_neighbors)
20
+ neighbor_coordinates = F.grouping(points_coords, neighbor_indices)
21
+ neighbor_coordinates = neighbor_coordinates - centers_coords.unsqueeze(-1)
22
+
23
+ if points_features is None:
24
+ assert self.include_coordinates, 'No Features For Grouping'
25
+ neighbor_features = neighbor_coordinates
26
+ else:
27
+ neighbor_features = F.grouping(points_features, neighbor_indices)
28
+ if self.include_coordinates:
29
+ neighbor_features = torch.cat([neighbor_coordinates, neighbor_features], dim=1)
30
+ return neighbor_features
31
+
32
+ def extra_repr(self):
33
+ return 'radius={}, num_neighbors={}{}'.format(
34
+ self.radius, self.num_neighbors, ', include coordinates' if self.include_coordinates else '')
PartField/partfield/model/PVCNN/pv_module/frustum.py ADDED
@@ -0,0 +1,141 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from . import functional as PF
7
+
8
+ __all__ = ['FrustumPointNetLoss', 'get_box_corners_3d']
9
+
10
+
11
+ class FrustumPointNetLoss(nn.Module):
12
+ def __init__(
13
+ self, num_heading_angle_bins, num_size_templates, size_templates, box_loss_weight=1.0,
14
+ corners_loss_weight=10.0, heading_residual_loss_weight=20.0, size_residual_loss_weight=20.0):
15
+ super().__init__()
16
+ self.box_loss_weight = box_loss_weight
17
+ self.corners_loss_weight = corners_loss_weight
18
+ self.heading_residual_loss_weight = heading_residual_loss_weight
19
+ self.size_residual_loss_weight = size_residual_loss_weight
20
+
21
+ self.num_heading_angle_bins = num_heading_angle_bins
22
+ self.num_size_templates = num_size_templates
23
+ self.register_buffer('size_templates', size_templates.view(self.num_size_templates, 3))
24
+ self.register_buffer(
25
+ 'heading_angle_bin_centers', torch.arange(0, 2 * np.pi, 2 * np.pi / self.num_heading_angle_bins)
26
+ )
27
+
28
+ def forward(self, inputs, targets):
29
+ mask_logits = inputs['mask_logits'] # (B, 2, N)
30
+ center_reg = inputs['center_reg'] # (B, 3)
31
+ center = inputs['center'] # (B, 3)
32
+ heading_scores = inputs['heading_scores'] # (B, NH)
33
+ heading_residuals_normalized = inputs['heading_residuals_normalized'] # (B, NH)
34
+ heading_residuals = inputs['heading_residuals'] # (B, NH)
35
+ size_scores = inputs['size_scores'] # (B, NS)
36
+ size_residuals_normalized = inputs['size_residuals_normalized'] # (B, NS, 3)
37
+ size_residuals = inputs['size_residuals'] # (B, NS, 3)
38
+
39
+ mask_logits_target = targets['mask_logits'] # (B, N)
40
+ center_target = targets['center'] # (B, 3)
41
+ heading_bin_id_target = targets['heading_bin_id'] # (B, )
42
+ heading_residual_target = targets['heading_residual'] # (B, )
43
+ size_template_id_target = targets['size_template_id'] # (B, )
44
+ size_residual_target = targets['size_residual'] # (B, 3)
45
+
46
+ batch_size = center.size(0)
47
+ batch_id = torch.arange(batch_size, device=center.device)
48
+
49
+ # Basic Classification and Regression losses
50
+ mask_loss = F.cross_entropy(mask_logits, mask_logits_target)
51
+ heading_loss = F.cross_entropy(heading_scores, heading_bin_id_target)
52
+ size_loss = F.cross_entropy(size_scores, size_template_id_target)
53
+ center_loss = PF.huber_loss(torch.norm(center_target - center, dim=-1), delta=2.0)
54
+ center_reg_loss = PF.huber_loss(torch.norm(center_target - center_reg, dim=-1), delta=1.0)
55
+
56
+ # Refinement losses for size/heading
57
+ heading_residuals_normalized = heading_residuals_normalized[batch_id, heading_bin_id_target] # (B, )
58
+ heading_residual_normalized_target = heading_residual_target / (np.pi / self.num_heading_angle_bins)
59
+ heading_residual_normalized_loss = PF.huber_loss(
60
+ heading_residuals_normalized - heading_residual_normalized_target, delta=1.0
61
+ )
62
+ size_residuals_normalized = size_residuals_normalized[batch_id, size_template_id_target] # (B, 3)
63
+ size_residual_normalized_target = size_residual_target / self.size_templates[size_template_id_target]
64
+ size_residual_normalized_loss = PF.huber_loss(
65
+ torch.norm(size_residual_normalized_target - size_residuals_normalized, dim=-1), delta=1.0
66
+ )
67
+
68
+ # Bounding box losses
69
+ heading = (heading_residuals[batch_id, heading_bin_id_target]
70
+ + self.heading_angle_bin_centers[heading_bin_id_target]) # (B, )
71
+ # Warning: in origin code, size_residuals are added twice (issue #43 and #49 in charlesq34/frustum-pointnets)
72
+ size = (size_residuals[batch_id, size_template_id_target]
73
+ + self.size_templates[size_template_id_target]) # (B, 3)
74
+ corners = get_box_corners_3d(centers=center, headings=heading, sizes=size, with_flip=False) # (B, 3, 8)
75
+ heading_target = self.heading_angle_bin_centers[heading_bin_id_target] + heading_residual_target # (B, )
76
+ size_target = self.size_templates[size_template_id_target] + size_residual_target # (B, 3)
77
+ corners_target, corners_target_flip = get_box_corners_3d(
78
+ centers=center_target, headings=heading_target,
79
+ sizes=size_target, with_flip=True) # (B, 3, 8)
80
+ corners_loss = PF.huber_loss(
81
+ torch.min(
82
+ torch.norm(corners - corners_target, dim=1), torch.norm(corners - corners_target_flip, dim=1)
83
+ ), delta=1.0)
84
+ # Summing up
85
+ loss = mask_loss + self.box_loss_weight * (
86
+ center_loss + center_reg_loss + heading_loss + size_loss
87
+ + self.heading_residual_loss_weight * heading_residual_normalized_loss
88
+ + self.size_residual_loss_weight * size_residual_normalized_loss
89
+ + self.corners_loss_weight * corners_loss
90
+ )
91
+
92
+ return loss
93
+
94
+
95
+ def get_box_corners_3d(centers, headings, sizes, with_flip=False):
96
+ """
97
+ :param centers: coords of box centers, FloatTensor[N, 3]
98
+ :param headings: heading angles, FloatTensor[N, ]
99
+ :param sizes: box sizes, FloatTensor[N, 3]
100
+ :param with_flip: bool, whether to return flipped box (headings + np.pi)
101
+ :return:
102
+ coords of box corners, FloatTensor[N, 3, 8]
103
+ NOTE: corner points are in counter clockwise order, e.g.,
104
+ 2--1
105
+ 3--0 5
106
+ 7--4
107
+ """
108
+ l = sizes[:, 0] # (N,)
109
+ w = sizes[:, 1] # (N,)
110
+ h = sizes[:, 2] # (N,)
111
+ x_corners = torch.stack([l / 2, l / 2, -l / 2, -l / 2, l / 2, l / 2, -l / 2, -l / 2], dim=1) # (N, 8)
112
+ y_corners = torch.stack([h / 2, h / 2, h / 2, h / 2, -h / 2, -h / 2, -h / 2, -h / 2], dim=1) # (N, 8)
113
+ z_corners = torch.stack([w / 2, -w / 2, -w / 2, w / 2, w / 2, -w / 2, -w / 2, w / 2], dim=1) # (N, 8)
114
+
115
+ c = torch.cos(headings) # (N,)
116
+ s = torch.sin(headings) # (N,)
117
+ o = torch.ones_like(headings) # (N,)
118
+ z = torch.zeros_like(headings) # (N,)
119
+
120
+ centers = centers.unsqueeze(-1) # (B, 3, 1)
121
+ corners = torch.stack([x_corners, y_corners, z_corners], dim=1) # (N, 3, 8)
122
+ R = torch.stack([c, z, s, z, o, z, -s, z, c], dim=1).view(-1, 3, 3) # roty matrix: (N, 3, 3)
123
+ if with_flip:
124
+ R_flip = torch.stack([-c, z, -s, z, o, z, s, z, -c], dim=1).view(-1, 3, 3)
125
+ return torch.matmul(R, corners) + centers, torch.matmul(R_flip, corners) + centers
126
+ else:
127
+ return torch.matmul(R, corners) + centers
128
+
129
+ # centers = centers.unsqueeze(1) # (B, 1, 3)
130
+ # corners = torch.stack([x_corners, y_corners, z_corners], dim=-1) # (N, 8, 3)
131
+ # RT = torch.stack([c, z, -s, z, o, z, s, z, c], dim=1).view(-1, 3, 3) # (N, 3, 3)
132
+ # if with_flip:
133
+ # RT_flip = torch.stack([-c, z, s, z, o, z, -s, z, -c], dim=1).view(-1, 3, 3) # (N, 3, 3)
134
+ # return torch.matmul(corners, RT) + centers, torch.matmul(corners, RT_flip) + centers # (N, 8, 3)
135
+ # else:
136
+ # return torch.matmul(corners, RT) + centers # (N, 8, 3)
137
+
138
+ # corners = torch.stack([x_corners, y_corners, z_corners], dim=1) # (N, 3, 8)
139
+ # R = torch.stack([c, z, s, z, o, z, -s, z, c], dim=1).view(-1, 3, 3) # (N, 3, 3)
140
+ # corners = torch.matmul(R, corners) + centers.unsqueeze(2) # (N, 3, 8)
141
+ # corners = corners.transpose(1, 2) # (N, 8, 3)
PartField/partfield/model/PVCNN/pv_module/functional/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .devoxelization import trilinear_devoxelize
PartField/partfield/model/PVCNN/pv_module/functional/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (273 Bytes). View file
 
PartField/partfield/model/PVCNN/pv_module/functional/__pycache__/devoxelization.cpython-310.pyc ADDED
Binary file (748 Bytes). View file
 
PartField/partfield/model/PVCNN/pv_module/functional/devoxelization.py ADDED
@@ -0,0 +1,12 @@
1
+ from torch.autograd import Function
2
+ import torch
3
+ import torch.nn.functional as F
4
+
5
+ __all__ = ['trilinear_devoxelize']
6
+
7
+ def trilinear_devoxelize(c, coords, r, training=None):
8
+ coords = (coords * 2 + 1.0) / r - 1.0
9
+ coords = coords.permute(0, 2, 1).reshape(c.shape[0], 1, 1, -1, 3)
10
+ f = F.grid_sample(input=c, grid=coords, padding_mode='border', align_corners=False)
11
+ f = f.squeeze(dim=2).squeeze(dim=2)
12
+ return f
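A minimal usage sketch for trilinear_devoxelize, not part of the commit (the import path is an assumption): voxel features of shape (B, C, r, r, r) are sampled at per-point voxel-space coordinates in [0, r-1], producing (B, C, N) point features via grid_sample.

import torch
from partfield.model.PVCNN.pv_module.functional import trilinear_devoxelize  # assumed path

B, C, r, N = 2, 8, 4, 100
voxel_features = torch.randn(B, C, r, r, r)
point_coords = torch.rand(B, 3, N) * (r - 1)     # voxel-space coordinates in [0, r-1]
point_features = trilinear_devoxelize(voxel_features, point_coords, r)
print(point_features.shape)                      # torch.Size([2, 8, 100])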
PartField/partfield/model/PVCNN/pv_module/loss.py ADDED
@@ -0,0 +1,10 @@
1
+ import torch.nn as nn
2
+
3
+ from . import functional as F
4
+
5
+ __all__ = ['KLLoss']
6
+
7
+
8
+ class KLLoss(nn.Module):
9
+ # NOTE: the trimmed `functional` package in this repo only exports
+ # trilinear_devoxelize, so F.kl_loss is undefined and calling this module
+ # would raise; it appears to be kept only for parity with upstream PVCNN.
+ def forward(self, x, y):
10
+ return F.kl_loss(x, y)
PartField/partfield/model/PVCNN/pv_module/pointnet.py ADDED
@@ -0,0 +1,113 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from . import functional as F
+ # NOTE: the trimmed `functional` package here only provides trilinear_devoxelize;
+ # F.furthest_point_sample and F.nearest_neighbor_interpolate used below are not
+ # included, so PointNetSAModule / PointNetFPModule will fail at call time in this repo.
5
+ from .ball_query import BallQuery
6
+ from .shared_mlp import SharedMLP
7
+
8
+ __all__ = ['PointNetAModule', 'PointNetSAModule', 'PointNetFPModule']
9
+
10
+
11
+ class PointNetAModule(nn.Module):
12
+ def __init__(self, in_channels, out_channels, include_coordinates=True):
13
+ super().__init__()
14
+ if not isinstance(out_channels, (list, tuple)):
15
+ out_channels = [[out_channels]]
16
+ elif not isinstance(out_channels[0], (list, tuple)):
17
+ out_channels = [out_channels]
18
+
19
+ mlps = []
20
+ total_out_channels = 0
21
+ for _out_channels in out_channels:
22
+ mlps.append(
23
+ SharedMLP(
24
+ in_channels=in_channels + (3 if include_coordinates else 0),
25
+ out_channels=_out_channels, dim=1)
26
+ )
27
+ total_out_channels += _out_channels[-1]
28
+
29
+ self.include_coordinates = include_coordinates
30
+ self.out_channels = total_out_channels
31
+ self.mlps = nn.ModuleList(mlps)
32
+
33
+ def forward(self, inputs):
34
+ features, coords = inputs
35
+ if self.include_coordinates:
36
+ features = torch.cat([features, coords], dim=1)
37
+ coords = torch.zeros((coords.size(0), 3, 1), device=coords.device)
38
+ if len(self.mlps) > 1:
39
+ features_list = []
40
+ for mlp in self.mlps:
41
+ features_list.append(mlp(features).max(dim=-1, keepdim=True).values)
42
+ return torch.cat(features_list, dim=1), coords
43
+ else:
44
+ return self.mlps[0](features).max(dim=-1, keepdim=True).values, coords
45
+
46
+ def extra_repr(self):
47
+ return f'out_channels={self.out_channels}, include_coordinates={self.include_coordinates}'
48
+
49
+
50
+ class PointNetSAModule(nn.Module):
51
+ def __init__(self, num_centers, radius, num_neighbors, in_channels, out_channels, include_coordinates=True):
52
+ super().__init__()
53
+ if not isinstance(radius, (list, tuple)):
54
+ radius = [radius]
55
+ if not isinstance(num_neighbors, (list, tuple)):
56
+ num_neighbors = [num_neighbors] * len(radius)
57
+ assert len(radius) == len(num_neighbors)
58
+ if not isinstance(out_channels, (list, tuple)):
59
+ out_channels = [[out_channels]] * len(radius)
60
+ elif not isinstance(out_channels[0], (list, tuple)):
61
+ out_channels = [out_channels] * len(radius)
62
+ assert len(radius) == len(out_channels)
63
+
64
+ groupers, mlps = [], []
65
+ total_out_channels = 0
66
+ for _radius, _out_channels, _num_neighbors in zip(radius, out_channels, num_neighbors):
67
+ groupers.append(
68
+ BallQuery(radius=_radius, num_neighbors=_num_neighbors, include_coordinates=include_coordinates)
69
+ )
70
+ mlps.append(
71
+ SharedMLP(
72
+ in_channels=in_channels + (3 if include_coordinates else 0),
73
+ out_channels=_out_channels, dim=2)
74
+ )
75
+ total_out_channels += _out_channels[-1]
76
+
77
+ self.num_centers = num_centers
78
+ self.out_channels = total_out_channels
79
+ self.groupers = nn.ModuleList(groupers)
80
+ self.mlps = nn.ModuleList(mlps)
81
+
82
+ def forward(self, inputs):
83
+ features, coords = inputs
84
+ centers_coords = F.furthest_point_sample(coords, self.num_centers)
85
+ features_list = []
86
+ for grouper, mlp in zip(self.groupers, self.mlps):
87
+ features_list.append(mlp(grouper(coords, centers_coords, features)).max(dim=-1).values)
88
+ if len(features_list) > 1:
89
+ return torch.cat(features_list, dim=1), centers_coords
90
+ else:
91
+ return features_list[0], centers_coords
92
+
93
+ def extra_repr(self):
94
+ return f'num_centers={self.num_centers}, out_channels={self.out_channels}'
95
+
96
+
97
+ class PointNetFPModule(nn.Module):
98
+ def __init__(self, in_channels, out_channels):
99
+ super().__init__()
100
+ self.mlp = SharedMLP(in_channels=in_channels, out_channels=out_channels, dim=1)
101
+
102
+ def forward(self, inputs):
103
+ if len(inputs) == 3:
104
+ points_coords, centers_coords, centers_features = inputs
105
+ points_features = None
106
+ else:
107
+ points_coords, centers_coords, centers_features, points_features = inputs
108
+ interpolated_features = F.nearest_neighbor_interpolate(points_coords, centers_coords, centers_features)
109
+ if points_features is not None:
110
+ interpolated_features = torch.cat(
111
+ [interpolated_features, points_features], dim=1
112
+ )
113
+ return self.mlp(interpolated_features), points_coords
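A minimal usage sketch for PointNetAModule, not part of the commit (the import path is an assumption, and a CUDA device is required because SharedMLP defaults to device='cuda'): it pools per-point features into one global feature vector per shape.

import torch
from partfield.model.PVCNN.pv_module.pointnet import PointNetAModule  # assumed path

module = PointNetAModule(in_channels=16, out_channels=[32, 64])   # layers are created on CUDA by default
features = torch.randn(2, 16, 1024, device='cuda')   # (B, C, N)
coords = torch.randn(2, 3, 1024, device='cuda')      # (B, 3, N)
global_features, pooled_coords = module((features, coords))
print(global_features.shape, pooled_coords.shape)    # torch.Size([2, 64, 1]) torch.Size([2, 3, 1])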
PartField/partfield/model/PVCNN/pv_module/pvconv.py ADDED
@@ -0,0 +1,38 @@
1
+ import torch.nn as nn
2
+
3
+ from . import functional as F
4
+ from .voxelization import Voxelization
5
+ from .shared_mlp import SharedMLP
6
+ import torch
7
+
8
+ __all__ = ['PVConv']
9
+
10
+
11
+ class PVConv(nn.Module):
12
+ def __init__(
13
+ self, in_channels, out_channels, kernel_size, resolution, with_se=False, normalize=True, eps=0, scale_pvcnn=False,
14
+ device='cuda'):
15
+ super().__init__()
16
+ self.in_channels = in_channels
17
+ self.out_channels = out_channels
18
+ self.kernel_size = kernel_size
19
+ self.resolution = resolution
20
+ self.voxelization = Voxelization(resolution, normalize=normalize, eps=eps, scale_pvcnn=scale_pvcnn)
21
+ voxel_layers = [
22
+ nn.Conv3d(in_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2, device=device),
23
+ nn.InstanceNorm3d(out_channels, eps=1e-4, device=device),
24
+ nn.LeakyReLU(0.1, True),
25
+ nn.Conv3d(out_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2, device=device),
26
+ nn.InstanceNorm3d(out_channels, eps=1e-4, device=device),
27
+ nn.LeakyReLU(0.1, True),
28
+ ]
29
+ self.voxel_layers = nn.Sequential(*voxel_layers)
30
+ self.point_features = SharedMLP(in_channels, out_channels, device=device)
31
+
32
+ def forward(self, inputs):
33
+ features, coords = inputs
34
+ voxel_features, voxel_coords = self.voxelization(features, coords)
35
+ voxel_features = self.voxel_layers(voxel_features)
36
+ devoxel_features = F.trilinear_devoxelize(voxel_features, voxel_coords, self.resolution, self.training)
37
+ fused_features = devoxel_features + self.point_features(features)
38
+ return fused_features, coords, voxel_features
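A minimal usage sketch for PVConv, not part of the commit (the import path is an assumption, and a CUDA device is required since the layers default to device='cuda'). normalize=False must be passed explicitly because the Voxelization module later in this diff asserts `not normalize`, and scale_pvcnn=True expects point coordinates already normalized to [-1, 1].

import torch
from partfield.model.PVCNN.pv_module.pvconv import PVConv  # assumed path

conv = PVConv(in_channels=16, out_channels=32, kernel_size=3, resolution=8,
              normalize=False, scale_pvcnn=True)
features = torch.randn(2, 16, 1024, device='cuda')       # (B, C, N)
coords = torch.rand(2, 3, 1024, device='cuda') * 2 - 1   # (B, 3, N), in [-1, 1]
fused, out_coords, voxel = conv((features, coords))
print(fused.shape, voxel.shape)   # torch.Size([2, 32, 1024]) torch.Size([2, 32, 8, 8, 8])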
PartField/partfield/model/PVCNN/pv_module/shared_mlp.py ADDED
@@ -0,0 +1,35 @@
1
+ import torch.nn as nn
2
+
3
+ __all__ = ['SharedMLP']
4
+
5
+
6
+ class SharedMLP(nn.Module):
7
+ def __init__(self, in_channels, out_channels, dim=1, device='cuda'):
8
+ super().__init__()
9
+ # print('==> SharedMLP device: ', device)
10
+ if dim == 1:
11
+ conv = nn.Conv1d
12
+ bn = nn.InstanceNorm1d
13
+ elif dim == 2:
14
+ conv = nn.Conv2d
15
+ bn = nn.InstanceNorm2d  # fix: was nn.InstanceNorm1d, which cannot normalize the 4D output of Conv2d
16
+ else:
17
+ raise ValueError
18
+ if not isinstance(out_channels, (list, tuple)):
19
+ out_channels = [out_channels]
20
+ layers = []
21
+ for oc in out_channels:
22
+ layers.extend(
23
+ [
24
+ conv(in_channels, oc, 1, device=device),
25
+ bn(oc, device=device),
26
+ nn.ReLU(True),
27
+ ])
28
+ in_channels = oc
29
+ self.layers = nn.Sequential(*layers)
30
+
31
+ def forward(self, inputs):
32
+ if isinstance(inputs, (list, tuple)):
33
+ return (self.layers(inputs[0]), *inputs[1:])
34
+ else:
35
+ return self.layers(inputs)
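A minimal usage sketch for SharedMLP, not part of the commit (the import path is an assumption, and a CUDA device is required since device='cuda' is the default): the 1x1 convolutions act per point, so only the channel dimension changes.

import torch
from partfield.model.PVCNN.pv_module.shared_mlp import SharedMLP  # assumed path

mlp = SharedMLP(in_channels=8, out_channels=[16, 32], dim=1)
x = torch.randn(4, 8, 500, device='cuda')   # (B, C, N)
print(mlp(x).shape)                         # torch.Size([4, 32, 500])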
PartField/partfield/model/PVCNN/pv_module/voxelization.py ADDED
@@ -0,0 +1,50 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from . import functional as F
5
+
6
+ __all__ = ['Voxelization']
7
+
8
+
9
+ def my_voxelization(features, coords, resolution):
10
+ b, c, _ = features.shape
11
+ result = torch.zeros(b, c + 1, resolution * resolution * resolution, device=features.device, dtype=torch.float)
12
+ r = resolution
13
+ r2 = resolution * resolution
14
+ indices = coords[:, 0] * r2 + coords[:, 1] * r + coords[:, 2]
15
+ indices = indices.unsqueeze(dim=1).expand(-1, result.shape[1], -1)
16
+ features = torch.cat([features, torch.ones(features.shape[0], 1, features.shape[2], device=features.device, dtype=features.dtype)], dim=1)
17
+ out_feature = result.scatter_(index=indices.long(), src=features, dim=2, reduce='add')
18
+ cnt = out_feature[:, -1:, :]
19
+ zero_mask = (cnt == 0).float()
20
+ cnt = cnt * (1 - zero_mask) + zero_mask * 1e-5
21
+ vox_feature = out_feature[:, :-1, :] / cnt
22
+ return vox_feature.view(b, c, resolution, resolution, resolution)
23
+
24
+ class Voxelization(nn.Module):
25
+ def __init__(self, resolution, normalize=True, eps=0, scale_pvcnn=False):
26
+ super().__init__()
27
+ self.r = int(resolution)
28
+ self.normalize = normalize
29
+ self.eps = eps
30
+ self.scale_pvcnn = scale_pvcnn
31
+ assert not normalize
32
+
33
+ def forward(self, features, coords):
34
+ with torch.no_grad():
35
+ coords = coords.detach()
36
+
37
+ if self.normalize:
38
+ # NOTE: this branch is unreachable given the `assert not normalize` above; the
+ # original reference to an undefined `norm_coords` is replaced with `coords`
+ # (upstream PVCNN mean-centres the coordinates before this step).
+ norm_coords = coords / (coords.norm(dim=1, keepdim=True).max(dim=2, keepdim=True).values * 2.0 + self.eps) + 0.5
39
+ else:
40
+ if self.scale_pvcnn:
41
+ norm_coords = (coords + 1) / 2.0 # [0, 1]
42
+ else:
43
+ norm_coords = (coords + 1) / 2.0  # fix: `norm_coords` was undefined here; upstream PVCNN uses mean-centred coords
44
+ norm_coords = torch.clamp(norm_coords * self.r, 0, self.r - 1)
45
+ vox_coords = torch.round(norm_coords)
46
+ new_vox_feat = my_voxelization(features, vox_coords, self.r)
47
+ return new_vox_feat, norm_coords
48
+
49
+ def extra_repr(self):
50
+ return 'resolution={}{}'.format(self.r, ', normalized eps = {}'.format(self.eps) if self.normalize else '')
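A minimal sketch of my_voxelization's averaging behaviour, not part of the commit (the import path is an assumption; this runs on CPU since only plain torch ops are involved): points that land in the same voxel are averaged rather than summed, thanks to the appended count channel.

import torch
from partfield.model.PVCNN.pv_module.voxelization import my_voxelization  # assumed path

features = torch.tensor([[[1.0, 3.0], [10.0, 30.0]]])    # (B=1, C=2, N=2)
coords = torch.zeros(1, 3, 2)                             # both points fall in voxel (0, 0, 0)
vox = my_voxelization(features, coords, resolution=2)     # (1, 2, 2, 2, 2)
print(vox[0, :, 0, 0, 0])                                 # tensor([ 2., 20.]) -- the per-voxel mean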
PartField/partfield/model/PVCNN/unet_3daware.py ADDED
@@ -0,0 +1,427 @@
1
+ import numpy as np
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torch.nn import init
7
+
8
+ import einops
9
+
10
+ def conv3x3(in_channels, out_channels, stride=1,
11
+ padding=1, bias=True, groups=1):
12
+ return nn.Conv2d(
13
+ in_channels,
14
+ out_channels,
15
+ kernel_size=3,
16
+ stride=stride,
17
+ padding=padding,
18
+ bias=bias,
19
+ groups=groups)
20
+
21
+ def upconv2x2(in_channels, out_channels, mode='transpose'):
22
+ if mode == 'transpose':
23
+ return nn.ConvTranspose2d(
24
+ in_channels,
25
+ out_channels,
26
+ kernel_size=2,
27
+ stride=2)
28
+ else:
29
+ # out_channels is always going to be the same
30
+ # as in_channels
31
+ return nn.Sequential(
32
+ nn.Upsample(mode='bilinear', scale_factor=2),
33
+ conv1x1(in_channels, out_channels))
34
+
35
+ def conv1x1(in_channels, out_channels, groups=1):
36
+ return nn.Conv2d(
37
+ in_channels,
38
+ out_channels,
39
+ kernel_size=1,
40
+ groups=groups,
41
+ stride=1)
42
+
43
+ class ConvTriplane3dAware(nn.Module):
44
+ """ 3D aware triplane conv (as described in RODIN) """
45
+ def __init__(self, internal_conv_f, in_channels, out_channels, order='xz'):
46
+ """
47
+ Args:
48
+ internal_conv_f: function that should return a 2D convolution Module
49
+ given in and out channels
50
+ order: if triplane input is in 'xz' order
51
+ """
52
+ super(ConvTriplane3dAware, self).__init__()
53
+ # Need 3 separate convolutions
54
+ self.in_channels = in_channels
55
+ self.out_channels = out_channels
56
+ assert order in ['xz', 'zx']
57
+ self.order = order
58
+ # Going to stack from other planes
59
+ self.plane_convs = nn.ModuleList([
60
+ internal_conv_f(3*self.in_channels, self.out_channels) for _ in range(3)])
61
+
62
+ def forward(self, triplanes_list):
63
+ """
64
+ Args:
65
+ triplanes_list: [(B,Ci,H,W)]*3 in xy,yz,(zx or xz) depending on order
66
+ Returns:
67
+ out_triplanes_list: [(B,Co,H,W)]*3 in xy,yz,(zx or xz) depending on order
68
+ """
69
+ inps = list(triplanes_list)
70
+ xp = 1 #(yz)
71
+ yp = 2 #(zx)
72
+ zp = 0 #(xy)
73
+
74
+ if self.order == 'xz':
75
+ # get into zx order
76
+ inps[yp] = einops.rearrange(inps[yp], 'b c x z -> b c z x')
77
+
78
+
79
+ oplanes = [None]*3
80
+ # order shouldn't matter
81
+ for iplane in [zp, xp, yp]:
82
+ # i_plane -> (j,k)
83
+
84
+ # need to average out i and convert to (j,k)
85
+ # j_plane -> (k,i)
86
+ # k_plane -> (i,j)
87
+ jplane = (iplane+1)%3
88
+ kplane = (iplane+2)%3
89
+
90
+ ifeat = inps[iplane]
91
+ # need to average out nonshared dim
92
+ # Average pool across
93
+
94
+ # j_plane -> (k,i) -> (k,1) -> (1,k) -> (j,k)
95
+ # b c k i -> b c k 1
96
+ jpool = torch.mean(inps[jplane], dim=3 ,keepdim=True)
97
+ jpool = einops.rearrange(jpool, 'b c k 1 -> b c 1 k')
98
+ jpool = einops.repeat(jpool, 'b c 1 k -> b c j k', j=ifeat.size(2))
99
+
100
+ # k_plane -> (i,j) -> (1,j) -> (j,1) -> (j,k)
101
+ # b c i j -> b c 1 j
102
+ kpool = torch.mean(inps[kplane], dim=2 ,keepdim=True)
103
+ kpool = einops.rearrange(kpool, 'b c 1 j -> b c j 1')
104
+ kpool = einops.repeat(kpool, 'b c j 1 -> b c j k', k=ifeat.size(3))
105
+
106
+ # b c h w
107
+ # jpool = jpool.expand_as(ifeat)
108
+ # kpool = kpool.expand_as(ifeat)
109
+
110
+ # concat and conv on feature dim
111
+ catfeat = torch.cat([ifeat, jpool, kpool], dim=1)
112
+ oplane = self.plane_convs[iplane](catfeat)
113
+ oplanes[iplane] = oplane
114
+
115
+ if self.order == 'xz':
116
+ # get back into xz order
117
+ oplanes[yp] = einops.rearrange(oplanes[yp], 'b c z x -> b c x z')
118
+
119
+ return oplanes
120
+
121
+ def roll_triplanes(triplanes_list):
122
+ # B, C, tri, h, w
123
+ tristack = torch.stack((triplanes_list),dim=2)
124
+ return einops.rearrange(tristack, 'b c tri h w -> b c (tri h) w', tri=3)
125
+
126
+ def unroll_triplanes(rolled_triplane):
127
+ # B, C, tri*h, w
128
+ tristack = einops.rearrange(rolled_triplane, 'b c (tri h) w -> b c tri h w', tri=3)
129
+ return torch.unbind(tristack, dim=2)
130
+
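+ # (editor's sketch, not in the original commit) roll/unroll are inverses:
+ #   planes = [torch.randn(2, 8, 16, 16) for _ in range(3)]   # xy, yz, xz planes
+ #   rolled = roll_triplanes(planes)                           # (2, 8, 48, 16)
+ #   restored = unroll_triplanes(rolled)                       # 3 tensors of (2, 8, 16, 16)
+ #   assert all(torch.equal(a, b) for a, b in zip(planes, restored))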
131
+ def conv1x1triplane3daware(in_channels, out_channels, order='xz', **kwargs):
132
+ return ConvTriplane3dAware(lambda inp, out: conv1x1(inp,out,**kwargs),
133
+ in_channels, out_channels,order=order)
134
+
135
+ def Normalize(in_channels, num_groups=32):
136
+ num_groups = min(in_channels, num_groups) # avoid error if in_channels < 32
137
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
138
+
139
+ def nonlinearity(x):
140
+ # return F.relu(x)
141
+ # Swish
142
+ return x*torch.sigmoid(x)
143
+
144
+ class Upsample(nn.Module):
145
+ def __init__(self, in_channels, with_conv):
146
+ super().__init__()
147
+ self.with_conv = with_conv
148
+ if self.with_conv:
149
+ self.conv = torch.nn.Conv2d(in_channels,
150
+ in_channels,
151
+ kernel_size=3,
152
+ stride=1,
153
+ padding=1)
154
+
155
+ def forward(self, x):
156
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
157
+ if self.with_conv:
158
+ x = self.conv(x)
159
+ return x
160
+
161
+ class Downsample(nn.Module):
162
+ def __init__(self, in_channels, with_conv):
163
+ super().__init__()
164
+ self.with_conv = with_conv
165
+ if self.with_conv:
166
+ # no asymmetric padding in torch conv, must do it ourselves
167
+ self.conv = torch.nn.Conv2d(in_channels,
168
+ in_channels,
169
+ kernel_size=3,
170
+ stride=2,
171
+ padding=0)
172
+
173
+ def forward(self, x):
174
+ if self.with_conv:
175
+ pad = (0,1,0,1)
176
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
177
+ x = self.conv(x)
178
+ else:
179
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
180
+ return x
181
+
182
+ class ResnetBlock3dAware(nn.Module):
183
+ def __init__(self, in_channels, out_channels=None):
184
+ #, conv_shortcut=False):
185
+ super().__init__()
186
+ self.in_channels = in_channels
187
+ out_channels = in_channels if out_channels is None else out_channels
188
+ self.out_channels = out_channels
189
+ # self.use_conv_shortcut = conv_shortcut
190
+
191
+ self.norm1 = Normalize(in_channels)
192
+ self.conv1 = conv3x3(self.in_channels, self.out_channels)
193
+
194
+ self.norm_mid = Normalize(out_channels)
195
+ self.conv_3daware = conv1x1triplane3daware(self.out_channels, self.out_channels)
196
+
197
+ self.norm2 = Normalize(out_channels)
198
+ self.conv2 = conv3x3(self.out_channels, self.out_channels)
199
+
200
+ if self.in_channels != self.out_channels:
201
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
202
+ out_channels,
203
+ kernel_size=1,
204
+ stride=1,
205
+ padding=0)
206
+
207
+ def forward(self, x):
208
+ # 3x3 plane comm
209
+ h = x
210
+ h = self.norm1(h)
211
+ h = nonlinearity(h)
212
+ h = self.conv1(h)
213
+
214
+ # 1x1 3d aware, crossplane comm
215
+ h = self.norm_mid(h)
216
+ h = nonlinearity(h)
217
+ h = unroll_triplanes(h)
218
+ h = self.conv_3daware(h)
219
+ h = roll_triplanes(h)
220
+
221
+ # 3x3 plane comm
222
+ h = self.norm2(h)
223
+ h = nonlinearity(h)
224
+ h = self.conv2(h)
225
+
226
+ if self.in_channels != self.out_channels:
227
+ x = self.nin_shortcut(x)
228
+
229
+ return x+h
230
+
231
+ class DownConv3dAware(nn.Module):
232
+ """
233
+ A helper Module that applies one 3D-aware ResNet block followed by an
234
+ optional 2x per-plane downsample.
235
+ """
236
+ def __init__(self, in_channels, out_channels, downsample=True, with_conv=False):
237
+ super(DownConv3dAware, self).__init__()
238
+
239
+ self.in_channels = in_channels
240
+ self.out_channels = out_channels
241
+
242
+ self.block = ResnetBlock3dAware(in_channels=in_channels,
243
+ out_channels=out_channels)
244
+
245
+ self.do_downsample = downsample
246
+ self.downsample = Downsample(out_channels, with_conv=with_conv)
247
+
248
+ def forward(self, x):
249
+ """
250
+ rolled input, rolled output
251
+ Args:
252
+ x: rolled (b c (tri*h) w)
253
+ """
254
+ x = self.block(x)
255
+ before_pool = x
256
+ # if self.pooling:
257
+ # x = self.pool(x)
258
+ if self.do_downsample:
259
+ # unroll and cat channel-wise (to prevent pooling across triplane boundaries)
260
+ x = einops.rearrange(x, 'b c (tri h) w -> b (c tri) h w', tri=3)
261
+ x = self.downsample(x)
262
+ # undo
263
+ x = einops.rearrange(x, 'b (c tri) h w -> b c (tri h) w', tri=3)
264
+ return x, before_pool
265
+
266
+ class UpConv3dAware(nn.Module):
267
+ """
268
+ A helper Module that upsamples the decoder features, merges them with
269
+ the encoder skip connection, and applies one 3D-aware ResNet block.
270
+ """
271
+ def __init__(self, in_channels, out_channels,
272
+ merge_mode='concat', with_conv=False): #up_mode='transpose', ):
273
+ super(UpConv3dAware, self).__init__()
274
+
275
+ self.in_channels = in_channels
276
+ self.out_channels = out_channels
277
+ self.merge_mode = merge_mode
278
+
279
+ self.upsample = Upsample(in_channels, with_conv)
280
+
281
+ if self.merge_mode == 'concat':
282
+ self.norm1 = Normalize(in_channels+out_channels)
283
+ self.block = ResnetBlock3dAware(in_channels=in_channels+out_channels,
284
+ out_channels=out_channels)
285
+ else:
286
+ self.norm1 = Normalize(in_channels)
287
+ self.block = ResnetBlock3dAware(in_channels=in_channels,
288
+ out_channels=out_channels)
289
+
290
+
291
+ def forward(self, from_down, from_up):
292
+ """ Forward pass
293
+ rolled inputs, rolled output
294
+ rolled (b c (tri*h) w)
295
+ Arguments:
296
+ from_down: tensor from the encoder pathway
297
+ from_up: upconv'd tensor from the decoder pathway
298
+ """
299
+ # from_up = self.upconv(from_up)
300
+ from_up = self.upsample(from_up)
301
+ if self.merge_mode == 'concat':
302
+ x = torch.cat((from_up, from_down), 1)
303
+ else:
304
+ x = from_up + from_down
305
+
306
+ x = self.norm1(x)
307
+ x = self.block(x)
308
+ return x
309
+
310
+ class UNetTriplane3dAware(nn.Module):
311
+ def __init__(self, out_channels, in_channels=3, depth=5,
312
+ start_filts=64,# up_mode='transpose',
313
+ use_initial_conv=False,
314
+ merge_mode='concat', **kwargs):
315
+ """
316
+ Arguments:
317
+ in_channels: int, number of channels in each input triplane.
318
+ Default is 3.
319
+ depth: int, number of down/upsampling levels in the U-Net.
320
+ start_filts: int, number of convolutional filters for the
321
+ first conv.
322
+ """
323
+ super(UNetTriplane3dAware, self).__init__()
324
+
325
+
326
+ self.out_channels = out_channels
327
+ self.in_channels = in_channels
328
+ self.start_filts = start_filts
329
+ self.depth = depth
330
+
331
+ self.use_initial_conv = use_initial_conv
332
+ if use_initial_conv:
333
+ self.conv_initial = conv1x1(self.in_channels, self.start_filts)
334
+
335
+ self.down_convs = []
336
+ self.up_convs = []
337
+
338
+ # create the encoder pathway and add to a list
339
+ for i in range(depth):
340
+ if i == 0:
341
+ ins = self.start_filts if use_initial_conv else self.in_channels
342
+ else:
343
+ ins = outs
344
+ outs = self.start_filts*(2**i)
345
+ downsamp_it = True if i < depth-1 else False
346
+
347
+ down_conv = DownConv3dAware(ins, outs, downsample = downsamp_it)
348
+ self.down_convs.append(down_conv)
349
+
350
+ for i in range(depth-1):
351
+ ins = outs
352
+ outs = ins // 2
353
+ up_conv = UpConv3dAware(ins, outs,
354
+ merge_mode=merge_mode)
355
+ self.up_convs.append(up_conv)
356
+
357
+ # add the list of modules to current module
358
+ self.down_convs = nn.ModuleList(self.down_convs)
359
+ self.up_convs = nn.ModuleList(self.up_convs)
360
+
361
+ self.norm_out = Normalize(outs)
362
+ self.conv_final = conv1x1(outs, self.out_channels)
363
+
364
+ self.reset_params()
365
+
366
+ @staticmethod
367
+ def weight_init(m):
368
+ if isinstance(m, nn.Conv2d):
369
+ # init.xavier_normal_(m.weight, gain=0.1)
370
+ init.xavier_normal_(m.weight)
371
+ init.constant_(m.bias, 0)
372
+
373
+
374
+ def reset_params(self):
375
+ for i, m in enumerate(self.modules()):
376
+ self.weight_init(m)
377
+
378
+
379
+ def forward(self, x):
380
+ """
381
+ Args:
382
+ x: Stacked triplane expected to be in (B,3,C,H,W)
383
+ """
384
+ # Roll
385
+ x = einops.rearrange(x, 'b tri c h w -> b c (tri h) w', tri=3)
386
+
387
+ if self.use_initial_conv:
388
+ x = self.conv_initial(x)
389
+
390
+ encoder_outs = []
391
+ # encoder pathway, save outputs for merging
392
+ for i, module in enumerate(self.down_convs):
393
+ x, before_pool = module(x)
394
+ encoder_outs.append(before_pool)
395
+
396
+ # Spend a block in the middle
397
+ # x = self.block_mid(x)
398
+
399
+ for i, module in enumerate(self.up_convs):
400
+ before_pool = encoder_outs[-(i+2)]
401
+ x = module(before_pool, x)
402
+
403
+ x = self.norm_out(x)
404
+
405
+ # No softmax is applied here. If you train with class labels, use
406
+ # nn.CrossEntropyLoss in your training script, since it applies
407
+ # log-softmax internally.
408
+ x = self.conv_final(nonlinearity(x))
409
+
410
+ # Unroll
411
+ x = einops.rearrange(x, 'b c (tri h) w -> b tri c h w', tri=3)
412
+ return x
413
+
414
+
415
+ def setup_unet(output_channels, input_channels, unet_cfg):
416
+ if unet_cfg['use_3d_aware']:
417
+ assert(unet_cfg['rolled'])
418
+ unet = UNetTriplane3dAware(
419
+ out_channels=output_channels,
420
+ in_channels=input_channels,
421
+ depth=unet_cfg['depth'],
422
+ use_initial_conv=unet_cfg['use_initial_conv'],
423
+ start_filts=unet_cfg['start_hidden_channels'],)
424
+ else:
425
+ raise NotImplementedError
426
+ return unet
427
+
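A minimal configuration sketch for setup_unet, not part of the commit (the import path, config values, and channel sizes are illustrative assumptions, not the values PartField actually uses): the input is a stacked triplane of shape (B, 3, C, H, W) and the output keeps that layout with out_channels feature channels.

import torch
from partfield.model.PVCNN.unet_3daware import setup_unet  # assumed path

unet_cfg = {
    'use_3d_aware': True,
    'rolled': True,
    'depth': 4,
    'use_initial_conv': False,
    'start_hidden_channels': 32,
}
unet = setup_unet(output_channels=64, input_channels=32, unet_cfg=unet_cfg)
triplanes = torch.randn(1, 3, 32, 64, 64)    # (B, 3 planes, C, H, W)
print(unet(triplanes).shape)                  # torch.Size([1, 3, 64, 64, 64])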
PartField/partfield/model/UNet/__pycache__/buildingblocks.cpython-310.pyc ADDED
Binary file (17.2 kB). View file