Spaces:
Running
on
Zero
Running
on
Zero
Ruining Li
commited on
Commit
·
4f22fc0
1
Parent(s):
5cbf9bb
Init: add PartField + particulate, track example assets via LFS
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- PartField/LICENSE +36 -0
- PartField/README.md +242 -0
- PartField/applications/.polyscope.ini +6 -0
- PartField/applications/README.md +142 -0
- PartField/applications/multi_shape_cosegment.py +482 -0
- PartField/applications/pack_labels_to_obj.py +47 -0
- PartField/applications/run_smooth_functional_map.py +80 -0
- PartField/applications/shape_pair.py +385 -0
- PartField/applications/single_shape.py +758 -0
- PartField/compute_metric.py +97 -0
- PartField/configs/final/correspondence_demo.yaml +44 -0
- PartField/configs/final/demo.yaml +28 -0
- PartField/download_demo_data.sh +19 -0
- PartField/environment.yml +772 -0
- PartField/partfield/__pycache__/dataloader.cpython-310.pyc +0 -0
- PartField/partfield/__pycache__/model_trainer_pvcnn_only_demo.cpython-310.pyc +0 -0
- PartField/partfield/__pycache__/utils.cpython-310.pyc +0 -0
- PartField/partfield/config/__init__.py +26 -0
- PartField/partfield/config/__pycache__/__init__.cpython-310.pyc +0 -0
- PartField/partfield/config/__pycache__/defaults.cpython-310.pyc +0 -0
- PartField/partfield/config/defaults.py +92 -0
- PartField/partfield/dataloader.py +366 -0
- PartField/partfield/model/PVCNN/__pycache__/conv_pointnet.cpython-310.pyc +0 -0
- PartField/partfield/model/PVCNN/__pycache__/dnnlib_util.cpython-310.pyc +0 -0
- PartField/partfield/model/PVCNN/__pycache__/encoder_pc.cpython-310.pyc +0 -0
- PartField/partfield/model/PVCNN/__pycache__/pc_encoder.cpython-310.pyc +0 -0
- PartField/partfield/model/PVCNN/__pycache__/unet_3daware.cpython-310.pyc +0 -0
- PartField/partfield/model/PVCNN/conv_pointnet.py +251 -0
- PartField/partfield/model/PVCNN/dnnlib_util.py +1074 -0
- PartField/partfield/model/PVCNN/encoder_pc.py +243 -0
- PartField/partfield/model/PVCNN/pc_encoder.py +90 -0
- PartField/partfield/model/PVCNN/pv_module/__init__.py +2 -0
- PartField/partfield/model/PVCNN/pv_module/__pycache__/__init__.cpython-310.pyc +0 -0
- PartField/partfield/model/PVCNN/pv_module/__pycache__/pvconv.cpython-310.pyc +0 -0
- PartField/partfield/model/PVCNN/pv_module/__pycache__/shared_mlp.cpython-310.pyc +0 -0
- PartField/partfield/model/PVCNN/pv_module/__pycache__/voxelization.cpython-310.pyc +0 -0
- PartField/partfield/model/PVCNN/pv_module/ball_query.py +34 -0
- PartField/partfield/model/PVCNN/pv_module/frustum.py +141 -0
- PartField/partfield/model/PVCNN/pv_module/functional/__init__.py +1 -0
- PartField/partfield/model/PVCNN/pv_module/functional/__pycache__/__init__.cpython-310.pyc +0 -0
- PartField/partfield/model/PVCNN/pv_module/functional/__pycache__/devoxelization.cpython-310.pyc +0 -0
- PartField/partfield/model/PVCNN/pv_module/functional/devoxelization.py +12 -0
- PartField/partfield/model/PVCNN/pv_module/loss.py +10 -0
- PartField/partfield/model/PVCNN/pv_module/pointnet.py +113 -0
- PartField/partfield/model/PVCNN/pv_module/pvconv.py +38 -0
- PartField/partfield/model/PVCNN/pv_module/shared_mlp.py +35 -0
- PartField/partfield/model/PVCNN/pv_module/voxelization.py +50 -0
- PartField/partfield/model/PVCNN/unet_3daware.py +427 -0
- PartField/partfield/model/UNet/__pycache__/buildingblocks.cpython-310.pyc +0 -0
.gitattributes
CHANGED
|
@@ -33,4 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 36 |
examples/*.png filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
examples/*.glb filter=lfs diff=lfs merge=lfs -text
|
| 37 |
examples/*.png filter=lfs diff=lfs merge=lfs -text
|
PartField/LICENSE
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
NVIDIA License
|
| 2 |
+
|
| 3 |
+
1. Definitions
|
| 4 |
+
|
| 5 |
+
“Licensor” means any person or entity that distributes its Work.
|
| 6 |
+
“Work” means (a) the original work of authorship made available under this license, which may include software, documentation, or other files, and (b) any additions to or derivative works thereof that are made available under this license.
|
| 7 |
+
The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S. copyright law; provided, however, that for the purposes of this license, derivative works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work.
|
| 8 |
+
Works are “made available” under this license by including in or with the Work either (a) a copyright notice referencing the applicability of this license to the Work, or (b) a copy of this license.
|
| 9 |
+
|
| 10 |
+
2. License Grant
|
| 11 |
+
|
| 12 |
+
2.1 Copyright Grant. Subject to the terms and conditions of this license, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to use, reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form.
|
| 13 |
+
|
| 14 |
+
3. Limitations
|
| 15 |
+
|
| 16 |
+
3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this license, (b) you include a complete copy of this license with your distribution, and (c) you retain without modification any copyright, patent, trademark, or attribution notices that are present in the Work.
|
| 17 |
+
|
| 18 |
+
3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms. Notwithstanding Your Terms, this license (including the redistribution requirements in Section 3.1) will continue to apply to the Work itself.
|
| 19 |
+
|
| 20 |
+
3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially. Notwithstanding the foregoing, NVIDIA Corporation and its affiliates may use the Work and any derivative works commercially. As used herein, “non-commercially” means for non-commercial research and educational purposes only.
|
| 21 |
+
|
| 22 |
+
3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under this license from such Licensor (including the grant in Section 2.1) will terminate immediately.
|
| 23 |
+
|
| 24 |
+
3.5 Trademarks. This license does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks, except as necessary to reproduce the notices described in this license.
|
| 25 |
+
|
| 26 |
+
3.6 Termination. If you violate any term of this license, then your rights under this license (including the grant in Section 2.1) will terminate immediately.
|
| 27 |
+
|
| 28 |
+
4. Disclaimer of Warranty.
|
| 29 |
+
|
| 30 |
+
THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
|
| 31 |
+
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE.
|
| 32 |
+
|
| 33 |
+
5. Limitation of Liability.
|
| 34 |
+
|
| 35 |
+
EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
|
| 36 |
+
|
PartField/README.md
ADDED
|
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# PartField: Learning 3D Feature Fields for Part Segmentation and Beyond [ICCV 2025]
|
| 2 |
+
**[[Project]](https://research.nvidia.com/labs/toronto-ai/partfield-release/)** **[[PDF]](https://arxiv.org/pdf/2504.11451)**
|
| 3 |
+
|
| 4 |
+
Minghua Liu*, Mikaela Angelina Uy*, Donglai Xiang, Hao Su, Sanja Fidler, Nicholas Sharp, Jun Gao
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
## Overview
|
| 9 |
+

|
| 10 |
+
|
| 11 |
+
PartField is a feedforward model that predicts part-based feature fields for 3D shapes. Our learned features can be clustered to yield a high-quality part decomposition, outperforming the latest open-world 3D part segmentation approaches in both quality and speed. PartField can be applied to a wide variety of inputs in terms of modality, semantic class, and style. The learned feature field exhibits consistency across shapes, enabling applications such as cosegmentation, interactive selection, and correspondence.
|
| 12 |
+
|
| 13 |
+
## Table of Contents
|
| 14 |
+
|
| 15 |
+
- [Pretrained Model](#pretrained-model)
|
| 16 |
+
- [Environment Setup](#environment-setup)
|
| 17 |
+
- [TLDR](#tldr)
|
| 18 |
+
- [Example Run](#example-run)
|
| 19 |
+
- [Interactive Tools and Applications](#interactive-tools-and-applications)
|
| 20 |
+
- [Evaluation on PartObjaverse-Tiny](#evaluation-on-partobjaverse-tiny)
|
| 21 |
+
- [Discussion](#discussion-clustering-with-messy-mesh-connectivities)
|
| 22 |
+
- [Citation](#citation)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
## Pretrained Model
|
| 26 |
+
```
|
| 27 |
+
mkdir model
|
| 28 |
+
```
|
| 29 |
+
The link to download our pretrained model is here: [Trained on Objaverse](https://huggingface.co/mikaelaangel/partfield-ckpt/blob/main/model_objaverse.ckpt). Due to licensing restrictions, we are unable to release the model that was also trained on PartNet.
|
| 30 |
+
|
| 31 |
+
## Environment Setup
|
| 32 |
+
|
| 33 |
+
We use Python 3.10 with PyTorch 2.4 and CUDA 12.4. The environment and required packages can be installed individually as follows:
|
| 34 |
+
```
|
| 35 |
+
conda create -n partfield python=3.10
|
| 36 |
+
conda activate partfield
|
| 37 |
+
conda install nvidia/label/cuda-12.4.0::cuda
|
| 38 |
+
pip install psutil
|
| 39 |
+
pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu124
|
| 40 |
+
pip install lightning==2.2 h5py yacs trimesh scikit-image loguru boto3
|
| 41 |
+
pip install mesh2sdf tetgen pymeshlab plyfile einops libigl polyscope potpourri3d simple_parsing arrgh open3d
|
| 42 |
+
pip install torch-scatter -f https://data.pyg.org/whl/torch-2.4.0+cu124.html
|
| 43 |
+
apt install libx11-6 libgl1 libxrender1
|
| 44 |
+
pip install vtk
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
An environment file is also provided and can be used for installation:
|
| 48 |
+
```
|
| 49 |
+
conda env create -f environment.yml
|
| 50 |
+
conda activate partfield
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
## TLDR
|
| 54 |
+
1. Input data (`.obj` or `.glb` for meshes, `.ply` for splats) are stored in subfolders under `data/`. You can create a new subfolder and copy your custom files into it.
|
| 55 |
+
2. Extract PartField features by running the script `partfield_inference.py`, passing the arguments `result_name [FEAT_FOL]` and `dataset.data_path [DATA_PATH]`. The output features will be saved in `exp_results/partfield_features/[FEAT_FOL]`.
|
| 56 |
+
3. Segmented parts can be obtained by running the script `run_part_clustering.py`, using the arguments `--root exp/[FEAT_FOL]` and `--dump_dir [PART_OUT_FOL]`. The output segmentations will be saved in `exp_results/clustering/[PART_OUT_FOL]`.
|
| 57 |
+
4. Application demo scripts are available in the `applications/` directory and can be used after extracting PartField features (i.e., after running `partfield_inference.py` on the desired demo data).
|
| 58 |
+
|
| 59 |
+
## Example Run
|
| 60 |
+
### Download Demo Data
|
| 61 |
+
|
| 62 |
+
#### Mesh Data
|
| 63 |
+
We showcase the feasibility of PartField using sample meshes from Objaverse (artist-created) and Trellis3D (AI-generated). Sample data can be downloaded below:
|
| 64 |
+
```
|
| 65 |
+
sh download_demo_data.sh
|
| 66 |
+
```
|
| 67 |
+
Downloaded meshes can be found in `data/objaverse_samples/` and `data/trellis_samples/`.
|
| 68 |
+
|
| 69 |
+
#### Gaussian Splats
|
| 70 |
+
We also demonstrate our approach using Gaussian splatting reconstructions as input. Sample splat reconstruction data from the NeRF dataset can be found [here](https://drive.google.com/drive/folders/1l0njShLq37hn1TovgeF-PVGBBrAdNQnf?usp=sharing). Download the data and place it in the `data/splat_samples/` folder.
|
| 71 |
+
|
| 72 |
+
### Extract Feature Field
|
| 73 |
+
#### Mesh Data
|
| 74 |
+
|
| 75 |
+
```
|
| 76 |
+
python partfield_inference.py -c configs/final/demo.yaml --opts continue_ckpt model/model_objaverse.ckpt result_name partfield_features/objaverse dataset.data_path data/objaverse_samples
|
| 77 |
+
python partfield_inference.py -c configs/final/demo.yaml --opts continue_ckpt model/model_objaverse.ckpt result_name partfield_features/trellis dataset.data_path data/trellis_samples
|
| 78 |
+
```
|
| 79 |
+
|
| 80 |
+
#### Point Clouds / Gaussian Splats
|
| 81 |
+
```
|
| 82 |
+
python partfield_inference.py -c configs/final/demo.yaml --opts continue_ckpt model/model_objaverse.ckpt result_name partfield_features/splat dataset.data_path data/splat_samples is_pc True
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
### Part Segmentation
|
| 86 |
+
#### Mesh Data
|
| 87 |
+
|
| 88 |
+
We use agglomerative clustering for part segmentation on mesh inputs.
|
| 89 |
+
```
|
| 90 |
+
python run_part_clustering.py --root exp_results/partfield_features/objaverse --dump_dir exp_results/clustering/objaverse --source_dir data/objaverse_samples --use_agglo True --max_num_clusters 20 --option 0
|
| 91 |
+
```
|
| 92 |
+
|
| 93 |
+
When the input mesh has multiple connected components or poor connectivity, defining face adjacency by connecting geometrically close faces can yield better results (see discussion below):
|
| 94 |
+
```
|
| 95 |
+
python run_part_clustering.py --root exp_results/partfield_features/trellis --dump_dir exp_results/clustering/trellis --source_dir data/trellis_samples --use_agglo True --max_num_clusters 20 --option 1 --with_knn True
|
| 96 |
+
```
|
| 97 |
+
|
| 98 |
+
Note that agglomerative clustering does not return a fixed clustering result, but rather a hierarchical part tree, where the root node represents the whole shape and each leaf node corresponds to a single triangle face. You can explore more clustering results by adaptively traversing the tree, such as deciding which part should be further segmented.
|
| 99 |
+
|
| 100 |
+
#### Point Cloud / Gaussian Splats
|
| 101 |
+
We use K-Means clustering for part segmentation on point cloud inputs.
|
| 102 |
+
```
|
| 103 |
+
python run_part_clustering.py --root exp_results/partfield_features/splat --dump_dir exp_results/clustering/splat --source_dir data/splat_samples --max_num_clusters 20 --is_pc True
|
| 104 |
+
```
|
| 105 |
+
|
| 106 |
+
## Interactive Tools and Applications
|
| 107 |
+
We include UI tools to demonstrate various applications of PartField. Set up and try out our demos [here](applications/)!
|
| 108 |
+
|
| 109 |
+

|
| 110 |
+
|
| 111 |
+

|
| 112 |
+
|
| 113 |
+
## Evaluation on PartObjaverse-Tiny
|
| 114 |
+
|
| 115 |
+

|
| 116 |
+
|
| 117 |
+
To evaluate all models in PartObjaverse-Tiny, you can download the data [here](https://github.com/Pointcept/SAMPart3D/blob/main/PartObjaverse-Tiny/PartObjaverse-Tiny.md) and run the following commands:
|
| 118 |
+
```
|
| 119 |
+
python partfield_inference.py -c configs/final/demo.yaml --opts continue_ckpt model/model_objaverse.ckpt result_name partfield_features/partobjtiny dataset.data_path data/PartObjaverse-Tiny/PartObjaverse-Tiny_mesh n_point_per_face 2000 n_sample_each 10000
|
| 120 |
+
python run_part_clustering.py --root exp_results/partfield_features/partobjtiny/ --dump_dir exp_results/clustering/partobjtiny --source_dir data/PartObjaverse-Tiny/PartObjaverse-Tiny_mesh --use_agglo True --max_num_clusters 20 --option 0
|
| 121 |
+
```
|
| 122 |
+
If an OOM error occurs, you can reduce the number of points sampled per face—for example, by setting `n_point_per_face` to 500.
|
| 123 |
+
|
| 124 |
+
Evaluation metrics can be obtained by running the command below. The per-category average mIoU reported in the paper is also computed.
|
| 125 |
+
```
|
| 126 |
+
python compute_metric.py
|
| 127 |
+
```
|
| 128 |
+
This evaluation code builds on top of the implementation released by [SAMPart3D](https://github.com/Pointcept/SAMPart3D). Users with their own data and corresponding ground truths can easily modify this script to compute their metrics.
|
| 129 |
+
|
| 130 |
+
## Discussion: Clustering with Messy Mesh Connectivities
|
| 131 |
+
<!-- Some meshes can get messy with a lot of connected components, here the connectivity information may not be useful, causing failure cases when using Agglomerative clustering. In these cases, we provide two alternatives, 1) cluster using KMeans. We provide sample code below, or 2) converting the input mesh to a manifold surface mesh.
|
| 132 |
+
|
| 133 |
+
Sample data download:
|
| 134 |
+
```
|
| 135 |
+
cd data
|
| 136 |
+
mkdir messy_meshes_samples
|
| 137 |
+
cd messy_meshes_samples
|
| 138 |
+
wget https://huggingface.co/datasets/allenai/objaverse/resolve/main/glbs/000-007/00790c705e4c4a1fbc0af9bf5c9e9525.glb
|
| 139 |
+
wget https://huggingface.co/datasets/allenai/objaverse/resolve/main/glbs/000-132/13cc3ffc69964894a2bc94154aed687f.glb
|
| 140 |
+
```
|
| 141 |
+
|
| 142 |
+
Extract Partfield feature on the original mesh and run KMeans clustering:
|
| 143 |
+
```
|
| 144 |
+
python partfield_inference.py -c configs/final/demo.yaml --opts continue_ckpt model/model_objaverse.ckpt result_name partfield_features/messy_meshes_samples dataset.data_path data/messy_meshes_samples
|
| 145 |
+
python run_part_clustering.py --root exp_results/partfield_features/messy_meshes_samples/ --dump_dir exp_results/clustering/messy_meshes_samples_kmeans/ --source_dir data/messy_meshes_samples --max_num_clusters 20
|
| 146 |
+
```
|
| 147 |
+
|
| 148 |
+
Extract convert mesh into a surface manifold, extract Partfield feature and run agglomerative clustering:
|
| 149 |
+
```
|
| 150 |
+
python partfield_inference.py -c configs/final/demo.yaml --opts continue_ckpt model/model_objaverse.ckpt result_name partfield_features/messy_meshes_samples_remesh dataset.data_path data/messy_meshes_samples remesh_demo True
|
| 151 |
+
python run_part_clustering_remesh.py --root exp_results/partfield_features/messy_meshes_samples_remesh --dump_dir exp_results/clustering/messy_meshes_samples_remesh --source_dir data/messy_meshes_samples --use_agglo True --max_num_clusters 20
|
| 152 |
+
|
| 153 |
+
python run_part_clustering_remesh.py --root exp_results/partfield_features/trellis_remesh --dump_dir exp_results/clustering/trellis_remesh --source_dir data/trellis_samples --use_agglo True --max_num_clusters 20
|
| 154 |
+
``` -->
|
| 155 |
+
When using agglomerative clustering for part segmentation, an adjacency matrix is passed into the algorithm, which ideally requires the mesh to be a single connected component. However, some meshes can be messy, containing multiple connected components. If the input mesh is not a single connected component, we add pseudo-edges to the adjacency matrix to make it one. By default, we take a simple approach: adding `N-1` pseudo-edges as a chain to connect `N` components together. However, this approach can lead to poor results when the mesh is poorly connected and fragmented.
|
| 156 |
+
|
| 157 |
+
<img src="assets/messy_meshes_screenshot/option0.png" width="480"/>
|
| 158 |
+
|
| 159 |
+
```
|
| 160 |
+
python run_part_clustering.py --root exp_results/partfield_features/trellis --dump_dir exp_results/clustering/trellis_bad --source_dir data/trellis_samples --use_agglo True --max_num_clusters 20 --option 0
|
| 161 |
+
```
|
| 162 |
+
|
| 163 |
+
When this occurs, we explore different options that can lead to better results:
|
| 164 |
+
|
| 165 |
+
### 1. Preprocess Input Mesh
|
| 166 |
+
|
| 167 |
+
We can perform a simple cleanup on the input meshes by removing duplicate vertices and faces, and by merging nearby vertices using `pymeshlab`. This preprocessing step can be enabled via a flag when generating PartField features:
|
| 168 |
+
|
| 169 |
+
```
|
| 170 |
+
python partfield_inference.py -c configs/final/demo.yaml --opts continue_ckpt model/model_objaverse.ckpt result_name partfield_features/trellis_preprocess dataset.data_path data/trellis_samples preprocess_mesh True
|
| 171 |
+
```
|
| 172 |
+
|
| 173 |
+
When running agglomerative clustering on a cleaned-up mesh, we observe improved part segmentation:
|
| 174 |
+
|
| 175 |
+
<img src="assets/messy_meshes_screenshot/preprocess.png" width="480"/>
|
| 176 |
+
|
| 177 |
+
```
|
| 178 |
+
python run_part_clustering.py --root exp_results/partfield_features/trellis_preprocess --dump_dir exp_results/clustering/trellis_preprocess --source_dir data/trellis_samples --use_agglo True --max_num_clusters 20 --option 0
|
| 179 |
+
```
|
| 180 |
+
|
| 181 |
+
### 2. Cluster with KMeans
|
| 182 |
+
|
| 183 |
+
If modifying the input mesh is not desirable and you prefer to avoid preprocessing, an alternative is to use KMeans clustering, which does not rely on an adjacency matrix.
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
<img src="assets/messy_meshes_screenshot/kmeans.png" width="480"/>
|
| 187 |
+
|
| 188 |
+
```
|
| 189 |
+
python run_part_clustering.py --root exp_results/partfield_features/trellis --dump_dir exp_results/clustering/trellis_kmeans --source_dir data/trellis_samples --max_num_clusters 20
|
| 190 |
+
```
|
| 191 |
+
|
| 192 |
+
### 3. MST-based Adjacency Matrix
|
| 193 |
+
|
| 194 |
+
Instead of simply chaining the connected components of the input mesh, we also explore adding pseudo-edges to the adjacency matrix by constructing a KNN graph using face centroids and computing the minimum spanning tree of that graph.
|
| 195 |
+
|
| 196 |
+
<img src="assets/messy_meshes_screenshot/option1.png" width="480"/>
|
| 197 |
+
|
| 198 |
+
```
|
| 199 |
+
python run_part_clustering.py --root exp_results/partfield_features/trellis --dump_dir exp_results/clustering/trellis_faceadj --source_dir data/trellis_samples --use_agglo True --max_num_clusters 20 --option 1 --with_knn True
|
| 200 |
+
```
|
| 201 |
+
|
| 202 |
+
<!-- ### Remesh with Marching Cubes (experimental)
|
| 203 |
+
|
| 204 |
+
We also explore computing the SDF of the input mesh and then running marching cubes, resulting in a surface mesh that is guaranteed to be a single connected component. We then run clustering on the new mesh and map back the segmentation labels to the original mesh by a voting scheme.
|
| 205 |
+
|
| 206 |
+
<img src="assets/messy_meshes_screenshot/remesh.png" width="480"/>
|
| 207 |
+
|
| 208 |
+
```
|
| 209 |
+
python partfield_inference.py -c configs/final/demo.yaml --opts continue_ckpt model/model_objaverse.ckpt result_name partfield_features/trellis_remesh dataset.data_path data/trellis_samples remesh_demo True
|
| 210 |
+
|
| 211 |
+
python run_part_clustering_remesh.py --root exp_results/partfield_features/trellis_remesh --dump_dir exp_results/clustering/trellis_remesh --source_dir data/trellis_samples --use_agglo True --max_num_clusters 20
|
| 212 |
+
```-->
|
| 213 |
+
|
| 214 |
+
### More Challenging Meshes!
|
| 215 |
+
The proposed approaches improve results for some meshes, but we find that certain cases still do not produce satisfactory segmentations. We leave these challenges for future work. If you're interested, here are some examples of difficult meshes we encountered:
|
| 216 |
+
|
| 217 |
+
**Challenging Meshes:**
|
| 218 |
+
```
|
| 219 |
+
cd data
|
| 220 |
+
mkdir challenge_samples
|
| 221 |
+
cd challenge_samples
|
| 222 |
+
wget https://huggingface.co/datasets/allenai/objaverse/resolve/main/glbs/000-007/00790c705e4c4a1fbc0af9bf5c9e9525.glb
|
| 223 |
+
wget https://huggingface.co/datasets/allenai/objaverse/resolve/main/glbs/000-132/13cc3ffc69964894a2bc94154aed687f.glb
|
| 224 |
+
```
|
| 225 |
+
|
| 226 |
+
## Citation
|
| 227 |
+
```
|
| 228 |
+
@inproceedings{partfield2025,
|
| 229 |
+
title={PartField: Learning 3D Feature Fields for Part Segmentation and Beyond},
|
| 230 |
+
author={Minghua Liu and Mikaela Angelina Uy and Donglai Xiang and Hao Su and Sanja Fidler and Nicholas Sharp and Jun Gao},
|
| 231 |
+
year={2025}
|
| 232 |
+
}
|
| 233 |
+
```
|
| 234 |
+
|
| 235 |
+
## References
|
| 236 |
+
PartField borrows code from the following repositories:
|
| 237 |
+
- [OpenLRM](https://github.com/3DTopia/OpenLRM)
|
| 238 |
+
- [PyTorch 3D UNet](https://github.com/wolny/pytorch-3dunet)
|
| 239 |
+
- [PVCNN](https://github.com/mit-han-lab/pvcnn)
|
| 240 |
+
- [SAMPart3D](https://github.com/Pointcept/SAMPart3D) — evaluation script
|
| 241 |
+
|
| 242 |
+
Many thanks to the authors for sharing their code!
|
PartField/applications/.polyscope.ini
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"windowHeight": 1104,
|
| 3 |
+
"windowPosX": 66,
|
| 4 |
+
"windowPosY": 121,
|
| 5 |
+
"windowWidth": 2215
|
| 6 |
+
}
|
PartField/applications/README.md
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Interactive Tools and Applications
|
| 2 |
+
|
| 3 |
+
## Single-Shape Feature and Segmentation Visualization Tool
|
| 4 |
+
We can visualize the output features and segmentation of a single shape by running the script below:
|
| 5 |
+
|
| 6 |
+
```
|
| 7 |
+
cd applications/
|
| 8 |
+
python single_shape.py --data_root ../exp_results/partfield_features/trellis/ --filename dwarf
|
| 9 |
+
```
|
| 10 |
+
|
| 11 |
+
- `Mode: pca, feature_viz, cluster_agglo, cluster_kmeans`
|
| 12 |
+
- `pca` : Visualizes the pca Partfield features of the input model as colors.
|
| 13 |
+
- `feature_viz` : Visualizes each dimension of the PartField features as a colormap.
|
| 14 |
+
- `cluster_agglo` : Visualizes the part segmentation of the input model using Agglomerative clustering.
|
| 15 |
+
- Number of clusters is specified with the slider.
|
| 16 |
+
- `Adj Matrix Def`: Specifies how the adjacency matrix is defined for the clustering algorithm by adding dummy edges to make the input mesh a single connected component.
|
| 17 |
+
- `Add KNN edges` : Adds additional dummy edges based on k nearest neighbors.
|
| 18 |
+
- `cluster_kmeans` : Visualizes the part segmentation of the input model using KMeans clustering.
|
| 19 |
+
- Number of clusters is specified with the slider.
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
## Shape-Pair Co-Segmentation and Feature Exploration Tool
|
| 23 |
+
We provide a tool to analyze and visualize a pair of shapes that has two main functionalities: 1) **Co-segmentation** via co-clustering and 2) Partfield **feature exploration** and visualization. Try it out as follows:
|
| 24 |
+
|
| 25 |
+
```
|
| 26 |
+
cd applications/
|
| 27 |
+
python shape_pair.py --data_root ../exp_results/partfield_features/trellis/ --filename dwarf --filename_alt goblin
|
| 28 |
+
```
|
| 29 |
+
|
| 30 |
+
### Co-Clustering for Co-Segmentation
|
| 31 |
+
|
| 32 |
+
Here explains the use-case for `Mode: co-segmentation`.
|
| 33 |
+
|
| 34 |
+

|
| 35 |
+
|
| 36 |
+
The shape-pair is co-segmented by running co-clustering, In this application, we use the KMeans clustering algorithm. The `first shape (left)` is separated into parts via **unsupervised clustering** of its features with KMeans, from which the parts of the `second shape (right)` are then defined.
|
| 37 |
+
|
| 38 |
+
Below are a list parameters:
|
| 39 |
+
- `Source init`:
|
| 40 |
+
- `True`: Initializes the cluster centers of the second shape (right) with the cluster centers of the first shape (left).
|
| 41 |
+
- `False`: Uses KMeans++ to initialize the cluster centers for KMeans for the second shape.
|
| 42 |
+
- `Independent`:
|
| 43 |
+
- `True`: Labels after running KMeans clustering are directly used as parts for the second shape. Correspondence with the parts of the first shape is not explicitly computed after KMeans clustering.
|
| 44 |
+
- `False`: After KMeans clustering is ran on the features of the second shape, the mean features for each unique part is then computed. The mean part feature for each part of the first shape is also computed. Then the parts of the second shaped are assigned labels based on the nearest neighbor part of the first shape.
|
| 45 |
+
- `Num cluster`:
|
| 46 |
+
- `Model1` : A slider is used to specify the number of parts for the first shape, i.e. number of clusters for KMeans clustering.
|
| 47 |
+
- `Model2` : A slider is used to specify the number of parts for the second shape, i.e. number of clusters for KMeans clustering. Note: if `Source init` is set to `True` then this slider is ignored and the number of clusters for Model1 is used.
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
### Feature Exploration and Visualization
|
| 51 |
+
|
| 52 |
+
Here explains the use-case for `Mode: feature_explore`.
|
| 53 |
+
|
| 54 |
+

|
| 55 |
+
|
| 56 |
+
This feature allows us to select a query point from the first shape (left) and the feature distance to all points in the second shape (left) and itself is then visualized as a colormap.
|
| 57 |
+
- `range` : A slider to specify the distance radius for feature similarity visualization. Large values will result in bigger highlighter areas.
|
| 58 |
+
- `continuous` :
|
| 59 |
+
- `False` : Query point is specified with a mouse click.
|
| 60 |
+
- `True` : You can slide your mouse around the first mesh to visualize feature distances.
|
| 61 |
+
|
| 62 |
+
## Multi-shape Cosegmentation Tool
|
| 63 |
+
We further demonstrate PartField for cosegmentation of multiple/a set of shapes. Try out our demo application as follows:
|
| 64 |
+
|
| 65 |
+
### Dependency Installation
|
| 66 |
+
Let's first install the necessary dependencies for this tool:
|
| 67 |
+
```
|
| 68 |
+
pip install cuml-cu12
|
| 69 |
+
pip install xgboost
|
| 70 |
+
```
|
| 71 |
+
|
| 72 |
+
### Dataset
|
| 73 |
+
We use the Shape COSEG dataset for our demo. We first download the dataset here:
|
| 74 |
+
```
|
| 75 |
+
mkdir data/coseg_guitar
|
| 76 |
+
cd data/coseg_guitar
|
| 77 |
+
wget https://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Guitars/shapes.zip
|
| 78 |
+
wget https://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Guitars/gt.zip
|
| 79 |
+
unzip shapes.zip
|
| 80 |
+
unzip gt.zip
|
| 81 |
+
```
|
| 82 |
+
|
| 83 |
+
Now, let's extract Partfield features for the set of shapes:
|
| 84 |
+
```
|
| 85 |
+
python partfield_inference.py -c configs/final/demo.yaml --opts continue_ckpt model/model_objaverse.ckpt result_name partfield_features/coseg_guitar/ dataset.data_path data/coseg_guitar/shapes
|
| 86 |
+
```
|
| 87 |
+
|
| 88 |
+
Now, we're ready to run the tool! We support two modes: 1) **Few-shot** with click-based annotations and 2) **Supervised** with ground truth labels.
|
| 89 |
+
|
| 90 |
+
### Annotate Mode
|
| 91 |
+

|
| 92 |
+
|
| 93 |
+
We can run our few-shot segmentation tool as follows:
|
| 94 |
+
```
|
| 95 |
+
cd applications/
|
| 96 |
+
python multi_shape_cosegment.py --meshes ../exp_results/partfield_features/coseg_guitar/
|
| 97 |
+
```
|
| 98 |
+
We can annotate the segments with a few clicks. A classifier is then ran to obtain the part segmentation.
|
| 99 |
+
- N_class: number of segmentation class labels
|
| 100 |
+
- Annotate:
|
| 101 |
+
- `00, 01, 02, ...`: Select the segmentation class label, then click on a shape region of that class.
|
| 102 |
+
- `Undo Last Selection`: Removes and disregards the last annotation made.
|
| 103 |
+
- Fit:
|
| 104 |
+
- Fit Method: Selects the classification method used for fitting. Default uses `Logistic Regression`.
|
| 105 |
+
- `Update Fit`: By default, the fitting process is automatically updated. This can also be changed to a manual update.
|
| 106 |
+
|
| 107 |
+
### Ground Truth Labels Mode
|
| 108 |
+
|
| 109 |
+

|
| 110 |
+
|
| 111 |
+
Alternatively, we can use the ground truth labels of a subset of the shapes to train the classifier.
|
| 112 |
+
|
| 113 |
+
```
|
| 114 |
+
cd applications/
|
| 115 |
+
python multi_shape_cosegment.py --meshes ../exp_results/partfield_features/coseg_guitar/ --n_train_subset 15
|
| 116 |
+
```
|
| 117 |
+
`Fit Method` can also be selected here to choose the classifier to be used.
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
## 3D Correspondences
|
| 121 |
+
|
| 122 |
+
First, we clone the repository [SmoothFunctionalMaps](https://github.com/RobinMagnet/SmoothFunctionalMaps) and install additional packages.
|
| 123 |
+
```
|
| 124 |
+
pip install omegaconf robust_laplacian
|
| 125 |
+
git submodule init
|
| 126 |
+
git submodule update --recursive
|
| 127 |
+
```
|
| 128 |
+
|
| 129 |
+
Download the [DenseCorr3D dataset](https://drive.google.com/file/d/1bpgsNu8JewRafhdRN4woQL7ObQtfgcpu/view?usp=sharing) into the `data` folder. Unzip the contents and ensure that the file structure is organized so that you can access
|
| 130 |
+
`data/DenseCorr3D/animals/071b8_toy_animals_017`.
|
| 131 |
+
|
| 132 |
+
Extract the PartField features.
|
| 133 |
+
```
|
| 134 |
+
# run in root directory of this repo
|
| 135 |
+
python partfield_inference.py -c configs/final/correspondence_demo.yaml --opts continue_ckpt model/model_objaverse.ckpt preprocess_mesh True
|
| 136 |
+
```
|
| 137 |
+
|
| 138 |
+
Run the functional map.
|
| 139 |
+
```
|
| 140 |
+
cd applications/
|
| 141 |
+
python run_smooth_functional_map.py -c ../configs/final/correspondence_demo.yaml --opts
|
| 142 |
+
```
|
PartField/applications/multi_shape_cosegment.py
ADDED
|
@@ -0,0 +1,482 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import argparse
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
|
| 6 |
+
from arrgh import arrgh
|
| 7 |
+
import polyscope as ps
|
| 8 |
+
import polyscope.imgui as psim
|
| 9 |
+
import potpourri3d as pp3d
|
| 10 |
+
import trimesh
|
| 11 |
+
|
| 12 |
+
import cuml
|
| 13 |
+
import xgboost as xgb
|
| 14 |
+
|
| 15 |
+
import os, random
|
| 16 |
+
|
| 17 |
+
import sys
|
| 18 |
+
sys.path.append("..")
|
| 19 |
+
from partfield.utils import *
|
| 20 |
+
|
| 21 |
+
@dataclass
|
| 22 |
+
class State:
|
| 23 |
+
|
| 24 |
+
objects = None
|
| 25 |
+
train_objects = None
|
| 26 |
+
|
| 27 |
+
# Input options
|
| 28 |
+
subsample_inputs: int = -1
|
| 29 |
+
n_train_subset: int = 0
|
| 30 |
+
|
| 31 |
+
# Label
|
| 32 |
+
N_class: int = 2
|
| 33 |
+
|
| 34 |
+
# Annotations
|
| 35 |
+
# A annotations (initially A = 0)
|
| 36 |
+
anno_feat: np.array = np.zeros((0,448), dtype=np.float32) # [A,F]
|
| 37 |
+
anno_label: np.array = np.zeros((0,), dtype=np.int32) # [A]
|
| 38 |
+
anno_pos: np.array = np.zeros((0,3), dtype=np.float32) # [A,3]
|
| 39 |
+
|
| 40 |
+
# Intermediate selection data
|
| 41 |
+
is_selecting: bool = False
|
| 42 |
+
selection_class: int = 0
|
| 43 |
+
|
| 44 |
+
# Fitting algorithm
|
| 45 |
+
fit_to: str = "Annotations"
|
| 46 |
+
fit_method : str = "LogisticRegression"
|
| 47 |
+
auto_update_fit: bool = True
|
| 48 |
+
|
| 49 |
+
# Training data
|
| 50 |
+
# T training datapoints
|
| 51 |
+
train_feat: np.array = np.zeros((0,448), dtype=np.float32) # [T,F]
|
| 52 |
+
train_label: np.array = np.zeros((0,), dtype=np.int32) # [T]
|
| 53 |
+
|
| 54 |
+
# Viz
|
| 55 |
+
grid_w : int = 8
|
| 56 |
+
per_obj_shift : float = 2.
|
| 57 |
+
anno_radius : float = 0.01
|
| 58 |
+
ps_cloud_annotation = None
|
| 59 |
+
ps_structure_name_to_index_map = {}
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
fit_methods_list = ["LinearRegression", "LogisticRegression", "LinearSVC", "RandomForest", "NearestNeighbors", "XGBoost"]
|
| 63 |
+
fit_to_list = ["Annotations", "TrainingSet"]
|
| 64 |
+
|
| 65 |
+
def load_mesh_and_features(mesh_filepath, ind, require_gt=False, gt_label_fol = ""):
|
| 66 |
+
|
| 67 |
+
dirpath, filename = os.path.split(mesh_filepath)
|
| 68 |
+
filename_core = filename[9:-6] # splits off "feat_pca_" ... "_0.ply"
|
| 69 |
+
feature_filename = "part_feat_"+ filename_core + "_0_batch.npy"
|
| 70 |
+
feature_filepath = os.path.join(dirpath, feature_filename)
|
| 71 |
+
|
| 72 |
+
gt_filename = filename_core + ".seg"
|
| 73 |
+
gt_filepath = os.path.join(gt_label_fol, gt_filename)
|
| 74 |
+
have_gt = os.path.isfile(gt_filepath)
|
| 75 |
+
|
| 76 |
+
print(" Reading file:")
|
| 77 |
+
print(f" Mesh filename: {mesh_filepath}")
|
| 78 |
+
print(f" Feature filename: {feature_filepath}")
|
| 79 |
+
print(f" Ground Truth Label filename: {gt_filepath} -- present = {have_gt}")
|
| 80 |
+
|
| 81 |
+
# load features
|
| 82 |
+
feat = np.load(feature_filepath, allow_pickle=False)
|
| 83 |
+
feat = feat.astype(np.float32)
|
| 84 |
+
|
| 85 |
+
# load mesh things
|
| 86 |
+
# TODO replace this with just loading V/F from numpy archive
|
| 87 |
+
tm = load_mesh_util(mesh_filepath)
|
| 88 |
+
|
| 89 |
+
V = np.array(tm.vertices, dtype=np.float32)
|
| 90 |
+
F = np.array(tm.faces)
|
| 91 |
+
|
| 92 |
+
# load ground truth, if available
|
| 93 |
+
if have_gt:
|
| 94 |
+
gt_labels = np.loadtxt(gt_filepath)
|
| 95 |
+
gt_labels = gt_labels.astype(np.int32) - 1
|
| 96 |
+
else:
|
| 97 |
+
if require_gt:
|
| 98 |
+
raise ValueError("could not find ground-truth file, but it is required")
|
| 99 |
+
gt_labels = None
|
| 100 |
+
|
| 101 |
+
# pca_colors = None
|
| 102 |
+
|
| 103 |
+
return {
|
| 104 |
+
'nicename' : f"{ind:02d}_{filename_core}",
|
| 105 |
+
'mesh_filepath' : mesh_filepath,
|
| 106 |
+
'feature_filepath' : feature_filepath,
|
| 107 |
+
'V' : V,
|
| 108 |
+
'F' : F,
|
| 109 |
+
'feat_np' : feat,
|
| 110 |
+
# 'feat_pt' : torch.tensor(feat, device='cuda'),
|
| 111 |
+
'gt_labels' : gt_labels
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
def shift_for_ind(state : State, ind):
|
| 115 |
+
|
| 116 |
+
x_ind = ind % state.grid_w
|
| 117 |
+
y_ind = ind // state.grid_w
|
| 118 |
+
|
| 119 |
+
shift = np.array([state.per_obj_shift * x_ind, 0, -state.per_obj_shift * y_ind])
|
| 120 |
+
|
| 121 |
+
return shift
|
| 122 |
+
|
| 123 |
+
def viz_upper_limit(state : State, ind_count):
|
| 124 |
+
|
| 125 |
+
x_max = min(ind_count, state.grid_w)
|
| 126 |
+
y_max = ind_count // state.grid_w
|
| 127 |
+
|
| 128 |
+
bound = np.array([state.per_obj_shift * x_max, 0, -state.per_obj_shift * y_max])
|
| 129 |
+
|
| 130 |
+
return bound
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def initialize_object_viz(state : State, obj, index=0):
|
| 134 |
+
obj['ps_mesh'] = ps.register_surface_mesh(obj['nicename'], obj['V'], obj['F'], color=(.8, .8, .8))
|
| 135 |
+
shift = shift_for_ind(state, index)
|
| 136 |
+
obj['ps_mesh'].translate(shift)
|
| 137 |
+
obj['ps_mesh'].set_selection_mode('faces_only')
|
| 138 |
+
state.ps_structure_name_to_index_map[obj['nicename']] = index
|
| 139 |
+
|
| 140 |
+
def update_prediction(state: State):
|
| 141 |
+
|
| 142 |
+
print("Updating predictions..")
|
| 143 |
+
|
| 144 |
+
N_anno = state.anno_label.shape[0]
|
| 145 |
+
|
| 146 |
+
# Quick out if we don't have at least two distinct class labels present
|
| 147 |
+
if(state.fit_to == "Annotations" and len(np.unique(state.anno_label)) <= 1):
|
| 148 |
+
return state
|
| 149 |
+
|
| 150 |
+
# Quick out if we don't have
|
| 151 |
+
if(state.fit_to == "TrainingSet" and state.train_objects is None):
|
| 152 |
+
return state
|
| 153 |
+
|
| 154 |
+
if state.fit_method == "LinearRegression":
|
| 155 |
+
classifier = cuml.multiclass.MulticlassClassifier(cuml.linear_model.LinearRegression(), strategy='ovr')
|
| 156 |
+
elif state.fit_method == "LogisticRegression":
|
| 157 |
+
classifier = cuml.multiclass.MulticlassClassifier(cuml.linear_model.LogisticRegression(), strategy='ovr')
|
| 158 |
+
elif state.fit_method == "LinearSVC":
|
| 159 |
+
classifier = cuml.multiclass.MulticlassClassifier(cuml.svm.LinearSVC(), strategy='ovr')
|
| 160 |
+
elif state.fit_method == "RandomForest":
|
| 161 |
+
classifier = cuml.ensemble.RandomForestClassifier()
|
| 162 |
+
elif state.fit_method == "NearestNeighbors":
|
| 163 |
+
classifier = cuml.multiclass.MulticlassClassifier(cuml.neighbors.KNeighborsRegressor(n_neighbors=1), strategy='ovr')
|
| 164 |
+
elif state.fit_method == "XGBoost":
|
| 165 |
+
classifier = xgb.XGBClassifier(max_depth=7, n_estimators=1000)
|
| 166 |
+
else:
|
| 167 |
+
raise ValueError("unrecognized fit method")
|
| 168 |
+
|
| 169 |
+
if state.fit_to == "TrainingSet":
|
| 170 |
+
|
| 171 |
+
all_train_feats = []
|
| 172 |
+
all_train_labels = []
|
| 173 |
+
for obj in state.train_objects:
|
| 174 |
+
all_train_feats.append(obj['feat_np'])
|
| 175 |
+
all_train_labels.append(obj['gt_labels'])
|
| 176 |
+
|
| 177 |
+
all_train_feats = np.concatenate(all_train_feats, axis=0)
|
| 178 |
+
all_train_labels = np.concatenate(all_train_labels, axis=0)
|
| 179 |
+
|
| 180 |
+
state.N_class = np.max(all_train_labels) + 1
|
| 181 |
+
|
| 182 |
+
classifier.fit(all_train_feats, all_train_labels)
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
elif state.fit_to == "Annotations":
|
| 186 |
+
classifier.fit(state.anno_feat,state.anno_label)
|
| 187 |
+
else:
|
| 188 |
+
raise ValueError("unrecognized fit to")
|
| 189 |
+
|
| 190 |
+
n_total = 0
|
| 191 |
+
n_correct = 0
|
| 192 |
+
|
| 193 |
+
for obj in state.objects:
|
| 194 |
+
obj['pred_label'] = classifier.predict(obj['feat_np'])
|
| 195 |
+
|
| 196 |
+
if obj['gt_labels'] is not None:
|
| 197 |
+
n_total += obj['gt_labels'].shape[0]
|
| 198 |
+
n_correct += np.sum(obj['pred_label'] == obj['gt_labels'], dtype=np.int32)
|
| 199 |
+
|
| 200 |
+
if(state.fit_to == "TrainingSet" and n_total > 0):
|
| 201 |
+
frac = n_correct / n_total
|
| 202 |
+
print(f"Test accuracy: {n_correct:d} / {n_total:d} {100*frac:.02f}%")
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
print("Done updating predictions.")
|
| 206 |
+
|
| 207 |
+
return state
|
| 208 |
+
|
| 209 |
+
def update_prediction_viz(state: State):
|
| 210 |
+
|
| 211 |
+
for obj in state.objects:
|
| 212 |
+
if 'pred_label' in obj:
|
| 213 |
+
obj['ps_mesh'].add_scalar_quantity("pred labels", obj['pred_label'], defined_on='faces', vminmax=(0,state.N_class-1), cmap='turbo', enabled=True)
|
| 214 |
+
|
| 215 |
+
return state
|
| 216 |
+
|
| 217 |
+
def update_annotation_viz(state: State):
|
| 218 |
+
|
| 219 |
+
ps_cloud = ps.register_point_cloud("annotations", state.anno_pos, radius=state.anno_radius, material='candy')
|
| 220 |
+
ps_cloud.add_scalar_quantity("labels", state.anno_label, vminmax=(0,state.N_class-1), cmap='turbo', enabled=True)
|
| 221 |
+
|
| 222 |
+
state.ps_cloud_annotation = ps_cloud
|
| 223 |
+
|
| 224 |
+
return state
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def filter_old_labels(state: State):
|
| 228 |
+
"""
|
| 229 |
+
Filter out annotations from classes that don't exist any more
|
| 230 |
+
"""
|
| 231 |
+
|
| 232 |
+
keep_mask = state.anno_label < state.N_class
|
| 233 |
+
state.anno_feat = state.anno_feat[keep_mask,:]
|
| 234 |
+
state.anno_label = state.anno_label[keep_mask]
|
| 235 |
+
state.anno_pos = state.anno_pos[keep_mask,:]
|
| 236 |
+
|
| 237 |
+
return state
|
| 238 |
+
|
| 239 |
+
def undo_last_annotation(state: State):
|
| 240 |
+
|
| 241 |
+
state.anno_feat = state.anno_feat[:-1,:]
|
| 242 |
+
state.anno_label = state.anno_label[:-1]
|
| 243 |
+
state.anno_pos = state.anno_pos[:-1,:]
|
| 244 |
+
|
| 245 |
+
return state
|
| 246 |
+
|
| 247 |
+
def ps_callback(state_list):
|
| 248 |
+
state : State = state_list[0] # hacky pass-by-reference, since we want to edit it below
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
# If we're in selection mode, that's the only thing we can do
|
| 252 |
+
if state.is_selecting:
|
| 253 |
+
|
| 254 |
+
psim.TextUnformatted(f"Annotating class {state.selection_class:02d}. Click on any mesh face.")
|
| 255 |
+
|
| 256 |
+
io = psim.GetIO()
|
| 257 |
+
if io.MouseClicked[0]:
|
| 258 |
+
screen_coords = io.MousePos
|
| 259 |
+
pick_result = ps.pick(screen_coords=screen_coords)
|
| 260 |
+
|
| 261 |
+
# Check if we hit one of the meshes
|
| 262 |
+
if pick_result.is_hit and pick_result.structure_name in state.ps_structure_name_to_index_map:
|
| 263 |
+
if pick_result.structure_data['element_type'] != "face":
|
| 264 |
+
# shouldn't be possible
|
| 265 |
+
raise ValueError("pick returned non-face")
|
| 266 |
+
|
| 267 |
+
i_obj = state.ps_structure_name_to_index_map[pick_result.structure_name]
|
| 268 |
+
f_hit = pick_result.structure_data['index']
|
| 269 |
+
|
| 270 |
+
obj = state.objects[i_obj]
|
| 271 |
+
V = obj['V']
|
| 272 |
+
F = obj['F']
|
| 273 |
+
feat = obj['feat_np']
|
| 274 |
+
|
| 275 |
+
face_corners = V[F[f_hit,:],:]
|
| 276 |
+
new_anno_feat = feat[f_hit,:]
|
| 277 |
+
new_anno_label = state.selection_class
|
| 278 |
+
new_anno_pos = np.mean(face_corners, axis=0) + shift_for_ind(state, i_obj)
|
| 279 |
+
|
| 280 |
+
state.anno_feat = np.concatenate((state.anno_feat, new_anno_feat[None,:]))
|
| 281 |
+
state.anno_label = np.concatenate((state.anno_label, np.array((new_anno_label,))))
|
| 282 |
+
state.anno_pos = np.concatenate((state.anno_pos, new_anno_pos[None,:]))
|
| 283 |
+
|
| 284 |
+
state = update_annotation_viz(state)
|
| 285 |
+
state.is_selecting = False
|
| 286 |
+
needs_pred_update = True
|
| 287 |
+
|
| 288 |
+
if state.auto_update_fit:
|
| 289 |
+
state = update_prediction(state)
|
| 290 |
+
state = update_prediction_viz(state)
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
return
|
| 294 |
+
|
| 295 |
+
# If not selecting, build the main UI
|
| 296 |
+
needs_pred_update = False
|
| 297 |
+
|
| 298 |
+
psim.PushItemWidth(150)
|
| 299 |
+
changed, state.N_class = psim.InputInt("N_class", state.N_class, step=1)
|
| 300 |
+
psim.PopItemWidth()
|
| 301 |
+
if changed:
|
| 302 |
+
state = filter_old_labels(state)
|
| 303 |
+
state = update_annotation_viz(state)
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
# Check for keypress annotation
|
| 307 |
+
io = psim.GetIO()
|
| 308 |
+
class_keys = { 'w' : 0, '1' : 1, '2' : 2, '3' : 3, '4' : 4, '5' : 5, '6' : 6, '7' : 7, '8' : 8, '9' : 9,}
|
| 309 |
+
for c in class_keys:
|
| 310 |
+
if class_keys[c] >= state.N_class:
|
| 311 |
+
continue
|
| 312 |
+
|
| 313 |
+
if psim.IsKeyPressed(ps.get_key_code(c)):
|
| 314 |
+
state.is_selecting = True
|
| 315 |
+
state.selection_class = class_keys[c]
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
psim.SetNextItemOpen(True, psim.ImGuiCond_FirstUseEver)
|
| 319 |
+
if(psim.TreeNode("Annotate")):
|
| 320 |
+
|
| 321 |
+
psim.TextUnformatted("New class annotation. Select class to add add annotation for:")
|
| 322 |
+
psim.TextUnformatted("(alternately, press key {w,1,2,3,4...})")
|
| 323 |
+
for i_class in range(state.N_class):
|
| 324 |
+
|
| 325 |
+
if i_class > 0:
|
| 326 |
+
psim.SameLine()
|
| 327 |
+
|
| 328 |
+
if psim.Button(f"{i_class:02d}"):
|
| 329 |
+
state.is_selecting = True
|
| 330 |
+
state.selection_class = i_class
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
if psim.Button("Undo Last Annotation"):
|
| 334 |
+
state = undo_last_annotation(state)
|
| 335 |
+
state = update_annotation_viz(state)
|
| 336 |
+
needs_pred_update = True
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
psim.TreePop()
|
| 341 |
+
|
| 342 |
+
psim.SetNextItemOpen(True, psim.ImGuiCond_FirstUseEver)
|
| 343 |
+
if(psim.TreeNode("Fit")):
|
| 344 |
+
|
| 345 |
+
psim.PushItemWidth(150)
|
| 346 |
+
|
| 347 |
+
changed, ind = psim.Combo("Fit To", fit_to_list.index(state.fit_to), fit_to_list)
|
| 348 |
+
if changed:
|
| 349 |
+
state.fit_to = fit_methods_list[ind]
|
| 350 |
+
needs_pred_update = True
|
| 351 |
+
|
| 352 |
+
changed, ind = psim.Combo("Fit Method", fit_methods_list.index(state.fit_method), fit_methods_list)
|
| 353 |
+
if changed:
|
| 354 |
+
state.fit_method = fit_methods_list[ind]
|
| 355 |
+
needs_pred_update = True
|
| 356 |
+
|
| 357 |
+
if psim.Button("Update fit"):
|
| 358 |
+
state = update_prediction(state)
|
| 359 |
+
state = update_prediction_viz(state)
|
| 360 |
+
|
| 361 |
+
psim.SameLine()
|
| 362 |
+
|
| 363 |
+
changed, state.auto_update_fit = psim.Checkbox("Auto-update fit", state.auto_update_fit)
|
| 364 |
+
if changed:
|
| 365 |
+
needs_pred_update = True
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
psim.PopItemWidth()
|
| 369 |
+
|
| 370 |
+
psim.TreePop()
|
| 371 |
+
|
| 372 |
+
psim.SetNextItemOpen(True, psim.ImGuiCond_FirstUseEver)
|
| 373 |
+
if(psim.TreeNode("Visualization")):
|
| 374 |
+
|
| 375 |
+
psim.PushItemWidth(150)
|
| 376 |
+
changed, state.anno_radius = psim.SliderFloat("Annotation Point Radius", state.anno_radius, 0.00001, 0.02)
|
| 377 |
+
if changed:
|
| 378 |
+
state = update_annotation_viz(state)
|
| 379 |
+
psim.PopItemWidth()
|
| 380 |
+
|
| 381 |
+
psim.TreePop()
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
if needs_pred_update and state.auto_update_fit:
|
| 385 |
+
state = update_prediction(state)
|
| 386 |
+
state = update_prediction_viz(state)
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
def main():
|
| 390 |
+
|
| 391 |
+
state = State()
|
| 392 |
+
|
| 393 |
+
## Parse args
|
| 394 |
+
parser = argparse.ArgumentParser()
|
| 395 |
+
|
| 396 |
+
parser.add_argument('--meshes', nargs='+', help='List of meshes to process.', required=True)
|
| 397 |
+
parser.add_argument('--n_train_subset', default=0, help='How many meshes to train on.')
|
| 398 |
+
parser.add_argument('--gt_label_fol', default="../data/coseg_guitar/gt", help='Path where labels are stored.')
|
| 399 |
+
parser.add_argument('--subsample_inputs', default=state.subsample_inputs, help='Only show a random fraction of inputs')
|
| 400 |
+
parser.add_argument('--per_obj_shift', default=state.per_obj_shift, help='How to space out objects in UI grid')
|
| 401 |
+
parser.add_argument('--grid_w', default=state.grid_w, help='Grid width')
|
| 402 |
+
|
| 403 |
+
args = parser.parse_args()
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
state.n_train_subset = int(args.n_train_subset)
|
| 407 |
+
state.subsample_inputs = int(args.subsample_inputs)
|
| 408 |
+
state.per_obj_shift = float(args.per_obj_shift)
|
| 409 |
+
state.grid_w = int(args.grid_w)
|
| 410 |
+
|
| 411 |
+
## Load data
|
| 412 |
+
# First, resolve directories to load all files in directory
|
| 413 |
+
all_filepaths = []
|
| 414 |
+
print("Resolving passed directories")
|
| 415 |
+
for entry in args.meshes:
|
| 416 |
+
if os.path.isdir(entry):
|
| 417 |
+
dir_path = entry
|
| 418 |
+
print(f" processing directory {dir_path}")
|
| 419 |
+
for filename in os.listdir(dir_path):
|
| 420 |
+
file_path = os.path.join(dir_path, filename)
|
| 421 |
+
if os.path.isfile(file_path) and file_path.endswith(".ply") and "feat_pca" in file_path:
|
| 422 |
+
print(f" adding file {file_path}")
|
| 423 |
+
all_filepaths.append(file_path)
|
| 424 |
+
else:
|
| 425 |
+
all_filepaths.append(entry)
|
| 426 |
+
|
| 427 |
+
random.shuffle(all_filepaths)
|
| 428 |
+
|
| 429 |
+
if state.subsample_inputs != -1:
|
| 430 |
+
all_filepaths = all_filepaths[:state.subsample_inputs]
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
if state.n_train_subset != 0:
|
| 434 |
+
|
| 435 |
+
print(state.n_train_subset)
|
| 436 |
+
|
| 437 |
+
train_filepaths = all_filepaths[:state.n_train_subset]
|
| 438 |
+
all_filepaths = all_filepaths[state.n_train_subset:]
|
| 439 |
+
|
| 440 |
+
print(f"Loading {len(train_filepaths)} files")
|
| 441 |
+
state.train_objects = []
|
| 442 |
+
for i, file_path in enumerate(train_filepaths):
|
| 443 |
+
state.train_objects.append(load_mesh_and_features(file_path, i, require_gt=True, gt_label_fol=args.gt_label_fol))
|
| 444 |
+
|
| 445 |
+
state.fit_to = "TrainingSet"
|
| 446 |
+
|
| 447 |
+
# Load files
|
| 448 |
+
print(f"Loading {len(all_filepaths)} files")
|
| 449 |
+
state.objects = []
|
| 450 |
+
for i, file_path in enumerate(all_filepaths):
|
| 451 |
+
state.objects.append(load_mesh_and_features(file_path, i))
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
## Set up visualization
|
| 455 |
+
ps.init()
|
| 456 |
+
ps.set_automatically_compute_scene_extents(False)
|
| 457 |
+
lim = viz_upper_limit(state, len(state.objects))
|
| 458 |
+
ps.set_length_scale(np.linalg.norm(lim) / 4.)
|
| 459 |
+
low = np.array((0, -1., -1.))
|
| 460 |
+
high = lim
|
| 461 |
+
ps.set_bounding_box(low, high)
|
| 462 |
+
|
| 463 |
+
for ind, o in enumerate(state.objects):
|
| 464 |
+
initialize_object_viz(state, o, ind)
|
| 465 |
+
|
| 466 |
+
print(f"Loaded {len(state.objects)} objects")
|
| 467 |
+
if state.n_train_subset != 0:
|
| 468 |
+
print(f"Loaded {len(state.train_objects)} training objects")
|
| 469 |
+
|
| 470 |
+
# One first prediction
|
| 471 |
+
# (does nothing if there is no annotatoins / training data)
|
| 472 |
+
state = update_prediction(state)
|
| 473 |
+
state = update_prediction_viz(state)
|
| 474 |
+
|
| 475 |
+
# Start the interactive UI
|
| 476 |
+
ps.set_user_callback(lambda : ps_callback([state]))
|
| 477 |
+
ps.show()
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
if __name__ == "__main__":
|
| 481 |
+
main()
|
| 482 |
+
|
PartField/applications/pack_labels_to_obj.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys, os, fnmatch, re
|
| 2 |
+
import argparse
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import matplotlib
|
| 6 |
+
from matplotlib import colors as mcolors
|
| 7 |
+
import matplotlib.cm
|
| 8 |
+
import potpourri3d as pp3d
|
| 9 |
+
import igl
|
| 10 |
+
from arrgh import arrgh
|
| 11 |
+
|
| 12 |
+
def main():
|
| 13 |
+
|
| 14 |
+
parser = argparse.ArgumentParser()
|
| 15 |
+
|
| 16 |
+
parser.add_argument("--input_mesh", type=str, required=True, help="The mesh to read from from, mesh file format.")
|
| 17 |
+
parser.add_argument("--input_labels", type=str, required=True, help="The labels, as a text file with one entry per line")
|
| 18 |
+
parser.add_argument("--label_count", type=int, default=-1, help="The number of labels to use for the visualization. If -1, computed as max of given labels.")
|
| 19 |
+
parser.add_argument("--output", type=str, required=True, help="The obj file to write output to")
|
| 20 |
+
|
| 21 |
+
args = parser.parse_args()
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# Read the mesh
|
| 25 |
+
V, F = igl.read_triangle_mesh(args.input_mesh)
|
| 26 |
+
|
| 27 |
+
# Read the scalar function
|
| 28 |
+
S = np.loadtxt(args.input_labels)
|
| 29 |
+
|
| 30 |
+
# Convert integers to scalars on [0,1]
|
| 31 |
+
if args.label_count == -1:
|
| 32 |
+
N_max = np.max(S) + 1
|
| 33 |
+
else:
|
| 34 |
+
N_max = args.label_count
|
| 35 |
+
S = S.astype(np.float32) / max(N_max-1, 1)
|
| 36 |
+
|
| 37 |
+
# Validate and write
|
| 38 |
+
if len(S.shape) != 1 or S.shape[0] != F.shape[0]:
|
| 39 |
+
raise ValueError(f"when scalar_on==faces, the scalar should be a length num-faces numpy array, but it has shape {S.shape[0]} and F={F.shape[0]}")
|
| 40 |
+
|
| 41 |
+
S = np.stack((S, np.zeros_like(S)), axis=-1)
|
| 42 |
+
|
| 43 |
+
pp3d.write_mesh(V, F, args.output, UV_coords=S, UV_type='per-face')
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
if __name__ == "__main__":
|
| 47 |
+
main()
|
PartField/applications/run_smooth_functional_map.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, sys
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import trimesh
|
| 5 |
+
import json
|
| 6 |
+
|
| 7 |
+
sys.path.append("..")
|
| 8 |
+
sys.path.append("../third_party/SmoothFunctionalMaps")
|
| 9 |
+
sys.path.append("../third_party/SmoothFunctionalMaps/pyFM")
|
| 10 |
+
|
| 11 |
+
from partfield.config import default_argument_parser, setup
|
| 12 |
+
from pyFM.mesh import TriMesh
|
| 13 |
+
from pyFM.spectral import mesh_FM_to_p2p
|
| 14 |
+
import DiscreteOpt
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def vertex_color_map(vertices):
|
| 18 |
+
min_coord, max_coord = np.min(vertices, axis=0, keepdims=True), np.max(vertices, axis=0, keepdims=True)
|
| 19 |
+
cmap = (vertices - min_coord) / (max_coord - min_coord)
|
| 20 |
+
return cmap
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
if __name__ == '__main__':
|
| 24 |
+
parser = default_argument_parser()
|
| 25 |
+
args = parser.parse_args()
|
| 26 |
+
cfg = setup(args, freeze=False)
|
| 27 |
+
|
| 28 |
+
feature_dir = os.path.join("../exp_results", cfg.result_name)
|
| 29 |
+
|
| 30 |
+
all_files = cfg.dataset.all_files
|
| 31 |
+
assert len(all_files) % 2 == 0
|
| 32 |
+
num_pairs = len(all_files) // 2
|
| 33 |
+
|
| 34 |
+
device = "cuda"
|
| 35 |
+
|
| 36 |
+
output_dir = "../exp_results/correspondence/"
|
| 37 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 38 |
+
|
| 39 |
+
for i in range(num_pairs):
|
| 40 |
+
file0 = all_files[2 * i]
|
| 41 |
+
file1 = all_files[2 * i + 1]
|
| 42 |
+
|
| 43 |
+
uid0 = file0.split(".")[-2].replace("/", "_")
|
| 44 |
+
uid1 = file1.split(".")[-2].replace("/", "_")
|
| 45 |
+
|
| 46 |
+
mesh0 = trimesh.load(os.path.join(feature_dir, f"input_{uid0}_0.ply"), process=True)
|
| 47 |
+
mesh1 = trimesh.load(os.path.join(feature_dir, f"input_{uid1}_0.ply"), process=True)
|
| 48 |
+
|
| 49 |
+
feat0 = np.load(os.path.join(feature_dir, f"part_feat_{uid0}_0_batch.npy"))
|
| 50 |
+
feat1 = np.load(os.path.join(feature_dir, f"part_feat_{uid1}_0_batch.npy"))
|
| 51 |
+
|
| 52 |
+
assert mesh0.vertices.shape[0] == feat0.shape[0], "num of vertices should match num of features"
|
| 53 |
+
assert mesh1.vertices.shape[0] == feat1.shape[0], "num of vertices should match num of features"
|
| 54 |
+
|
| 55 |
+
th_descr0 = torch.tensor(feat0, device=device, dtype=torch.float32)
|
| 56 |
+
th_descr1 = torch.tensor(feat1, device=device, dtype=torch.float32)
|
| 57 |
+
|
| 58 |
+
cdist_01 = torch.cdist(th_descr0, th_descr1, p=2)
|
| 59 |
+
p2p_10_init = cdist_01.argmin(dim=0).cpu().numpy()
|
| 60 |
+
p2p_01_init = cdist_01.argmin(dim=1).cpu().numpy()
|
| 61 |
+
|
| 62 |
+
fm_mesh0 = TriMesh(mesh0.vertices, mesh0.faces, area_normalize=True, center=True).process(k=200, intrinsic=True)
|
| 63 |
+
fm_mesh1 = TriMesh(mesh1.vertices, mesh1.faces, area_normalize=True, center=True).process(k=200, intrinsic=True)
|
| 64 |
+
|
| 65 |
+
model = DiscreteOpt.SmoothDiscreteOptimization(fm_mesh0, fm_mesh1)
|
| 66 |
+
model.set_params("zoomout_rhm")
|
| 67 |
+
model.opt_params.step = 10
|
| 68 |
+
model.solve_from_p2p(p2p_21=p2p_10_init, p2p_12=p2p_01_init, n_jobs=30, verbose=True)
|
| 69 |
+
|
| 70 |
+
p2p_10_FM = mesh_FM_to_p2p(model.FM_12, fm_mesh0, fm_mesh1, use_adj=True)
|
| 71 |
+
|
| 72 |
+
color0 = vertex_color_map(mesh0.vertices)
|
| 73 |
+
color1 = color0[p2p_10_FM]
|
| 74 |
+
|
| 75 |
+
output_mesh0 = trimesh.Trimesh(mesh0.vertices, mesh0.faces, vertex_colors=color0)
|
| 76 |
+
output_mesh1 = trimesh.Trimesh(mesh1.vertices, mesh1.faces, vertex_colors=color1)
|
| 77 |
+
|
| 78 |
+
output_mesh0.export(os.path.join(output_dir, f"correspondence_{uid0}_{uid1}_0.ply"))
|
| 79 |
+
output_mesh1.export(os.path.join(output_dir, f"correspondence_{uid0}_{uid1}_1.ply"))
|
| 80 |
+
|
PartField/applications/shape_pair.py
ADDED
|
@@ -0,0 +1,385 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import polyscope as ps
|
| 4 |
+
import polyscope.imgui as psim
|
| 5 |
+
import potpourri3d as pp3d
|
| 6 |
+
import trimesh
|
| 7 |
+
import igl
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
from simple_parsing import ArgumentParser
|
| 10 |
+
from arrgh import arrgh
|
| 11 |
+
|
| 12 |
+
### For clustering
|
| 13 |
+
from collections import defaultdict
|
| 14 |
+
from sklearn.cluster import AgglomerativeClustering, DBSCAN, KMeans
|
| 15 |
+
from scipy.sparse import coo_matrix, csr_matrix
|
| 16 |
+
from scipy.spatial import KDTree
|
| 17 |
+
from scipy.sparse.csgraph import connected_components
|
| 18 |
+
from sklearn.neighbors import NearestNeighbors
|
| 19 |
+
import networkx as nx
|
| 20 |
+
|
| 21 |
+
from scipy.optimize import linear_sum_assignment
|
| 22 |
+
|
| 23 |
+
import os, sys
|
| 24 |
+
sys.path.append("..")
|
| 25 |
+
from partfield.utils import *
|
| 26 |
+
|
| 27 |
+
@dataclass
|
| 28 |
+
class Options:
|
| 29 |
+
|
| 30 |
+
""" Basic Options """
|
| 31 |
+
filename: str
|
| 32 |
+
filename_alt: str = None
|
| 33 |
+
|
| 34 |
+
"""System Options"""
|
| 35 |
+
device: str = "cuda" # Device
|
| 36 |
+
debug: bool = False # enable debug checks
|
| 37 |
+
extras: bool = False # include extra output for viz/debugging
|
| 38 |
+
|
| 39 |
+
""" State """
|
| 40 |
+
mode: str = 'co-segmentation'
|
| 41 |
+
m: dict = None # mesh
|
| 42 |
+
m_alt: dict = None # second mesh
|
| 43 |
+
|
| 44 |
+
# pca mode
|
| 45 |
+
|
| 46 |
+
# feature explore mode
|
| 47 |
+
i_feature: int = 0
|
| 48 |
+
|
| 49 |
+
i_cluster: int = 1
|
| 50 |
+
i_cluster2: int = 1
|
| 51 |
+
|
| 52 |
+
i_eps: int = 0.6
|
| 53 |
+
|
| 54 |
+
### For mixing in clustering
|
| 55 |
+
weight_dist = 1.0
|
| 56 |
+
weight_feat = 1.0
|
| 57 |
+
|
| 58 |
+
### For clustering visualization
|
| 59 |
+
independent: bool = True
|
| 60 |
+
source_init: bool = True
|
| 61 |
+
|
| 62 |
+
feature_range: float = 0.1
|
| 63 |
+
continuous_explore: bool = False
|
| 64 |
+
|
| 65 |
+
viz_mode: str = "faces"
|
| 66 |
+
|
| 67 |
+
output_fol: str = "results_pair"
|
| 68 |
+
|
| 69 |
+
### counter for screenshot
|
| 70 |
+
counter: int = 0
|
| 71 |
+
|
| 72 |
+
modes_list = ['feature_explore', "co-segmentation"]
|
| 73 |
+
|
| 74 |
+
def load_features(feature_filename, mesh_filename, viz_mode):
|
| 75 |
+
|
| 76 |
+
print("Reading features:")
|
| 77 |
+
print(f" Feature filename: {feature_filename}")
|
| 78 |
+
print(f" Mesh filename: {mesh_filename}")
|
| 79 |
+
|
| 80 |
+
# load features
|
| 81 |
+
feat = np.load(feature_filename, allow_pickle=True)
|
| 82 |
+
feat = feat.astype(np.float32)
|
| 83 |
+
|
| 84 |
+
# load mesh things
|
| 85 |
+
tm = load_mesh_util(mesh_filename)
|
| 86 |
+
|
| 87 |
+
V = np.array(tm.vertices, dtype=np.float32)
|
| 88 |
+
F = np.array(tm.faces)
|
| 89 |
+
|
| 90 |
+
if viz_mode == "faces":
|
| 91 |
+
pca_colors = np.array(tm.visual.face_colors, dtype=np.float32)
|
| 92 |
+
pca_colors = pca_colors[:,:3] / 255.
|
| 93 |
+
|
| 94 |
+
else:
|
| 95 |
+
pca_colors = np.array(tm.visual.vertex_colors, dtype=np.float32)
|
| 96 |
+
pca_colors = pca_colors[:,:3] / 255.
|
| 97 |
+
|
| 98 |
+
arrgh(V, F, pca_colors, feat)
|
| 99 |
+
|
| 100 |
+
return {
|
| 101 |
+
'V' : V,
|
| 102 |
+
'F' : F,
|
| 103 |
+
'pca_colors' : pca_colors,
|
| 104 |
+
'feat_np' : feat,
|
| 105 |
+
'feat_pt' : torch.tensor(feat, device='cuda'),
|
| 106 |
+
'trimesh' : tm,
|
| 107 |
+
'label' : None,
|
| 108 |
+
'num_cluster' : 1,
|
| 109 |
+
'scalar' : None
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
def prep_feature_mesh(m, name='mesh'):
|
| 113 |
+
ps_mesh = ps.register_surface_mesh(name, m['V'], m['F'])
|
| 114 |
+
ps_mesh.set_selection_mode('faces_only')
|
| 115 |
+
m['ps_mesh'] = ps_mesh
|
| 116 |
+
|
| 117 |
+
def viz_pca_colors(m):
|
| 118 |
+
m['ps_mesh'].add_color_quantity('pca colors', m['pca_colors'], enabled=True, defined_on=m["viz_mode"])
|
| 119 |
+
|
| 120 |
+
def viz_feature(m, ind):
|
| 121 |
+
m['ps_mesh'].add_scalar_quantity('pca colors', m['feat_np'][:,ind], cmap='turbo', enabled=True, defined_on=m["viz_mode"])
|
| 122 |
+
|
| 123 |
+
def feature_distance_np(feats, query_feat):
|
| 124 |
+
# normalize
|
| 125 |
+
feats = feats / np.linalg.norm(feats,axis=1)[:,None]
|
| 126 |
+
query_feat = query_feat / np.linalg.norm(query_feat)
|
| 127 |
+
# cosine distance
|
| 128 |
+
cos_sim = np.dot(feats, query_feat)
|
| 129 |
+
cos_dist = (1 - cos_sim) / 2.
|
| 130 |
+
return cos_dist
|
| 131 |
+
|
| 132 |
+
def feature_distance_pt(feats, query_feat):
|
| 133 |
+
return (1. - torch.nn.functional.cosine_similarity(feats, query_feat[None,:], dim=-1)) / 2.
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def ps_callback(opts):
|
| 137 |
+
m = opts.m
|
| 138 |
+
|
| 139 |
+
changed, ind = psim.Combo("Mode", modes_list.index(opts.mode), modes_list)
|
| 140 |
+
if changed:
|
| 141 |
+
opts.mode = modes_list[ind]
|
| 142 |
+
m['ps_mesh'].remove_all_quantities()
|
| 143 |
+
if opts.m_alt is not None:
|
| 144 |
+
opts.m_alt['ps_mesh'].remove_all_quantities()
|
| 145 |
+
|
| 146 |
+
elif opts.mode == 'feature_explore':
|
| 147 |
+
psim.TextUnformatted("Click on the mesh on the left")
|
| 148 |
+
psim.TextUnformatted("to highlight all faces within a given radius in feature space.""")
|
| 149 |
+
|
| 150 |
+
io = psim.GetIO()
|
| 151 |
+
if io.MouseClicked[0] or opts.continuous_explore:
|
| 152 |
+
screen_coords = io.MousePos
|
| 153 |
+
cam_params = ps.get_view_camera_parameters()
|
| 154 |
+
|
| 155 |
+
pick_result = ps.pick(screen_coords=screen_coords)
|
| 156 |
+
|
| 157 |
+
# Check if we hit one of the meshes
|
| 158 |
+
if pick_result.is_hit and pick_result.structure_name == "mesh":
|
| 159 |
+
if pick_result.structure_data['element_type'] != "face":
|
| 160 |
+
# shouldn't be possible
|
| 161 |
+
raise ValueError("pick returned non-face")
|
| 162 |
+
|
| 163 |
+
f_hit = pick_result.structure_data['index']
|
| 164 |
+
bary_weights = np.array(pick_result.structure_data['bary_coords'])
|
| 165 |
+
|
| 166 |
+
# get the feature via interpolation
|
| 167 |
+
point_feat = m['feat_np'][f_hit,:]
|
| 168 |
+
point_feat_pt = torch.tensor(point_feat, device='cuda')
|
| 169 |
+
|
| 170 |
+
all_dists1 = feature_distance_pt(m['feat_pt'], point_feat_pt).detach().cpu().numpy()
|
| 171 |
+
m['ps_mesh'].add_scalar_quantity("distance", all_dists1, cmap='blues', vminmax=(0, opts.feature_range), enabled=True, defined_on=m["viz_mode"])
|
| 172 |
+
opts.m['scalar'] = all_dists1
|
| 173 |
+
|
| 174 |
+
if opts.m_alt is not None:
|
| 175 |
+
all_dists2 = feature_distance_pt(opts.m_alt['feat_pt'], point_feat_pt).detach().cpu().numpy()
|
| 176 |
+
opts.m_alt['ps_mesh'].add_scalar_quantity("distance", all_dists2, cmap='blues', vminmax=(0, opts.feature_range), enabled=True, defined_on=m["viz_mode"])
|
| 177 |
+
opts.m_alt['scalar'] = all_dists2
|
| 178 |
+
|
| 179 |
+
else:
|
| 180 |
+
# not hit
|
| 181 |
+
pass
|
| 182 |
+
|
| 183 |
+
if psim.Button("Export"):
|
| 184 |
+
### Save output
|
| 185 |
+
OUTPUT_FOL = opts.output_fol
|
| 186 |
+
fname1 = opts.filename
|
| 187 |
+
out_mesh_file = os.path.join(OUTPUT_FOL, fname1+'.obj')
|
| 188 |
+
|
| 189 |
+
igl.write_obj(out_mesh_file, opts.m["V"], opts.m["F"])
|
| 190 |
+
print("Saved '{}'.".format(out_mesh_file))
|
| 191 |
+
|
| 192 |
+
out_face_ids_file = os.path.join(OUTPUT_FOL, fname1 + '_feat_dist_' + str(opts.counter) +'.txt')
|
| 193 |
+
np.savetxt(out_face_ids_file, opts.m['scalar'], fmt='%f')
|
| 194 |
+
print("Saved '{}'.".format(out_face_ids_file))
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
fname2 = opts.filename_alt
|
| 198 |
+
out_mesh_file = os.path.join(OUTPUT_FOL, fname2+'.obj')
|
| 199 |
+
|
| 200 |
+
igl.write_obj(out_mesh_file, opts.m_alt["V"], opts.m_alt["F"])
|
| 201 |
+
print("Saved '{}'.".format(out_mesh_file))
|
| 202 |
+
|
| 203 |
+
out_face_ids_file = os.path.join(OUTPUT_FOL, fname2 + '_feat_dist_' + str(opts.counter) +'.txt')
|
| 204 |
+
np.savetxt(out_face_ids_file, opts.m_alt['scalar'], fmt='%f')
|
| 205 |
+
# print("Saved '{}'.".format(out_face_ids_file))
|
| 206 |
+
|
| 207 |
+
opts.counter += 1
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
_, opts.feature_range = psim.SliderFloat('range', opts.feature_range, v_min=0., v_max=1.0, power=3)
|
| 211 |
+
_, opts.continuous_explore = psim.Checkbox('continuous', opts.continuous_explore)
|
| 212 |
+
|
| 213 |
+
# TODO nsharp remember how the keycodes work
|
| 214 |
+
if io.KeysDown[ord('q')]:
|
| 215 |
+
opts.feature_range += 0.01
|
| 216 |
+
if io.KeysDown[ord('w')]:
|
| 217 |
+
opts.feature_range -= 0.01
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
elif opts.mode == "co-segmentation":
|
| 221 |
+
|
| 222 |
+
changed, opts.source_init = psim.Checkbox("Source Init", opts.source_init)
|
| 223 |
+
changed, opts.independent = psim.Checkbox("Independent", opts.independent)
|
| 224 |
+
|
| 225 |
+
psim.TextUnformatted("Use the slider to toggle the number of desired clusters.")
|
| 226 |
+
cluster_changed, opts.i_cluster = psim.SliderInt("num clusters for model1", opts.i_cluster, v_min=1, v_max=30)
|
| 227 |
+
cluster_changed, opts.i_cluster2 = psim.SliderInt("num clusters for model2", opts.i_cluster2, v_min=1, v_max=30)
|
| 228 |
+
|
| 229 |
+
# if cluster_changed:
|
| 230 |
+
if psim.Button("Recompute"):
|
| 231 |
+
|
| 232 |
+
### Run clustering algorithm
|
| 233 |
+
|
| 234 |
+
### Mesh 1
|
| 235 |
+
num_clusters1 = opts.i_cluster
|
| 236 |
+
point_feat1 = m['feat_np']
|
| 237 |
+
point_feat1 = point_feat1 / np.linalg.norm(point_feat1, axis=-1, keepdims=True)
|
| 238 |
+
clustering1 = KMeans(n_clusters=num_clusters1, random_state=0, n_init="auto").fit(point_feat1)
|
| 239 |
+
|
| 240 |
+
### Get feature means per cluster
|
| 241 |
+
feature_means1 = []
|
| 242 |
+
for j in range(num_clusters1):
|
| 243 |
+
all_cluster_feat = point_feat1[clustering1.labels_==j]
|
| 244 |
+
mean_feat = np.mean(all_cluster_feat, axis=0)
|
| 245 |
+
feature_means1.append(mean_feat)
|
| 246 |
+
|
| 247 |
+
feature_means1 = np.array(feature_means1)
|
| 248 |
+
tree = KDTree(feature_means1)
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
if opts.source_init:
|
| 252 |
+
num_clusters2 = opts.i_cluster
|
| 253 |
+
init_mode = np.array(feature_means1)
|
| 254 |
+
|
| 255 |
+
## default is kmeans++
|
| 256 |
+
else:
|
| 257 |
+
num_clusters2 = opts.i_cluster2
|
| 258 |
+
init_mode = "k-means++"
|
| 259 |
+
|
| 260 |
+
### Mesh 2
|
| 261 |
+
point_feat2 = opts.m_alt['feat_np']
|
| 262 |
+
point_feat2 = point_feat2 / np.linalg.norm(point_feat2, axis=-1, keepdims=True)
|
| 263 |
+
|
| 264 |
+
clustering2 = KMeans(n_clusters=num_clusters2, random_state=0, init=init_mode).fit(point_feat2)
|
| 265 |
+
|
| 266 |
+
### Get feature means per cluster
|
| 267 |
+
feature_means2 = []
|
| 268 |
+
for j in range(num_clusters2):
|
| 269 |
+
all_cluster_feat = point_feat2[clustering2.labels_==j]
|
| 270 |
+
mean_feat = np.mean(all_cluster_feat, axis=0)
|
| 271 |
+
feature_means2.append(mean_feat)
|
| 272 |
+
|
| 273 |
+
feature_means2 = np.array(feature_means2)
|
| 274 |
+
_, nn_idx = tree.query(feature_means2, k=1)
|
| 275 |
+
|
| 276 |
+
print(nn_idx)
|
| 277 |
+
print("Both KMeans")
|
| 278 |
+
print(np.unique(clustering1.labels_))
|
| 279 |
+
print(np.unique(clustering2.labels_))
|
| 280 |
+
|
| 281 |
+
relabelled_2 = nn_idx[clustering2.labels_]
|
| 282 |
+
|
| 283 |
+
print(np.unique(relabelled_2))
|
| 284 |
+
print()
|
| 285 |
+
|
| 286 |
+
m['ps_mesh'].add_scalar_quantity("cluster_both_kmeans", clustering1.labels_, cmap='turbo', vminmax=(0, num_clusters1-1), enabled=True, defined_on=m["viz_mode"])
|
| 287 |
+
opts.m['label'] = clustering1.labels_
|
| 288 |
+
opts.m['num_cluster'] = num_clusters1
|
| 289 |
+
|
| 290 |
+
if opts.independent:
|
| 291 |
+
opts.m_alt['ps_mesh'].add_scalar_quantity("cluster", clustering2.labels_, cmap='turbo', vminmax=(0, num_clusters2-1), enabled=True, defined_on=m["viz_mode"])
|
| 292 |
+
opts.m_alt['label'] = clustering2.labels_
|
| 293 |
+
opts.m_alt['num_cluster'] = num_clusters2
|
| 294 |
+
else:
|
| 295 |
+
opts.m_alt['ps_mesh'].add_scalar_quantity("cluster", relabelled_2, cmap='turbo', vminmax=(0, num_clusters1-1), enabled=True, defined_on=m["viz_mode"])
|
| 296 |
+
opts.m_alt['label'] = relabelled_2
|
| 297 |
+
opts.m_alt['num_cluster'] = num_clusters1
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
if psim.Button("Export"):
|
| 301 |
+
### Save output
|
| 302 |
+
OUTPUT_FOL = opts.output_fol
|
| 303 |
+
fname1 = opts.filename
|
| 304 |
+
out_mesh_file = os.path.join(OUTPUT_FOL, fname1+'.obj')
|
| 305 |
+
|
| 306 |
+
igl.write_obj(out_mesh_file, opts.m["V"], opts.m["F"])
|
| 307 |
+
print("Saved '{}'.".format(out_mesh_file))
|
| 308 |
+
|
| 309 |
+
if m["viz_mode"] == "faces":
|
| 310 |
+
out_face_ids_file = os.path.join(OUTPUT_FOL, fname1 + "_" + str(opts.m['num_cluster']) + '_pred_face_ids.txt')
|
| 311 |
+
else:
|
| 312 |
+
out_face_ids_file = os.path.join(OUTPUT_FOL, fname1 + "_" + str(opts.m['num_cluster']) + '_pred_vertices_ids.txt')
|
| 313 |
+
|
| 314 |
+
np.savetxt(out_face_ids_file, opts.m['label'], fmt='%d')
|
| 315 |
+
print("Saved '{}'.".format(out_face_ids_file))
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
fname2 = opts.filename_alt
|
| 319 |
+
out_mesh_file = os.path.join(OUTPUT_FOL, fname2 +'.obj')
|
| 320 |
+
|
| 321 |
+
igl.write_obj(out_mesh_file, opts.m_alt["V"], opts.m_alt["F"])
|
| 322 |
+
print("Saved '{}'.".format(out_mesh_file))
|
| 323 |
+
|
| 324 |
+
if m["viz_mode"] == "faces":
|
| 325 |
+
out_face_ids_file = os.path.join(OUTPUT_FOL, fname2 + "_" + str(opts.m_alt['num_cluster']) + '_pred_face_ids.txt')
|
| 326 |
+
else:
|
| 327 |
+
out_face_ids_file = os.path.join(OUTPUT_FOL, fname2 + "_" + str(opts.m_alt['num_cluster']) + '_pred_vertices_ids.txt')
|
| 328 |
+
|
| 329 |
+
np.savetxt(out_face_ids_file, opts.m_alt['label'], fmt='%d')
|
| 330 |
+
print("Saved '{}'.".format(out_face_ids_file))
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
def main():
|
| 334 |
+
## Parse args
|
| 335 |
+
# Uses simple_parsing library to automatically construct parser from the dataclass Options
|
| 336 |
+
parser = ArgumentParser()
|
| 337 |
+
parser.add_arguments(Options, dest="options")
|
| 338 |
+
parser.add_argument('--data_root', default="../exp_results/partfield_features/trellis", help='Path the model features are stored.')
|
| 339 |
+
args = parser.parse_args()
|
| 340 |
+
opts: Options = args.options
|
| 341 |
+
|
| 342 |
+
DATA_ROOT = args.data_root
|
| 343 |
+
|
| 344 |
+
shape_1 = opts.filename
|
| 345 |
+
shape_2 = opts.filename_alt
|
| 346 |
+
|
| 347 |
+
if os.path.exists(os.path.join(DATA_ROOT, "part_feat_"+ shape_1 + "_0.npy")):
|
| 348 |
+
feature_fname1 = os.path.join(DATA_ROOT, "part_feat_"+ shape_1 + "_0.npy")
|
| 349 |
+
feature_fname2 = os.path.join(DATA_ROOT, "part_feat_"+ shape_2 + "_0.npy")
|
| 350 |
+
|
| 351 |
+
mesh_fname1 = os.path.join(DATA_ROOT, "feat_pca_"+ shape_1 + "_0.ply")
|
| 352 |
+
mesh_fname2 = os.path.join(DATA_ROOT, "feat_pca_"+ shape_2 + "_0.ply")
|
| 353 |
+
else:
|
| 354 |
+
feature_fname1 = os.path.join(DATA_ROOT, "part_feat_"+ shape_1 + "_0_batch.npy")
|
| 355 |
+
feature_fname2 = os.path.join(DATA_ROOT, "part_feat_"+ shape_2 + "_0_batch.npy")
|
| 356 |
+
|
| 357 |
+
mesh_fname1 = os.path.join(DATA_ROOT, "feat_pca_"+ shape_1 + "_0.ply")
|
| 358 |
+
mesh_fname2 = os.path.join(DATA_ROOT, "feat_pca_"+ shape_2 + "_0.ply")
|
| 359 |
+
|
| 360 |
+
#### To save output ####
|
| 361 |
+
os.makedirs(opts.output_fol, exist_ok=True)
|
| 362 |
+
########################
|
| 363 |
+
|
| 364 |
+
# Initialize
|
| 365 |
+
ps.init()
|
| 366 |
+
|
| 367 |
+
mesh_dict = load_features(feature_fname1, mesh_fname1, opts.viz_mode)
|
| 368 |
+
prep_feature_mesh(mesh_dict)
|
| 369 |
+
mesh_dict["viz_mode"] = opts.viz_mode
|
| 370 |
+
opts.m = mesh_dict
|
| 371 |
+
|
| 372 |
+
mesh_dict_alt = load_features(feature_fname2, mesh_fname2, opts.viz_mode)
|
| 373 |
+
prep_feature_mesh(mesh_dict_alt, name='mesh_alt')
|
| 374 |
+
mesh_dict_alt['ps_mesh'].translate((2.5, 0., 0.))
|
| 375 |
+
mesh_dict_alt["viz_mode"] = opts.viz_mode
|
| 376 |
+
opts.m_alt = mesh_dict_alt
|
| 377 |
+
|
| 378 |
+
# Start the interactive UI
|
| 379 |
+
ps.set_user_callback(lambda : ps_callback(opts))
|
| 380 |
+
ps.show()
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
if __name__ == "__main__":
|
| 384 |
+
main()
|
| 385 |
+
|
PartField/applications/single_shape.py
ADDED
|
@@ -0,0 +1,758 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import polyscope as ps
|
| 4 |
+
import polyscope.imgui as psim
|
| 5 |
+
import potpourri3d as pp3d
|
| 6 |
+
import trimesh
|
| 7 |
+
import igl
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
from simple_parsing import ArgumentParser
|
| 10 |
+
from arrgh import arrgh
|
| 11 |
+
|
| 12 |
+
### For clustering
|
| 13 |
+
from collections import defaultdict
|
| 14 |
+
from sklearn.cluster import AgglomerativeClustering, DBSCAN, KMeans
|
| 15 |
+
from scipy.sparse import coo_matrix, csr_matrix
|
| 16 |
+
from scipy.spatial import KDTree
|
| 17 |
+
from scipy.sparse.csgraph import connected_components
|
| 18 |
+
from sklearn.neighbors import NearestNeighbors
|
| 19 |
+
import networkx as nx
|
| 20 |
+
|
| 21 |
+
from scipy.optimize import linear_sum_assignment
|
| 22 |
+
|
| 23 |
+
import os, sys
|
| 24 |
+
sys.path.append("..")
|
| 25 |
+
from partfield.utils import *
|
| 26 |
+
|
| 27 |
+
@dataclass
|
| 28 |
+
class Options:
|
| 29 |
+
|
| 30 |
+
""" Basic Options """
|
| 31 |
+
filename: str
|
| 32 |
+
|
| 33 |
+
"""System Options"""
|
| 34 |
+
device: str = "cuda" # Device
|
| 35 |
+
debug: bool = False # enable debug checks
|
| 36 |
+
extras: bool = False # include extra output for viz/debugging
|
| 37 |
+
|
| 38 |
+
""" State """
|
| 39 |
+
mode: str = 'pca'
|
| 40 |
+
m: dict = None # mesh
|
| 41 |
+
|
| 42 |
+
# pca mode
|
| 43 |
+
|
| 44 |
+
# feature explore mode
|
| 45 |
+
i_feature: int = 0
|
| 46 |
+
|
| 47 |
+
i_cluster: int = 1
|
| 48 |
+
|
| 49 |
+
i_eps: int = 0.6
|
| 50 |
+
|
| 51 |
+
### For mixing in clustering
|
| 52 |
+
weight_dist = 1.0
|
| 53 |
+
weight_feat = 1.0
|
| 54 |
+
|
| 55 |
+
### For clustering visualization
|
| 56 |
+
feature_range: float = 0.1
|
| 57 |
+
continuous_explore: bool = False
|
| 58 |
+
|
| 59 |
+
viz_mode: str = "faces"
|
| 60 |
+
|
| 61 |
+
output_fol: str = "results_single"
|
| 62 |
+
|
| 63 |
+
### For adj_matrix
|
| 64 |
+
adj_mode: str = "Vanilla"
|
| 65 |
+
add_knn_edges: bool = False
|
| 66 |
+
|
| 67 |
+
### counter for screenshot
|
| 68 |
+
counter: int = 0
|
| 69 |
+
|
| 70 |
+
modes_list = ['pca', 'feature_viz', 'cluster_agglo', 'cluster_kmeans']
|
| 71 |
+
adj_mode_list = ["Vanilla", "Face_MST", "CC_MST"]
|
| 72 |
+
|
| 73 |
+
#### For clustering
|
| 74 |
+
class UnionFind:
|
| 75 |
+
def __init__(self, n):
|
| 76 |
+
self.parent = list(range(n))
|
| 77 |
+
self.rank = [1] * n
|
| 78 |
+
|
| 79 |
+
def find(self, x):
|
| 80 |
+
if self.parent[x] != x:
|
| 81 |
+
self.parent[x] = self.find(self.parent[x])
|
| 82 |
+
return self.parent[x]
|
| 83 |
+
|
| 84 |
+
def union(self, x, y):
|
| 85 |
+
rootX = self.find(x)
|
| 86 |
+
rootY = self.find(y)
|
| 87 |
+
|
| 88 |
+
if rootX != rootY:
|
| 89 |
+
if self.rank[rootX] > self.rank[rootY]:
|
| 90 |
+
self.parent[rootY] = rootX
|
| 91 |
+
elif self.rank[rootX] < self.rank[rootY]:
|
| 92 |
+
self.parent[rootX] = rootY
|
| 93 |
+
else:
|
| 94 |
+
self.parent[rootY] = rootX
|
| 95 |
+
self.rank[rootX] += 1
|
| 96 |
+
|
| 97 |
+
#####################################
|
| 98 |
+
## Face adjacency computation options
|
| 99 |
+
#####################################
|
| 100 |
+
def construct_face_adjacency_matrix_ccmst(face_list, vertices, k=10, with_knn=True):
|
| 101 |
+
"""
|
| 102 |
+
Given a list of faces (each face is a 3-tuple of vertex indices),
|
| 103 |
+
construct a face-based adjacency matrix of shape (num_faces, num_faces).
|
| 104 |
+
|
| 105 |
+
Two faces are adjacent if they share an edge (the "mesh adjacency").
|
| 106 |
+
If multiple connected components remain, we:
|
| 107 |
+
1) Compute the centroid of each connected component as the mean of all face centroids.
|
| 108 |
+
2) Use a KNN graph (k=10) based on centroid distances on each connected component.
|
| 109 |
+
3) Compute MST of that KNN graph.
|
| 110 |
+
4) Add MST edges that connect different components as "dummy" edges
|
| 111 |
+
in the face adjacency matrix, ensuring one connected component. The selected face for
|
| 112 |
+
each connected component is the face closest to the component centroid.
|
| 113 |
+
|
| 114 |
+
Parameters
|
| 115 |
+
----------
|
| 116 |
+
face_list : list of tuples
|
| 117 |
+
List of faces, each face is a tuple (v0, v1, v2) of vertex indices.
|
| 118 |
+
vertices : np.ndarray of shape (num_vertices, 3)
|
| 119 |
+
Array of vertex coordinates.
|
| 120 |
+
k : int, optional
|
| 121 |
+
Number of neighbors to use in centroid KNN. Default is 10.
|
| 122 |
+
|
| 123 |
+
Returns
|
| 124 |
+
-------
|
| 125 |
+
face_adjacency : scipy.sparse.csr_matrix
|
| 126 |
+
A CSR sparse matrix of shape (num_faces, num_faces),
|
| 127 |
+
containing 1s for adjacent faces (shared-edge adjacency)
|
| 128 |
+
plus dummy edges ensuring a single connected component.
|
| 129 |
+
"""
|
| 130 |
+
num_faces = len(face_list)
|
| 131 |
+
if num_faces == 0:
|
| 132 |
+
# Return an empty matrix if no faces
|
| 133 |
+
return csr_matrix((0, 0))
|
| 134 |
+
|
| 135 |
+
#--------------------------------------------------------------------------
|
| 136 |
+
# 1) Build adjacency based on shared edges.
|
| 137 |
+
# (Same logic as the original code, plus import statements.)
|
| 138 |
+
#--------------------------------------------------------------------------
|
| 139 |
+
edge_to_faces = defaultdict(list)
|
| 140 |
+
uf = UnionFind(num_faces)
|
| 141 |
+
for f_idx, (v0, v1, v2) in enumerate(face_list):
|
| 142 |
+
# Sort each edge’s endpoints so (i, j) == (j, i)
|
| 143 |
+
edges = [
|
| 144 |
+
tuple(sorted((v0, v1))),
|
| 145 |
+
tuple(sorted((v1, v2))),
|
| 146 |
+
tuple(sorted((v2, v0)))
|
| 147 |
+
]
|
| 148 |
+
for e in edges:
|
| 149 |
+
edge_to_faces[e].append(f_idx)
|
| 150 |
+
|
| 151 |
+
row = []
|
| 152 |
+
col = []
|
| 153 |
+
for edge, face_indices in edge_to_faces.items():
|
| 154 |
+
unique_faces = list(set(face_indices))
|
| 155 |
+
if len(unique_faces) > 1:
|
| 156 |
+
# For every pair of distinct faces that share this edge,
|
| 157 |
+
# mark them as mutually adjacent
|
| 158 |
+
for i in range(len(unique_faces)):
|
| 159 |
+
for j in range(i + 1, len(unique_faces)):
|
| 160 |
+
fi = unique_faces[i]
|
| 161 |
+
fj = unique_faces[j]
|
| 162 |
+
row.append(fi)
|
| 163 |
+
col.append(fj)
|
| 164 |
+
row.append(fj)
|
| 165 |
+
col.append(fi)
|
| 166 |
+
uf.union(fi, fj)
|
| 167 |
+
|
| 168 |
+
data = np.ones(len(row), dtype=np.int8)
|
| 169 |
+
face_adjacency = coo_matrix(
|
| 170 |
+
(data, (row, col)), shape=(num_faces, num_faces)
|
| 171 |
+
).tocsr()
|
| 172 |
+
|
| 173 |
+
#--------------------------------------------------------------------------
|
| 174 |
+
# 2) Check if the graph from shared edges is already connected.
|
| 175 |
+
#--------------------------------------------------------------------------
|
| 176 |
+
n_components = 0
|
| 177 |
+
for i in range(num_faces):
|
| 178 |
+
if uf.find(i) == i:
|
| 179 |
+
n_components += 1
|
| 180 |
+
print("n_components", n_components)
|
| 181 |
+
|
| 182 |
+
if n_components == 1:
|
| 183 |
+
# Already a single connected component, no need for dummy edges
|
| 184 |
+
return face_adjacency
|
| 185 |
+
|
| 186 |
+
#--------------------------------------------------------------------------
|
| 187 |
+
# 3) Compute centroids of each face for building a KNN graph.
|
| 188 |
+
#--------------------------------------------------------------------------
|
| 189 |
+
face_centroids = []
|
| 190 |
+
for (v0, v1, v2) in face_list:
|
| 191 |
+
centroid = (vertices[v0] + vertices[v1] + vertices[v2]) / 3.0
|
| 192 |
+
face_centroids.append(centroid)
|
| 193 |
+
face_centroids = np.array(face_centroids)
|
| 194 |
+
|
| 195 |
+
#--------------------------------------------------------------------------
|
| 196 |
+
# 4b) Build a KNN graph on connected components
|
| 197 |
+
#--------------------------------------------------------------------------
|
| 198 |
+
# Group faces by their root representative in the Union-Find structure
|
| 199 |
+
component_dict = {}
|
| 200 |
+
for face_idx in range(num_faces):
|
| 201 |
+
root = uf.find(face_idx)
|
| 202 |
+
if root not in component_dict:
|
| 203 |
+
component_dict[root] = set()
|
| 204 |
+
component_dict[root].add(face_idx)
|
| 205 |
+
|
| 206 |
+
connected_components = list(component_dict.values())
|
| 207 |
+
|
| 208 |
+
print("Using connected component MST.")
|
| 209 |
+
component_centroid_face_idx = []
|
| 210 |
+
connected_component_centroids = []
|
| 211 |
+
knn = NearestNeighbors(n_neighbors=1, algorithm='auto')
|
| 212 |
+
for component in connected_components:
|
| 213 |
+
curr_component_faces = list(component)
|
| 214 |
+
curr_component_face_centroids = face_centroids[curr_component_faces]
|
| 215 |
+
component_centroid = np.mean(curr_component_face_centroids, axis=0)
|
| 216 |
+
|
| 217 |
+
### Assign a face closest to the centroid
|
| 218 |
+
face_idx = curr_component_faces[np.argmin(np.linalg.norm(curr_component_face_centroids-component_centroid, axis=-1))]
|
| 219 |
+
|
| 220 |
+
connected_component_centroids.append(component_centroid)
|
| 221 |
+
component_centroid_face_idx.append(face_idx)
|
| 222 |
+
|
| 223 |
+
component_centroid_face_idx = np.array(component_centroid_face_idx)
|
| 224 |
+
connected_component_centroids = np.array(connected_component_centroids)
|
| 225 |
+
|
| 226 |
+
if n_components < k:
|
| 227 |
+
knn = NearestNeighbors(n_neighbors=n_components, algorithm='auto')
|
| 228 |
+
else:
|
| 229 |
+
knn = NearestNeighbors(n_neighbors=k, algorithm='auto')
|
| 230 |
+
knn.fit(connected_component_centroids)
|
| 231 |
+
distances, indices = knn.kneighbors(connected_component_centroids)
|
| 232 |
+
|
| 233 |
+
#--------------------------------------------------------------------------
|
| 234 |
+
# 5) Build a weighted graph in NetworkX using centroid-distances as edges
|
| 235 |
+
#--------------------------------------------------------------------------
|
| 236 |
+
G = nx.Graph()
|
| 237 |
+
# Add each face as a node in the graph
|
| 238 |
+
G.add_nodes_from(range(num_faces))
|
| 239 |
+
|
| 240 |
+
# For each face i, add edges (i -> j) for each neighbor j in the KNN
|
| 241 |
+
for idx1 in range(n_components):
|
| 242 |
+
i = component_centroid_face_idx[idx1]
|
| 243 |
+
for idx2, dist in zip(indices[idx1], distances[idx1]):
|
| 244 |
+
j = component_centroid_face_idx[idx2]
|
| 245 |
+
if i == j:
|
| 246 |
+
continue # skip self-loop
|
| 247 |
+
# Add an undirected edge with 'weight' = distance
|
| 248 |
+
# NetworkX handles parallel edges gracefully via last add_edge,
|
| 249 |
+
# but it typically overwrites the weight if (i, j) already exists.
|
| 250 |
+
G.add_edge(i, j, weight=dist)
|
| 251 |
+
|
| 252 |
+
#--------------------------------------------------------------------------
|
| 253 |
+
# 6) Compute MST on that KNN graph
|
| 254 |
+
#--------------------------------------------------------------------------
|
| 255 |
+
mst = nx.minimum_spanning_tree(G, weight='weight')
|
| 256 |
+
# Sort MST edges by ascending weight, so we add the shortest edges first
|
| 257 |
+
mst_edges_sorted = sorted(
|
| 258 |
+
mst.edges(data=True), key=lambda e: e[2]['weight']
|
| 259 |
+
)
|
| 260 |
+
print("mst edges sorted", len(mst_edges_sorted))
|
| 261 |
+
#--------------------------------------------------------------------------
|
| 262 |
+
# 7) Use a union-find structure to add MST edges only if they
|
| 263 |
+
# connect two currently disconnected components of the adjacency matrix
|
| 264 |
+
#--------------------------------------------------------------------------
|
| 265 |
+
|
| 266 |
+
# Convert face_adjacency to LIL format for efficient edge addition
|
| 267 |
+
adjacency_lil = face_adjacency.tolil()
|
| 268 |
+
|
| 269 |
+
# Now, step through MST edges in ascending order
|
| 270 |
+
for (u, v, attr) in mst_edges_sorted:
|
| 271 |
+
if uf.find(u) != uf.find(v):
|
| 272 |
+
# These belong to different components, so unify them
|
| 273 |
+
uf.union(u, v)
|
| 274 |
+
# And add a "dummy" edge to our adjacency matrix
|
| 275 |
+
adjacency_lil[u, v] = 1
|
| 276 |
+
adjacency_lil[v, u] = 1
|
| 277 |
+
|
| 278 |
+
# Convert back to CSR format and return
|
| 279 |
+
face_adjacency = adjacency_lil.tocsr()
|
| 280 |
+
|
| 281 |
+
if with_knn:
|
| 282 |
+
print("Adding KNN edges.")
|
| 283 |
+
### Add KNN edges graph too
|
| 284 |
+
dummy_row = []
|
| 285 |
+
dummy_col = []
|
| 286 |
+
for idx1 in range(n_components):
|
| 287 |
+
i = component_centroid_face_idx[idx1]
|
| 288 |
+
for idx2 in indices[idx1]:
|
| 289 |
+
j = component_centroid_face_idx[idx2]
|
| 290 |
+
dummy_row.extend([i, j])
|
| 291 |
+
dummy_col.extend([j, i]) ### duplicates are handled by coo
|
| 292 |
+
|
| 293 |
+
dummy_data = np.ones(len(dummy_row), dtype=np.int16)
|
| 294 |
+
dummy_mat = coo_matrix(
|
| 295 |
+
(dummy_data, (dummy_row, dummy_col)),
|
| 296 |
+
shape=(num_faces, num_faces)
|
| 297 |
+
).tocsr()
|
| 298 |
+
face_adjacency = face_adjacency + dummy_mat
|
| 299 |
+
###########################
|
| 300 |
+
|
| 301 |
+
return face_adjacency
|
| 302 |
+
#########################
|
| 303 |
+
|
| 304 |
+
def construct_face_adjacency_matrix_facemst(face_list, vertices, k=10, with_knn=True):
|
| 305 |
+
"""
|
| 306 |
+
Given a list of faces (each face is a 3-tuple of vertex indices),
|
| 307 |
+
construct a face-based adjacency matrix of shape (num_faces, num_faces).
|
| 308 |
+
|
| 309 |
+
Two faces are adjacent if they share an edge (the "mesh adjacency").
|
| 310 |
+
If multiple connected components remain, we:
|
| 311 |
+
1) Compute the centroid of each face.
|
| 312 |
+
2) Use a KNN graph (k=10) based on centroid distances.
|
| 313 |
+
3) Compute MST of that KNN graph.
|
| 314 |
+
4) Add MST edges that connect different components as "dummy" edges
|
| 315 |
+
in the face adjacency matrix, ensuring one connected component.
|
| 316 |
+
|
| 317 |
+
Parameters
|
| 318 |
+
----------
|
| 319 |
+
face_list : list of tuples
|
| 320 |
+
List of faces, each face is a tuple (v0, v1, v2) of vertex indices.
|
| 321 |
+
vertices : np.ndarray of shape (num_vertices, 3)
|
| 322 |
+
Array of vertex coordinates.
|
| 323 |
+
k : int, optional
|
| 324 |
+
Number of neighbors to use in centroid KNN. Default is 10.
|
| 325 |
+
|
| 326 |
+
Returns
|
| 327 |
+
-------
|
| 328 |
+
face_adjacency : scipy.sparse.csr_matrix
|
| 329 |
+
A CSR sparse matrix of shape (num_faces, num_faces),
|
| 330 |
+
containing 1s for adjacent faces (shared-edge adjacency)
|
| 331 |
+
plus dummy edges ensuring a single connected component.
|
| 332 |
+
"""
|
| 333 |
+
num_faces = len(face_list)
|
| 334 |
+
if num_faces == 0:
|
| 335 |
+
# Return an empty matrix if no faces
|
| 336 |
+
return csr_matrix((0, 0))
|
| 337 |
+
|
| 338 |
+
#--------------------------------------------------------------------------
|
| 339 |
+
# 1) Build adjacency based on shared edges.
|
| 340 |
+
# (Same logic as the original code, plus import statements.)
|
| 341 |
+
#--------------------------------------------------------------------------
|
| 342 |
+
edge_to_faces = defaultdict(list)
|
| 343 |
+
uf = UnionFind(num_faces)
|
| 344 |
+
for f_idx, (v0, v1, v2) in enumerate(face_list):
|
| 345 |
+
# Sort each edge’s endpoints so (i, j) == (j, i)
|
| 346 |
+
edges = [
|
| 347 |
+
tuple(sorted((v0, v1))),
|
| 348 |
+
tuple(sorted((v1, v2))),
|
| 349 |
+
tuple(sorted((v2, v0)))
|
| 350 |
+
]
|
| 351 |
+
for e in edges:
|
| 352 |
+
edge_to_faces[e].append(f_idx)
|
| 353 |
+
|
| 354 |
+
row = []
|
| 355 |
+
col = []
|
| 356 |
+
for edge, face_indices in edge_to_faces.items():
|
| 357 |
+
unique_faces = list(set(face_indices))
|
| 358 |
+
if len(unique_faces) > 1:
|
| 359 |
+
# For every pair of distinct faces that share this edge,
|
| 360 |
+
# mark them as mutually adjacent
|
| 361 |
+
for i in range(len(unique_faces)):
|
| 362 |
+
for j in range(i + 1, len(unique_faces)):
|
| 363 |
+
fi = unique_faces[i]
|
| 364 |
+
fj = unique_faces[j]
|
| 365 |
+
row.append(fi)
|
| 366 |
+
col.append(fj)
|
| 367 |
+
row.append(fj)
|
| 368 |
+
col.append(fi)
|
| 369 |
+
uf.union(fi, fj)
|
| 370 |
+
|
| 371 |
+
data = np.ones(len(row), dtype=np.int8)
|
| 372 |
+
face_adjacency = coo_matrix(
|
| 373 |
+
(data, (row, col)), shape=(num_faces, num_faces)
|
| 374 |
+
).tocsr()
|
| 375 |
+
|
| 376 |
+
#--------------------------------------------------------------------------
|
| 377 |
+
# 2) Check if the graph from shared edges is already connected.
|
| 378 |
+
#--------------------------------------------------------------------------
|
| 379 |
+
n_components = 0
|
| 380 |
+
for i in range(num_faces):
|
| 381 |
+
if uf.find(i) == i:
|
| 382 |
+
n_components += 1
|
| 383 |
+
print("n_components", n_components)
|
| 384 |
+
|
| 385 |
+
if n_components == 1:
|
| 386 |
+
# Already a single connected component, no need for dummy edges
|
| 387 |
+
return face_adjacency
|
| 388 |
+
#--------------------------------------------------------------------------
|
| 389 |
+
# 3) Compute centroids of each face for building a KNN graph.
|
| 390 |
+
#--------------------------------------------------------------------------
|
| 391 |
+
face_centroids = []
|
| 392 |
+
for (v0, v1, v2) in face_list:
|
| 393 |
+
centroid = (vertices[v0] + vertices[v1] + vertices[v2]) / 3.0
|
| 394 |
+
face_centroids.append(centroid)
|
| 395 |
+
face_centroids = np.array(face_centroids)
|
| 396 |
+
|
| 397 |
+
#--------------------------------------------------------------------------
|
| 398 |
+
# 4) Build a KNN graph (k=10) over face centroids using scikit‐learn
|
| 399 |
+
#--------------------------------------------------------------------------
|
| 400 |
+
knn = NearestNeighbors(n_neighbors=k, algorithm='auto')
|
| 401 |
+
knn.fit(face_centroids)
|
| 402 |
+
distances, indices = knn.kneighbors(face_centroids)
|
| 403 |
+
# 'distances[i]' are the distances from face i to each of its 'k' neighbors
|
| 404 |
+
# 'indices[i]' are the face indices of those neighbors
|
| 405 |
+
|
| 406 |
+
#--------------------------------------------------------------------------
|
| 407 |
+
# 5) Build a weighted graph in NetworkX using centroid-distances as edges
|
| 408 |
+
#--------------------------------------------------------------------------
|
| 409 |
+
G = nx.Graph()
|
| 410 |
+
# Add each face as a node in the graph
|
| 411 |
+
G.add_nodes_from(range(num_faces))
|
| 412 |
+
|
| 413 |
+
# For each face i, add edges (i -> j) for each neighbor j in the KNN
|
| 414 |
+
for i in range(num_faces):
|
| 415 |
+
for j, dist in zip(indices[i], distances[i]):
|
| 416 |
+
if i == j:
|
| 417 |
+
continue # skip self-loop
|
| 418 |
+
# Add an undirected edge with 'weight' = distance
|
| 419 |
+
# NetworkX handles parallel edges gracefully via last add_edge,
|
| 420 |
+
# but it typically overwrites the weight if (i, j) already exists.
|
| 421 |
+
G.add_edge(i, j, weight=dist)
|
| 422 |
+
|
| 423 |
+
#--------------------------------------------------------------------------
|
| 424 |
+
# 6) Compute MST on that KNN graph
|
| 425 |
+
#--------------------------------------------------------------------------
|
| 426 |
+
mst = nx.minimum_spanning_tree(G, weight='weight')
|
| 427 |
+
# Sort MST edges by ascending weight, so we add the shortest edges first
|
| 428 |
+
mst_edges_sorted = sorted(
|
| 429 |
+
mst.edges(data=True), key=lambda e: e[2]['weight']
|
| 430 |
+
)
|
| 431 |
+
print("mst edges sorted", len(mst_edges_sorted))
|
| 432 |
+
#--------------------------------------------------------------------------
|
| 433 |
+
# 7) Use a union-find structure to add MST edges only if they
|
| 434 |
+
# connect two currently disconnected components of the adjacency matrix
|
| 435 |
+
#--------------------------------------------------------------------------
|
| 436 |
+
|
| 437 |
+
# Convert face_adjacency to LIL format for efficient edge addition
|
| 438 |
+
adjacency_lil = face_adjacency.tolil()
|
| 439 |
+
|
| 440 |
+
# Now, step through MST edges in ascending order
|
| 441 |
+
for (u, v, attr) in mst_edges_sorted:
|
| 442 |
+
if uf.find(u) != uf.find(v):
|
| 443 |
+
# These belong to different components, so unify them
|
| 444 |
+
uf.union(u, v)
|
| 445 |
+
# And add a "dummy" edge to our adjacency matrix
|
| 446 |
+
adjacency_lil[u, v] = 1
|
| 447 |
+
adjacency_lil[v, u] = 1
|
| 448 |
+
|
| 449 |
+
# Convert back to CSR format and return
|
| 450 |
+
face_adjacency = adjacency_lil.tocsr()
|
| 451 |
+
|
| 452 |
+
if with_knn:
|
| 453 |
+
print("Adding KNN edges.")
|
| 454 |
+
### Add KNN edges graph too
|
| 455 |
+
dummy_row = []
|
| 456 |
+
dummy_col = []
|
| 457 |
+
for i in range(num_faces):
|
| 458 |
+
for j in indices[i]:
|
| 459 |
+
dummy_row.extend([i, j])
|
| 460 |
+
dummy_col.extend([j, i]) ### duplicates are handled by coo
|
| 461 |
+
|
| 462 |
+
dummy_data = np.ones(len(dummy_row), dtype=np.int16)
|
| 463 |
+
dummy_mat = coo_matrix(
|
| 464 |
+
(dummy_data, (dummy_row, dummy_col)),
|
| 465 |
+
shape=(num_faces, num_faces)
|
| 466 |
+
).tocsr()
|
| 467 |
+
face_adjacency = face_adjacency + dummy_mat
|
| 468 |
+
###########################
|
| 469 |
+
|
| 470 |
+
return face_adjacency
|
| 471 |
+
|
| 472 |
+
def construct_face_adjacency_matrix_naive(face_list):
|
| 473 |
+
"""
|
| 474 |
+
Given a list of faces (each face is a 3-tuple of vertex indices),
|
| 475 |
+
construct a face-based adjacency matrix of shape (num_faces, num_faces).
|
| 476 |
+
Two faces are adjacent if they share an edge.
|
| 477 |
+
|
| 478 |
+
If multiple connected components exist, dummy edges are added to
|
| 479 |
+
turn them into a single connected component. Edges are added naively by
|
| 480 |
+
randomly selecting a face and connecting consecutive components -- (comp_i, comp_i+1) ...
|
| 481 |
+
|
| 482 |
+
Parameters
|
| 483 |
+
----------
|
| 484 |
+
face_list : list of tuples
|
| 485 |
+
List of faces, each face is a tuple (v0, v1, v2) of vertex indices.
|
| 486 |
+
|
| 487 |
+
Returns
|
| 488 |
+
-------
|
| 489 |
+
face_adjacency : scipy.sparse.csr_matrix
|
| 490 |
+
A CSR sparse matrix of shape (num_faces, num_faces),
|
| 491 |
+
containing 1s for adjacent faces and 0s otherwise.
|
| 492 |
+
Additional edges are added if the faces are in multiple components.
|
| 493 |
+
"""
|
| 494 |
+
|
| 495 |
+
num_faces = len(face_list)
|
| 496 |
+
if num_faces == 0:
|
| 497 |
+
# Return an empty matrix if no faces
|
| 498 |
+
return csr_matrix((0, 0))
|
| 499 |
+
|
| 500 |
+
# Step 1: Map each undirected edge -> list of face indices that contain that edge
|
| 501 |
+
edge_to_faces = defaultdict(list)
|
| 502 |
+
|
| 503 |
+
# Populate the edge_to_faces dictionary
|
| 504 |
+
for f_idx, (v0, v1, v2) in enumerate(face_list):
|
| 505 |
+
# For an edge, we always store its endpoints in sorted order
|
| 506 |
+
# to avoid duplication (e.g. edge (2,5) is the same as (5,2)).
|
| 507 |
+
edges = [
|
| 508 |
+
tuple(sorted((v0, v1))),
|
| 509 |
+
tuple(sorted((v1, v2))),
|
| 510 |
+
tuple(sorted((v2, v0)))
|
| 511 |
+
]
|
| 512 |
+
for e in edges:
|
| 513 |
+
edge_to_faces[e].append(f_idx)
|
| 514 |
+
|
| 515 |
+
# Step 2: Build the adjacency (row, col) lists among faces
|
| 516 |
+
row = []
|
| 517 |
+
col = []
|
| 518 |
+
for e, faces_sharing_e in edge_to_faces.items():
|
| 519 |
+
# If an edge is shared by multiple faces, make each pair of those faces adjacent
|
| 520 |
+
f_indices = list(set(faces_sharing_e)) # unique face indices for this edge
|
| 521 |
+
if len(f_indices) > 1:
|
| 522 |
+
# For each pair of faces, mark them as adjacent
|
| 523 |
+
for i in range(len(f_indices)):
|
| 524 |
+
for j in range(i + 1, len(f_indices)):
|
| 525 |
+
f_i = f_indices[i]
|
| 526 |
+
f_j = f_indices[j]
|
| 527 |
+
row.append(f_i)
|
| 528 |
+
col.append(f_j)
|
| 529 |
+
row.append(f_j)
|
| 530 |
+
col.append(f_i)
|
| 531 |
+
|
| 532 |
+
# Create a COO matrix, then convert it to CSR
|
| 533 |
+
data = np.ones(len(row), dtype=np.int8)
|
| 534 |
+
face_adjacency = coo_matrix(
|
| 535 |
+
(data, (row, col)),
|
| 536 |
+
shape=(num_faces, num_faces)
|
| 537 |
+
).tocsr()
|
| 538 |
+
|
| 539 |
+
# Step 3: Ensure single connected component
|
| 540 |
+
# Use connected_components to see how many components exist
|
| 541 |
+
n_components, labels = connected_components(face_adjacency, directed=False)
|
| 542 |
+
|
| 543 |
+
if n_components > 1:
|
| 544 |
+
# We have multiple components; let's "connect" them via dummy edges
|
| 545 |
+
# The simplest approach is to pick one face from each component
|
| 546 |
+
# and connect them sequentially to enforce a single component.
|
| 547 |
+
component_representatives = []
|
| 548 |
+
|
| 549 |
+
for comp_id in range(n_components):
|
| 550 |
+
# indices of faces in this component
|
| 551 |
+
faces_in_comp = np.where(labels == comp_id)[0]
|
| 552 |
+
if len(faces_in_comp) > 0:
|
| 553 |
+
# take the first face in this component as a representative
|
| 554 |
+
component_representatives.append(faces_in_comp[0])
|
| 555 |
+
|
| 556 |
+
# Now, add edges between consecutive representatives
|
| 557 |
+
dummy_row = []
|
| 558 |
+
dummy_col = []
|
| 559 |
+
for i in range(len(component_representatives) - 1):
|
| 560 |
+
f_i = component_representatives[i]
|
| 561 |
+
f_j = component_representatives[i + 1]
|
| 562 |
+
dummy_row.extend([f_i, f_j])
|
| 563 |
+
dummy_col.extend([f_j, f_i])
|
| 564 |
+
|
| 565 |
+
if dummy_row:
|
| 566 |
+
dummy_data = np.ones(len(dummy_row), dtype=np.int8)
|
| 567 |
+
dummy_mat = coo_matrix(
|
| 568 |
+
(dummy_data, (dummy_row, dummy_col)),
|
| 569 |
+
shape=(num_faces, num_faces)
|
| 570 |
+
).tocsr()
|
| 571 |
+
face_adjacency = face_adjacency + dummy_mat
|
| 572 |
+
|
| 573 |
+
return face_adjacency
|
| 574 |
+
#####################################
|
| 575 |
+
|
| 576 |
+
def load_features(feature_filename, mesh_filename, viz_mode):
|
| 577 |
+
|
| 578 |
+
print("Reading features:")
|
| 579 |
+
print(f" Feature filename: {feature_filename}")
|
| 580 |
+
print(f" Mesh filename: {mesh_filename}")
|
| 581 |
+
|
| 582 |
+
# load features
|
| 583 |
+
feat = np.load(feature_filename, allow_pickle=True)
|
| 584 |
+
feat = feat.astype(np.float32)
|
| 585 |
+
|
| 586 |
+
# load mesh things
|
| 587 |
+
tm = load_mesh_util(mesh_filename)
|
| 588 |
+
|
| 589 |
+
V = np.array(tm.vertices, dtype=np.float32)
|
| 590 |
+
F = np.array(tm.faces)
|
| 591 |
+
|
| 592 |
+
if viz_mode == "faces":
|
| 593 |
+
pca_colors = np.array(tm.visual.face_colors, dtype=np.float32)
|
| 594 |
+
pca_colors = pca_colors[:,:3] / 255.
|
| 595 |
+
|
| 596 |
+
else:
|
| 597 |
+
pca_colors = np.array(tm.visual.vertex_colors, dtype=np.float32)
|
| 598 |
+
pca_colors = pca_colors[:,:3] / 255.
|
| 599 |
+
|
| 600 |
+
arrgh(V, F, pca_colors, feat)
|
| 601 |
+
|
| 602 |
+
print(F)
|
| 603 |
+
print(V[F[1][0]])
|
| 604 |
+
print(V[F[1][1]])
|
| 605 |
+
print(V[F[1][2]])
|
| 606 |
+
|
| 607 |
+
return {
|
| 608 |
+
'V' : V,
|
| 609 |
+
'F' : F,
|
| 610 |
+
'pca_colors' : pca_colors,
|
| 611 |
+
'feat_np' : feat,
|
| 612 |
+
'feat_pt' : torch.tensor(feat, device='cuda'),
|
| 613 |
+
'trimesh' : tm,
|
| 614 |
+
'label' : None,
|
| 615 |
+
'num_cluster' : 1,
|
| 616 |
+
'scalar' : None
|
| 617 |
+
}
|
| 618 |
+
|
| 619 |
+
def prep_feature_mesh(m, name='mesh'):
|
| 620 |
+
ps_mesh = ps.register_surface_mesh(name, m['V'], m['F'])
|
| 621 |
+
ps_mesh.set_selection_mode('faces_only')
|
| 622 |
+
m['ps_mesh'] = ps_mesh
|
| 623 |
+
|
| 624 |
+
def viz_pca_colors(m):
|
| 625 |
+
m['ps_mesh'].add_color_quantity('pca colors', m['pca_colors'], enabled=True, defined_on=m["viz_mode"])
|
| 626 |
+
|
| 627 |
+
def viz_feature(m, ind):
|
| 628 |
+
m['ps_mesh'].add_scalar_quantity('pca colors', m['feat_np'][:,ind], cmap='turbo', enabled=True, defined_on=m["viz_mode"])
|
| 629 |
+
|
| 630 |
+
def feature_distance_np(feats, query_feat):
|
| 631 |
+
# normalize
|
| 632 |
+
feats = feats / np.linalg.norm(feats,axis=1)[:,None]
|
| 633 |
+
query_feat = query_feat / np.linalg.norm(query_feat)
|
| 634 |
+
# cosine distance
|
| 635 |
+
cos_sim = np.dot(feats, query_feat)
|
| 636 |
+
cos_dist = (1 - cos_sim) / 2.
|
| 637 |
+
return cos_dist
|
| 638 |
+
|
| 639 |
+
def feature_distance_pt(feats, query_feat):
|
| 640 |
+
return (1. - torch.nn.functional.cosine_similarity(feats, query_feat[None,:], dim=-1)) / 2.
|
| 641 |
+
|
| 642 |
+
|
| 643 |
+
def ps_callback(opts):
|
| 644 |
+
m = opts.m
|
| 645 |
+
|
| 646 |
+
changed, ind = psim.Combo("Mode", modes_list.index(opts.mode), modes_list)
|
| 647 |
+
if changed:
|
| 648 |
+
opts.mode = modes_list[ind]
|
| 649 |
+
m['ps_mesh'].remove_all_quantities()
|
| 650 |
+
|
| 651 |
+
if opts.mode == 'pca':
|
| 652 |
+
psim.TextUnformatted("""3-dim PCA embeddeding of features is shown as rgb color""")
|
| 653 |
+
viz_pca_colors(m)
|
| 654 |
+
|
| 655 |
+
elif opts.mode == 'feature_viz':
|
| 656 |
+
psim.TextUnformatted("""Use the slider to scrub through all features.\nCtrl-click to type a particular index.""")
|
| 657 |
+
|
| 658 |
+
this_changed, opts.i_feature = psim.SliderInt("feature index", opts.i_feature, v_min=0, v_max=(m['feat_np'].shape[-1]-1))
|
| 659 |
+
this_changed = this_changed or changed
|
| 660 |
+
|
| 661 |
+
if this_changed:
|
| 662 |
+
viz_feature(m, opts.i_feature)
|
| 663 |
+
|
| 664 |
+
elif opts.mode == "cluster_agglo":
|
| 665 |
+
psim.TextUnformatted("""Use the slider to toggle the number of desired clusters.""")
|
| 666 |
+
cluster_changed, opts.i_cluster = psim.SliderInt("number of clusters", opts.i_cluster, v_min=1, v_max=30)
|
| 667 |
+
|
| 668 |
+
### To handle different face adjacency options
|
| 669 |
+
mode_changed, ind = psim.Combo("Adj Matrix Def", adj_mode_list.index(opts.adj_mode), adj_mode_list)
|
| 670 |
+
knn_changed, opts.add_knn_edges = psim.Checkbox("Add KNN edges", opts.add_knn_edges)
|
| 671 |
+
|
| 672 |
+
if mode_changed:
|
| 673 |
+
opts.adj_mode = adj_mode_list[ind]
|
| 674 |
+
|
| 675 |
+
if psim.Button("Recompute"):
|
| 676 |
+
|
| 677 |
+
### Run clustering algorithm
|
| 678 |
+
num_clusters = opts.i_cluster
|
| 679 |
+
|
| 680 |
+
### Mesh 1
|
| 681 |
+
point_feat = m['feat_np']
|
| 682 |
+
point_feat = point_feat / np.linalg.norm(point_feat, axis=-1, keepdims=True)
|
| 683 |
+
|
| 684 |
+
### Compute adjacency matrix ###
|
| 685 |
+
if opts.adj_mode == "Vanilla":
|
| 686 |
+
adj_matrix = construct_face_adjacency_matrix_naive(opts.m["F"])
|
| 687 |
+
elif opts.adj_mode == "Face_MST":
|
| 688 |
+
adj_matrix = construct_face_adjacency_matrix_facemst(opts.m["F"], opts.m["V"], with_knn=opts.add_knn_edges)
|
| 689 |
+
elif opts.adj_mode == "CC_MST":
|
| 690 |
+
adj_matrix = construct_face_adjacency_matrix_ccmst(opts.m["F"], opts.m["V"], with_knn=opts.add_knn_edges)
|
| 691 |
+
################################
|
| 692 |
+
|
| 693 |
+
## Agglomerative clustering
|
| 694 |
+
clustering = AgglomerativeClustering(connectivity= adj_matrix,
|
| 695 |
+
n_clusters=num_clusters,
|
| 696 |
+
).fit(point_feat)
|
| 697 |
+
|
| 698 |
+
m['ps_mesh'].add_scalar_quantity("cluster", clustering.labels_, cmap='turbo', vminmax=(0, num_clusters-1), enabled=True, defined_on=m["viz_mode"])
|
| 699 |
+
print("Recomputed.")
|
| 700 |
+
|
| 701 |
+
|
| 702 |
+
elif opts.mode == "cluster_kmeans":
|
| 703 |
+
psim.TextUnformatted("""Use the slider to toggle the number of desired clusters.""")
|
| 704 |
+
|
| 705 |
+
cluster_changed, opts.i_cluster = psim.SliderInt("number of clusters", opts.i_cluster, v_min=1, v_max=30)
|
| 706 |
+
|
| 707 |
+
if psim.Button("Recompute"):
|
| 708 |
+
|
| 709 |
+
### Run clustering algorithm
|
| 710 |
+
num_clusters = opts.i_cluster
|
| 711 |
+
|
| 712 |
+
### Mesh 1
|
| 713 |
+
point_feat = m['feat_np']
|
| 714 |
+
point_feat = point_feat / np.linalg.norm(point_feat, axis=-1, keepdims=True)
|
| 715 |
+
clustering = KMeans(n_clusters=num_clusters, random_state=0, n_init="auto").fit(point_feat)
|
| 716 |
+
|
| 717 |
+
m['ps_mesh'].add_scalar_quantity("cluster", clustering.labels_, cmap='turbo', vminmax=(0, num_clusters-1), enabled=True, defined_on=m["viz_mode"])
|
| 718 |
+
|
| 719 |
+
def main():
|
| 720 |
+
## Parse args
|
| 721 |
+
# Uses simple_parsing library to automatically construct parser from the dataclass Options
|
| 722 |
+
parser = ArgumentParser()
|
| 723 |
+
parser.add_arguments(Options, dest="options")
|
| 724 |
+
parser.add_argument('--data_root', default="../exp_results/partfield_features/trellis/", help='Path the model features are stored.')
|
| 725 |
+
args = parser.parse_args()
|
| 726 |
+
opts: Options = args.options
|
| 727 |
+
|
| 728 |
+
DATA_ROOT = args.data_root
|
| 729 |
+
|
| 730 |
+
shape_1 = opts.filename
|
| 731 |
+
|
| 732 |
+
if os.path.exists(os.path.join(DATA_ROOT, "part_feat_"+ shape_1 + "_0.npy")):
|
| 733 |
+
feature_fname1 = os.path.join(DATA_ROOT, "part_feat_"+ shape_1 + "_0.npy")
|
| 734 |
+
mesh_fname1 = os.path.join(DATA_ROOT, "feat_pca_"+ shape_1 + "_0.ply")
|
| 735 |
+
else:
|
| 736 |
+
feature_fname1 = os.path.join(DATA_ROOT, "part_feat_"+ shape_1 + "_0_batch.npy")
|
| 737 |
+
mesh_fname1 = os.path.join(DATA_ROOT, "feat_pca_"+ shape_1 + "_0.ply")
|
| 738 |
+
|
| 739 |
+
#### To save output ####
|
| 740 |
+
os.makedirs(opts.output_fol, exist_ok=True)
|
| 741 |
+
########################
|
| 742 |
+
|
| 743 |
+
# Initialize
|
| 744 |
+
ps.init()
|
| 745 |
+
|
| 746 |
+
mesh_dict = load_features(feature_fname1, mesh_fname1, opts.viz_mode)
|
| 747 |
+
prep_feature_mesh(mesh_dict)
|
| 748 |
+
mesh_dict["viz_mode"] = opts.viz_mode
|
| 749 |
+
opts.m = mesh_dict
|
| 750 |
+
|
| 751 |
+
# Start the interactive UI
|
| 752 |
+
ps.set_user_callback(lambda : ps_callback(opts))
|
| 753 |
+
ps.show()
|
| 754 |
+
|
| 755 |
+
|
| 756 |
+
if __name__ == "__main__":
|
| 757 |
+
main()
|
| 758 |
+
|
PartField/compute_metric.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import json
|
| 3 |
+
from os.path import join
|
| 4 |
+
from typing import List
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
def compute_iou(pred, gt):
|
| 8 |
+
intersection = np.logical_and(pred, gt).sum()
|
| 9 |
+
union = np.logical_or(pred, gt).sum()
|
| 10 |
+
if union != 0:
|
| 11 |
+
return (intersection / union) * 100
|
| 12 |
+
else:
|
| 13 |
+
return 0
|
| 14 |
+
|
| 15 |
+
def eval_single_gt_shape(gt_label, pred_masks):
|
| 16 |
+
# gt: [N,], label index
|
| 17 |
+
# pred: [B, N], B is the number of predicted parts, binary label
|
| 18 |
+
unique_gt_label = np.unique(gt_label)
|
| 19 |
+
best_ious = []
|
| 20 |
+
for label in unique_gt_label:
|
| 21 |
+
best_iou = 0
|
| 22 |
+
if label == -1:
|
| 23 |
+
continue
|
| 24 |
+
for mask in pred_masks:
|
| 25 |
+
iou = compute_iou(mask, gt_label == label)
|
| 26 |
+
best_iou = max(best_iou, iou)
|
| 27 |
+
best_ious.append(best_iou)
|
| 28 |
+
return np.mean(best_ious)
|
| 29 |
+
|
| 30 |
+
def eval_whole_dataset(pred_folder, merge_parts=False):
|
| 31 |
+
print(pred_folder)
|
| 32 |
+
meta = json.load(open("/home/mikaelaangel/Desktop/data/PartObjaverse-Tiny_semantic.json", "r"))
|
| 33 |
+
|
| 34 |
+
categories = meta.keys()
|
| 35 |
+
results_per_cat = {}
|
| 36 |
+
per_cat_mious = []
|
| 37 |
+
overall_mious = []
|
| 38 |
+
|
| 39 |
+
MAX_NUM_CLUSTERS = 20
|
| 40 |
+
view_id = 0
|
| 41 |
+
|
| 42 |
+
for cat in categories:
|
| 43 |
+
results_per_cat[cat] = []
|
| 44 |
+
for shape_id in meta[cat].keys():
|
| 45 |
+
|
| 46 |
+
try:
|
| 47 |
+
all_pred_labels = []
|
| 48 |
+
for num_cluster in range(2, MAX_NUM_CLUSTERS):
|
| 49 |
+
### load each label
|
| 50 |
+
fname_clustering = os.path.join(pred_folder, "cluster_out", str(shape_id) + "_" + str(view_id) + "_" + str(num_cluster).zfill(2)) + ".npy"
|
| 51 |
+
pred_label = np.load(fname_clustering)
|
| 52 |
+
all_pred_labels.append(np.squeeze(pred_label))
|
| 53 |
+
|
| 54 |
+
all_pred_labels = np.array(all_pred_labels)
|
| 55 |
+
|
| 56 |
+
except:
|
| 57 |
+
continue
|
| 58 |
+
|
| 59 |
+
pred_masks = []
|
| 60 |
+
|
| 61 |
+
#### Path for PartObjaverseTiny Labels
|
| 62 |
+
gt_labels_path = "PartObjaverse-Tiny_instance_gt"
|
| 63 |
+
#################################
|
| 64 |
+
|
| 65 |
+
gt_label = np.load(os.path.join(gt_labels_path, shape_id + ".npy"))
|
| 66 |
+
|
| 67 |
+
if merge_parts:
|
| 68 |
+
pred_masks = []
|
| 69 |
+
for result in all_pred_labels:
|
| 70 |
+
pred = result
|
| 71 |
+
assert pred.shape[0] == gt_label.shape[0]
|
| 72 |
+
for label in np.unique(pred):
|
| 73 |
+
pred_masks.append(pred == label)
|
| 74 |
+
miou = eval_single_gt_shape(gt_label, np.array(pred_masks))
|
| 75 |
+
results_per_cat[cat].append(miou)
|
| 76 |
+
else:
|
| 77 |
+
best_miou = 0
|
| 78 |
+
for result in all_pred_labels:
|
| 79 |
+
pred_masks = []
|
| 80 |
+
pred = result
|
| 81 |
+
|
| 82 |
+
for label in np.unique(pred):
|
| 83 |
+
pred_masks.append(pred == label)
|
| 84 |
+
miou = eval_single_gt_shape(gt_label, np.array(pred_masks))
|
| 85 |
+
best_miou = max(best_miou, miou)
|
| 86 |
+
results_per_cat[cat].append(best_miou)
|
| 87 |
+
|
| 88 |
+
print(np.mean(results_per_cat[cat]))
|
| 89 |
+
per_cat_mious.append(np.mean(results_per_cat[cat]))
|
| 90 |
+
overall_mious += results_per_cat[cat]
|
| 91 |
+
print(np.mean(per_cat_mious))
|
| 92 |
+
print(np.mean(overall_mious), len(overall_mious))
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
if __name__ == "__main__":
|
| 96 |
+
eval_whole_dataset("dump_partobjtiny_clustering")
|
| 97 |
+
|
PartField/configs/final/correspondence_demo.yaml
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
result_name: partfield_features/correspondence_demo
|
| 2 |
+
|
| 3 |
+
continue_ckpt: model/model.ckpt
|
| 4 |
+
|
| 5 |
+
triplane_channels_low: 128
|
| 6 |
+
triplane_channels_high: 512
|
| 7 |
+
triplane_resolution: 128
|
| 8 |
+
|
| 9 |
+
vertex_feature: True
|
| 10 |
+
n_point_per_face: 1000
|
| 11 |
+
n_sample_each: 10000
|
| 12 |
+
is_pc: False
|
| 13 |
+
remesh_demo: False
|
| 14 |
+
correspondence_demo: True
|
| 15 |
+
|
| 16 |
+
preprocess_mesh: True
|
| 17 |
+
|
| 18 |
+
dataset:
|
| 19 |
+
type: "Mix"
|
| 20 |
+
data_path: data/DenseCorr3D
|
| 21 |
+
train_batch_size: 1
|
| 22 |
+
val_batch_size: 1
|
| 23 |
+
train_num_workers: 8
|
| 24 |
+
all_files:
|
| 25 |
+
# pairs of example to run correspondence
|
| 26 |
+
- animals/071b8_toy_animals_017/simple_mesh.obj
|
| 27 |
+
- animals/bdfd0_toy_animals_016/simple_mesh.obj
|
| 28 |
+
- animals/2d6b3_toy_animals_009/simple_mesh.obj
|
| 29 |
+
- animals/96615_toy_animals_018/simple_mesh.obj
|
| 30 |
+
- chairs/063d1_chair_006/simple_mesh.obj
|
| 31 |
+
- chairs/bea57_chair_012/simple_mesh.obj
|
| 32 |
+
- chairs/fe0fe_chair_004/simple_mesh.obj
|
| 33 |
+
- chairs/288dc_chair_011/simple_mesh.obj
|
| 34 |
+
# consider decimating animals/../color_mesh.obj yourself for better mesh topology than the provided simple_mesh.obj
|
| 35 |
+
# (e.g. <50k vertices for functional map efficiency).
|
| 36 |
+
|
| 37 |
+
loss:
|
| 38 |
+
triplet: 1.0
|
| 39 |
+
|
| 40 |
+
use_2d_feat: False
|
| 41 |
+
pvcnn:
|
| 42 |
+
point_encoder_type: 'pvcnn'
|
| 43 |
+
z_triplane_channels: 256
|
| 44 |
+
z_triplane_resolution: 128
|
PartField/configs/final/demo.yaml
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
result_name: demo_test
|
| 2 |
+
|
| 3 |
+
continue_ckpt: model/model.ckpt
|
| 4 |
+
|
| 5 |
+
triplane_channels_low: 128
|
| 6 |
+
triplane_channels_high: 512
|
| 7 |
+
triplane_resolution: 128
|
| 8 |
+
|
| 9 |
+
n_point_per_face: 1000
|
| 10 |
+
n_sample_each: 10000
|
| 11 |
+
is_pc : False
|
| 12 |
+
remesh_demo : False
|
| 13 |
+
|
| 14 |
+
dataset:
|
| 15 |
+
type: "Mix"
|
| 16 |
+
data_path: "objaverse_data"
|
| 17 |
+
train_batch_size: 1
|
| 18 |
+
val_batch_size: 1
|
| 19 |
+
train_num_workers: 8
|
| 20 |
+
|
| 21 |
+
loss:
|
| 22 |
+
triplet: 1.0
|
| 23 |
+
|
| 24 |
+
use_2d_feat: False
|
| 25 |
+
pvcnn:
|
| 26 |
+
point_encoder_type: 'pvcnn'
|
| 27 |
+
z_triplane_channels: 256
|
| 28 |
+
z_triplane_resolution: 128
|
PartField/download_demo_data.sh
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
mkdir data
|
| 3 |
+
cd data
|
| 4 |
+
mkdir objaverse_samples
|
| 5 |
+
cd objaverse_samples
|
| 6 |
+
wget https://huggingface.co/datasets/allenai/objaverse/resolve/main/glbs/000-050/00200996b8f34f55a2dd2f44d316d107.glb
|
| 7 |
+
wget https://huggingface.co/datasets/allenai/objaverse/resolve/main/glbs/000-042/002e462c8bfa4267a9c9f038c7966f3b.glb
|
| 8 |
+
wget https://huggingface.co/datasets/allenai/objaverse/resolve/main/glbs/000-046/0c3ca2b32545416f8f1e6f0e87def1a6.glb
|
| 9 |
+
wget https://huggingface.co/datasets/allenai/objaverse/resolve/main/glbs/000-063/65c6ffa083c6496eb84a0aa3c48d63ad.glb
|
| 10 |
+
|
| 11 |
+
cd ..
|
| 12 |
+
mkdir trellis_samples
|
| 13 |
+
cd trellis_samples
|
| 14 |
+
wget https://github.com/Trellis3D/trellis3d.github.io/raw/refs/heads/main/assets/scenes/blacksmith/glbs/dwarf.glb
|
| 15 |
+
wget https://github.com/Trellis3D/trellis3d.github.io/raw/refs/heads/main/assets/img2/glbs/goblin.glb
|
| 16 |
+
wget https://github.com/Trellis3D/trellis3d.github.io/raw/refs/heads/main/assets/img2/glbs/excavator.glb
|
| 17 |
+
wget https://github.com/Trellis3D/trellis3d.github.io/raw/refs/heads/main/assets/img2/glbs/elephant.glb
|
| 18 |
+
cd ..
|
| 19 |
+
cd ..
|
PartField/environment.yml
ADDED
|
@@ -0,0 +1,772 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: partfield
|
| 2 |
+
channels:
|
| 3 |
+
- nvidia/label/cuda-12.4.0
|
| 4 |
+
- conda-forge
|
| 5 |
+
- defaults
|
| 6 |
+
dependencies:
|
| 7 |
+
- _anaconda_depends=2025.03=py310_mkl_0
|
| 8 |
+
- _libgcc_mutex=0.1=conda_forge
|
| 9 |
+
- _openmp_mutex=4.5=2_gnu
|
| 10 |
+
- aiobotocore=2.21.1=pyhd8ed1ab_0
|
| 11 |
+
- aiohappyeyeballs=2.6.1=pyhd8ed1ab_0
|
| 12 |
+
- aiohttp=3.11.14=py310h89163eb_0
|
| 13 |
+
- aioitertools=0.12.0=pyhd8ed1ab_1
|
| 14 |
+
- aiosignal=1.3.2=pyhd8ed1ab_0
|
| 15 |
+
- alabaster=1.0.0=pyhd8ed1ab_1
|
| 16 |
+
- alsa-lib=1.2.13=hb9d3cd8_0
|
| 17 |
+
- altair=5.5.0=pyhd8ed1ab_1
|
| 18 |
+
- anaconda=custom=py310_3
|
| 19 |
+
- anyio=4.9.0=pyh29332c3_0
|
| 20 |
+
- aom=3.9.1=hac33072_0
|
| 21 |
+
- appdirs=1.4.4=pyhd8ed1ab_1
|
| 22 |
+
- argon2-cffi=23.1.0=pyhd8ed1ab_1
|
| 23 |
+
- argon2-cffi-bindings=21.2.0=py310ha75aee5_5
|
| 24 |
+
- arrow=1.3.0=pyhd8ed1ab_1
|
| 25 |
+
- astroid=3.3.9=py310hff52083_0
|
| 26 |
+
- astropy=6.1.7=py310hf462985_0
|
| 27 |
+
- astropy-iers-data=0.2025.3.31.0.36.18=pyhd8ed1ab_0
|
| 28 |
+
- asttokens=3.0.0=pyhd8ed1ab_1
|
| 29 |
+
- async-lru=2.0.5=pyh29332c3_0
|
| 30 |
+
- async-timeout=5.0.1=pyhd8ed1ab_1
|
| 31 |
+
- asyncssh=2.20.0=pyhd8ed1ab_0
|
| 32 |
+
- atomicwrites=1.4.1=pyhd8ed1ab_1
|
| 33 |
+
- attr=2.5.1=h166bdaf_1
|
| 34 |
+
- attrs=25.3.0=pyh71513ae_0
|
| 35 |
+
- automat=24.8.1=pyhd8ed1ab_1
|
| 36 |
+
- autopep8=2.0.4=pyhd8ed1ab_0
|
| 37 |
+
- aws-c-auth=0.8.6=hd08a7f5_4
|
| 38 |
+
- aws-c-cal=0.8.7=h043a21b_0
|
| 39 |
+
- aws-c-common=0.12.0=hb9d3cd8_0
|
| 40 |
+
- aws-c-compression=0.3.1=h3870646_2
|
| 41 |
+
- aws-c-event-stream=0.5.4=h04a3f94_2
|
| 42 |
+
- aws-c-http=0.9.4=hb9b18c6_4
|
| 43 |
+
- aws-c-io=0.17.0=h3dad3f2_6
|
| 44 |
+
- aws-c-mqtt=0.12.2=h108da3e_2
|
| 45 |
+
- aws-c-s3=0.7.13=h822ba82_2
|
| 46 |
+
- aws-c-sdkutils=0.2.3=h3870646_2
|
| 47 |
+
- aws-checksums=0.2.3=h3870646_2
|
| 48 |
+
- aws-crt-cpp=0.31.0=h55f77e1_4
|
| 49 |
+
- aws-sdk-cpp=1.11.510=h37a5c72_3
|
| 50 |
+
- azure-core-cpp=1.14.0=h5cfcd09_0
|
| 51 |
+
- azure-identity-cpp=1.10.0=h113e628_0
|
| 52 |
+
- azure-storage-blobs-cpp=12.13.0=h3cf044e_1
|
| 53 |
+
- azure-storage-common-cpp=12.8.0=h736e048_1
|
| 54 |
+
- azure-storage-files-datalake-cpp=12.12.0=ha633028_1
|
| 55 |
+
- babel=2.17.0=pyhd8ed1ab_0
|
| 56 |
+
- backports=1.0=pyhd8ed1ab_5
|
| 57 |
+
- backports.tarfile=1.2.0=pyhd8ed1ab_1
|
| 58 |
+
- bcrypt=4.3.0=py310h505e2c1_0
|
| 59 |
+
- beautifulsoup4=4.13.3=pyha770c72_0
|
| 60 |
+
- binaryornot=0.4.4=pyhd8ed1ab_2
|
| 61 |
+
- binutils=2.43=h4852527_4
|
| 62 |
+
- binutils_impl_linux-64=2.43=h4bf12b8_4
|
| 63 |
+
- binutils_linux-64=2.43=h4852527_4
|
| 64 |
+
- black=25.1.0=pyha5154f8_0
|
| 65 |
+
- blas=1.0=mkl
|
| 66 |
+
- bleach=6.2.0=pyh29332c3_4
|
| 67 |
+
- bleach-with-css=6.2.0=h82add2a_4
|
| 68 |
+
- blinker=1.9.0=pyhff2d567_0
|
| 69 |
+
- blosc=1.21.6=he440d0b_1
|
| 70 |
+
- bokeh=3.7.0=pyhd8ed1ab_0
|
| 71 |
+
- brotli=1.1.0=hb9d3cd8_2
|
| 72 |
+
- brotli-bin=1.1.0=hb9d3cd8_2
|
| 73 |
+
- brotli-python=1.1.0=py310hf71b8c6_2
|
| 74 |
+
- brunsli=0.1=h9c3ff4c_0
|
| 75 |
+
- bzip2=1.0.8=h4bc722e_7
|
| 76 |
+
- c-ares=1.34.4=hb9d3cd8_0
|
| 77 |
+
- c-blosc2=2.15.2=h3122c55_1
|
| 78 |
+
- c-compiler=1.9.0=h2b85faf_0
|
| 79 |
+
- ca-certificates=2025.1.31=hbcca054_0
|
| 80 |
+
- cached-property=1.5.2=hd8ed1ab_1
|
| 81 |
+
- cached_property=1.5.2=pyha770c72_1
|
| 82 |
+
- cachetools=5.5.2=pyhd8ed1ab_0
|
| 83 |
+
- cairo=1.18.4=h3394656_0
|
| 84 |
+
- certifi=2025.1.31=pyhd8ed1ab_0
|
| 85 |
+
- cffi=1.17.1=py310h8deb56e_0
|
| 86 |
+
- chardet=5.2.0=pyhd8ed1ab_3
|
| 87 |
+
- charls=2.4.2=h59595ed_0
|
| 88 |
+
- charset-normalizer=3.4.1=pyhd8ed1ab_0
|
| 89 |
+
- click=8.1.8=pyh707e725_0
|
| 90 |
+
- cloudpickle=3.1.1=pyhd8ed1ab_0
|
| 91 |
+
- colorama=0.4.6=pyhd8ed1ab_1
|
| 92 |
+
- colorcet=3.1.0=pyhd8ed1ab_1
|
| 93 |
+
- comm=0.2.2=pyhd8ed1ab_1
|
| 94 |
+
- constantly=15.1.0=py_0
|
| 95 |
+
- contourpy=1.3.1=py310h3788b33_0
|
| 96 |
+
- cookiecutter=2.6.0=pyhd8ed1ab_1
|
| 97 |
+
- cpython=3.10.16=py310hd8ed1ab_1
|
| 98 |
+
- cryptography=44.0.2=py310h6c63255_0
|
| 99 |
+
- cssselect=1.2.0=pyhd8ed1ab_1
|
| 100 |
+
- cuda=12.4.0=0
|
| 101 |
+
- cuda-cccl_linux-64=12.8.90=ha770c72_1
|
| 102 |
+
- cuda-command-line-tools=12.8.1=ha770c72_0
|
| 103 |
+
- cuda-compiler=12.8.1=hbad6d8a_0
|
| 104 |
+
- cuda-crt-dev_linux-64=12.8.93=ha770c72_1
|
| 105 |
+
- cuda-crt-tools=12.8.93=ha770c72_1
|
| 106 |
+
- cuda-cudart=12.8.90=h5888daf_1
|
| 107 |
+
- cuda-cudart-dev=12.8.90=h5888daf_1
|
| 108 |
+
- cuda-cudart-dev_linux-64=12.8.90=h3f2d84a_1
|
| 109 |
+
- cuda-cudart-static=12.8.90=h5888daf_1
|
| 110 |
+
- cuda-cudart-static_linux-64=12.8.90=h3f2d84a_1
|
| 111 |
+
- cuda-cudart_linux-64=12.8.90=h3f2d84a_1
|
| 112 |
+
- cuda-cuobjdump=12.8.90=hbd13f7d_1
|
| 113 |
+
- cuda-cupti=12.8.90=hbd13f7d_0
|
| 114 |
+
- cuda-cupti-dev=12.8.90=h5888daf_0
|
| 115 |
+
- cuda-cuxxfilt=12.8.90=hbd13f7d_1
|
| 116 |
+
- cuda-demo-suite=12.4.99=0
|
| 117 |
+
- cuda-driver-dev=12.8.90=h5888daf_1
|
| 118 |
+
- cuda-driver-dev_linux-64=12.8.90=h3f2d84a_1
|
| 119 |
+
- cuda-gdb=12.8.90=h50b4baa_0
|
| 120 |
+
- cuda-libraries=12.8.1=ha770c72_0
|
| 121 |
+
- cuda-libraries-dev=12.8.1=ha770c72_0
|
| 122 |
+
- cuda-nsight=12.8.90=h7938cbb_1
|
| 123 |
+
- cuda-nvcc=12.8.93=hcdd1206_1
|
| 124 |
+
- cuda-nvcc-dev_linux-64=12.8.93=he91c749_1
|
| 125 |
+
- cuda-nvcc-impl=12.8.93=h85509e4_1
|
| 126 |
+
- cuda-nvcc-tools=12.8.93=he02047a_1
|
| 127 |
+
- cuda-nvcc_linux-64=12.8.93=h04802cd_1
|
| 128 |
+
- cuda-nvdisasm=12.8.90=hbd13f7d_1
|
| 129 |
+
- cuda-nvml-dev=12.8.90=hbd13f7d_0
|
| 130 |
+
- cuda-nvprof=12.8.90=hbd13f7d_0
|
| 131 |
+
- cuda-nvprune=12.8.90=hbd13f7d_1
|
| 132 |
+
- cuda-nvrtc=12.8.93=h5888daf_1
|
| 133 |
+
- cuda-nvrtc-dev=12.8.93=h5888daf_1
|
| 134 |
+
- cuda-nvtx=12.8.90=hbd13f7d_0
|
| 135 |
+
- cuda-nvvm-dev_linux-64=12.8.93=ha770c72_1
|
| 136 |
+
- cuda-nvvm-impl=12.8.93=he02047a_1
|
| 137 |
+
- cuda-nvvm-tools=12.8.93=he02047a_1
|
| 138 |
+
- cuda-nvvp=12.8.93=hbd13f7d_1
|
| 139 |
+
- cuda-opencl=12.8.90=hbd13f7d_0
|
| 140 |
+
- cuda-opencl-dev=12.8.90=h5888daf_0
|
| 141 |
+
- cuda-profiler-api=12.8.90=h7938cbb_1
|
| 142 |
+
- cuda-sanitizer-api=12.8.93=hbd13f7d_1
|
| 143 |
+
- cuda-toolkit=12.8.1=ha804496_0
|
| 144 |
+
- cuda-tools=12.8.1=ha770c72_0
|
| 145 |
+
- cuda-version=12.8=h5d125a7_3
|
| 146 |
+
- cuda-visual-tools=12.8.1=ha770c72_0
|
| 147 |
+
- curl=8.12.1=h332b0f4_0
|
| 148 |
+
- cxx-compiler=1.9.0=h1a2810e_0
|
| 149 |
+
- cycler=0.12.1=pyhd8ed1ab_1
|
| 150 |
+
- cyrus-sasl=2.1.27=h54b06d7_7
|
| 151 |
+
- cytoolz=1.0.1=py310ha75aee5_0
|
| 152 |
+
- datashader=0.17.0=pyhd8ed1ab_0
|
| 153 |
+
- dav1d=1.2.1=hd590300_0
|
| 154 |
+
- dbus=1.13.6=h5008d03_3
|
| 155 |
+
- debugpy=1.8.13=py310hf71b8c6_0
|
| 156 |
+
- decorator=5.2.1=pyhd8ed1ab_0
|
| 157 |
+
- defusedxml=0.7.1=pyhd8ed1ab_0
|
| 158 |
+
- deprecated=1.2.18=pyhd8ed1ab_0
|
| 159 |
+
- diff-match-patch=20241021=pyhd8ed1ab_1
|
| 160 |
+
- dill=0.3.9=pyhd8ed1ab_1
|
| 161 |
+
- docstring-to-markdown=0.16=pyh29332c3_1
|
| 162 |
+
- docutils=0.21.2=pyhd8ed1ab_1
|
| 163 |
+
- double-conversion=3.3.1=h5888daf_0
|
| 164 |
+
- et_xmlfile=2.0.0=pyhd8ed1ab_1
|
| 165 |
+
- exceptiongroup=1.2.2=pyhd8ed1ab_1
|
| 166 |
+
- executing=2.1.0=pyhd8ed1ab_1
|
| 167 |
+
- expat=2.7.0=h5888daf_0
|
| 168 |
+
- fcitx-qt5=1.2.7=h748e8b9_2
|
| 169 |
+
- filelock=3.18.0=pyhd8ed1ab_0
|
| 170 |
+
- flake8=7.1.2=pyhd8ed1ab_0
|
| 171 |
+
- font-ttf-dejavu-sans-mono=2.37=hab24e00_0
|
| 172 |
+
- font-ttf-inconsolata=3.000=h77eed37_0
|
| 173 |
+
- font-ttf-source-code-pro=2.038=h77eed37_0
|
| 174 |
+
- font-ttf-ubuntu=0.83=h77eed37_3
|
| 175 |
+
- fontconfig=2.15.0=h7e30c49_1
|
| 176 |
+
- fonts-conda-ecosystem=1=0
|
| 177 |
+
- fonts-conda-forge=1=0
|
| 178 |
+
- fonttools=4.56.0=py310h89163eb_0
|
| 179 |
+
- fqdn=1.5.1=pyhd8ed1ab_1
|
| 180 |
+
- freetype=2.13.3=h48d6fc4_0
|
| 181 |
+
- frozenlist=1.5.0=py310h89163eb_1
|
| 182 |
+
- fzf=0.61.0=h59e48b9_0
|
| 183 |
+
- gcc=13.3.0=h9576a4e_2
|
| 184 |
+
- gcc_impl_linux-64=13.3.0=h1e990d8_2
|
| 185 |
+
- gcc_linux-64=13.3.0=hc28eda2_8
|
| 186 |
+
- gds-tools=1.13.1.3=h5888daf_0
|
| 187 |
+
- gettext=0.23.1=h5888daf_0
|
| 188 |
+
- gettext-tools=0.23.1=h5888daf_0
|
| 189 |
+
- gflags=2.2.2=h5888daf_1005
|
| 190 |
+
- giflib=5.2.2=hd590300_0
|
| 191 |
+
- gitdb=4.0.12=pyhd8ed1ab_0
|
| 192 |
+
- gitpython=3.1.44=pyhff2d567_0
|
| 193 |
+
- glib=2.84.0=h07242d1_0
|
| 194 |
+
- glib-tools=2.84.0=h4833e2c_0
|
| 195 |
+
- glog=0.7.1=hbabe93e_0
|
| 196 |
+
- gmp=6.3.0=hac33072_2
|
| 197 |
+
- gmpy2=2.1.5=py310he8512ff_3
|
| 198 |
+
- graphite2=1.3.13=h59595ed_1003
|
| 199 |
+
- greenlet=3.1.1=py310hf71b8c6_1
|
| 200 |
+
- gst-plugins-base=1.24.7=h0a52356_0
|
| 201 |
+
- gstreamer=1.24.7=hf3bb09a_0
|
| 202 |
+
- gxx=13.3.0=h9576a4e_2
|
| 203 |
+
- gxx_impl_linux-64=13.3.0=hae580e1_2
|
| 204 |
+
- gxx_linux-64=13.3.0=h6834431_8
|
| 205 |
+
- h11=0.14.0=pyhd8ed1ab_1
|
| 206 |
+
- h2=4.2.0=pyhd8ed1ab_0
|
| 207 |
+
- h5py=3.13.0=nompi_py310h60e0fe6_100
|
| 208 |
+
- harfbuzz=10.4.0=h76408a6_0
|
| 209 |
+
- hdf5=1.14.3=nompi_h2d575fe_109
|
| 210 |
+
- holoviews=1.20.2=pyhd8ed1ab_0
|
| 211 |
+
- hpack=4.1.0=pyhd8ed1ab_0
|
| 212 |
+
- httpcore=1.0.7=pyh29332c3_1
|
| 213 |
+
- httpx=0.28.1=pyhd8ed1ab_0
|
| 214 |
+
- hvplot=0.11.2=pyhd8ed1ab_0
|
| 215 |
+
- hyperframe=6.1.0=pyhd8ed1ab_0
|
| 216 |
+
- hyperlink=21.0.0=pyh29332c3_1
|
| 217 |
+
- icu=75.1=he02047a_0
|
| 218 |
+
- idna=3.10=pyhd8ed1ab_1
|
| 219 |
+
- imagecodecs=2024.12.30=py310h78a9a29_0
|
| 220 |
+
- imageio=2.37.0=pyhfb79c49_0
|
| 221 |
+
- imagesize=1.4.1=pyhd8ed1ab_0
|
| 222 |
+
- imbalanced-learn=0.13.0=pyhd8ed1ab_0
|
| 223 |
+
- importlib-metadata=8.6.1=pyha770c72_0
|
| 224 |
+
- importlib_resources=6.5.2=pyhd8ed1ab_0
|
| 225 |
+
- incremental=24.7.2=pyhd8ed1ab_1
|
| 226 |
+
- inflection=0.5.1=pyhd8ed1ab_1
|
| 227 |
+
- iniconfig=2.0.0=pyhd8ed1ab_1
|
| 228 |
+
- intake=2.0.8=pyhd8ed1ab_0
|
| 229 |
+
- intervaltree=3.1.0=pyhd8ed1ab_1
|
| 230 |
+
- ipykernel=6.29.5=pyh3099207_0
|
| 231 |
+
- ipython=8.34.0=pyh907856f_0
|
| 232 |
+
- ipython_genutils=0.2.0=pyhd8ed1ab_2
|
| 233 |
+
- isoduration=20.11.0=pyhd8ed1ab_1
|
| 234 |
+
- isort=6.0.1=pyhd8ed1ab_0
|
| 235 |
+
- itemadapter=0.11.0=pyhd8ed1ab_0
|
| 236 |
+
- itemloaders=1.3.2=pyhd8ed1ab_1
|
| 237 |
+
- itsdangerous=2.2.0=pyhd8ed1ab_1
|
| 238 |
+
- jaraco.classes=3.4.0=pyhd8ed1ab_2
|
| 239 |
+
- jaraco.context=6.0.1=pyhd8ed1ab_0
|
| 240 |
+
- jaraco.functools=4.1.0=pyhd8ed1ab_0
|
| 241 |
+
- jedi=0.19.2=pyhd8ed1ab_1
|
| 242 |
+
- jeepney=0.9.0=pyhd8ed1ab_0
|
| 243 |
+
- jellyfish=1.1.3=py310h505e2c1_0
|
| 244 |
+
- jinja2=3.1.6=pyhd8ed1ab_0
|
| 245 |
+
- jmespath=1.0.1=pyhd8ed1ab_1
|
| 246 |
+
- joblib=1.4.2=pyhd8ed1ab_1
|
| 247 |
+
- jq=1.7.1=hd590300_0
|
| 248 |
+
- json5=0.10.0=pyhd8ed1ab_1
|
| 249 |
+
- jsonpointer=3.0.0=py310hff52083_1
|
| 250 |
+
- jsonschema=4.23.0=pyhd8ed1ab_1
|
| 251 |
+
- jsonschema-specifications=2024.10.1=pyhd8ed1ab_1
|
| 252 |
+
- jsonschema-with-format-nongpl=4.23.0=hd8ed1ab_1
|
| 253 |
+
- jupyter=1.1.1=pyhd8ed1ab_1
|
| 254 |
+
- jupyter-lsp=2.2.5=pyhd8ed1ab_1
|
| 255 |
+
- jupyter_client=8.6.3=pyhd8ed1ab_1
|
| 256 |
+
- jupyter_console=6.6.3=pyhd8ed1ab_1
|
| 257 |
+
- jupyter_core=5.7.2=pyh31011fe_1
|
| 258 |
+
- jupyter_events=0.12.0=pyh29332c3_0
|
| 259 |
+
- jupyter_server=2.15.0=pyhd8ed1ab_0
|
| 260 |
+
- jupyter_server_terminals=0.5.3=pyhd8ed1ab_1
|
| 261 |
+
- jupyterlab=4.3.6=pyhd8ed1ab_0
|
| 262 |
+
- jupyterlab-variableinspector=3.2.4=pyhd8ed1ab_0
|
| 263 |
+
- jupyterlab_pygments=0.3.0=pyhd8ed1ab_2
|
| 264 |
+
- jupyterlab_server=2.27.3=pyhd8ed1ab_1
|
| 265 |
+
- jxrlib=1.1=hd590300_3
|
| 266 |
+
- kernel-headers_linux-64=3.10.0=he073ed8_18
|
| 267 |
+
- keyring=25.6.0=pyha804496_0
|
| 268 |
+
- keyutils=1.6.1=h166bdaf_0
|
| 269 |
+
- kiwisolver=1.4.7=py310h3788b33_0
|
| 270 |
+
- krb5=1.21.3=h659f571_0
|
| 271 |
+
- lame=3.100=h166bdaf_1003
|
| 272 |
+
- lazy-loader=0.4=pyhd8ed1ab_2
|
| 273 |
+
- lazy_loader=0.4=pyhd8ed1ab_2
|
| 274 |
+
- lcms2=2.17=h717163a_0
|
| 275 |
+
- ld_impl_linux-64=2.43=h712a8e2_4
|
| 276 |
+
- lerc=4.0.0=h27087fc_0
|
| 277 |
+
- libabseil=20250127.1=cxx17_hbbce691_0
|
| 278 |
+
- libaec=1.1.3=h59595ed_0
|
| 279 |
+
- libarrow=19.0.1=h120c447_5_cpu
|
| 280 |
+
- libarrow-acero=19.0.1=hcb10f89_5_cpu
|
| 281 |
+
- libarrow-dataset=19.0.1=hcb10f89_5_cpu
|
| 282 |
+
- libarrow-substrait=19.0.1=h1bed206_5_cpu
|
| 283 |
+
- libasprintf=0.23.1=h8e693c7_0
|
| 284 |
+
- libasprintf-devel=0.23.1=h8e693c7_0
|
| 285 |
+
- libavif16=1.2.1=hbb36593_2
|
| 286 |
+
- libblas=3.9.0=1_h86c2bf4_netlib
|
| 287 |
+
- libbrotlicommon=1.1.0=hb9d3cd8_2
|
| 288 |
+
- libbrotlidec=1.1.0=hb9d3cd8_2
|
| 289 |
+
- libbrotlienc=1.1.0=hb9d3cd8_2
|
| 290 |
+
- libcap=2.75=h39aace5_0
|
| 291 |
+
- libcblas=3.9.0=8_h3b12eaf_netlib
|
| 292 |
+
- libclang-cpp19.1=19.1.7=default_hb5137d0_2
|
| 293 |
+
- libclang-cpp20.1=20.1.1=default_hb5137d0_0
|
| 294 |
+
- libclang13=20.1.1=default_h9c6a7e4_0
|
| 295 |
+
- libcrc32c=1.1.2=h9c3ff4c_0
|
| 296 |
+
- libcublas=12.8.4.1=h9ab20c4_1
|
| 297 |
+
- libcublas-dev=12.8.4.1=h9ab20c4_1
|
| 298 |
+
- libcufft=11.3.3.83=h5888daf_1
|
| 299 |
+
- libcufft-dev=11.3.3.83=h5888daf_1
|
| 300 |
+
- libcufile=1.13.1.3=h12f29b5_0
|
| 301 |
+
- libcufile-dev=1.13.1.3=h5888daf_0
|
| 302 |
+
- libcups=2.3.3=h4637d8d_4
|
| 303 |
+
- libcurand=10.3.9.90=h9ab20c4_1
|
| 304 |
+
- libcurand-dev=10.3.9.90=h9ab20c4_1
|
| 305 |
+
- libcurl=8.12.1=h332b0f4_0
|
| 306 |
+
- libcusolver=11.7.3.90=h9ab20c4_1
|
| 307 |
+
- libcusolver-dev=11.7.3.90=h9ab20c4_1
|
| 308 |
+
- libcusparse=12.5.8.93=hbd13f7d_0
|
| 309 |
+
- libcusparse-dev=12.5.8.93=h5888daf_0
|
| 310 |
+
- libdeflate=1.23=h4ddbbb0_0
|
| 311 |
+
- libdrm=2.4.124=hb9d3cd8_0
|
| 312 |
+
- libedit=3.1.20250104=pl5321h7949ede_0
|
| 313 |
+
- libegl=1.7.0=ha4b6fd6_2
|
| 314 |
+
- libev=4.33=hd590300_2
|
| 315 |
+
- libevent=2.1.12=hf998b51_1
|
| 316 |
+
- libexpat=2.7.0=h5888daf_0
|
| 317 |
+
- libffi=3.4.6=h2dba641_1
|
| 318 |
+
- libflac=1.4.3=h59595ed_0
|
| 319 |
+
- libgcc=14.2.0=h767d61c_2
|
| 320 |
+
- libgcc-devel_linux-64=13.3.0=hc03c837_102
|
| 321 |
+
- libgcc-ng=14.2.0=h69a702a_2
|
| 322 |
+
- libgcrypt-lib=1.11.0=hb9d3cd8_2
|
| 323 |
+
- libgettextpo=0.23.1=h5888daf_0
|
| 324 |
+
- libgettextpo-devel=0.23.1=h5888daf_0
|
| 325 |
+
- libgfortran=14.2.0=h69a702a_2
|
| 326 |
+
- libgfortran-ng=14.2.0=h69a702a_2
|
| 327 |
+
- libgfortran5=14.2.0=hf1ad2bd_2
|
| 328 |
+
- libgl=1.7.0=ha4b6fd6_2
|
| 329 |
+
- libglib=2.84.0=h2ff4ddf_0
|
| 330 |
+
- libglvnd=1.7.0=ha4b6fd6_2
|
| 331 |
+
- libglx=1.7.0=ha4b6fd6_2
|
| 332 |
+
- libgomp=14.2.0=h767d61c_2
|
| 333 |
+
- libgoogle-cloud=2.36.0=hc4361e1_1
|
| 334 |
+
- libgoogle-cloud-storage=2.36.0=h0121fbd_1
|
| 335 |
+
- libgpg-error=1.51=hbd13f7d_1
|
| 336 |
+
- libgrpc=1.71.0=he753a82_0
|
| 337 |
+
- libhwy=1.1.0=h00ab1b0_0
|
| 338 |
+
- libiconv=1.18=h4ce23a2_1
|
| 339 |
+
- libjpeg-turbo=3.0.0=hd590300_1
|
| 340 |
+
- libjxl=0.11.1=hdb8da77_0
|
| 341 |
+
- liblapack=3.9.0=8_h3b12eaf_netlib
|
| 342 |
+
- libllvm19=19.1.7=ha7bfdaf_1
|
| 343 |
+
- libllvm20=20.1.1=ha7bfdaf_0
|
| 344 |
+
- liblzma=5.6.4=hb9d3cd8_0
|
| 345 |
+
- libnghttp2=1.64.0=h161d5f1_0
|
| 346 |
+
- libnl=3.11.0=hb9d3cd8_0
|
| 347 |
+
- libnpp=12.3.3.100=h9ab20c4_1
|
| 348 |
+
- libnpp-dev=12.3.3.100=h9ab20c4_1
|
| 349 |
+
- libnsl=2.0.1=hd590300_0
|
| 350 |
+
- libntlm=1.8=hb9d3cd8_0
|
| 351 |
+
- libnuma=2.0.18=h4ab18f5_2
|
| 352 |
+
- libnvfatbin=12.8.90=hbd13f7d_0
|
| 353 |
+
- libnvfatbin-dev=12.8.90=h5888daf_0
|
| 354 |
+
- libnvjitlink=12.8.93=h5888daf_1
|
| 355 |
+
- libnvjitlink-dev=12.8.93=h5888daf_1
|
| 356 |
+
- libnvjpeg=12.3.5.92=h97fd463_0
|
| 357 |
+
- libnvjpeg-dev=12.3.5.92=ha770c72_0
|
| 358 |
+
- libogg=1.3.5=h4ab18f5_0
|
| 359 |
+
- libopengl=1.7.0=ha4b6fd6_2
|
| 360 |
+
- libopentelemetry-cpp=1.19.0=hd1b1c89_0
|
| 361 |
+
- libopentelemetry-cpp-headers=1.19.0=ha770c72_0
|
| 362 |
+
- libopus=1.3.1=h7f98852_1
|
| 363 |
+
- libparquet=19.0.1=h081d1f1_5_cpu
|
| 364 |
+
- libpciaccess=0.18=hd590300_0
|
| 365 |
+
- libpng=1.6.47=h943b412_0
|
| 366 |
+
- libpq=17.4=h27ae623_0
|
| 367 |
+
- libprotobuf=5.29.3=h501fc15_0
|
| 368 |
+
- libre2-11=2024.07.02=hba17884_3
|
| 369 |
+
- libsanitizer=13.3.0=he8ea267_2
|
| 370 |
+
- libsndfile=1.2.2=hc60ed4a_1
|
| 371 |
+
- libsodium=1.0.20=h4ab18f5_0
|
| 372 |
+
- libspatialindex=2.1.0=he57a185_0
|
| 373 |
+
- libsqlite=3.49.1=hee588c1_2
|
| 374 |
+
- libssh2=1.11.1=hf672d98_0
|
| 375 |
+
- libstdcxx=14.2.0=h8f9b012_2
|
| 376 |
+
- libstdcxx-devel_linux-64=13.3.0=hc03c837_102
|
| 377 |
+
- libstdcxx-ng=14.2.0=h4852527_2
|
| 378 |
+
- libsystemd0=257.4=h4e0b6ca_1
|
| 379 |
+
- libthrift=0.21.0=h0e7cc3e_0
|
| 380 |
+
- libtiff=4.7.0=hd9ff511_3
|
| 381 |
+
- libudev1=257.4=hbe16f8c_1
|
| 382 |
+
- libutf8proc=2.10.0=h4c51ac1_0
|
| 383 |
+
- libuuid=2.38.1=h0b41bf4_0
|
| 384 |
+
- libvorbis=1.3.7=h9c3ff4c_0
|
| 385 |
+
- libwebp=1.5.0=hae8dbeb_0
|
| 386 |
+
- libwebp-base=1.5.0=h851e524_0
|
| 387 |
+
- libxcb=1.17.0=h8a09558_0
|
| 388 |
+
- libxcrypt=4.4.36=hd590300_1
|
| 389 |
+
- libxkbcommon=1.8.1=hc4a0caf_0
|
| 390 |
+
- libxkbfile=1.1.0=h166bdaf_1
|
| 391 |
+
- libxml2=2.13.7=h8d12d68_0
|
| 392 |
+
- libxslt=1.1.39=h76b75d6_0
|
| 393 |
+
- libzlib=1.3.1=hb9d3cd8_2
|
| 394 |
+
- libzopfli=1.0.3=h9c3ff4c_0
|
| 395 |
+
- linkify-it-py=2.0.3=pyhd8ed1ab_1
|
| 396 |
+
- locket=1.0.0=pyhd8ed1ab_0
|
| 397 |
+
- lxml=5.3.1=py310h6ee67d5_0
|
| 398 |
+
- lz4=4.3.3=py310h80b8a69_2
|
| 399 |
+
- lz4-c=1.10.0=h5888daf_1
|
| 400 |
+
- markdown=3.6=pyhd8ed1ab_0
|
| 401 |
+
- markdown-it-py=3.0.0=pyhd8ed1ab_1
|
| 402 |
+
- markupsafe=3.0.2=py310h89163eb_1
|
| 403 |
+
- matplotlib=3.10.1=py310hff52083_0
|
| 404 |
+
- matplotlib-base=3.10.1=py310h68603db_0
|
| 405 |
+
- matplotlib-inline=0.1.7=pyhd8ed1ab_1
|
| 406 |
+
- mccabe=0.7.0=pyhd8ed1ab_1
|
| 407 |
+
- mdit-py-plugins=0.4.2=pyhd8ed1ab_1
|
| 408 |
+
- mdurl=0.1.2=pyhd8ed1ab_1
|
| 409 |
+
- mistune=3.1.3=pyh29332c3_0
|
| 410 |
+
- more-itertools=10.6.0=pyhd8ed1ab_0
|
| 411 |
+
- mpc=1.3.1=h24ddda3_1
|
| 412 |
+
- mpfr=4.2.1=h90cbb55_3
|
| 413 |
+
- mpg123=1.32.9=hc50e24c_0
|
| 414 |
+
- mpmath=1.3.0=pyhd8ed1ab_1
|
| 415 |
+
- msgpack-python=1.1.0=py310h3788b33_0
|
| 416 |
+
- multidict=6.2.0=py310h89163eb_0
|
| 417 |
+
- multipledispatch=0.6.0=pyhd8ed1ab_1
|
| 418 |
+
- munkres=1.1.4=pyh9f0ad1d_0
|
| 419 |
+
- mypy=1.15.0=py310ha75aee5_0
|
| 420 |
+
- mypy_extensions=1.0.0=pyha770c72_1
|
| 421 |
+
- mysql-common=9.0.1=h266115a_5
|
| 422 |
+
- mysql-libs=9.0.1=he0572af_5
|
| 423 |
+
- narwhals=1.32.0=pyhd8ed1ab_0
|
| 424 |
+
- nbclient=0.10.2=pyhd8ed1ab_0
|
| 425 |
+
- nbconvert=7.16.6=hb482800_0
|
| 426 |
+
- nbconvert-core=7.16.6=pyh29332c3_0
|
| 427 |
+
- nbconvert-pandoc=7.16.6=hed9df3c_0
|
| 428 |
+
- nbformat=5.10.4=pyhd8ed1ab_1
|
| 429 |
+
- ncurses=6.5=h2d0b736_3
|
| 430 |
+
- nest-asyncio=1.6.0=pyhd8ed1ab_1
|
| 431 |
+
- networkx=3.4.2=pyh267e887_2
|
| 432 |
+
- nlohmann_json=3.11.3=he02047a_1
|
| 433 |
+
- nltk=3.9.1=pyhd8ed1ab_1
|
| 434 |
+
- nomkl=1.0=h5ca1d4c_0
|
| 435 |
+
- notebook=7.3.3=pyhd8ed1ab_0
|
| 436 |
+
- notebook-shim=0.2.4=pyhd8ed1ab_1
|
| 437 |
+
- nsight-compute=2025.1.1.2=hb5ebaad_0
|
| 438 |
+
- nspr=4.36=h5888daf_0
|
| 439 |
+
- nss=3.110=h159eef7_0
|
| 440 |
+
- numexpr=2.10.2=py310hdb6e06b_100
|
| 441 |
+
- numpydoc=1.8.0=pyhd8ed1ab_1
|
| 442 |
+
- ocl-icd=2.3.2=hb9d3cd8_2
|
| 443 |
+
- oniguruma=6.9.10=hb9d3cd8_0
|
| 444 |
+
- opencl-headers=2024.10.24=h5888daf_0
|
| 445 |
+
- openjpeg=2.5.3=h5fbd93e_0
|
| 446 |
+
- openldap=2.6.9=he970967_0
|
| 447 |
+
- openpyxl=3.1.5=py310h0999ad4_1
|
| 448 |
+
- openssl=3.4.1=h7b32b05_0
|
| 449 |
+
- orc=2.1.1=h17f744e_1
|
| 450 |
+
- overrides=7.7.0=pyhd8ed1ab_1
|
| 451 |
+
- packaging=24.2=pyhd8ed1ab_2
|
| 452 |
+
- pandas=2.2.3=py310h5eaa309_1
|
| 453 |
+
- pandoc=3.6.4=ha770c72_0
|
| 454 |
+
- pandocfilters=1.5.0=pyhd8ed1ab_0
|
| 455 |
+
- panel=1.6.2=pyhd8ed1ab_0
|
| 456 |
+
- param=2.2.0=pyhd8ed1ab_0
|
| 457 |
+
- parsel=1.10.0=pyhd8ed1ab_0
|
| 458 |
+
- parso=0.8.4=pyhd8ed1ab_1
|
| 459 |
+
- partd=1.4.2=pyhd8ed1ab_0
|
| 460 |
+
- pathspec=0.12.1=pyhd8ed1ab_1
|
| 461 |
+
- patsy=1.0.1=pyhd8ed1ab_1
|
| 462 |
+
- pcre2=10.44=hba22ea6_2
|
| 463 |
+
- pexpect=4.9.0=pyhd8ed1ab_1
|
| 464 |
+
- pickleshare=0.7.5=pyhd8ed1ab_1004
|
| 465 |
+
- pillow=11.1.0=py310h7e6dc6c_0
|
| 466 |
+
- pip=25.0.1=pyh8b19718_0
|
| 467 |
+
- pixman=0.44.2=h29eaf8c_0
|
| 468 |
+
- pkgutil-resolve-name=1.3.10=pyhd8ed1ab_2
|
| 469 |
+
- platformdirs=4.3.7=pyh29332c3_0
|
| 470 |
+
- plotly=6.0.1=pyhd8ed1ab_0
|
| 471 |
+
- pluggy=1.5.0=pyhd8ed1ab_1
|
| 472 |
+
- ply=3.11=pyhd8ed1ab_3
|
| 473 |
+
- prometheus-cpp=1.3.0=ha5d0236_0
|
| 474 |
+
- prometheus_client=0.21.1=pyhd8ed1ab_0
|
| 475 |
+
- prompt-toolkit=3.0.50=pyha770c72_0
|
| 476 |
+
- prompt_toolkit=3.0.50=hd8ed1ab_0
|
| 477 |
+
- propcache=0.2.1=py310h89163eb_1
|
| 478 |
+
- protego=0.4.0=pyhd8ed1ab_0
|
| 479 |
+
- protobuf=5.29.3=py310hcba5963_0
|
| 480 |
+
- psutil=7.0.0=py310ha75aee5_0
|
| 481 |
+
- pthread-stubs=0.4=hb9d3cd8_1002
|
| 482 |
+
- ptyprocess=0.7.0=pyhd8ed1ab_1
|
| 483 |
+
- pulseaudio-client=17.0=hac146a9_1
|
| 484 |
+
- pure_eval=0.2.3=pyhd8ed1ab_1
|
| 485 |
+
- py-cpuinfo=9.0.0=pyhd8ed1ab_1
|
| 486 |
+
- pyarrow=19.0.1=py310hff52083_0
|
| 487 |
+
- pyarrow-core=19.0.1=py310hac404ae_0_cpu
|
| 488 |
+
- pyasn1=0.6.1=pyhd8ed1ab_2
|
| 489 |
+
- pyasn1-modules=0.4.2=pyhd8ed1ab_0
|
| 490 |
+
- pycodestyle=2.12.1=pyhd8ed1ab_1
|
| 491 |
+
- pyconify=0.2.1=pyhd8ed1ab_0
|
| 492 |
+
- pycparser=2.22=pyh29332c3_1
|
| 493 |
+
- pyct=0.5.0=pyhd8ed1ab_1
|
| 494 |
+
- pycurl=7.45.6=py310h6811363_0
|
| 495 |
+
- pydeck=0.9.1=pyhd8ed1ab_0
|
| 496 |
+
- pydispatcher=2.0.5=py_1
|
| 497 |
+
- pydocstyle=6.3.0=pyhd8ed1ab_1
|
| 498 |
+
- pyerfa=2.0.1.5=py310hf462985_0
|
| 499 |
+
- pyflakes=3.2.0=pyhd8ed1ab_1
|
| 500 |
+
- pygithub=2.6.1=pyhd8ed1ab_0
|
| 501 |
+
- pygments=2.19.1=pyhd8ed1ab_0
|
| 502 |
+
- pyjwt=2.10.1=pyhd8ed1ab_0
|
| 503 |
+
- pylint=3.3.5=pyh29332c3_0
|
| 504 |
+
- pylint-venv=3.0.4=pyhd8ed1ab_1
|
| 505 |
+
- pyls-spyder=0.4.0=pyhd8ed1ab_1
|
| 506 |
+
- pynacl=1.5.0=py310ha75aee5_4
|
| 507 |
+
- pyodbc=5.2.0=py310hf71b8c6_0
|
| 508 |
+
- pyopenssl=25.0.0=pyhd8ed1ab_0
|
| 509 |
+
- pyparsing=3.2.3=pyhd8ed1ab_1
|
| 510 |
+
- pyqt=5.15.9=py310h04931ad_5
|
| 511 |
+
- pyqt5-sip=12.12.2=py310hc6cd4ac_5
|
| 512 |
+
- pyqtwebengine=5.15.9=py310h704022c_5
|
| 513 |
+
- pyside6=6.8.3=py310hfd10a26_0
|
| 514 |
+
- pysocks=1.7.1=pyha55dd90_7
|
| 515 |
+
- pytables=3.10.1=py310h1affd9f_4
|
| 516 |
+
- pytest=8.3.5=pyhd8ed1ab_0
|
| 517 |
+
- python=3.10.16=he725a3c_1_cpython
|
| 518 |
+
- python-dateutil=2.9.0.post0=pyhff2d567_1
|
| 519 |
+
- python-fastjsonschema=2.21.1=pyhd8ed1ab_0
|
| 520 |
+
- python-gssapi=1.9.0=py310h695cd88_1
|
| 521 |
+
- python-json-logger=2.0.7=pyhd8ed1ab_0
|
| 522 |
+
- python-lsp-black=2.0.0=pyhff2d567_1
|
| 523 |
+
- python-lsp-jsonrpc=1.1.2=pyhff2d567_1
|
| 524 |
+
- python-lsp-server=1.12.2=pyhff2d567_0
|
| 525 |
+
- python-lsp-server-base=1.12.2=pyhd8ed1ab_0
|
| 526 |
+
- python-slugify=8.0.4=pyhd8ed1ab_1
|
| 527 |
+
- python-tzdata=2025.2=pyhd8ed1ab_0
|
| 528 |
+
- python_abi=3.10=5_cp310
|
| 529 |
+
- pytoolconfig=1.2.5=pyhd8ed1ab_1
|
| 530 |
+
- pytz=2024.1=pyhd8ed1ab_0
|
| 531 |
+
- pyuca=1.2=pyhd8ed1ab_2
|
| 532 |
+
- pyviz_comms=3.0.4=pyhd8ed1ab_1
|
| 533 |
+
- pywavelets=1.8.0=py310hf462985_0
|
| 534 |
+
- pyxdg=0.28=pyhd8ed1ab_0
|
| 535 |
+
- pyyaml=6.0.2=py310h89163eb_2
|
| 536 |
+
- pyzmq=26.3.0=py310h71f11fc_0
|
| 537 |
+
- qdarkstyle=3.2.3=pyhd8ed1ab_1
|
| 538 |
+
- qhull=2020.2=h434a139_5
|
| 539 |
+
- qstylizer=0.2.4=pyhff2d567_0
|
| 540 |
+
- qt-main=5.15.15=hc3cb62f_2
|
| 541 |
+
- qt-webengine=5.15.15=h0071231_2
|
| 542 |
+
- qt6-main=6.8.3=h588cce1_0
|
| 543 |
+
- qtawesome=1.4.0=pyh9208f05_1
|
| 544 |
+
- qtconsole=5.6.1=pyhd8ed1ab_1
|
| 545 |
+
- qtconsole-base=5.6.1=pyha770c72_1
|
| 546 |
+
- qtpy=2.4.3=pyhd8ed1ab_0
|
| 547 |
+
- queuelib=1.8.0=pyhd8ed1ab_0
|
| 548 |
+
- rav1e=0.6.6=he8a937b_2
|
| 549 |
+
- rdma-core=56.0=h5888daf_0
|
| 550 |
+
- re2=2024.07.02=h9925aae_3
|
| 551 |
+
- readline=8.2=h8c095d6_2
|
| 552 |
+
- referencing=0.36.2=pyh29332c3_0
|
| 553 |
+
- regex=2024.11.6=py310ha75aee5_0
|
| 554 |
+
- requests=2.32.3=pyhd8ed1ab_1
|
| 555 |
+
- requests-file=2.1.0=pyhd8ed1ab_1
|
| 556 |
+
- rfc3339-validator=0.1.4=pyhd8ed1ab_1
|
| 557 |
+
- rfc3986-validator=0.1.1=pyh9f0ad1d_0
|
| 558 |
+
- rich=14.0.0=pyh29332c3_0
|
| 559 |
+
- rope=1.13.0=pyhd8ed1ab_1
|
| 560 |
+
- rpds-py=0.24.0=py310hc1293b2_0
|
| 561 |
+
- rtree=1.4.0=pyh11ca60a_1
|
| 562 |
+
- s2n=1.5.14=h6c98b2b_0
|
| 563 |
+
- s3fs=2025.3.1=pyhd8ed1ab_0
|
| 564 |
+
- scikit-image=0.25.2=py310h5eaa309_0
|
| 565 |
+
- scikit-learn=1.6.1=py310h27f47ee_0
|
| 566 |
+
- scipy=1.15.2=py310h1d65ade_0
|
| 567 |
+
- scrapy=2.12.0=py310hff52083_1
|
| 568 |
+
- seaborn=0.13.2=hd8ed1ab_3
|
| 569 |
+
- seaborn-base=0.13.2=pyhd8ed1ab_3
|
| 570 |
+
- secretstorage=3.3.3=py310hff52083_3
|
| 571 |
+
- send2trash=1.8.3=pyh0d859eb_1
|
| 572 |
+
- service-identity=24.2.0=pyha770c72_1
|
| 573 |
+
- service_identity=24.2.0=hd8ed1ab_1
|
| 574 |
+
- setuptools=75.8.2=pyhff2d567_0
|
| 575 |
+
- sip=6.7.12=py310hc6cd4ac_0
|
| 576 |
+
- six=1.17.0=pyhd8ed1ab_0
|
| 577 |
+
- sklearn-compat=0.1.3=pyhd8ed1ab_0
|
| 578 |
+
- smmap=5.0.2=pyhd8ed1ab_0
|
| 579 |
+
- snappy=1.2.1=h8bd8927_1
|
| 580 |
+
- sniffio=1.3.1=pyhd8ed1ab_1
|
| 581 |
+
- snowballstemmer=2.2.0=pyhd8ed1ab_0
|
| 582 |
+
- sortedcontainers=2.4.0=pyhd8ed1ab_1
|
| 583 |
+
- soupsieve=2.5=pyhd8ed1ab_1
|
| 584 |
+
- sphinx=8.1.3=pyhd8ed1ab_1
|
| 585 |
+
- sphinxcontrib-applehelp=2.0.0=pyhd8ed1ab_1
|
| 586 |
+
- sphinxcontrib-devhelp=2.0.0=pyhd8ed1ab_1
|
| 587 |
+
- sphinxcontrib-htmlhelp=2.1.0=pyhd8ed1ab_1
|
| 588 |
+
- sphinxcontrib-jsmath=1.0.1=pyhd8ed1ab_1
|
| 589 |
+
- sphinxcontrib-qthelp=2.0.0=pyhd8ed1ab_1
|
| 590 |
+
- sphinxcontrib-serializinghtml=1.1.10=pyhd8ed1ab_1
|
| 591 |
+
- spyder=6.0.5=hd8ed1ab_0
|
| 592 |
+
- spyder-base=6.0.5=linux_pyh62a8a7d_0
|
| 593 |
+
- spyder-kernels=3.0.3=unix_pyh707e725_0
|
| 594 |
+
- sqlalchemy=2.0.40=py310ha75aee5_0
|
| 595 |
+
- stack_data=0.6.3=pyhd8ed1ab_1
|
| 596 |
+
- statsmodels=0.14.4=py310hf462985_0
|
| 597 |
+
- streamlit=1.44.0=pyhd8ed1ab_1
|
| 598 |
+
- superqt=0.7.3=pyhb6d5dde_0
|
| 599 |
+
- svt-av1=3.0.2=h5888daf_0
|
| 600 |
+
- sympy=1.13.3=pyh2585a3b_105
|
| 601 |
+
- sysroot_linux-64=2.17=h0157908_18
|
| 602 |
+
- tabulate=0.9.0=pyhd8ed1ab_2
|
| 603 |
+
- tblib=3.0.0=pyhd8ed1ab_1
|
| 604 |
+
- tenacity=9.0.0=pyhd8ed1ab_1
|
| 605 |
+
- terminado=0.18.1=pyh0d859eb_0
|
| 606 |
+
- text-unidecode=1.3=pyhd8ed1ab_2
|
| 607 |
+
- textdistance=4.6.3=pyhd8ed1ab_1
|
| 608 |
+
- threadpoolctl=3.6.0=pyhecae5ae_0
|
| 609 |
+
- three-merge=0.1.1=pyhd8ed1ab_1
|
| 610 |
+
- tifffile=2025.3.30=pyhd8ed1ab_0
|
| 611 |
+
- tinycss2=1.4.0=pyhd8ed1ab_0
|
| 612 |
+
- tk=8.6.13=noxft_h4845f30_101
|
| 613 |
+
- tldextract=5.1.3=pyhd8ed1ab_1
|
| 614 |
+
- toml=0.10.2=pyhd8ed1ab_1
|
| 615 |
+
- tomli=2.2.1=pyhd8ed1ab_1
|
| 616 |
+
- tomlkit=0.13.2=pyha770c72_1
|
| 617 |
+
- toolz=1.0.0=pyhd8ed1ab_1
|
| 618 |
+
- tornado=6.4.2=py310ha75aee5_0
|
| 619 |
+
- tqdm=4.67.1=pyhd8ed1ab_1
|
| 620 |
+
- traitlets=5.14.3=pyhd8ed1ab_1
|
| 621 |
+
- twisted=24.11.0=py310ha75aee5_0
|
| 622 |
+
- types-python-dateutil=2.9.0.20241206=pyhd8ed1ab_0
|
| 623 |
+
- typing-extensions=4.13.0=h9fa5a19_1
|
| 624 |
+
- typing_extensions=4.13.0=pyh29332c3_1
|
| 625 |
+
- typing_utils=0.1.0=pyhd8ed1ab_1
|
| 626 |
+
- tzdata=2025b=h78e105d_0
|
| 627 |
+
- uc-micro-py=1.0.3=pyhd8ed1ab_1
|
| 628 |
+
- ujson=5.10.0=py310hf71b8c6_1
|
| 629 |
+
- unicodedata2=16.0.0=py310ha75aee5_0
|
| 630 |
+
- unixodbc=2.3.12=h661eb56_0
|
| 631 |
+
- uri-template=1.3.0=pyhd8ed1ab_1
|
| 632 |
+
- urllib3=2.3.0=pyhd8ed1ab_0
|
| 633 |
+
- w3lib=2.3.1=pyhd8ed1ab_0
|
| 634 |
+
- watchdog=6.0.0=py310hff52083_0
|
| 635 |
+
- wayland=1.23.1=h3e06ad9_0
|
| 636 |
+
- wcwidth=0.2.13=pyhd8ed1ab_1
|
| 637 |
+
- webcolors=24.11.1=pyhd8ed1ab_0
|
| 638 |
+
- webencodings=0.5.1=pyhd8ed1ab_3
|
| 639 |
+
- websocket-client=1.8.0=pyhd8ed1ab_1
|
| 640 |
+
- whatthepatch=1.0.7=pyhd8ed1ab_1
|
| 641 |
+
- wheel=0.45.1=pyhd8ed1ab_1
|
| 642 |
+
- wrapt=1.17.2=py310ha75aee5_0
|
| 643 |
+
- wurlitzer=3.1.1=pyhd8ed1ab_1
|
| 644 |
+
- xarray=2025.3.1=pyhd8ed1ab_0
|
| 645 |
+
- xcb-util=0.4.1=hb711507_2
|
| 646 |
+
- xcb-util-cursor=0.1.5=hb9d3cd8_0
|
| 647 |
+
- xcb-util-image=0.4.0=hb711507_2
|
| 648 |
+
- xcb-util-keysyms=0.4.1=hb711507_0
|
| 649 |
+
- xcb-util-renderutil=0.3.10=hb711507_0
|
| 650 |
+
- xcb-util-wm=0.4.2=hb711507_0
|
| 651 |
+
- xkeyboard-config=2.43=hb9d3cd8_0
|
| 652 |
+
- xorg-libice=1.1.2=hb9d3cd8_0
|
| 653 |
+
- xorg-libsm=1.2.6=he73a12e_0
|
| 654 |
+
- xorg-libx11=1.8.12=h4f16b4b_0
|
| 655 |
+
- xorg-libxau=1.0.12=hb9d3cd8_0
|
| 656 |
+
- xorg-libxcomposite=0.4.6=hb9d3cd8_2
|
| 657 |
+
- xorg-libxcursor=1.2.3=hb9d3cd8_0
|
| 658 |
+
- xorg-libxdamage=1.1.6=hb9d3cd8_0
|
| 659 |
+
- xorg-libxdmcp=1.1.5=hb9d3cd8_0
|
| 660 |
+
- xorg-libxext=1.3.6=hb9d3cd8_0
|
| 661 |
+
- xorg-libxfixes=6.0.1=hb9d3cd8_0
|
| 662 |
+
- xorg-libxi=1.8.2=hb9d3cd8_0
|
| 663 |
+
- xorg-libxrandr=1.5.4=hb9d3cd8_0
|
| 664 |
+
- xorg-libxrender=0.9.12=hb9d3cd8_0
|
| 665 |
+
- xorg-libxtst=1.2.5=hb9d3cd8_3
|
| 666 |
+
- xorg-libxxf86vm=1.1.6=hb9d3cd8_0
|
| 667 |
+
- xyzservices=2025.1.0=pyhd8ed1ab_0
|
| 668 |
+
- yaml=0.2.5=h7f98852_2
|
| 669 |
+
- yapf=0.43.0=pyhd8ed1ab_1
|
| 670 |
+
- yarl=1.18.3=py310h89163eb_1
|
| 671 |
+
- zeromq=4.3.5=h3b0a872_7
|
| 672 |
+
- zfp=1.0.1=h5888daf_2
|
| 673 |
+
- zict=3.0.0=pyhd8ed1ab_1
|
| 674 |
+
- zipp=3.21.0=pyhd8ed1ab_1
|
| 675 |
+
- zlib=1.3.1=hb9d3cd8_2
|
| 676 |
+
- zlib-ng=2.2.4=h7955e40_0
|
| 677 |
+
- zope.interface=7.2=py310ha75aee5_0
|
| 678 |
+
- zstandard=0.23.0=py310ha75aee5_1
|
| 679 |
+
- zstd=1.5.7=hb8e6e7a_2
|
| 680 |
+
- pip:
|
| 681 |
+
- addict==2.4.0
|
| 682 |
+
- arrgh==1.0.0
|
| 683 |
+
- boto3==1.37.24
|
| 684 |
+
- botocore==1.37.24
|
| 685 |
+
- configargparse==1.7
|
| 686 |
+
- cuda-bindings==12.8.0
|
| 687 |
+
- cuda-python==12.8.0
|
| 688 |
+
- cudf-cu12==25.2.2
|
| 689 |
+
- cuml-cu12==25.2.1
|
| 690 |
+
- cupy-cuda12x==13.4.1
|
| 691 |
+
- cuvs-cu12==25.2.1
|
| 692 |
+
- dash==3.0.2
|
| 693 |
+
- dask==2024.12.1
|
| 694 |
+
- dask-cuda==25.2.0
|
| 695 |
+
- dask-cudf-cu12==25.2.2
|
| 696 |
+
- dask-expr==1.1.21
|
| 697 |
+
- distributed==2024.12.1
|
| 698 |
+
- distributed-ucxx-cu12==0.42.0
|
| 699 |
+
- docstring-parser==0.16
|
| 700 |
+
- einops==0.8.1
|
| 701 |
+
- fastrlock==0.8.3
|
| 702 |
+
- flask==3.0.3
|
| 703 |
+
- fsspec==2024.12.0
|
| 704 |
+
- ipywidgets==8.1.5
|
| 705 |
+
- jupyterlab-widgets==3.0.13
|
| 706 |
+
- libcudf-cu12==25.2.2
|
| 707 |
+
- libcuml-cu12==25.2.1
|
| 708 |
+
- libcuvs-cu12==25.2.1
|
| 709 |
+
- libigl==2.5.1
|
| 710 |
+
- libkvikio-cu12==25.2.1
|
| 711 |
+
- libraft-cu12==25.2.0
|
| 712 |
+
- libucx-cu12==1.18.0
|
| 713 |
+
- libucxx-cu12==0.42.0
|
| 714 |
+
- lightning==2.2.0
|
| 715 |
+
- lightning-utilities==0.14.2
|
| 716 |
+
- llvmlite==0.43.0
|
| 717 |
+
- loguru==0.7.3
|
| 718 |
+
- mesh2sdf==1.1.0
|
| 719 |
+
- numba==0.60.0
|
| 720 |
+
- numba-cuda==0.2.0
|
| 721 |
+
- numpy==2.0.2
|
| 722 |
+
- nvidia-cublas-cu12==12.4.2.65
|
| 723 |
+
- nvidia-cuda-cupti-cu12==12.4.99
|
| 724 |
+
- nvidia-cuda-nvrtc-cu12==12.4.99
|
| 725 |
+
- nvidia-cuda-runtime-cu12==12.4.99
|
| 726 |
+
- nvidia-cudnn-cu12==9.1.0.70
|
| 727 |
+
- nvidia-cufft-cu12==11.2.0.44
|
| 728 |
+
- nvidia-curand-cu12==10.3.5.119
|
| 729 |
+
- nvidia-cusolver-cu12==11.6.0.99
|
| 730 |
+
- nvidia-cusparse-cu12==12.3.0.142
|
| 731 |
+
- nvidia-ml-py==12.570.86
|
| 732 |
+
- nvidia-nccl-cu12==2.20.5
|
| 733 |
+
- nvidia-nvcomp-cu12==4.2.0.11
|
| 734 |
+
- nvidia-nvjitlink-cu12==12.4.99
|
| 735 |
+
- nvidia-nvtx-cu12==12.4.99
|
| 736 |
+
- nvtx==0.2.11
|
| 737 |
+
- open3d==0.19.0
|
| 738 |
+
- plyfile==1.1
|
| 739 |
+
- polyscope==2.4.0
|
| 740 |
+
- pooch==1.8.2
|
| 741 |
+
- potpourri3d==1.2.1
|
| 742 |
+
- pylibcudf-cu12==25.2.2
|
| 743 |
+
- pylibraft-cu12==25.2.0
|
| 744 |
+
- pymeshlab==2023.12.post3
|
| 745 |
+
- pynvjitlink-cu12==0.5.2
|
| 746 |
+
- pynvml==12.0.0
|
| 747 |
+
- pyquaternion==0.9.9
|
| 748 |
+
- pytorch-lightning==2.5.1
|
| 749 |
+
- pyvista==0.44.2
|
| 750 |
+
- raft-dask-cu12==25.2.0
|
| 751 |
+
- rapids-dask-dependency==25.2.0
|
| 752 |
+
- retrying==1.3.4
|
| 753 |
+
- rmm-cu12==25.2.0
|
| 754 |
+
- s3transfer==0.11.4
|
| 755 |
+
- scooby==0.10.0
|
| 756 |
+
- simple-parsing==0.1.7
|
| 757 |
+
- tetgen==0.6.5
|
| 758 |
+
- torch==2.4.0+cu124
|
| 759 |
+
- torch-scatter==2.1.2+pt24cu124
|
| 760 |
+
- torchaudio==2.4.0+cu124
|
| 761 |
+
- torchmetrics==1.7.0
|
| 762 |
+
- torchvision==0.19.0+cu124
|
| 763 |
+
- treelite==4.4.1
|
| 764 |
+
- trimesh==4.6.6
|
| 765 |
+
- triton==3.0.0
|
| 766 |
+
- ucx-py-cu12==0.42.0
|
| 767 |
+
- ucxx-cu12==0.42.0
|
| 768 |
+
- vtk==9.3.1
|
| 769 |
+
- werkzeug==3.0.6
|
| 770 |
+
- widgetsnbextension==4.0.13
|
| 771 |
+
- xgboost==3.0.0
|
| 772 |
+
- yacs==0.1.8
|
PartField/partfield/__pycache__/dataloader.cpython-310.pyc
ADDED
|
Binary file (7.71 kB). View file
|
|
|
PartField/partfield/__pycache__/model_trainer_pvcnn_only_demo.cpython-310.pyc
ADDED
|
Binary file (7.72 kB). View file
|
|
|
PartField/partfield/__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (372 Bytes). View file
|
|
|
PartField/partfield/config/__init__.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os.path as osp
|
| 3 |
+
from datetime import datetime
|
| 4 |
+
import pytz
|
| 5 |
+
|
| 6 |
+
def default_argument_parser(add_help=True, default_config_file=""):
|
| 7 |
+
parser = argparse.ArgumentParser(add_help=add_help)
|
| 8 |
+
parser.add_argument("--config-file", '-c', default=default_config_file, metavar="FILE", help="path to config file")
|
| 9 |
+
parser.add_argument(
|
| 10 |
+
"--opts",
|
| 11 |
+
help="Modify config options using the command-line",
|
| 12 |
+
default=None,
|
| 13 |
+
nargs=argparse.REMAINDER,
|
| 14 |
+
)
|
| 15 |
+
return parser
|
| 16 |
+
|
| 17 |
+
def setup(args, freeze=True):
|
| 18 |
+
from .defaults import _C as cfg
|
| 19 |
+
cfg = cfg.clone()
|
| 20 |
+
cfg.merge_from_file(args.config_file)
|
| 21 |
+
cfg.merge_from_list(args.opts)
|
| 22 |
+
dt = datetime.now(pytz.timezone('America/Los_Angeles')).strftime('%y%m%d-%H%M%S')
|
| 23 |
+
cfg.output_dir = osp.join(cfg.output_dir, cfg.name, dt)
|
| 24 |
+
if freeze:
|
| 25 |
+
cfg.freeze()
|
| 26 |
+
return cfg
|
PartField/partfield/config/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (1.15 kB). View file
|
|
|
PartField/partfield/config/__pycache__/defaults.cpython-310.pyc
ADDED
|
Binary file (2.17 kB). View file
|
|
|
PartField/partfield/config/defaults.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from yacs.config import CfgNode as CN
|
| 2 |
+
|
| 3 |
+
_C = CN()
|
| 4 |
+
_C.seed = 0
|
| 5 |
+
_C.output_dir = "results"
|
| 6 |
+
_C.result_name = "test_all"
|
| 7 |
+
|
| 8 |
+
_C.triplet_sampling = "random"
|
| 9 |
+
_C.load_original_mesh = False
|
| 10 |
+
|
| 11 |
+
_C.num_pos = 64
|
| 12 |
+
_C.num_neg_random = 256
|
| 13 |
+
_C.num_neg_hard_pc = 128
|
| 14 |
+
_C.num_neg_hard_emb = 128
|
| 15 |
+
|
| 16 |
+
_C.vertex_feature = False # if true, sample feature on vertices; if false, sample feature on faces
|
| 17 |
+
_C.n_point_per_face = 2000
|
| 18 |
+
_C.n_sample_each = 10000
|
| 19 |
+
_C.preprocess_mesh = False
|
| 20 |
+
|
| 21 |
+
_C.regress_2d_feat = False
|
| 22 |
+
|
| 23 |
+
_C.is_pc = False
|
| 24 |
+
|
| 25 |
+
_C.cut_manifold = False
|
| 26 |
+
_C.remesh_demo = False
|
| 27 |
+
_C.correspondence_demo = False
|
| 28 |
+
|
| 29 |
+
_C.save_every_epoch = 10
|
| 30 |
+
_C.training_epochs = 30
|
| 31 |
+
_C.continue_training = False
|
| 32 |
+
|
| 33 |
+
_C.continue_ckpt = None
|
| 34 |
+
_C.epoch_selected = "epoch=50.ckpt"
|
| 35 |
+
|
| 36 |
+
_C.triplane_resolution = 128
|
| 37 |
+
_C.triplane_channels_low = 128
|
| 38 |
+
_C.triplane_channels_high = 512
|
| 39 |
+
_C.lr = 1e-3
|
| 40 |
+
_C.train = True
|
| 41 |
+
_C.test = False
|
| 42 |
+
|
| 43 |
+
_C.inference_save_pred_sdf_to_mesh=True
|
| 44 |
+
_C.inference_save_feat_pca=True
|
| 45 |
+
_C.name = "test"
|
| 46 |
+
_C.test_subset = False
|
| 47 |
+
_C.test_corres = False
|
| 48 |
+
_C.test_partobjaversetiny = False
|
| 49 |
+
|
| 50 |
+
_C.dataset = CN()
|
| 51 |
+
_C.dataset.type = "Demo_Dataset"
|
| 52 |
+
_C.dataset.data_path = "objaverse_data/"
|
| 53 |
+
_C.dataset.train_num_workers = 64
|
| 54 |
+
_C.dataset.val_num_workers = 32
|
| 55 |
+
_C.dataset.train_batch_size = 2
|
| 56 |
+
_C.dataset.val_batch_size = 2
|
| 57 |
+
_C.dataset.all_files = [] # only used for correspondence demo
|
| 58 |
+
|
| 59 |
+
_C.voxel2triplane = CN()
|
| 60 |
+
_C.voxel2triplane.transformer_dim = 1024
|
| 61 |
+
_C.voxel2triplane.transformer_layers = 6
|
| 62 |
+
_C.voxel2triplane.transformer_heads = 8
|
| 63 |
+
_C.voxel2triplane.triplane_low_res = 32
|
| 64 |
+
_C.voxel2triplane.triplane_high_res = 256
|
| 65 |
+
_C.voxel2triplane.triplane_dim = 64
|
| 66 |
+
_C.voxel2triplane.normalize_vox_feat = False
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
_C.loss = CN()
|
| 70 |
+
_C.loss.triplet = 0.0
|
| 71 |
+
_C.loss.sdf = 1.0
|
| 72 |
+
_C.loss.feat = 10.0
|
| 73 |
+
_C.loss.l1 = 0.0
|
| 74 |
+
|
| 75 |
+
_C.use_pvcnn = False
|
| 76 |
+
_C.use_pvcnnonly = True
|
| 77 |
+
|
| 78 |
+
_C.pvcnn = CN()
|
| 79 |
+
_C.pvcnn.point_encoder_type = 'pvcnn'
|
| 80 |
+
_C.pvcnn.use_point_scatter = True
|
| 81 |
+
_C.pvcnn.z_triplane_channels = 64
|
| 82 |
+
_C.pvcnn.z_triplane_resolution = 256
|
| 83 |
+
_C.pvcnn.unet_cfg = CN()
|
| 84 |
+
_C.pvcnn.unet_cfg.depth = 3
|
| 85 |
+
_C.pvcnn.unet_cfg.enabled = True
|
| 86 |
+
_C.pvcnn.unet_cfg.rolled = True
|
| 87 |
+
_C.pvcnn.unet_cfg.use_3d_aware = True
|
| 88 |
+
_C.pvcnn.unet_cfg.start_hidden_channels = 32
|
| 89 |
+
_C.pvcnn.unet_cfg.use_initial_conv = False
|
| 90 |
+
|
| 91 |
+
_C.use_2d_feat = False
|
| 92 |
+
_C.inference_metrics_only = False
|
PartField/partfield/dataloader.py
ADDED
|
@@ -0,0 +1,366 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import boto3
|
| 3 |
+
import json
|
| 4 |
+
from os import path as osp
|
| 5 |
+
# from botocore.config import Config
|
| 6 |
+
# from botocore.exceptions import ClientError
|
| 7 |
+
import h5py
|
| 8 |
+
import io
|
| 9 |
+
import numpy as np
|
| 10 |
+
import skimage
|
| 11 |
+
import trimesh
|
| 12 |
+
import os
|
| 13 |
+
from scipy.spatial import KDTree
|
| 14 |
+
import gc
|
| 15 |
+
from plyfile import PlyData
|
| 16 |
+
|
| 17 |
+
## For remeshing
|
| 18 |
+
import mesh2sdf
|
| 19 |
+
import tetgen
|
| 20 |
+
import vtk
|
| 21 |
+
import math
|
| 22 |
+
import tempfile
|
| 23 |
+
|
| 24 |
+
### For mesh processing
|
| 25 |
+
import pymeshlab
|
| 26 |
+
|
| 27 |
+
from partfield.utils import *
|
| 28 |
+
|
| 29 |
+
#########################
|
| 30 |
+
## To handle quad inputs
|
| 31 |
+
#########################
|
| 32 |
+
def quad_to_triangle_mesh(F):
|
| 33 |
+
"""
|
| 34 |
+
Converts a quad-dominant mesh into a pure triangle mesh by splitting quads into two triangles.
|
| 35 |
+
|
| 36 |
+
Parameters:
|
| 37 |
+
quad_mesh (trimesh.Trimesh): Input mesh with quad faces.
|
| 38 |
+
|
| 39 |
+
Returns:
|
| 40 |
+
trimesh.Trimesh: A new mesh with only triangle faces.
|
| 41 |
+
"""
|
| 42 |
+
faces = F
|
| 43 |
+
|
| 44 |
+
### If already a triangle mesh -- skip
|
| 45 |
+
if len(faces[0]) == 3:
|
| 46 |
+
return F
|
| 47 |
+
|
| 48 |
+
new_faces = []
|
| 49 |
+
|
| 50 |
+
for face in faces:
|
| 51 |
+
if len(face) == 4: # Quad face
|
| 52 |
+
# Split into two triangles
|
| 53 |
+
new_faces.append([face[0], face[1], face[2]]) # Triangle 1
|
| 54 |
+
new_faces.append([face[0], face[2], face[3]]) # Triangle 2
|
| 55 |
+
else:
|
| 56 |
+
print(f"Warning: Skipping non-triangle/non-quad face {face}")
|
| 57 |
+
|
| 58 |
+
new_faces = np.array(new_faces)
|
| 59 |
+
|
| 60 |
+
return new_faces
|
| 61 |
+
#########################
|
| 62 |
+
|
| 63 |
+
class Demo_Dataset(torch.utils.data.Dataset):
|
| 64 |
+
def __init__(self, cfg):
|
| 65 |
+
super().__init__()
|
| 66 |
+
|
| 67 |
+
self.data_path = cfg.dataset.data_path
|
| 68 |
+
self.is_pc = cfg.is_pc
|
| 69 |
+
|
| 70 |
+
all_files = os.listdir(self.data_path)
|
| 71 |
+
|
| 72 |
+
selected = []
|
| 73 |
+
for f in all_files:
|
| 74 |
+
if ".ply" in f and self.is_pc:
|
| 75 |
+
selected.append(f)
|
| 76 |
+
elif (".obj" in f or ".glb" in f or ".off" in f) and not self.is_pc:
|
| 77 |
+
selected.append(f)
|
| 78 |
+
|
| 79 |
+
self.data_list = selected
|
| 80 |
+
self.pc_num_pts = 100000
|
| 81 |
+
|
| 82 |
+
self.preprocess_mesh = cfg.preprocess_mesh
|
| 83 |
+
self.result_name = cfg.result_name
|
| 84 |
+
|
| 85 |
+
print("val dataset len:", len(self.data_list))
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def __len__(self):
|
| 89 |
+
return len(self.data_list)
|
| 90 |
+
|
| 91 |
+
def load_ply_to_numpy(self, filename):
|
| 92 |
+
"""
|
| 93 |
+
Load a PLY file and extract the point cloud as a (N, 3) NumPy array.
|
| 94 |
+
|
| 95 |
+
Parameters:
|
| 96 |
+
filename (str): Path to the PLY file.
|
| 97 |
+
|
| 98 |
+
Returns:
|
| 99 |
+
numpy.ndarray: Point cloud array of shape (N, 3).
|
| 100 |
+
"""
|
| 101 |
+
ply_data = PlyData.read(filename)
|
| 102 |
+
|
| 103 |
+
# Extract vertex data
|
| 104 |
+
vertex_data = ply_data["vertex"]
|
| 105 |
+
|
| 106 |
+
# Convert to NumPy array (x, y, z)
|
| 107 |
+
points = np.vstack([vertex_data["x"], vertex_data["y"], vertex_data["z"]]).T
|
| 108 |
+
|
| 109 |
+
return points
|
| 110 |
+
|
| 111 |
+
def get_model(self, ply_file):
|
| 112 |
+
|
| 113 |
+
uid = ply_file.split(".")[-2].replace("/", "_")
|
| 114 |
+
|
| 115 |
+
####
|
| 116 |
+
if self.is_pc:
|
| 117 |
+
ply_file_read = os.path.join(self.data_path, ply_file)
|
| 118 |
+
pc = self.load_ply_to_numpy(ply_file_read)
|
| 119 |
+
|
| 120 |
+
bbmin = pc.min(0)
|
| 121 |
+
bbmax = pc.max(0)
|
| 122 |
+
center = (bbmin + bbmax) * 0.5
|
| 123 |
+
scale = 2.0 * 0.9 / (bbmax - bbmin).max()
|
| 124 |
+
pc = (pc - center) * scale
|
| 125 |
+
|
| 126 |
+
else:
|
| 127 |
+
obj_path = os.path.join(self.data_path, ply_file)
|
| 128 |
+
mesh = load_mesh_util(obj_path)
|
| 129 |
+
vertices = mesh.vertices
|
| 130 |
+
faces = mesh.faces
|
| 131 |
+
|
| 132 |
+
bbmin = vertices.min(0)
|
| 133 |
+
bbmax = vertices.max(0)
|
| 134 |
+
center = (bbmin + bbmax) * 0.5
|
| 135 |
+
scale = 2.0 * 0.9 / (bbmax - bbmin).max()
|
| 136 |
+
vertices = (vertices - center) * scale
|
| 137 |
+
mesh.vertices = vertices
|
| 138 |
+
|
| 139 |
+
### Make sure it is a triangle mesh -- just convert the quad
|
| 140 |
+
mesh.faces = quad_to_triangle_mesh(faces)
|
| 141 |
+
|
| 142 |
+
print("before preprocessing...")
|
| 143 |
+
print(mesh.vertices.shape)
|
| 144 |
+
print(mesh.faces.shape)
|
| 145 |
+
print()
|
| 146 |
+
|
| 147 |
+
### Pre-process mesh
|
| 148 |
+
if self.preprocess_mesh:
|
| 149 |
+
# Create a PyMeshLab mesh directly from vertices and faces
|
| 150 |
+
ml_mesh = pymeshlab.Mesh(vertex_matrix=mesh.vertices, face_matrix=mesh.faces)
|
| 151 |
+
|
| 152 |
+
# Create a MeshSet and add your mesh
|
| 153 |
+
ms = pymeshlab.MeshSet()
|
| 154 |
+
ms.add_mesh(ml_mesh, "from_trimesh")
|
| 155 |
+
|
| 156 |
+
# Apply filters
|
| 157 |
+
ms.apply_filter('meshing_remove_duplicate_faces')
|
| 158 |
+
ms.apply_filter('meshing_remove_duplicate_vertices')
|
| 159 |
+
percentageMerge = pymeshlab.PercentageValue(0.5)
|
| 160 |
+
ms.apply_filter('meshing_merge_close_vertices', threshold=percentageMerge)
|
| 161 |
+
ms.apply_filter('meshing_remove_unreferenced_vertices')
|
| 162 |
+
|
| 163 |
+
# Save or extract mesh
|
| 164 |
+
processed = ms.current_mesh()
|
| 165 |
+
mesh.vertices = processed.vertex_matrix()
|
| 166 |
+
mesh.faces = processed.face_matrix()
|
| 167 |
+
|
| 168 |
+
print("after preprocessing...")
|
| 169 |
+
print(mesh.vertices.shape)
|
| 170 |
+
print(mesh.faces.shape)
|
| 171 |
+
|
| 172 |
+
### Save input
|
| 173 |
+
save_dir = f"exp_results/{self.result_name}"
|
| 174 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 175 |
+
view_id = 0
|
| 176 |
+
mesh.export(f'{save_dir}/input_{uid}_{view_id}.ply')
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
pc, _ = trimesh.sample.sample_surface(mesh, self.pc_num_pts)
|
| 180 |
+
|
| 181 |
+
result = {
|
| 182 |
+
'uid': uid
|
| 183 |
+
}
|
| 184 |
+
|
| 185 |
+
result['pc'] = torch.tensor(pc, dtype=torch.float32)
|
| 186 |
+
|
| 187 |
+
if not self.is_pc:
|
| 188 |
+
result['vertices'] = mesh.vertices
|
| 189 |
+
result['faces'] = mesh.faces
|
| 190 |
+
|
| 191 |
+
return result
|
| 192 |
+
|
| 193 |
+
def __getitem__(self, index):
|
| 194 |
+
|
| 195 |
+
gc.collect()
|
| 196 |
+
|
| 197 |
+
return self.get_model(self.data_list[index])
|
| 198 |
+
|
| 199 |
+
##############
|
| 200 |
+
|
| 201 |
+
###############################
|
| 202 |
+
class Demo_Remesh_Dataset(torch.utils.data.Dataset):
|
| 203 |
+
def __init__(self, cfg):
|
| 204 |
+
super().__init__()
|
| 205 |
+
|
| 206 |
+
self.data_path = cfg.dataset.data_path
|
| 207 |
+
|
| 208 |
+
all_files = os.listdir(self.data_path)
|
| 209 |
+
|
| 210 |
+
selected = []
|
| 211 |
+
for f in all_files:
|
| 212 |
+
if (".obj" in f or ".glb" in f):
|
| 213 |
+
selected.append(f)
|
| 214 |
+
|
| 215 |
+
self.data_list = selected
|
| 216 |
+
self.pc_num_pts = 100000
|
| 217 |
+
|
| 218 |
+
self.preprocess_mesh = cfg.preprocess_mesh
|
| 219 |
+
self.result_name = cfg.result_name
|
| 220 |
+
|
| 221 |
+
print("val dataset len:", len(self.data_list))
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def __len__(self):
|
| 225 |
+
return len(self.data_list)
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def get_model(self, ply_file):
|
| 229 |
+
|
| 230 |
+
uid = ply_file.split(".")[-2]
|
| 231 |
+
|
| 232 |
+
####
|
| 233 |
+
obj_path = os.path.join(self.data_path, ply_file)
|
| 234 |
+
mesh = load_mesh_util(obj_path)
|
| 235 |
+
vertices = mesh.vertices
|
| 236 |
+
faces = mesh.faces
|
| 237 |
+
|
| 238 |
+
bbmin = vertices.min(0)
|
| 239 |
+
bbmax = vertices.max(0)
|
| 240 |
+
center = (bbmin + bbmax) * 0.5
|
| 241 |
+
scale = 2.0 * 0.9 / (bbmax - bbmin).max()
|
| 242 |
+
vertices = (vertices - center) * scale
|
| 243 |
+
mesh.vertices = vertices
|
| 244 |
+
|
| 245 |
+
### Pre-process mesh
|
| 246 |
+
if self.preprocess_mesh:
|
| 247 |
+
# Create a PyMeshLab mesh directly from vertices and faces
|
| 248 |
+
ml_mesh = pymeshlab.Mesh(vertex_matrix=mesh.vertices, face_matrix=mesh.faces)
|
| 249 |
+
|
| 250 |
+
# Create a MeshSet and add your mesh
|
| 251 |
+
ms = pymeshlab.MeshSet()
|
| 252 |
+
ms.add_mesh(ml_mesh, "from_trimesh")
|
| 253 |
+
|
| 254 |
+
# Apply filters
|
| 255 |
+
ms.apply_filter('meshing_remove_duplicate_faces')
|
| 256 |
+
ms.apply_filter('meshing_remove_duplicate_vertices')
|
| 257 |
+
percentageMerge = pymeshlab.PercentageValue(0.5)
|
| 258 |
+
ms.apply_filter('meshing_merge_close_vertices', threshold=percentageMerge)
|
| 259 |
+
ms.apply_filter('meshing_remove_unreferenced_vertices')
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
# Save or extract mesh
|
| 263 |
+
processed = ms.current_mesh()
|
| 264 |
+
mesh.vertices = processed.vertex_matrix()
|
| 265 |
+
mesh.faces = processed.face_matrix()
|
| 266 |
+
|
| 267 |
+
print("after preprocessing...")
|
| 268 |
+
print(mesh.vertices.shape)
|
| 269 |
+
print(mesh.faces.shape)
|
| 270 |
+
|
| 271 |
+
### Save input
|
| 272 |
+
save_dir = f"exp_results/{self.result_name}"
|
| 273 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 274 |
+
view_id = 0
|
| 275 |
+
mesh.export(f'{save_dir}/input_{uid}_{view_id}.ply')
|
| 276 |
+
|
| 277 |
+
try:
|
| 278 |
+
###### Remesh ######
|
| 279 |
+
size= 256
|
| 280 |
+
level = 2 / size
|
| 281 |
+
|
| 282 |
+
sdf = mesh2sdf.core.compute(mesh.vertices, mesh.faces, size)
|
| 283 |
+
# NOTE: the negative value is not reliable if the mesh is not watertight
|
| 284 |
+
udf = np.abs(sdf)
|
| 285 |
+
vertices, faces, _, _ = skimage.measure.marching_cubes(udf, level)
|
| 286 |
+
|
| 287 |
+
#### Only use SDF mesh ###
|
| 288 |
+
# new_mesh = trimesh.Trimesh(vertices, faces)
|
| 289 |
+
##########################
|
| 290 |
+
|
| 291 |
+
#### Make tet #####
|
| 292 |
+
components = trimesh.Trimesh(vertices, faces).split(only_watertight=False)
|
| 293 |
+
new_mesh = [] #trimesh.Trimesh()
|
| 294 |
+
if len(components) > 100000:
|
| 295 |
+
raise NotImplementedError
|
| 296 |
+
for i, c in enumerate(components):
|
| 297 |
+
c.fix_normals()
|
| 298 |
+
new_mesh.append(c) #trimesh.util.concatenate(new_mesh, c)
|
| 299 |
+
new_mesh = trimesh.util.concatenate(new_mesh)
|
| 300 |
+
|
| 301 |
+
# generate tet mesh
|
| 302 |
+
tet = tetgen.TetGen(new_mesh.vertices, new_mesh.faces)
|
| 303 |
+
tet.tetrahedralize(plc=True, nobisect=1., quality=True, fixedvolume=True, maxvolume=math.sqrt(2) / 12 * (2 / size) ** 3)
|
| 304 |
+
tmp_vtk = tempfile.NamedTemporaryFile(suffix='.vtk', delete=True)
|
| 305 |
+
tet.grid.save(tmp_vtk.name)
|
| 306 |
+
|
| 307 |
+
# extract surface mesh from tet mesh
|
| 308 |
+
reader = vtk.vtkUnstructuredGridReader()
|
| 309 |
+
reader.SetFileName(tmp_vtk.name)
|
| 310 |
+
reader.Update()
|
| 311 |
+
surface_filter = vtk.vtkDataSetSurfaceFilter()
|
| 312 |
+
surface_filter.SetInputConnection(reader.GetOutputPort())
|
| 313 |
+
surface_filter.Update()
|
| 314 |
+
polydata = surface_filter.GetOutput()
|
| 315 |
+
writer = vtk.vtkOBJWriter()
|
| 316 |
+
tmp_obj = tempfile.NamedTemporaryFile(suffix='.obj', delete=True)
|
| 317 |
+
writer.SetFileName(tmp_obj.name)
|
| 318 |
+
writer.SetInputData(polydata)
|
| 319 |
+
writer.Update()
|
| 320 |
+
new_mesh = load_mesh_util(tmp_obj.name)
|
| 321 |
+
##########################
|
| 322 |
+
|
| 323 |
+
new_mesh.vertices = new_mesh.vertices * (2.0 / size) - 1.0 # normalize it to [-1, 1]
|
| 324 |
+
|
| 325 |
+
mesh = new_mesh
|
| 326 |
+
####################
|
| 327 |
+
|
| 328 |
+
except:
|
| 329 |
+
print("Error in tet.")
|
| 330 |
+
mesh = mesh
|
| 331 |
+
|
| 332 |
+
pc, _ = trimesh.sample.sample_surface(mesh, self.pc_num_pts)
|
| 333 |
+
|
| 334 |
+
result = {
|
| 335 |
+
'uid': uid
|
| 336 |
+
}
|
| 337 |
+
|
| 338 |
+
result['pc'] = torch.tensor(pc, dtype=torch.float32)
|
| 339 |
+
result['vertices'] = mesh.vertices
|
| 340 |
+
result['faces'] = mesh.faces
|
| 341 |
+
|
| 342 |
+
return result
|
| 343 |
+
|
| 344 |
+
def __getitem__(self, index):
|
| 345 |
+
|
| 346 |
+
gc.collect()
|
| 347 |
+
|
| 348 |
+
return self.get_model(self.data_list[index])
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
class Correspondence_Demo_Dataset(Demo_Dataset):
|
| 352 |
+
def __init__(self, cfg):
|
| 353 |
+
super().__init__(cfg)
|
| 354 |
+
|
| 355 |
+
self.data_path = cfg.dataset.data_path
|
| 356 |
+
self.is_pc = cfg.is_pc
|
| 357 |
+
|
| 358 |
+
self.data_list = cfg.dataset.all_files
|
| 359 |
+
|
| 360 |
+
self.pc_num_pts = 100000
|
| 361 |
+
|
| 362 |
+
self.preprocess_mesh = cfg.preprocess_mesh
|
| 363 |
+
self.result_name = cfg.result_name
|
| 364 |
+
|
| 365 |
+
print("val dataset len:", len(self.data_list))
|
| 366 |
+
|
PartField/partfield/model/PVCNN/__pycache__/conv_pointnet.cpython-310.pyc
ADDED
|
Binary file (6.6 kB). View file
|
|
|
PartField/partfield/model/PVCNN/__pycache__/dnnlib_util.cpython-310.pyc
ADDED
|
Binary file (33.9 kB). View file
|
|
|
PartField/partfield/model/PVCNN/__pycache__/encoder_pc.cpython-310.pyc
ADDED
|
Binary file (6.09 kB). View file
|
|
|
PartField/partfield/model/PVCNN/__pycache__/pc_encoder.cpython-310.pyc
ADDED
|
Binary file (3.44 kB). View file
|
|
|
PartField/partfield/model/PVCNN/__pycache__/unet_3daware.cpython-310.pyc
ADDED
|
Binary file (11.2 kB). View file
|
|
|
PartField/partfield/model/PVCNN/conv_pointnet.py
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Taken from gensdf
|
| 3 |
+
https://github.com/princeton-computational-imaging/gensdf
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
# from dnnlib.util import printarr
|
| 10 |
+
try:
|
| 11 |
+
from torch_scatter import scatter_mean, scatter_max
|
| 12 |
+
except:
|
| 13 |
+
pass
|
| 14 |
+
# from .unet import UNet
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# Resnet Blocks
|
| 21 |
+
class ResnetBlockFC(nn.Module):
|
| 22 |
+
''' Fully connected ResNet Block class.
|
| 23 |
+
Args:
|
| 24 |
+
size_in (int): input dimension
|
| 25 |
+
size_out (int): output dimension
|
| 26 |
+
size_h (int): hidden dimension
|
| 27 |
+
'''
|
| 28 |
+
|
| 29 |
+
def __init__(self, size_in, size_out=None, size_h=None):
|
| 30 |
+
super().__init__()
|
| 31 |
+
# Attributes
|
| 32 |
+
if size_out is None:
|
| 33 |
+
size_out = size_in
|
| 34 |
+
|
| 35 |
+
if size_h is None:
|
| 36 |
+
size_h = min(size_in, size_out)
|
| 37 |
+
|
| 38 |
+
self.size_in = size_in
|
| 39 |
+
self.size_h = size_h
|
| 40 |
+
self.size_out = size_out
|
| 41 |
+
# Submodules
|
| 42 |
+
self.fc_0 = nn.Linear(size_in, size_h)
|
| 43 |
+
self.fc_1 = nn.Linear(size_h, size_out)
|
| 44 |
+
self.actvn = nn.ReLU()
|
| 45 |
+
|
| 46 |
+
if size_in == size_out:
|
| 47 |
+
self.shortcut = None
|
| 48 |
+
else:
|
| 49 |
+
self.shortcut = nn.Linear(size_in, size_out, bias=False)
|
| 50 |
+
# Initialization
|
| 51 |
+
nn.init.zeros_(self.fc_1.weight)
|
| 52 |
+
|
| 53 |
+
def forward(self, x):
|
| 54 |
+
net = self.fc_0(self.actvn(x))
|
| 55 |
+
dx = self.fc_1(self.actvn(net))
|
| 56 |
+
|
| 57 |
+
if self.shortcut is not None:
|
| 58 |
+
x_s = self.shortcut(x)
|
| 59 |
+
else:
|
| 60 |
+
x_s = x
|
| 61 |
+
|
| 62 |
+
return x_s + dx
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class ConvPointnet(nn.Module):
|
| 66 |
+
''' PointNet-based encoder network with ResNet blocks for each point.
|
| 67 |
+
Number of input points are fixed.
|
| 68 |
+
|
| 69 |
+
Args:
|
| 70 |
+
c_dim (int): dimension of latent code c
|
| 71 |
+
dim (int): input points dimension
|
| 72 |
+
hidden_dim (int): hidden dimension of the network
|
| 73 |
+
scatter_type (str): feature aggregation when doing local pooling
|
| 74 |
+
unet (bool): weather to use U-Net
|
| 75 |
+
unet_kwargs (str): U-Net parameters
|
| 76 |
+
plane_resolution (int): defined resolution for plane feature
|
| 77 |
+
plane_type (str): feature type, 'xz' - 1-plane, ['xz', 'xy', 'yz'] - 3-plane, ['grid'] - 3D grid volume
|
| 78 |
+
padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
|
| 79 |
+
n_blocks (int): number of blocks ResNetBlockFC layers
|
| 80 |
+
'''
|
| 81 |
+
|
| 82 |
+
def __init__(self, c_dim=128, dim=3, hidden_dim=128, scatter_type='max',
|
| 83 |
+
# unet=False, unet_kwargs=None,
|
| 84 |
+
plane_resolution=None, plane_type=['xz', 'xy', 'yz'], padding=0.1, n_blocks=5):
|
| 85 |
+
super().__init__()
|
| 86 |
+
self.c_dim = c_dim
|
| 87 |
+
|
| 88 |
+
self.fc_pos = nn.Linear(dim, 2*hidden_dim)
|
| 89 |
+
self.blocks = nn.ModuleList([
|
| 90 |
+
ResnetBlockFC(2*hidden_dim, hidden_dim) for i in range(n_blocks)
|
| 91 |
+
])
|
| 92 |
+
self.fc_c = nn.Linear(hidden_dim, c_dim)
|
| 93 |
+
|
| 94 |
+
self.actvn = nn.ReLU()
|
| 95 |
+
self.hidden_dim = hidden_dim
|
| 96 |
+
|
| 97 |
+
# if unet:
|
| 98 |
+
# self.unet = UNet(c_dim, in_channels=c_dim, **unet_kwargs)
|
| 99 |
+
# else:
|
| 100 |
+
# self.unet = None
|
| 101 |
+
|
| 102 |
+
self.reso_plane = plane_resolution
|
| 103 |
+
self.plane_type = plane_type
|
| 104 |
+
self.padding = padding
|
| 105 |
+
|
| 106 |
+
if scatter_type == 'max':
|
| 107 |
+
self.scatter = scatter_max
|
| 108 |
+
elif scatter_type == 'mean':
|
| 109 |
+
self.scatter = scatter_mean
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
# takes in "p": point cloud and "query": sdf_xyz
|
| 113 |
+
# sample plane features for unlabeled_query as well
|
| 114 |
+
def forward(self, p):#, query2):
|
| 115 |
+
batch_size, T, D = p.size()
|
| 116 |
+
|
| 117 |
+
# acquire the index for each point
|
| 118 |
+
coord = {}
|
| 119 |
+
index = {}
|
| 120 |
+
if 'xz' in self.plane_type:
|
| 121 |
+
coord['xz'] = self.normalize_coordinate(p.clone(), plane='xz', padding=self.padding)
|
| 122 |
+
index['xz'] = self.coordinate2index(coord['xz'], self.reso_plane)
|
| 123 |
+
if 'xy' in self.plane_type:
|
| 124 |
+
coord['xy'] = self.normalize_coordinate(p.clone(), plane='xy', padding=self.padding)
|
| 125 |
+
index['xy'] = self.coordinate2index(coord['xy'], self.reso_plane)
|
| 126 |
+
if 'yz' in self.plane_type:
|
| 127 |
+
coord['yz'] = self.normalize_coordinate(p.clone(), plane='yz', padding=self.padding)
|
| 128 |
+
index['yz'] = self.coordinate2index(coord['yz'], self.reso_plane)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
net = self.fc_pos(p)
|
| 132 |
+
|
| 133 |
+
net = self.blocks[0](net)
|
| 134 |
+
for block in self.blocks[1:]:
|
| 135 |
+
pooled = self.pool_local(coord, index, net)
|
| 136 |
+
net = torch.cat([net, pooled], dim=2)
|
| 137 |
+
net = block(net)
|
| 138 |
+
|
| 139 |
+
c = self.fc_c(net)
|
| 140 |
+
|
| 141 |
+
fea = {}
|
| 142 |
+
plane_feat_sum = 0
|
| 143 |
+
#second_sum = 0
|
| 144 |
+
if 'xz' in self.plane_type:
|
| 145 |
+
fea['xz'] = self.generate_plane_features(p, c, plane='xz') # shape: batch, latent size, resolution, resolution (e.g. 16, 256, 64, 64)
|
| 146 |
+
# plane_feat_sum += self.sample_plane_feature(query, fea['xz'], 'xz')
|
| 147 |
+
#second_sum += self.sample_plane_feature(query2, fea['xz'], 'xz')
|
| 148 |
+
if 'xy' in self.plane_type:
|
| 149 |
+
fea['xy'] = self.generate_plane_features(p, c, plane='xy')
|
| 150 |
+
# plane_feat_sum += self.sample_plane_feature(query, fea['xy'], 'xy')
|
| 151 |
+
#second_sum += self.sample_plane_feature(query2, fea['xy'], 'xy')
|
| 152 |
+
if 'yz' in self.plane_type:
|
| 153 |
+
fea['yz'] = self.generate_plane_features(p, c, plane='yz')
|
| 154 |
+
# plane_feat_sum += self.sample_plane_feature(query, fea['yz'], 'yz')
|
| 155 |
+
#second_sum += self.sample_plane_feature(query2, fea['yz'], 'yz')
|
| 156 |
+
return fea
|
| 157 |
+
|
| 158 |
+
# return plane_feat_sum.transpose(2,1)#, second_sum.transpose(2,1)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def normalize_coordinate(self, p, padding=0.1, plane='xz'):
|
| 162 |
+
''' Normalize coordinate to [0, 1] for unit cube experiments
|
| 163 |
+
|
| 164 |
+
Args:
|
| 165 |
+
p (tensor): point
|
| 166 |
+
padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
|
| 167 |
+
plane (str): plane feature type, ['xz', 'xy', 'yz']
|
| 168 |
+
'''
|
| 169 |
+
if plane == 'xz':
|
| 170 |
+
xy = p[:, :, [0, 2]]
|
| 171 |
+
elif plane =='xy':
|
| 172 |
+
xy = p[:, :, [0, 1]]
|
| 173 |
+
else:
|
| 174 |
+
xy = p[:, :, [1, 2]]
|
| 175 |
+
|
| 176 |
+
xy_new = xy / (1 + padding + 10e-6) # (-0.5, 0.5)
|
| 177 |
+
xy_new = xy_new + 0.5 # range (0, 1)
|
| 178 |
+
|
| 179 |
+
# f there are outliers out of the range
|
| 180 |
+
if xy_new.max() >= 1:
|
| 181 |
+
xy_new[xy_new >= 1] = 1 - 10e-6
|
| 182 |
+
if xy_new.min() < 0:
|
| 183 |
+
xy_new[xy_new < 0] = 0.0
|
| 184 |
+
return xy_new
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def coordinate2index(self, x, reso):
|
| 188 |
+
''' Normalize coordinate to [0, 1] for unit cube experiments.
|
| 189 |
+
Corresponds to our 3D model
|
| 190 |
+
|
| 191 |
+
Args:
|
| 192 |
+
x (tensor): coordinate
|
| 193 |
+
reso (int): defined resolution
|
| 194 |
+
coord_type (str): coordinate type
|
| 195 |
+
'''
|
| 196 |
+
x = (x * reso).long()
|
| 197 |
+
index = x[:, :, 0] + reso * x[:, :, 1]
|
| 198 |
+
index = index[:, None, :]
|
| 199 |
+
return index
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
# xy is the normalized coordinates of the point cloud of each plane
|
| 203 |
+
# I'm pretty sure the keys of xy are the same as those of index, so xy isn't needed here as input
|
| 204 |
+
def pool_local(self, xy, index, c):
|
| 205 |
+
bs, fea_dim = c.size(0), c.size(2)
|
| 206 |
+
keys = xy.keys()
|
| 207 |
+
|
| 208 |
+
c_out = 0
|
| 209 |
+
for key in keys:
|
| 210 |
+
# scatter plane features from points
|
| 211 |
+
fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.reso_plane**2)
|
| 212 |
+
if self.scatter == scatter_max:
|
| 213 |
+
fea = fea[0]
|
| 214 |
+
# gather feature back to points
|
| 215 |
+
fea = fea.gather(dim=2, index=index[key].expand(-1, fea_dim, -1))
|
| 216 |
+
c_out += fea
|
| 217 |
+
return c_out.permute(0, 2, 1)
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def generate_plane_features(self, p, c, plane='xz'):
|
| 221 |
+
# acquire indices of features in plane
|
| 222 |
+
xy = self.normalize_coordinate(p.clone(), plane=plane, padding=self.padding) # normalize to the range of (0, 1)
|
| 223 |
+
index = self.coordinate2index(xy, self.reso_plane)
|
| 224 |
+
|
| 225 |
+
# scatter plane features from points
|
| 226 |
+
fea_plane = c.new_zeros(p.size(0), self.c_dim, self.reso_plane**2)
|
| 227 |
+
c = c.permute(0, 2, 1) # B x 512 x T
|
| 228 |
+
fea_plane = scatter_mean(c, index, out=fea_plane) # B x 512 x reso^2
|
| 229 |
+
fea_plane = fea_plane.reshape(p.size(0), self.c_dim, self.reso_plane, self.reso_plane) # sparce matrix (B x 512 x reso x reso)
|
| 230 |
+
|
| 231 |
+
# printarr(fea_plane, c, p, xy, index)
|
| 232 |
+
# import pdb; pdb.set_trace()
|
| 233 |
+
|
| 234 |
+
# process the plane features with UNet
|
| 235 |
+
# if self.unet is not None:
|
| 236 |
+
# fea_plane = self.unet(fea_plane)
|
| 237 |
+
|
| 238 |
+
return fea_plane
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
# sample_plane_feature function copied from /src/conv_onet/models/decoder.py
|
| 242 |
+
# uses values from plane_feature and pixel locations from vgrid to interpolate feature
|
| 243 |
+
def sample_plane_feature(self, query, plane_feature, plane):
|
| 244 |
+
xy = self.normalize_coordinate(query.clone(), plane=plane, padding=self.padding)
|
| 245 |
+
xy = xy[:, :, None].float()
|
| 246 |
+
vgrid = 2.0 * xy - 1.0 # normalize to (-1, 1)
|
| 247 |
+
sampled_feat = F.grid_sample(plane_feature, vgrid, padding_mode='border', align_corners=True, mode='bilinear').squeeze(-1)
|
| 248 |
+
return sampled_feat
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
|
PartField/partfield/model/PVCNN/dnnlib_util.py
ADDED
|
@@ -0,0 +1,1074 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
| 4 |
+
# and proprietary rights in and to this software, related documentation
|
| 5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
| 6 |
+
# distribution of this software and related documentation without an express
|
| 7 |
+
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
| 8 |
+
|
| 9 |
+
"""Miscellaneous utility classes and functions."""
|
| 10 |
+
from collections import namedtuple
|
| 11 |
+
import time
|
| 12 |
+
import ctypes
|
| 13 |
+
import fnmatch
|
| 14 |
+
import importlib
|
| 15 |
+
import inspect
|
| 16 |
+
import numpy as np
|
| 17 |
+
import json
|
| 18 |
+
import os
|
| 19 |
+
import shutil
|
| 20 |
+
import sys
|
| 21 |
+
import types
|
| 22 |
+
import io
|
| 23 |
+
import pickle
|
| 24 |
+
import re
|
| 25 |
+
# import requests
|
| 26 |
+
import html
|
| 27 |
+
import hashlib
|
| 28 |
+
import glob
|
| 29 |
+
import tempfile
|
| 30 |
+
import urllib
|
| 31 |
+
import urllib.request
|
| 32 |
+
import uuid
|
| 33 |
+
import boto3
|
| 34 |
+
import threading
|
| 35 |
+
from contextlib import ContextDecorator
|
| 36 |
+
from contextlib import contextmanager, nullcontext
|
| 37 |
+
|
| 38 |
+
from distutils.util import strtobool
|
| 39 |
+
from typing import Any, List, Tuple, Union
|
| 40 |
+
import importlib
|
| 41 |
+
from loguru import logger
|
| 42 |
+
# import wandb
|
| 43 |
+
import torch
|
| 44 |
+
import psutil
|
| 45 |
+
import subprocess
|
| 46 |
+
|
| 47 |
+
import random
|
| 48 |
+
import string
|
| 49 |
+
import pdb
|
| 50 |
+
|
| 51 |
+
# Util classes
|
| 52 |
+
# ------------------------------------------------------------------------------------------
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class EasyDict(dict):
|
| 56 |
+
"""Convenience class that behaves like a dict but allows access with the attribute syntax."""
|
| 57 |
+
|
| 58 |
+
def __getattr__(self, name: str) -> Any:
|
| 59 |
+
try:
|
| 60 |
+
return self[name]
|
| 61 |
+
except KeyError:
|
| 62 |
+
raise AttributeError(name)
|
| 63 |
+
|
| 64 |
+
def __setattr__(self, name: str, value: Any) -> None:
|
| 65 |
+
self[name] = value
|
| 66 |
+
|
| 67 |
+
def __delattr__(self, name: str) -> None:
|
| 68 |
+
del self[name]
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class Logger(object):
|
| 72 |
+
"""Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
|
| 73 |
+
|
| 74 |
+
def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
|
| 75 |
+
self.file = None
|
| 76 |
+
|
| 77 |
+
if file_name is not None:
|
| 78 |
+
self.file = open(file_name, file_mode)
|
| 79 |
+
|
| 80 |
+
self.should_flush = should_flush
|
| 81 |
+
self.stdout = sys.stdout
|
| 82 |
+
self.stderr = sys.stderr
|
| 83 |
+
|
| 84 |
+
sys.stdout = self
|
| 85 |
+
sys.stderr = self
|
| 86 |
+
|
| 87 |
+
def __enter__(self) -> "Logger":
|
| 88 |
+
return self
|
| 89 |
+
|
| 90 |
+
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
|
| 91 |
+
self.close()
|
| 92 |
+
|
| 93 |
+
def write(self, text: Union[str, bytes]) -> None:
|
| 94 |
+
"""Write text to stdout (and a file) and optionally flush."""
|
| 95 |
+
if isinstance(text, bytes):
|
| 96 |
+
text = text.decode()
|
| 97 |
+
if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
|
| 98 |
+
return
|
| 99 |
+
|
| 100 |
+
if self.file is not None:
|
| 101 |
+
self.file.write(text)
|
| 102 |
+
|
| 103 |
+
self.stdout.write(text)
|
| 104 |
+
|
| 105 |
+
if self.should_flush:
|
| 106 |
+
self.flush()
|
| 107 |
+
|
| 108 |
+
def flush(self) -> None:
|
| 109 |
+
"""Flush written text to both stdout and a file, if open."""
|
| 110 |
+
if self.file is not None:
|
| 111 |
+
self.file.flush()
|
| 112 |
+
|
| 113 |
+
self.stdout.flush()
|
| 114 |
+
|
| 115 |
+
def close(self) -> None:
|
| 116 |
+
"""Flush, close possible files, and remove stdout/stderr mirroring."""
|
| 117 |
+
self.flush()
|
| 118 |
+
|
| 119 |
+
# if using multiple loggers, prevent closing in wrong order
|
| 120 |
+
if sys.stdout is self:
|
| 121 |
+
sys.stdout = self.stdout
|
| 122 |
+
if sys.stderr is self:
|
| 123 |
+
sys.stderr = self.stderr
|
| 124 |
+
|
| 125 |
+
if self.file is not None:
|
| 126 |
+
self.file.close()
|
| 127 |
+
self.file = None
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
# Cache directories
|
| 131 |
+
# ------------------------------------------------------------------------------------------
|
| 132 |
+
|
| 133 |
+
_dnnlib_cache_dir = None
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def set_cache_dir(path: str) -> None:
|
| 137 |
+
global _dnnlib_cache_dir
|
| 138 |
+
_dnnlib_cache_dir = path
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def make_cache_dir_path(*paths: str) -> str:
|
| 142 |
+
if _dnnlib_cache_dir is not None:
|
| 143 |
+
return os.path.join(_dnnlib_cache_dir, *paths)
|
| 144 |
+
if 'DNNLIB_CACHE_DIR' in os.environ:
|
| 145 |
+
return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
|
| 146 |
+
if 'HOME' in os.environ:
|
| 147 |
+
return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
|
| 148 |
+
if 'USERPROFILE' in os.environ:
|
| 149 |
+
return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
|
| 150 |
+
return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
# Small util functions
|
| 154 |
+
# ------------------------------------------------------------------------------------------
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def format_time(seconds: Union[int, float]) -> str:
|
| 158 |
+
"""Convert the seconds to human readable string with days, hours, minutes and seconds."""
|
| 159 |
+
s = int(np.rint(seconds))
|
| 160 |
+
|
| 161 |
+
if s < 60:
|
| 162 |
+
return "{0}s".format(s)
|
| 163 |
+
elif s < 60 * 60:
|
| 164 |
+
return "{0}m {1:02}s".format(s // 60, s % 60)
|
| 165 |
+
elif s < 24 * 60 * 60:
|
| 166 |
+
return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
|
| 167 |
+
else:
|
| 168 |
+
return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def format_time_brief(seconds: Union[int, float]) -> str:
|
| 172 |
+
"""Convert the seconds to human readable string with days, hours, minutes and seconds."""
|
| 173 |
+
s = int(np.rint(seconds))
|
| 174 |
+
|
| 175 |
+
if s < 60:
|
| 176 |
+
return "{0}s".format(s)
|
| 177 |
+
elif s < 60 * 60:
|
| 178 |
+
return "{0}m {1:02}s".format(s // 60, s % 60)
|
| 179 |
+
elif s < 24 * 60 * 60:
|
| 180 |
+
return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60)
|
| 181 |
+
else:
|
| 182 |
+
return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24)
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def ask_yes_no(question: str) -> bool:
|
| 186 |
+
"""Ask the user the question until the user inputs a valid answer."""
|
| 187 |
+
while True:
|
| 188 |
+
try:
|
| 189 |
+
print("{0} [y/n]".format(question))
|
| 190 |
+
return strtobool(input().lower())
|
| 191 |
+
except ValueError:
|
| 192 |
+
pass
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def tuple_product(t: Tuple) -> Any:
|
| 196 |
+
"""Calculate the product of the tuple elements."""
|
| 197 |
+
result = 1
|
| 198 |
+
|
| 199 |
+
for v in t:
|
| 200 |
+
result *= v
|
| 201 |
+
|
| 202 |
+
return result
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
_str_to_ctype = {
|
| 206 |
+
"uint8": ctypes.c_ubyte,
|
| 207 |
+
"uint16": ctypes.c_uint16,
|
| 208 |
+
"uint32": ctypes.c_uint32,
|
| 209 |
+
"uint64": ctypes.c_uint64,
|
| 210 |
+
"int8": ctypes.c_byte,
|
| 211 |
+
"int16": ctypes.c_int16,
|
| 212 |
+
"int32": ctypes.c_int32,
|
| 213 |
+
"int64": ctypes.c_int64,
|
| 214 |
+
"float32": ctypes.c_float,
|
| 215 |
+
"float64": ctypes.c_double
|
| 216 |
+
}
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
|
| 220 |
+
"""Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
|
| 221 |
+
type_str = None
|
| 222 |
+
|
| 223 |
+
if isinstance(type_obj, str):
|
| 224 |
+
type_str = type_obj
|
| 225 |
+
elif hasattr(type_obj, "__name__"):
|
| 226 |
+
type_str = type_obj.__name__
|
| 227 |
+
elif hasattr(type_obj, "name"):
|
| 228 |
+
type_str = type_obj.name
|
| 229 |
+
else:
|
| 230 |
+
raise RuntimeError("Cannot infer type name from input")
|
| 231 |
+
|
| 232 |
+
assert type_str in _str_to_ctype.keys()
|
| 233 |
+
|
| 234 |
+
my_dtype = np.dtype(type_str)
|
| 235 |
+
my_ctype = _str_to_ctype[type_str]
|
| 236 |
+
|
| 237 |
+
assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
|
| 238 |
+
|
| 239 |
+
return my_dtype, my_ctype
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def is_pickleable(obj: Any) -> bool:
|
| 243 |
+
try:
|
| 244 |
+
with io.BytesIO() as stream:
|
| 245 |
+
pickle.dump(obj, stream)
|
| 246 |
+
return True
|
| 247 |
+
except:
|
| 248 |
+
return False
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
# Functionality to import modules/objects by name, and call functions by name
|
| 252 |
+
# ------------------------------------------------------------------------------------------
|
| 253 |
+
|
| 254 |
+
def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
|
| 255 |
+
"""Searches for the underlying module behind the name to some python object.
|
| 256 |
+
Returns the module and the object name (original name with module part removed)."""
|
| 257 |
+
|
| 258 |
+
# allow convenience shorthands, substitute them by full names
|
| 259 |
+
obj_name = re.sub("^np.", "numpy.", obj_name)
|
| 260 |
+
obj_name = re.sub("^tf.", "tensorflow.", obj_name)
|
| 261 |
+
|
| 262 |
+
# list alternatives for (module_name, local_obj_name)
|
| 263 |
+
parts = obj_name.split(".")
|
| 264 |
+
name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
|
| 265 |
+
|
| 266 |
+
# try each alternative in turn
|
| 267 |
+
for module_name, local_obj_name in name_pairs:
|
| 268 |
+
try:
|
| 269 |
+
module = importlib.import_module(module_name) # may raise ImportError
|
| 270 |
+
get_obj_from_module(module, local_obj_name) # may raise AttributeError
|
| 271 |
+
return module, local_obj_name
|
| 272 |
+
except:
|
| 273 |
+
pass
|
| 274 |
+
|
| 275 |
+
# maybe some of the modules themselves contain errors?
|
| 276 |
+
for module_name, _local_obj_name in name_pairs:
|
| 277 |
+
try:
|
| 278 |
+
importlib.import_module(module_name) # may raise ImportError
|
| 279 |
+
except ImportError:
|
| 280 |
+
if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
|
| 281 |
+
raise
|
| 282 |
+
|
| 283 |
+
# maybe the requested attribute is missing?
|
| 284 |
+
for module_name, local_obj_name in name_pairs:
|
| 285 |
+
try:
|
| 286 |
+
module = importlib.import_module(module_name) # may raise ImportError
|
| 287 |
+
get_obj_from_module(module, local_obj_name) # may raise AttributeError
|
| 288 |
+
except ImportError:
|
| 289 |
+
pass
|
| 290 |
+
|
| 291 |
+
# we are out of luck, but we have no idea why
|
| 292 |
+
raise ImportError(obj_name)
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
|
| 296 |
+
"""Traverses the object name and returns the last (rightmost) python object."""
|
| 297 |
+
if obj_name == '':
|
| 298 |
+
return module
|
| 299 |
+
obj = module
|
| 300 |
+
for part in obj_name.split("."):
|
| 301 |
+
obj = getattr(obj, part)
|
| 302 |
+
return obj
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
def get_obj_by_name(name: str) -> Any:
|
| 306 |
+
"""Finds the python object with the given name."""
|
| 307 |
+
module, obj_name = get_module_from_obj_name(name)
|
| 308 |
+
return get_obj_from_module(module, obj_name)
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
|
| 312 |
+
"""Finds the python object with the given name and calls it as a function."""
|
| 313 |
+
assert func_name is not None
|
| 314 |
+
func_obj = get_obj_by_name(func_name)
|
| 315 |
+
assert callable(func_obj)
|
| 316 |
+
return func_obj(*args, **kwargs)
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
|
| 320 |
+
"""Finds the python class with the given name and constructs it with the given arguments."""
|
| 321 |
+
return call_func_by_name(*args, func_name=class_name, **kwargs)
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
def get_module_dir_by_obj_name(obj_name: str) -> str:
|
| 325 |
+
"""Get the directory path of the module containing the given object name."""
|
| 326 |
+
module, _ = get_module_from_obj_name(obj_name)
|
| 327 |
+
return os.path.dirname(inspect.getfile(module))
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
def is_top_level_function(obj: Any) -> bool:
|
| 331 |
+
"""Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
|
| 332 |
+
return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
def get_top_level_function_name(obj: Any) -> str:
|
| 336 |
+
"""Return the fully-qualified name of a top-level function."""
|
| 337 |
+
assert is_top_level_function(obj)
|
| 338 |
+
module = obj.__module__
|
| 339 |
+
if module == '__main__':
|
| 340 |
+
module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
|
| 341 |
+
return module + "." + obj.__name__
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
# File system helpers
|
| 345 |
+
# ------------------------------------------------------------------------------------------
|
| 346 |
+
|
| 347 |
+
def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
|
| 348 |
+
"""List all files recursively in a given directory while ignoring given file and directory names.
|
| 349 |
+
Returns list of tuples containing both absolute and relative paths."""
|
| 350 |
+
assert os.path.isdir(dir_path)
|
| 351 |
+
base_name = os.path.basename(os.path.normpath(dir_path))
|
| 352 |
+
|
| 353 |
+
if ignores is None:
|
| 354 |
+
ignores = []
|
| 355 |
+
|
| 356 |
+
result = []
|
| 357 |
+
|
| 358 |
+
for root, dirs, files in os.walk(dir_path, topdown=True):
|
| 359 |
+
for ignore_ in ignores:
|
| 360 |
+
dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
|
| 361 |
+
|
| 362 |
+
# dirs need to be edited in-place
|
| 363 |
+
for d in dirs_to_remove:
|
| 364 |
+
dirs.remove(d)
|
| 365 |
+
|
| 366 |
+
files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
|
| 367 |
+
|
| 368 |
+
absolute_paths = [os.path.join(root, f) for f in files]
|
| 369 |
+
relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
|
| 370 |
+
|
| 371 |
+
if add_base_to_relative:
|
| 372 |
+
relative_paths = [os.path.join(base_name, p) for p in relative_paths]
|
| 373 |
+
|
| 374 |
+
assert len(absolute_paths) == len(relative_paths)
|
| 375 |
+
result += zip(absolute_paths, relative_paths)
|
| 376 |
+
|
| 377 |
+
return result
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
|
| 381 |
+
"""Takes in a list of tuples of (src, dst) paths and copies files.
|
| 382 |
+
Will create all necessary directories."""
|
| 383 |
+
for file in files:
|
| 384 |
+
target_dir_name = os.path.dirname(file[1])
|
| 385 |
+
|
| 386 |
+
# will create all intermediate-level directories
|
| 387 |
+
if not os.path.exists(target_dir_name):
|
| 388 |
+
os.makedirs(target_dir_name)
|
| 389 |
+
|
| 390 |
+
shutil.copyfile(file[0], file[1])
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
# URL helpers
|
| 394 |
+
# ------------------------------------------------------------------------------------------
|
| 395 |
+
|
| 396 |
+
def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
|
| 397 |
+
"""Determine whether the given object is a valid URL string."""
|
| 398 |
+
if not isinstance(obj, str) or not "://" in obj:
|
| 399 |
+
return False
|
| 400 |
+
if allow_file_urls and obj.startswith('file://'):
|
| 401 |
+
return True
|
| 402 |
+
try:
|
| 403 |
+
res = requests.compat.urlparse(obj)
|
| 404 |
+
if not res.scheme or not res.netloc or not "." in res.netloc:
|
| 405 |
+
return False
|
| 406 |
+
res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
|
| 407 |
+
if not res.scheme or not res.netloc or not "." in res.netloc:
|
| 408 |
+
return False
|
| 409 |
+
except:
|
| 410 |
+
return False
|
| 411 |
+
return True
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
|
| 415 |
+
"""Download the given URL and return a binary-mode file object to access the data."""
|
| 416 |
+
assert num_attempts >= 1
|
| 417 |
+
assert not (return_filename and (not cache))
|
| 418 |
+
|
| 419 |
+
# Doesn't look like an URL scheme so interpret it as a local filename.
|
| 420 |
+
if not re.match('^[a-z]+://', url):
|
| 421 |
+
return url if return_filename else open(url, "rb")
|
| 422 |
+
|
| 423 |
+
# Handle file URLs. This code handles unusual file:// patterns that
|
| 424 |
+
# arise on Windows:
|
| 425 |
+
#
|
| 426 |
+
# file:///c:/foo.txt
|
| 427 |
+
#
|
| 428 |
+
# which would translate to a local '/c:/foo.txt' filename that's
|
| 429 |
+
# invalid. Drop the forward slash for such pathnames.
|
| 430 |
+
#
|
| 431 |
+
# If you touch this code path, you should test it on both Linux and
|
| 432 |
+
# Windows.
|
| 433 |
+
#
|
| 434 |
+
# Some internet resources suggest using urllib.request.url2pathname() but
|
| 435 |
+
# but that converts forward slashes to backslashes and this causes
|
| 436 |
+
# its own set of problems.
|
| 437 |
+
if url.startswith('file://'):
|
| 438 |
+
filename = urllib.parse.urlparse(url).path
|
| 439 |
+
if re.match(r'^/[a-zA-Z]:', filename):
|
| 440 |
+
filename = filename[1:]
|
| 441 |
+
return filename if return_filename else open(filename, "rb")
|
| 442 |
+
|
| 443 |
+
assert is_url(url)
|
| 444 |
+
|
| 445 |
+
# Lookup from cache.
|
| 446 |
+
if cache_dir is None:
|
| 447 |
+
cache_dir = make_cache_dir_path('downloads')
|
| 448 |
+
|
| 449 |
+
url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
|
| 450 |
+
if cache:
|
| 451 |
+
cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
|
| 452 |
+
if len(cache_files) == 1:
|
| 453 |
+
filename = cache_files[0]
|
| 454 |
+
return filename if return_filename else open(filename, "rb")
|
| 455 |
+
|
| 456 |
+
# Download.
|
| 457 |
+
url_name = None
|
| 458 |
+
url_data = None
|
| 459 |
+
with requests.Session() as session:
|
| 460 |
+
if verbose:
|
| 461 |
+
print("Downloading %s ..." % url, end="", flush=True)
|
| 462 |
+
for attempts_left in reversed(range(num_attempts)):
|
| 463 |
+
try:
|
| 464 |
+
with session.get(url) as res:
|
| 465 |
+
res.raise_for_status()
|
| 466 |
+
if len(res.content) == 0:
|
| 467 |
+
raise IOError("No data received")
|
| 468 |
+
|
| 469 |
+
if len(res.content) < 8192:
|
| 470 |
+
content_str = res.content.decode("utf-8")
|
| 471 |
+
if "download_warning" in res.headers.get("Set-Cookie", ""):
|
| 472 |
+
links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
|
| 473 |
+
if len(links) == 1:
|
| 474 |
+
url = requests.compat.urljoin(url, links[0])
|
| 475 |
+
raise IOError("Google Drive virus checker nag")
|
| 476 |
+
if "Google Drive - Quota exceeded" in content_str:
|
| 477 |
+
raise IOError("Google Drive download quota exceeded -- please try again later")
|
| 478 |
+
|
| 479 |
+
match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
|
| 480 |
+
url_name = match[1] if match else url
|
| 481 |
+
url_data = res.content
|
| 482 |
+
if verbose:
|
| 483 |
+
print(" done")
|
| 484 |
+
break
|
| 485 |
+
except KeyboardInterrupt:
|
| 486 |
+
raise
|
| 487 |
+
except:
|
| 488 |
+
if not attempts_left:
|
| 489 |
+
if verbose:
|
| 490 |
+
print(" failed")
|
| 491 |
+
raise
|
| 492 |
+
if verbose:
|
| 493 |
+
print(".", end="", flush=True)
|
| 494 |
+
|
| 495 |
+
# Save to cache.
|
| 496 |
+
if cache:
|
| 497 |
+
safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
|
| 498 |
+
cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
|
| 499 |
+
temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
|
| 500 |
+
os.makedirs(cache_dir, exist_ok=True)
|
| 501 |
+
with open(temp_file, "wb") as f:
|
| 502 |
+
f.write(url_data)
|
| 503 |
+
os.replace(temp_file, cache_file) # atomic
|
| 504 |
+
if return_filename:
|
| 505 |
+
return cache_file
|
| 506 |
+
|
| 507 |
+
# Return data as file object.
|
| 508 |
+
assert not return_filename
|
| 509 |
+
return io.BytesIO(url_data)
|
| 510 |
+
|
| 511 |
+
# ------------------------------------------------------------------------------------------
|
| 512 |
+
# util function modified from https://github.com/nv-tlabs/LION/blob/0467d2199076e95a7e88bafd99dcd7d48a04b4a7/utils/model_helper.py
|
| 513 |
+
def import_class(model_str):
|
| 514 |
+
from torch_utils.dist_utils import is_rank0
|
| 515 |
+
if is_rank0():
|
| 516 |
+
logger.info('import: {}', model_str)
|
| 517 |
+
p, m = model_str.rsplit('.', 1)
|
| 518 |
+
mod = importlib.import_module(p)
|
| 519 |
+
Model = getattr(mod, m)
|
| 520 |
+
return Model
|
| 521 |
+
|
| 522 |
+
class ScopedTorchProfiler(ContextDecorator):
|
| 523 |
+
"""
|
| 524 |
+
Marks ranges for both nvtx profiling (with nsys) and torch autograd profiler
|
| 525 |
+
"""
|
| 526 |
+
__global_counts = {}
|
| 527 |
+
enabled=False
|
| 528 |
+
|
| 529 |
+
def __init__(self, unique_name: str):
|
| 530 |
+
"""
|
| 531 |
+
Names must be unique!
|
| 532 |
+
"""
|
| 533 |
+
ScopedTorchProfiler.__global_counts[unique_name] = 0
|
| 534 |
+
self._name = unique_name
|
| 535 |
+
self._autograd_scope = torch.profiler.record_function(unique_name)
|
| 536 |
+
|
| 537 |
+
def __enter__(self):
|
| 538 |
+
if ScopedTorchProfiler.enabled:
|
| 539 |
+
torch.cuda.nvtx.range_push(self._name)
|
| 540 |
+
self._autograd_scope.__enter__()
|
| 541 |
+
|
| 542 |
+
def __exit__(self, exc_type, exc_value, traceback):
|
| 543 |
+
self._autograd_scope.__exit__(exc_type, exc_value, traceback)
|
| 544 |
+
if ScopedTorchProfiler.enabled:
|
| 545 |
+
torch.cuda.nvtx.range_pop()
|
| 546 |
+
|
| 547 |
+
class TimingsMonitor():
|
| 548 |
+
CUDATimer = namedtuple('CUDATimer', ['start', 'end'])
|
| 549 |
+
def __init__(self, device, enabled=True, timing_names:List[str]=[], cuda_timing_names:List[str]=[]):
|
| 550 |
+
"""
|
| 551 |
+
Usage:
|
| 552 |
+
tmonitor = TimingsMonitor(device)
|
| 553 |
+
for i in range(n_iter):
|
| 554 |
+
# Record arbitrary scopes
|
| 555 |
+
with tmonitor.timing_scope('regular_scope_name'):
|
| 556 |
+
...
|
| 557 |
+
with tmonitor.cuda_timing_scope('nested_scope_name'):
|
| 558 |
+
...
|
| 559 |
+
with tmonitor.cuda_timing_scope('cuda_scope_name'):
|
| 560 |
+
...
|
| 561 |
+
tmonitor.record_timing('duration_name', end_time - start_time)
|
| 562 |
+
|
| 563 |
+
# Gather timings
|
| 564 |
+
tmonitor.record_all_cuda_timings()
|
| 565 |
+
tmonitor.update_all_averages()
|
| 566 |
+
averages = tmonitor.get_average_timings()
|
| 567 |
+
all_timings = tmonitor.get_timings()
|
| 568 |
+
|
| 569 |
+
Two types of timers, standard report timing and cuda timings.
|
| 570 |
+
Cuda timing supports scoped context manager cuda_event_scope.
|
| 571 |
+
Args:
|
| 572 |
+
device: device to time on (needed for cuda timers)
|
| 573 |
+
# enabled: HACK to only report timings from rank 0, set enabled=(global_rank==0)
|
| 574 |
+
timing_names: timings to report optional (will auto add new names)
|
| 575 |
+
cuda_timing_names: cuda periods to time optional (will auto add new names)
|
| 576 |
+
"""
|
| 577 |
+
self.enabled=enabled
|
| 578 |
+
self.device = device
|
| 579 |
+
|
| 580 |
+
# Normal timing
|
| 581 |
+
# self.all_timings_dict = {k:None for k in timing_names + cuda_timing_names}
|
| 582 |
+
self.all_timings_dict = {}
|
| 583 |
+
self.avg_meter_dict = {}
|
| 584 |
+
|
| 585 |
+
# Cuda event timers to measure time spent on pushing data to gpu and on training step
|
| 586 |
+
self.cuda_event_timers = {}
|
| 587 |
+
|
| 588 |
+
for k in timing_names:
|
| 589 |
+
self.add_new_timing(k)
|
| 590 |
+
|
| 591 |
+
for k in cuda_timing_names:
|
| 592 |
+
self.add_new_cuda_timing(k)
|
| 593 |
+
|
| 594 |
+
# Running averages
|
| 595 |
+
# self.avg_meter_dict = {k:AverageMeter() for k in self.all_timings_dict}
|
| 596 |
+
|
| 597 |
+
def add_new_timing(self, name):
|
| 598 |
+
self.avg_meter_dict[name] = AverageMeter()
|
| 599 |
+
self.all_timings_dict[name] = None
|
| 600 |
+
|
| 601 |
+
def add_new_cuda_timing(self, name):
|
| 602 |
+
start_event = torch.cuda.Event(enable_timing=True)
|
| 603 |
+
end_event = torch.cuda.Event(enable_timing=True)
|
| 604 |
+
self.cuda_event_timers[name] = self.CUDATimer(start=start_event, end=end_event)
|
| 605 |
+
self.add_new_timing(name)
|
| 606 |
+
|
| 607 |
+
def clear_timings(self):
|
| 608 |
+
self.all_timings_dict = {k:None for k in self.all_timings_dict}
|
| 609 |
+
|
| 610 |
+
def get_timings(self):
|
| 611 |
+
return self.all_timings_dict
|
| 612 |
+
|
| 613 |
+
def get_average_timings(self):
|
| 614 |
+
return {k:v.avg for k,v in self.avg_meter_dict.items()}
|
| 615 |
+
|
| 616 |
+
def update_all_averages(self):
|
| 617 |
+
"""
|
| 618 |
+
Once per iter, when timings have been finished recording, one should
|
| 619 |
+
call update_average_iter to keep running average of timings.
|
| 620 |
+
"""
|
| 621 |
+
for k,v in self.all_timings_dict.items():
|
| 622 |
+
if v is None:
|
| 623 |
+
print("none_timing", k)
|
| 624 |
+
continue
|
| 625 |
+
self.avg_meter_dict[k].update(v)
|
| 626 |
+
|
| 627 |
+
def record_timing(self, name, value):
|
| 628 |
+
if name not in self.all_timings_dict: self.add_new_timing(name)
|
| 629 |
+
# assert name in self.all_timings_dict
|
| 630 |
+
self.all_timings_dict[name] = value
|
| 631 |
+
|
| 632 |
+
def _record_cuda_event_start(self, name):
|
| 633 |
+
if name in self.cuda_event_timers:
|
| 634 |
+
self.cuda_event_timers[name].start.record(
|
| 635 |
+
torch.cuda.current_stream(self.device))
|
| 636 |
+
|
| 637 |
+
def _record_cuda_event_end(self, name):
|
| 638 |
+
if name in self.cuda_event_timers:
|
| 639 |
+
self.cuda_event_timers[name].end.record(
|
| 640 |
+
torch.cuda.current_stream(self.device))
|
| 641 |
+
|
| 642 |
+
@contextmanager
|
| 643 |
+
def cuda_timing_scope(self, name, profile=True):
|
| 644 |
+
if name not in self.all_timings_dict: self.add_new_cuda_timing(name)
|
| 645 |
+
with ScopedTorchProfiler(name) if profile else nullcontext():
|
| 646 |
+
self._record_cuda_event_start(name)
|
| 647 |
+
try:
|
| 648 |
+
yield
|
| 649 |
+
finally:
|
| 650 |
+
self._record_cuda_event_end(name)
|
| 651 |
+
|
| 652 |
+
@contextmanager
|
| 653 |
+
def timing_scope(self, name, profile=True):
|
| 654 |
+
if name not in self.all_timings_dict: self.add_new_timing(name)
|
| 655 |
+
with ScopedTorchProfiler(name) if profile else nullcontext():
|
| 656 |
+
start_time = time.time()
|
| 657 |
+
try:
|
| 658 |
+
yield
|
| 659 |
+
finally:
|
| 660 |
+
self.record_timing(name, time.time()-start_time)
|
| 661 |
+
|
| 662 |
+
def record_all_cuda_timings(self):
|
| 663 |
+
""" After all the cuda events call this to synchronize and record down the cuda timings. """
|
| 664 |
+
for k, events in self.cuda_event_timers.items():
|
| 665 |
+
with torch.no_grad():
|
| 666 |
+
events.end.synchronize()
|
| 667 |
+
# Convert to seconds
|
| 668 |
+
time_elapsed = events.start.elapsed_time(events.end)/1000.
|
| 669 |
+
self.all_timings_dict[k] = time_elapsed
|
| 670 |
+
|
| 671 |
+
def init_s3(config_file):
|
| 672 |
+
config = json.load(open(config_file, 'r'))
|
| 673 |
+
s3_client = boto3.client("s3", **config)
|
| 674 |
+
return s3_client
|
| 675 |
+
|
| 676 |
+
def download_from_s3(file_path, target_path, cfg):
|
| 677 |
+
tic = time.time()
|
| 678 |
+
s3_client = init_s3(cfg.checkpoint.write_s3_config) # use to test the s3_client can be init
|
| 679 |
+
bucket_name = file_path.split('/')[2]
|
| 680 |
+
file_key = file_path.split(bucket_name+'/')[-1]
|
| 681 |
+
print(bucket_name, file_key)
|
| 682 |
+
s3_client.download_file(bucket_name, file_key, target_path)
|
| 683 |
+
logger.info(f'finish download from ! s3://{bucket_name}/{file_key} to {target_path} %.1f sec'%(
|
| 684 |
+
time.time() - tic))
|
| 685 |
+
|
| 686 |
+
def upload_to_s3(buffer, bucket_name, key, config_dict):
|
| 687 |
+
logger.info(f'start upload_to_s3! bucket_name={bucket_name}, key={key}')
|
| 688 |
+
tic = time.time()
|
| 689 |
+
s3 = boto3.client('s3', **config_dict)
|
| 690 |
+
s3.put_object(Bucket=bucket_name, Key=key, Body=buffer.getvalue())
|
| 691 |
+
logger.info(f'finish upload_to_s3! s3://{bucket_name}/{key} %.1f sec'%(time.time() - tic))
|
| 692 |
+
|
| 693 |
+
def write_ckpt_to_s3(cfg, all_model_dict, ckpt_name):
|
| 694 |
+
buffer = io.BytesIO()
|
| 695 |
+
tic = time.time()
|
| 696 |
+
torch.save(all_model_dict, buffer) # take ~0.25 sec
|
| 697 |
+
# logger.info('write ckpt to buffer: %.2f sec'%(time.time() - tic))
|
| 698 |
+
group, name = cfg.outdir.rstrip("/").split("/")[-2:]
|
| 699 |
+
key = f"checkpoints/{group}/{name}/ckpt/{ckpt_name}"
|
| 700 |
+
bucket_name = cfg.checkpoint.write_s3_bucket
|
| 701 |
+
|
| 702 |
+
s3_client = init_s3(cfg.checkpoint.write_s3_config) # use to test the s3_client can be init
|
| 703 |
+
|
| 704 |
+
config_dict = json.load(open(cfg.checkpoint.write_s3_config, 'r'))
|
| 705 |
+
upload_thread = threading.Thread(target=upload_to_s3, args=(buffer, bucket_name, key, config_dict))
|
| 706 |
+
upload_thread.start()
|
| 707 |
+
path = f"s3://{bucket_name}/{key}"
|
| 708 |
+
return path
|
| 709 |
+
|
| 710 |
+
def upload_file_to_s3(cfg, file_path, key_name=None):
|
| 711 |
+
# file_path is the local file path, can be a yaml file
|
| 712 |
+
# this function is used to upload the ckecpoint only
|
| 713 |
+
tic = time.time()
|
| 714 |
+
group, name = cfg.outdir.rstrip("/").split("/")[-2:]
|
| 715 |
+
if key_name is None:
|
| 716 |
+
key = os.path.basename(file_path)
|
| 717 |
+
key = f"checkpoints/{group}/{name}/{key}"
|
| 718 |
+
bucket_name = cfg.checkpoint.write_s3_bucket
|
| 719 |
+
s3_client = init_s3(cfg.checkpoint.write_s3_config)
|
| 720 |
+
# Upload the file
|
| 721 |
+
with open(file_path, 'rb') as f:
|
| 722 |
+
s3_client.upload_fileobj(f, bucket_name, key)
|
| 723 |
+
full_s3_path = f"s3://{bucket_name}/{key}"
|
| 724 |
+
logger.info(f'upload_to_s3: {file_path} {full_s3_path} | use time: {time.time()-tic}')
|
| 725 |
+
|
| 726 |
+
return full_s3_path
|
| 727 |
+
|
| 728 |
+
|
| 729 |
+
def load_from_s3(file_path, cfg, load_fn):
|
| 730 |
+
"""
|
| 731 |
+
ckpt_path example:
|
| 732 |
+
s3://xzeng/checkpoints/2023_0413/vae_kl_5e-1/ckpt/snapshot_epo000163_iter164000.pt
|
| 733 |
+
"""
|
| 734 |
+
s3_client = init_s3(cfg.checkpoint.write_s3_config) # use to test the s3_client can be init
|
| 735 |
+
bucket_name = file_path.split("s3://")[-1].split('/')[0]
|
| 736 |
+
key = file_path.split(f'{bucket_name}/')[-1]
|
| 737 |
+
# logger.info(f"-> try to load s3://{bucket_name}/{key} ")
|
| 738 |
+
tic = time.time()
|
| 739 |
+
for attemp in range(10):
|
| 740 |
+
try:
|
| 741 |
+
# Download the state dict from S3 into memory (as a binary stream)
|
| 742 |
+
with io.BytesIO() as buffer:
|
| 743 |
+
s3_client.download_fileobj(bucket_name, key, buffer)
|
| 744 |
+
buffer.seek(0)
|
| 745 |
+
|
| 746 |
+
# Load the state dict into a PyTorch model
|
| 747 |
+
# out = torch.load(buffer, map_location=torch.device("cpu"))
|
| 748 |
+
out = load_fn(buffer)
|
| 749 |
+
break
|
| 750 |
+
except:
|
| 751 |
+
logger.info(f"fail to load s3://{bucket_name}/{key} attemp: {attemp}")
|
| 752 |
+
from torch_utils.dist_utils import is_rank0
|
| 753 |
+
if is_rank0():
|
| 754 |
+
logger.info(f'loaded {file_path} | use time: {time.time()-tic:.1f} sec')
|
| 755 |
+
return out
|
| 756 |
+
|
| 757 |
+
def load_torch_dict_from_s3(ckpt_path, cfg):
|
| 758 |
+
"""
|
| 759 |
+
ckpt_path example:
|
| 760 |
+
s3://xzeng/checkpoints/2023_0413/vae_kl_5e-1/ckpt/snapshot_epo000163_iter164000.pt
|
| 761 |
+
"""
|
| 762 |
+
s3_client = init_s3(cfg.checkpoint.write_s3_config) # use to test the s3_client can be init
|
| 763 |
+
bucket_name = ckpt_path.split("s3://")[-1].split('/')[0]
|
| 764 |
+
key = ckpt_path.split(f'{bucket_name}/')[-1]
|
| 765 |
+
for attemp in range(10):
|
| 766 |
+
try:
|
| 767 |
+
# Download the state dict from S3 into memory (as a binary stream)
|
| 768 |
+
with io.BytesIO() as buffer:
|
| 769 |
+
s3_client.download_fileobj(bucket_name, key, buffer)
|
| 770 |
+
buffer.seek(0)
|
| 771 |
+
|
| 772 |
+
# Load the state dict into a PyTorch model
|
| 773 |
+
out = torch.load(buffer, map_location=torch.device("cpu"))
|
| 774 |
+
break
|
| 775 |
+
except:
|
| 776 |
+
logger.info(f"fail to load s3://{bucket_name}/{key} attemp: {attemp}")
|
| 777 |
+
return out
|
| 778 |
+
|
| 779 |
+
def count_parameters_in_M(model):
|
| 780 |
+
return np.sum(np.prod(v.size()) for name, v in model.named_parameters() if "auxiliary" not in name) / 1e6
|
| 781 |
+
|
| 782 |
+
def printarr(*arrs, float_width=6, **kwargs):
|
| 783 |
+
"""
|
| 784 |
+
Print a pretty table giving name, shape, dtype, type, and content information for input tensors or scalars.
|
| 785 |
+
|
| 786 |
+
Call like: printarr(my_arr, some_other_arr, maybe_a_scalar). Accepts a variable number of arguments.
|
| 787 |
+
|
| 788 |
+
Inputs can be:
|
| 789 |
+
- Numpy tensor arrays
|
| 790 |
+
- Pytorch tensor arrays
|
| 791 |
+
- Jax tensor arrays
|
| 792 |
+
- Python ints / floats
|
| 793 |
+
- None
|
| 794 |
+
|
| 795 |
+
It may also work with other array-like types, but they have not been tested.
|
| 796 |
+
|
| 797 |
+
Use the `float_width` option specify the precision to which floating point types are printed.
|
| 798 |
+
|
| 799 |
+
Author: Nicholas Sharp (nmwsharp.com)
|
| 800 |
+
Canonical source: https://gist.github.com/nmwsharp/54d04af87872a4988809f128e1a1d233
|
| 801 |
+
License: This snippet may be used under an MIT license, and it is also released into the public domain.
|
| 802 |
+
Please retain this docstring as a reference.
|
| 803 |
+
"""
|
| 804 |
+
|
| 805 |
+
frame = inspect.currentframe().f_back
|
| 806 |
+
default_name = "[temporary]"
|
| 807 |
+
|
| 808 |
+
## helpers to gather data about each array
|
| 809 |
+
def name_from_outer_scope(a):
|
| 810 |
+
if a is None:
|
| 811 |
+
return '[None]'
|
| 812 |
+
name = default_name
|
| 813 |
+
for k, v in frame.f_locals.items():
|
| 814 |
+
if v is a:
|
| 815 |
+
name = k
|
| 816 |
+
break
|
| 817 |
+
return name
|
| 818 |
+
|
| 819 |
+
def type_strip(type_str):
|
| 820 |
+
return type_str.lstrip('<class ').rstrip('>').replace('torch.', '').strip("'")
|
| 821 |
+
|
| 822 |
+
def dtype_str(a):
|
| 823 |
+
if a is None:
|
| 824 |
+
return 'None'
|
| 825 |
+
if isinstance(a, int):
|
| 826 |
+
return 'int'
|
| 827 |
+
if isinstance(a, float):
|
| 828 |
+
return 'float'
|
| 829 |
+
if isinstance(a, list) and len(a)>0:
|
| 830 |
+
return type_strip(str(type(a[0])))
|
| 831 |
+
if hasattr(a, 'dtype'):
|
| 832 |
+
return type_strip(str(a.dtype))
|
| 833 |
+
else:
|
| 834 |
+
return ''
|
| 835 |
+
def shape_str(a):
|
| 836 |
+
if a is None:
|
| 837 |
+
return 'N/A'
|
| 838 |
+
if isinstance(a, int):
|
| 839 |
+
return 'scalar'
|
| 840 |
+
if isinstance(a, float):
|
| 841 |
+
return 'scalar'
|
| 842 |
+
if isinstance(a, list):
|
| 843 |
+
return f"[{shape_str(a[0]) if len(a)>0 else '?'}]*{len(a)}"
|
| 844 |
+
if hasattr(a, 'shape'):
|
| 845 |
+
return str(tuple(a.shape))
|
| 846 |
+
else:
|
| 847 |
+
return ''
|
| 848 |
+
def type_str(a):
|
| 849 |
+
return type_strip(str(type(a))) # TODO this is is weird... what's the better way?
|
| 850 |
+
def device_str(a):
|
| 851 |
+
if hasattr(a, 'device'):
|
| 852 |
+
device_str = str(a.device)
|
| 853 |
+
if len(device_str) < 10:
|
| 854 |
+
# heuristic: jax returns some goofy long string we don't want, ignore it
|
| 855 |
+
return device_str
|
| 856 |
+
return ""
|
| 857 |
+
def format_float(x):
|
| 858 |
+
return f"{x:{float_width}g}"
|
| 859 |
+
def minmaxmean_str(a):
|
| 860 |
+
if a is None:
|
| 861 |
+
return ('N/A', 'N/A', 'N/A', 'N/A')
|
| 862 |
+
if isinstance(a, int) or isinstance(a, float):
|
| 863 |
+
return (format_float(a),)*4
|
| 864 |
+
|
| 865 |
+
# compute min/max/mean. if anything goes wrong, just print 'N/A'
|
| 866 |
+
min_str = "N/A"
|
| 867 |
+
try: min_str = format_float(a.min())
|
| 868 |
+
except: pass
|
| 869 |
+
max_str = "N/A"
|
| 870 |
+
try: max_str = format_float(a.max())
|
| 871 |
+
except: pass
|
| 872 |
+
mean_str = "N/A"
|
| 873 |
+
try: mean_str = format_float(a.mean())
|
| 874 |
+
except: pass
|
| 875 |
+
try: median_str = format_float(a.median())
|
| 876 |
+
except:
|
| 877 |
+
try: median_str = format_float(np.median(np.array(a)))
|
| 878 |
+
except: median_str = 'N/A'
|
| 879 |
+
return (min_str, max_str, mean_str, median_str)
|
| 880 |
+
|
| 881 |
+
def get_prop_dict(a,k=None):
|
| 882 |
+
minmaxmean = minmaxmean_str(a)
|
| 883 |
+
props = {
|
| 884 |
+
'name' : name_from_outer_scope(a) if k is None else k,
|
| 885 |
+
# 'type' : str(type(a)).replace('torch.',''),
|
| 886 |
+
'dtype' : dtype_str(a),
|
| 887 |
+
'shape' : shape_str(a),
|
| 888 |
+
'type' : type_str(a),
|
| 889 |
+
'device' : device_str(a),
|
| 890 |
+
'min' : minmaxmean[0],
|
| 891 |
+
'max' : minmaxmean[1],
|
| 892 |
+
'mean' : minmaxmean[2],
|
| 893 |
+
'median': minmaxmean[3]
|
| 894 |
+
}
|
| 895 |
+
return props
|
| 896 |
+
|
| 897 |
+
try:
|
| 898 |
+
|
| 899 |
+
props = ['name', 'type', 'dtype', 'shape', 'device', 'min', 'max', 'mean', 'median']
|
| 900 |
+
|
| 901 |
+
# precompute all of the properties for each input
|
| 902 |
+
str_props = []
|
| 903 |
+
for a in arrs:
|
| 904 |
+
str_props.append(get_prop_dict(a))
|
| 905 |
+
for k,a in kwargs.items():
|
| 906 |
+
str_props.append(get_prop_dict(a, k=k))
|
| 907 |
+
|
| 908 |
+
# for each property, compute its length
|
| 909 |
+
maxlen = {}
|
| 910 |
+
for p in props: maxlen[p] = 0
|
| 911 |
+
for sp in str_props:
|
| 912 |
+
for p in props:
|
| 913 |
+
maxlen[p] = max(maxlen[p], len(sp[p]))
|
| 914 |
+
|
| 915 |
+
# if any property got all empty strings, don't bother printing it, remove if from the list
|
| 916 |
+
props = [p for p in props if maxlen[p] > 0]
|
| 917 |
+
|
| 918 |
+
# print a header
|
| 919 |
+
header_str = ""
|
| 920 |
+
for p in props:
|
| 921 |
+
prefix = "" if p == 'name' else " | "
|
| 922 |
+
fmt_key = ">" if p == 'name' else "<"
|
| 923 |
+
header_str += f"{prefix}{p:{fmt_key}{maxlen[p]}}"
|
| 924 |
+
print(header_str)
|
| 925 |
+
print("-"*len(header_str))
|
| 926 |
+
|
| 927 |
+
# now print the acual arrays
|
| 928 |
+
for strp in str_props:
|
| 929 |
+
for p in props:
|
| 930 |
+
prefix = "" if p == 'name' else " | "
|
| 931 |
+
fmt_key = ">" if p == 'name' else "<"
|
| 932 |
+
print(f"{prefix}{strp[p]:{fmt_key}{maxlen[p]}}", end='')
|
| 933 |
+
print("")
|
| 934 |
+
|
| 935 |
+
finally:
|
| 936 |
+
del frame
|
| 937 |
+
|
| 938 |
+
def debug_print_all_tensor_sizes(min_tot_size = 0):
|
| 939 |
+
import gc
|
| 940 |
+
print("---------------------------------------"*3)
|
| 941 |
+
for obj in gc.get_objects():
|
| 942 |
+
try:
|
| 943 |
+
if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
|
| 944 |
+
if np.prod(obj.size())>=min_tot_size:
|
| 945 |
+
print(type(obj), obj.size())
|
| 946 |
+
except:
|
| 947 |
+
pass
|
| 948 |
+
def print_cpu_usage():
|
| 949 |
+
|
| 950 |
+
# Get current CPU usage as a percentage
|
| 951 |
+
cpu_usage = psutil.cpu_percent()
|
| 952 |
+
|
| 953 |
+
# Get current memory usage
|
| 954 |
+
memory_usage = psutil.virtual_memory().used
|
| 955 |
+
|
| 956 |
+
# Convert memory usage to a human-readable format
|
| 957 |
+
memory_usage_str = psutil._common.bytes2human(memory_usage)
|
| 958 |
+
|
| 959 |
+
# Print CPU and memory usage
|
| 960 |
+
msg = f"Current CPU usage: {cpu_usage}% | "
|
| 961 |
+
msg += f"Current memory usage: {memory_usage_str}"
|
| 962 |
+
return msg
|
| 963 |
+
|
| 964 |
+
def calmsize(num_bytes):
|
| 965 |
+
if math.isnan(num_bytes):
|
| 966 |
+
return ''
|
| 967 |
+
for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
|
| 968 |
+
if abs(num_bytes) < 1024.0:
|
| 969 |
+
return "{:.1f}{}B".format(num_bytes, unit)
|
| 970 |
+
num_bytes /= 1024.0
|
| 971 |
+
return "{:.1f}{}B".format(num_bytes, 'Y')
|
| 972 |
+
|
| 973 |
+
def readable_size(num_bytes: int) -> str:
|
| 974 |
+
return calmsize(num_bytes) ## '' if math.isnan(num_bytes) else '{:.1f}'.format(calmsize(num_bytes))
|
| 975 |
+
|
| 976 |
+
def get_gpu_memory():
|
| 977 |
+
"""
|
| 978 |
+
Get the current GPU memory usage for each device as a dictionary
|
| 979 |
+
"""
|
| 980 |
+
output = subprocess.check_output(["nvidia-smi", "--query-gpu=memory.used", "--format=csv"])
|
| 981 |
+
output = output.decode("utf-8")
|
| 982 |
+
gpu_memory_values = output.split("\n")[1:-1]
|
| 983 |
+
gpu_memory_values = [int(x.strip().split()[0]) for x in gpu_memory_values]
|
| 984 |
+
gpu_memory = dict(zip(range(len(gpu_memory_values)), gpu_memory_values))
|
| 985 |
+
return gpu_memory
|
| 986 |
+
|
| 987 |
+
def get_gpu_util():
|
| 988 |
+
"""
|
| 989 |
+
Get the current GPU memory usage for each device as a dictionary
|
| 990 |
+
"""
|
| 991 |
+
output = subprocess.check_output(["nvidia-smi", "--query-gpu=utilization.gpu", "--format=csv"])
|
| 992 |
+
output = output.decode("utf-8")
|
| 993 |
+
gpu_memory_values = output.split("\n")[1:-1]
|
| 994 |
+
gpu_memory_values = [int(x.strip().split()[0]) for x in gpu_memory_values]
|
| 995 |
+
gpu_util = dict(zip(range(len(gpu_memory_values)), gpu_memory_values))
|
| 996 |
+
return gpu_util
|
| 997 |
+
|
| 998 |
+
|
| 999 |
+
def print_gpu_usage():
|
| 1000 |
+
useage = get_gpu_memory()
|
| 1001 |
+
msg = f" | GPU usage: "
|
| 1002 |
+
for k, v in useage.items():
|
| 1003 |
+
msg += f"{k}: {v} MB "
|
| 1004 |
+
# utilization = get_gpu_util()
|
| 1005 |
+
# msg + ' | util '
|
| 1006 |
+
# for k, v in utilization.items():
|
| 1007 |
+
# msg += f"{k}: {v} % "
|
| 1008 |
+
return msg
|
| 1009 |
+
|
| 1010 |
+
class AverageMeter(object):
|
| 1011 |
+
|
| 1012 |
+
def __init__(self):
|
| 1013 |
+
self.reset()
|
| 1014 |
+
|
| 1015 |
+
def reset(self):
|
| 1016 |
+
self.avg = 0
|
| 1017 |
+
self.sum = 0
|
| 1018 |
+
self.cnt = 0
|
| 1019 |
+
|
| 1020 |
+
def update(self, val, n=1):
|
| 1021 |
+
self.sum += val * n
|
| 1022 |
+
self.cnt += n
|
| 1023 |
+
self.avg = self.sum / self.cnt
|
| 1024 |
+
|
| 1025 |
+
|
| 1026 |
+
def generate_random_string(length):
|
| 1027 |
+
# This script will generate a string of 10 random ASCII letters (both lowercase and uppercase).
|
| 1028 |
+
# You can adjust the length parameter to fit your needs.
|
| 1029 |
+
letters = string.ascii_letters
|
| 1030 |
+
return ''.join(random.choice(letters) for _ in range(length))
|
| 1031 |
+
|
| 1032 |
+
|
| 1033 |
+
class ForkedPdb(pdb.Pdb):
|
| 1034 |
+
"""
|
| 1035 |
+
PDB Subclass for debugging multi-processed code
|
| 1036 |
+
Suggested in: https://stackoverflow.com/questions/4716533/how-to-attach-debugger-to-a-python-subproccess
|
| 1037 |
+
"""
|
| 1038 |
+
def interaction(self, *args, **kwargs):
|
| 1039 |
+
_stdin = sys.stdin
|
| 1040 |
+
try:
|
| 1041 |
+
sys.stdin = open('/dev/stdin')
|
| 1042 |
+
pdb.Pdb.interaction(self, *args, **kwargs)
|
| 1043 |
+
finally:
|
| 1044 |
+
sys.stdin = _stdin
|
| 1045 |
+
|
| 1046 |
+
def check_exist_in_s3(file_path, s3_config):
|
| 1047 |
+
s3 = init_s3(s3_config)
|
| 1048 |
+
bucket_name, object_name = s3path_to_bucket_key(file_path)
|
| 1049 |
+
|
| 1050 |
+
try:
|
| 1051 |
+
s3.head_object(Bucket=bucket_name, Key=object_name)
|
| 1052 |
+
return 1
|
| 1053 |
+
except:
|
| 1054 |
+
logger.info(f'file not found: s3://{bucket_name}/{object_name}')
|
| 1055 |
+
return 0
|
| 1056 |
+
|
| 1057 |
+
def s3path_to_bucket_key(file_path):
|
| 1058 |
+
bucket_name = file_path.split('/')[2]
|
| 1059 |
+
object_name = file_path.split(bucket_name + '/')[-1]
|
| 1060 |
+
return bucket_name, object_name
|
| 1061 |
+
|
| 1062 |
+
def copy_file_to_s3(cfg, file_path_local, file_path_s3):
|
| 1063 |
+
# work similar as upload_file_to_s3, but not trying to parse the file path
|
| 1064 |
+
# file_path_s3: s3://{bucket}/{key}
|
| 1065 |
+
bucket_name, key = s3path_to_bucket_key(file_path_s3)
|
| 1066 |
+
tic = time.time()
|
| 1067 |
+
s3_client = init_s3(cfg.checkpoint.write_s3_config)
|
| 1068 |
+
|
| 1069 |
+
# Upload the file
|
| 1070 |
+
with open(file_path_local, 'rb') as f:
|
| 1071 |
+
s3_client.upload_fileobj(f, bucket_name, key)
|
| 1072 |
+
full_s3_path = f"s3://{bucket_name}/{key}"
|
| 1073 |
+
logger.info(f'copy file: {file_path_local} {full_s3_path} | use time: {time.time()-tic}')
|
| 1074 |
+
return full_s3_path
|
PartField/partfield/model/PVCNN/encoder_pc.py
ADDED
|
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
| 4 |
+
# and proprietary rights in and to this software, related documentation
|
| 5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
| 6 |
+
# distribution of this software and related documentation without an express
|
| 7 |
+
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
| 8 |
+
|
| 9 |
+
from ast import Dict
|
| 10 |
+
import math
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
from torch import nn
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
from torch_scatter import scatter_mean #, scatter_max
|
| 17 |
+
|
| 18 |
+
from .unet_3daware import setup_unet #UNetTriplane3dAware
|
| 19 |
+
from .conv_pointnet import ConvPointnet
|
| 20 |
+
|
| 21 |
+
from .pc_encoder import PVCNNEncoder #PointNet
|
| 22 |
+
|
| 23 |
+
import einops
|
| 24 |
+
|
| 25 |
+
from .dnnlib_util import ScopedTorchProfiler, printarr
|
| 26 |
+
|
| 27 |
+
def generate_plane_features(p, c, resolution, plane='xz'):
|
| 28 |
+
"""
|
| 29 |
+
Args:
|
| 30 |
+
p: (B,3,n_p)
|
| 31 |
+
c: (B,C,n_p)
|
| 32 |
+
"""
|
| 33 |
+
padding = 0.
|
| 34 |
+
c_dim = c.size(1)
|
| 35 |
+
# acquire indices of features in plane
|
| 36 |
+
xy = normalize_coordinate(p.clone(), plane=plane, padding=padding) # normalize to the range of (0, 1)
|
| 37 |
+
index = coordinate2index(xy, resolution)
|
| 38 |
+
|
| 39 |
+
# scatter plane features from points
|
| 40 |
+
fea_plane = c.new_zeros(p.size(0), c_dim, resolution**2)
|
| 41 |
+
fea_plane = scatter_mean(c, index, out=fea_plane) # B x 512 x reso^2
|
| 42 |
+
fea_plane = fea_plane.reshape(p.size(0), c_dim, resolution, resolution) # sparce matrix (B x 512 x reso x reso)
|
| 43 |
+
return fea_plane
|
| 44 |
+
|
| 45 |
+
def normalize_coordinate(p, padding=0.1, plane='xz'):
|
| 46 |
+
''' Normalize coordinate to [0, 1] for unit cube experiments
|
| 47 |
+
|
| 48 |
+
Args:
|
| 49 |
+
p (tensor): point
|
| 50 |
+
padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
|
| 51 |
+
plane (str): plane feature type, ['xz', 'xy', 'yz']
|
| 52 |
+
'''
|
| 53 |
+
if plane == 'xz':
|
| 54 |
+
xy = p[:, :, [0, 2]]
|
| 55 |
+
elif plane =='xy':
|
| 56 |
+
xy = p[:, :, [0, 1]]
|
| 57 |
+
else:
|
| 58 |
+
xy = p[:, :, [1, 2]]
|
| 59 |
+
|
| 60 |
+
xy_new = xy / (1 + padding + 10e-6) # (-0.5, 0.5)
|
| 61 |
+
xy_new = xy_new + 0.5 # range (0, 1)
|
| 62 |
+
|
| 63 |
+
# if there are outliers out of the range
|
| 64 |
+
if xy_new.max() >= 1:
|
| 65 |
+
xy_new[xy_new >= 1] = 1 - 10e-6
|
| 66 |
+
if xy_new.min() < 0:
|
| 67 |
+
xy_new[xy_new < 0] = 0.0
|
| 68 |
+
return xy_new
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def coordinate2index(x, resolution):
|
| 72 |
+
''' Normalize coordinate to [0, 1] for unit cube experiments.
|
| 73 |
+
Corresponds to our 3D model
|
| 74 |
+
|
| 75 |
+
Args:
|
| 76 |
+
x (tensor): coordinate
|
| 77 |
+
reso (int): defined resolution
|
| 78 |
+
coord_type (str): coordinate type
|
| 79 |
+
'''
|
| 80 |
+
x = (x * resolution).long()
|
| 81 |
+
index = x[:, :, 0] + resolution * x[:, :, 1]
|
| 82 |
+
index = index[:, None, :]
|
| 83 |
+
return index
|
| 84 |
+
|
| 85 |
+
def softclip(x, min, max, hardness=5):
|
| 86 |
+
# Soft clipping for the logsigma
|
| 87 |
+
x = min + F.softplus(hardness*(x - min))/hardness
|
| 88 |
+
x = max - F.softplus(-hardness*(x - max))/hardness
|
| 89 |
+
return x
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def sample_triplane_feat(feature_triplane, normalized_pos):
|
| 93 |
+
'''
|
| 94 |
+
normalized_pos [-1, 1]
|
| 95 |
+
'''
|
| 96 |
+
tri_plane = torch.unbind(feature_triplane, dim=1)
|
| 97 |
+
|
| 98 |
+
x_feat = F.grid_sample(
|
| 99 |
+
tri_plane[0],
|
| 100 |
+
torch.cat(
|
| 101 |
+
[normalized_pos[:, :, 0:1], normalized_pos[:, :, 1:2]],
|
| 102 |
+
dim=-1).unsqueeze(dim=1), padding_mode='border',
|
| 103 |
+
align_corners=True)
|
| 104 |
+
y_feat = F.grid_sample(
|
| 105 |
+
tri_plane[1],
|
| 106 |
+
torch.cat(
|
| 107 |
+
[normalized_pos[:, :, 1:2], normalized_pos[:, :, 2:3]],
|
| 108 |
+
dim=-1).unsqueeze(dim=1), padding_mode='border',
|
| 109 |
+
align_corners=True)
|
| 110 |
+
|
| 111 |
+
z_feat = F.grid_sample(
|
| 112 |
+
tri_plane[2],
|
| 113 |
+
torch.cat(
|
| 114 |
+
[normalized_pos[:, :, 0:1], normalized_pos[:, :, 2:3]],
|
| 115 |
+
dim=-1).unsqueeze(dim=1), padding_mode='border',
|
| 116 |
+
align_corners=True)
|
| 117 |
+
final_feat = (x_feat + y_feat + z_feat)
|
| 118 |
+
final_feat = final_feat.squeeze(dim=2).permute(0, 2, 1) # 32dimension
|
| 119 |
+
return final_feat
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
# @persistence.persistent_class
|
| 123 |
+
class TriPlanePC2Encoder(torch.nn.Module):
|
| 124 |
+
# Encoder that encode point cloud to triplane feature vector similar to ConvOccNet
|
| 125 |
+
def __init__(
|
| 126 |
+
self,
|
| 127 |
+
cfg,
|
| 128 |
+
device='cuda',
|
| 129 |
+
shape_min=-1.0,
|
| 130 |
+
shape_length=2.0,
|
| 131 |
+
use_2d_feat=False,
|
| 132 |
+
# point_encoder='pvcnn',
|
| 133 |
+
# use_point_scatter=False
|
| 134 |
+
):
|
| 135 |
+
"""
|
| 136 |
+
Outputs latent triplane from PC input
|
| 137 |
+
Configs:
|
| 138 |
+
max_logsigma: (float) Soft clip upper range for logsigm
|
| 139 |
+
min_logsigma: (float)
|
| 140 |
+
point_encoder_type: (str) one of ['pvcnn', 'pointnet']
|
| 141 |
+
pvcnn_flatten_voxels: (bool) for pvcnn whether to reduce voxel
|
| 142 |
+
features (instead of scattering point features)
|
| 143 |
+
unet_cfg: (dict)
|
| 144 |
+
z_triplane_channels: (int) output latent triplane
|
| 145 |
+
z_triplane_resolution: (int)
|
| 146 |
+
Args:
|
| 147 |
+
|
| 148 |
+
"""
|
| 149 |
+
# assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0
|
| 150 |
+
super().__init__()
|
| 151 |
+
self.device = device
|
| 152 |
+
|
| 153 |
+
self.cfg = cfg
|
| 154 |
+
|
| 155 |
+
self.shape_min = shape_min
|
| 156 |
+
self.shape_length = shape_length
|
| 157 |
+
|
| 158 |
+
self.z_triplane_resolution = cfg.z_triplane_resolution
|
| 159 |
+
z_triplane_channels = cfg.z_triplane_channels
|
| 160 |
+
|
| 161 |
+
point_encoder_out_dim = z_triplane_channels #* 2
|
| 162 |
+
|
| 163 |
+
in_channels = 6
|
| 164 |
+
# self.resample_filter=[1, 3, 3, 1]
|
| 165 |
+
if cfg.point_encoder_type == 'pvcnn':
|
| 166 |
+
self.pc_encoder = PVCNNEncoder(point_encoder_out_dim,
|
| 167 |
+
device=self.device, in_channels=in_channels, use_2d_feat=use_2d_feat) # Encode it to a volume vector.
|
| 168 |
+
elif cfg.point_encoder_type == 'pointnet':
|
| 169 |
+
# TODO the pointnet was buggy, investigate
|
| 170 |
+
self.pc_encoder = ConvPointnet(c_dim=point_encoder_out_dim,
|
| 171 |
+
dim=in_channels, hidden_dim=32,
|
| 172 |
+
plane_resolution=self.z_triplane_resolution,
|
| 173 |
+
padding=0)
|
| 174 |
+
else:
|
| 175 |
+
raise NotImplementedError(f"Point encoder {cfg.point_encoder_type} not implemented")
|
| 176 |
+
|
| 177 |
+
if cfg.unet_cfg.enabled:
|
| 178 |
+
self.unet_encoder = setup_unet(
|
| 179 |
+
output_channels=point_encoder_out_dim,
|
| 180 |
+
input_channels=point_encoder_out_dim,
|
| 181 |
+
unet_cfg=cfg.unet_cfg)
|
| 182 |
+
else:
|
| 183 |
+
self.unet_encoder = None
|
| 184 |
+
|
| 185 |
+
# @ScopedTorchProfiler('encode')
|
| 186 |
+
def encode(self, point_cloud_xyz, point_cloud_feature, mv_feat=None, pc2pc_idx=None) -> Dict:
|
| 187 |
+
# output = AttrDict()
|
| 188 |
+
point_cloud_xyz = (point_cloud_xyz - self.shape_min) / self.shape_length # [0, 1]
|
| 189 |
+
point_cloud_xyz = point_cloud_xyz - 0.5 # [-0.5, 0.5]
|
| 190 |
+
point_cloud = torch.cat([point_cloud_xyz, point_cloud_feature], dim=-1)
|
| 191 |
+
|
| 192 |
+
if self.cfg.point_encoder_type == 'pvcnn':
|
| 193 |
+
if mv_feat is not None:
|
| 194 |
+
pc_feat, points_feat = self.pc_encoder(point_cloud, mv_feat, pc2pc_idx)
|
| 195 |
+
else:
|
| 196 |
+
pc_feat, points_feat = self.pc_encoder(point_cloud) # 3D feature volume: BxDx32x32x32
|
| 197 |
+
if self.cfg.use_point_scatter:
|
| 198 |
+
# Scattering from PVCNN point features
|
| 199 |
+
points_feat_ = points_feat[0]
|
| 200 |
+
# shape: batch, latent size, resolution, resolution (e.g. 16, 256, 64, 64)
|
| 201 |
+
pc_feat_1 = generate_plane_features(point_cloud_xyz, points_feat_,
|
| 202 |
+
resolution=self.z_triplane_resolution, plane='xy')
|
| 203 |
+
pc_feat_2 = generate_plane_features(point_cloud_xyz, points_feat_,
|
| 204 |
+
resolution=self.z_triplane_resolution, plane='yz')
|
| 205 |
+
pc_feat_3 = generate_plane_features(point_cloud_xyz, points_feat_,
|
| 206 |
+
resolution=self.z_triplane_resolution, plane='xz')
|
| 207 |
+
pc_feat = pc_feat[0]
|
| 208 |
+
|
| 209 |
+
else:
|
| 210 |
+
pc_feat = pc_feat[0]
|
| 211 |
+
sf = self.z_triplane_resolution//32 # 32 is PVCNN's voxel dim
|
| 212 |
+
|
| 213 |
+
pc_feat_1 = torch.mean(pc_feat, dim=-1) #xy_plane, normalize in z plane
|
| 214 |
+
pc_feat_2 = torch.mean(pc_feat, dim=-3) #yz_plane, normalize in x plane
|
| 215 |
+
pc_feat_3 = torch.mean(pc_feat, dim=-2) #xz_plane, normalize in y plane
|
| 216 |
+
|
| 217 |
+
# nearest upsample
|
| 218 |
+
pc_feat_1 = einops.repeat(pc_feat_1, 'b c h w -> b c (h hm ) (w wm)', hm = sf, wm = sf)
|
| 219 |
+
pc_feat_2 = einops.repeat(pc_feat_2, 'b c h w -> b c (h hm) (w wm)', hm = sf, wm = sf)
|
| 220 |
+
pc_feat_3 = einops.repeat(pc_feat_3, 'b c h w -> b c (h hm) (w wm)', hm = sf, wm = sf)
|
| 221 |
+
elif self.cfg.point_encoder_type == 'pointnet':
|
| 222 |
+
assert self.cfg.use_point_scatter
|
| 223 |
+
# Run ConvPointnet
|
| 224 |
+
pc_feat = self.pc_encoder(point_cloud)
|
| 225 |
+
pc_feat_1 = pc_feat['xy'] #
|
| 226 |
+
pc_feat_2 = pc_feat['yz']
|
| 227 |
+
pc_feat_3 = pc_feat['xz']
|
| 228 |
+
else:
|
| 229 |
+
raise NotImplementedError()
|
| 230 |
+
|
| 231 |
+
if self.unet_encoder is not None:
|
| 232 |
+
# TODO eval adding a skip connection
|
| 233 |
+
# Unet expects B, 3, C, H, W
|
| 234 |
+
pc_feat_tri_plane_stack_pre = torch.stack([pc_feat_1, pc_feat_2, pc_feat_3], dim=1)
|
| 235 |
+
# dpc_feat_tri_plane_stack = self.unet_encoder(pc_feat_tri_plane_stack_pre)
|
| 236 |
+
# pc_feat_tri_plane_stack = pc_feat_tri_plane_stack_pre + dpc_feat_tri_plane_stack
|
| 237 |
+
pc_feat_tri_plane_stack = self.unet_encoder(pc_feat_tri_plane_stack_pre)
|
| 238 |
+
pc_feat_1, pc_feat_2, pc_feat_3 = torch.unbind(pc_feat_tri_plane_stack, dim=1)
|
| 239 |
+
|
| 240 |
+
return torch.stack([pc_feat_1, pc_feat_2, pc_feat_3], dim=1)
|
| 241 |
+
|
| 242 |
+
def forward(self, point_cloud_xyz, point_cloud_feature=None, mv_feat=None, pc2pc_idx=None):
|
| 243 |
+
return self.encode(point_cloud_xyz, point_cloud_feature=point_cloud_feature, mv_feat=mv_feat, pc2pc_idx=pc2pc_idx)
|
PartField/partfield/model/PVCNN/pc_encoder.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import functools
|
| 5 |
+
|
| 6 |
+
from .pv_module import SharedMLP, PVConv
|
| 7 |
+
|
| 8 |
+
def create_pointnet_components(
|
| 9 |
+
blocks, in_channels, with_se=False, normalize=True, eps=0,
|
| 10 |
+
width_multiplier=1, voxel_resolution_multiplier=1, scale_pvcnn=False, device='cuda'):
|
| 11 |
+
r, vr = width_multiplier, voxel_resolution_multiplier
|
| 12 |
+
layers, concat_channels = [], 0
|
| 13 |
+
for out_channels, num_blocks, voxel_resolution in blocks:
|
| 14 |
+
out_channels = int(r * out_channels)
|
| 15 |
+
if voxel_resolution is None:
|
| 16 |
+
block = functools.partial(SharedMLP, device=device)
|
| 17 |
+
else:
|
| 18 |
+
block = functools.partial(
|
| 19 |
+
PVConv, kernel_size=3, resolution=int(vr * voxel_resolution),
|
| 20 |
+
with_se=with_se, normalize=normalize, eps=eps, scale_pvcnn=scale_pvcnn, device=device)
|
| 21 |
+
for _ in range(num_blocks):
|
| 22 |
+
layers.append(block(in_channels, out_channels))
|
| 23 |
+
in_channels = out_channels
|
| 24 |
+
concat_channels += out_channels
|
| 25 |
+
return layers, in_channels, concat_channels
|
| 26 |
+
|
| 27 |
+
class PCMerger(nn.Module):
|
| 28 |
+
# merge surface sampled PC and rendering backprojected PC (w/ 2D features):
|
| 29 |
+
def __init__(self, in_channels=204, device="cuda"):
|
| 30 |
+
super(PCMerger, self).__init__()
|
| 31 |
+
self.mlp_normal = SharedMLP(3, [128, 128], device=device)
|
| 32 |
+
self.mlp_rgb = SharedMLP(3, [128, 128], device=device)
|
| 33 |
+
self.mlp_sam = SharedMLP(204 - 6, [128, 128], device=device)
|
| 34 |
+
|
| 35 |
+
def forward(self, feat, mv_feat, pc2pc_idx):
|
| 36 |
+
mv_feat_normal = self.mlp_normal(mv_feat[:, :3, :])
|
| 37 |
+
mv_feat_rgb = self.mlp_rgb(mv_feat[:, 3:6, :])
|
| 38 |
+
mv_feat_sam = self.mlp_sam(mv_feat[:, 6:, :])
|
| 39 |
+
|
| 40 |
+
mv_feat_normal = mv_feat_normal.permute(0, 2, 1)
|
| 41 |
+
mv_feat_rgb = mv_feat_rgb.permute(0, 2, 1)
|
| 42 |
+
mv_feat_sam = mv_feat_sam.permute(0, 2, 1)
|
| 43 |
+
feat = feat.permute(0, 2, 1)
|
| 44 |
+
|
| 45 |
+
for i in range(mv_feat.shape[0]):
|
| 46 |
+
mask = (pc2pc_idx[i] != -1).reshape(-1)
|
| 47 |
+
idx = pc2pc_idx[i][mask].reshape(-1)
|
| 48 |
+
feat[i][mask] += mv_feat_normal[i][idx] + mv_feat_rgb[i][idx] + mv_feat_sam[i][idx]
|
| 49 |
+
|
| 50 |
+
return feat.permute(0, 2, 1)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class PVCNNEncoder(nn.Module):
|
| 54 |
+
def __init__(self, pvcnn_feat_dim, device='cuda', in_channels=3, use_2d_feat=False):
|
| 55 |
+
super(PVCNNEncoder, self).__init__()
|
| 56 |
+
self.device = device
|
| 57 |
+
self.blocks = ((pvcnn_feat_dim, 1, 32), (128, 2, 16), (256, 1, 8))
|
| 58 |
+
self.use_2d_feat=use_2d_feat
|
| 59 |
+
if in_channels == 6:
|
| 60 |
+
self.append_channel = 2
|
| 61 |
+
elif in_channels == 3:
|
| 62 |
+
self.append_channel = 1
|
| 63 |
+
else:
|
| 64 |
+
raise NotImplementedError
|
| 65 |
+
layers, channels_point, concat_channels_point = create_pointnet_components(
|
| 66 |
+
blocks=self.blocks, in_channels=in_channels + self.append_channel, with_se=False, normalize=False,
|
| 67 |
+
width_multiplier=1, voxel_resolution_multiplier=1, scale_pvcnn=True,
|
| 68 |
+
device=device
|
| 69 |
+
)
|
| 70 |
+
self.encoder = nn.ModuleList(layers)#.to(self.device)
|
| 71 |
+
if self.use_2d_feat:
|
| 72 |
+
self.merger = PCMerger()
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def forward(self, input_pc, mv_feat=None, pc2pc_idx=None):
|
| 77 |
+
features = input_pc.permute(0, 2, 1) * 2 # make point cloud [-1, 1]
|
| 78 |
+
coords = features[:, :3, :]
|
| 79 |
+
out_features_list = []
|
| 80 |
+
voxel_feature_list = []
|
| 81 |
+
zero_padding = torch.zeros(features.shape[0], self.append_channel, features.shape[-1], device=features.device, dtype=torch.float)
|
| 82 |
+
features = torch.cat([features, zero_padding], dim=1)##################
|
| 83 |
+
|
| 84 |
+
for i in range(len(self.encoder)):
|
| 85 |
+
features, _, voxel_feature = self.encoder[i]((features, coords))
|
| 86 |
+
if i == 0 and mv_feat is not None:
|
| 87 |
+
features = self.merger(features, mv_feat.permute(0, 2, 1), pc2pc_idx)
|
| 88 |
+
out_features_list.append(features)
|
| 89 |
+
voxel_feature_list.append(voxel_feature)
|
| 90 |
+
return voxel_feature_list, out_features_list
|
PartField/partfield/model/PVCNN/pv_module/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .pvconv import PVConv
|
| 2 |
+
from .shared_mlp import SharedMLP
|
PartField/partfield/model/PVCNN/pv_module/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (284 Bytes). View file
|
|
|
PartField/partfield/model/PVCNN/pv_module/__pycache__/pvconv.cpython-310.pyc
ADDED
|
Binary file (1.62 kB). View file
|
|
|
PartField/partfield/model/PVCNN/pv_module/__pycache__/shared_mlp.cpython-310.pyc
ADDED
|
Binary file (1.18 kB). View file
|
|
|
PartField/partfield/model/PVCNN/pv_module/__pycache__/voxelization.cpython-310.pyc
ADDED
|
Binary file (2.23 kB). View file
|
|
|
PartField/partfield/model/PVCNN/pv_module/ball_query.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
from . import functional as F
|
| 5 |
+
|
| 6 |
+
__all__ = ['BallQuery']
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class BallQuery(nn.Module):
|
| 10 |
+
def __init__(self, radius, num_neighbors, include_coordinates=True):
|
| 11 |
+
super().__init__()
|
| 12 |
+
self.radius = radius
|
| 13 |
+
self.num_neighbors = num_neighbors
|
| 14 |
+
self.include_coordinates = include_coordinates
|
| 15 |
+
|
| 16 |
+
def forward(self, points_coords, centers_coords, points_features=None):
|
| 17 |
+
points_coords = points_coords.contiguous()
|
| 18 |
+
centers_coords = centers_coords.contiguous()
|
| 19 |
+
neighbor_indices = F.ball_query(centers_coords, points_coords, self.radius, self.num_neighbors)
|
| 20 |
+
neighbor_coordinates = F.grouping(points_coords, neighbor_indices)
|
| 21 |
+
neighbor_coordinates = neighbor_coordinates - centers_coords.unsqueeze(-1)
|
| 22 |
+
|
| 23 |
+
if points_features is None:
|
| 24 |
+
assert self.include_coordinates, 'No Features For Grouping'
|
| 25 |
+
neighbor_features = neighbor_coordinates
|
| 26 |
+
else:
|
| 27 |
+
neighbor_features = F.grouping(points_features, neighbor_indices)
|
| 28 |
+
if self.include_coordinates:
|
| 29 |
+
neighbor_features = torch.cat([neighbor_coordinates, neighbor_features], dim=1)
|
| 30 |
+
return neighbor_features
|
| 31 |
+
|
| 32 |
+
def extra_repr(self):
|
| 33 |
+
return 'radius={}, num_neighbors={}{}'.format(
|
| 34 |
+
self.radius, self.num_neighbors, ', include coordinates' if self.include_coordinates else '')
|
PartField/partfield/model/PVCNN/pv_module/frustum.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
from . import functional as PF
|
| 7 |
+
|
| 8 |
+
__all__ = ['FrustumPointNetLoss', 'get_box_corners_3d']
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class FrustumPointNetLoss(nn.Module):
|
| 12 |
+
def __init__(
|
| 13 |
+
self, num_heading_angle_bins, num_size_templates, size_templates, box_loss_weight=1.0,
|
| 14 |
+
corners_loss_weight=10.0, heading_residual_loss_weight=20.0, size_residual_loss_weight=20.0):
|
| 15 |
+
super().__init__()
|
| 16 |
+
self.box_loss_weight = box_loss_weight
|
| 17 |
+
self.corners_loss_weight = corners_loss_weight
|
| 18 |
+
self.heading_residual_loss_weight = heading_residual_loss_weight
|
| 19 |
+
self.size_residual_loss_weight = size_residual_loss_weight
|
| 20 |
+
|
| 21 |
+
self.num_heading_angle_bins = num_heading_angle_bins
|
| 22 |
+
self.num_size_templates = num_size_templates
|
| 23 |
+
self.register_buffer('size_templates', size_templates.view(self.num_size_templates, 3))
|
| 24 |
+
self.register_buffer(
|
| 25 |
+
'heading_angle_bin_centers', torch.arange(0, 2 * np.pi, 2 * np.pi / self.num_heading_angle_bins)
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
def forward(self, inputs, targets):
|
| 29 |
+
mask_logits = inputs['mask_logits'] # (B, 2, N)
|
| 30 |
+
center_reg = inputs['center_reg'] # (B, 3)
|
| 31 |
+
center = inputs['center'] # (B, 3)
|
| 32 |
+
heading_scores = inputs['heading_scores'] # (B, NH)
|
| 33 |
+
heading_residuals_normalized = inputs['heading_residuals_normalized'] # (B, NH)
|
| 34 |
+
heading_residuals = inputs['heading_residuals'] # (B, NH)
|
| 35 |
+
size_scores = inputs['size_scores'] # (B, NS)
|
| 36 |
+
size_residuals_normalized = inputs['size_residuals_normalized'] # (B, NS, 3)
|
| 37 |
+
size_residuals = inputs['size_residuals'] # (B, NS, 3)
|
| 38 |
+
|
| 39 |
+
mask_logits_target = targets['mask_logits'] # (B, N)
|
| 40 |
+
center_target = targets['center'] # (B, 3)
|
| 41 |
+
heading_bin_id_target = targets['heading_bin_id'] # (B, )
|
| 42 |
+
heading_residual_target = targets['heading_residual'] # (B, )
|
| 43 |
+
size_template_id_target = targets['size_template_id'] # (B, )
|
| 44 |
+
size_residual_target = targets['size_residual'] # (B, 3)
|
| 45 |
+
|
| 46 |
+
batch_size = center.size(0)
|
| 47 |
+
batch_id = torch.arange(batch_size, device=center.device)
|
| 48 |
+
|
| 49 |
+
# Basic Classification and Regression losses
|
| 50 |
+
mask_loss = F.cross_entropy(mask_logits, mask_logits_target)
|
| 51 |
+
heading_loss = F.cross_entropy(heading_scores, heading_bin_id_target)
|
| 52 |
+
size_loss = F.cross_entropy(size_scores, size_template_id_target)
|
| 53 |
+
center_loss = PF.huber_loss(torch.norm(center_target - center, dim=-1), delta=2.0)
|
| 54 |
+
center_reg_loss = PF.huber_loss(torch.norm(center_target - center_reg, dim=-1), delta=1.0)
|
| 55 |
+
|
| 56 |
+
# Refinement losses for size/heading
|
| 57 |
+
heading_residuals_normalized = heading_residuals_normalized[batch_id, heading_bin_id_target] # (B, )
|
| 58 |
+
heading_residual_normalized_target = heading_residual_target / (np.pi / self.num_heading_angle_bins)
|
| 59 |
+
heading_residual_normalized_loss = PF.huber_loss(
|
| 60 |
+
heading_residuals_normalized - heading_residual_normalized_target, delta=1.0
|
| 61 |
+
)
|
| 62 |
+
size_residuals_normalized = size_residuals_normalized[batch_id, size_template_id_target] # (B, 3)
|
| 63 |
+
size_residual_normalized_target = size_residual_target / self.size_templates[size_template_id_target]
|
| 64 |
+
size_residual_normalized_loss = PF.huber_loss(
|
| 65 |
+
torch.norm(size_residual_normalized_target - size_residuals_normalized, dim=-1), delta=1.0
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
# Bounding box losses
|
| 69 |
+
heading = (heading_residuals[batch_id, heading_bin_id_target]
|
| 70 |
+
+ self.heading_angle_bin_centers[heading_bin_id_target]) # (B, )
|
| 71 |
+
# Warning: in origin code, size_residuals are added twice (issue #43 and #49 in charlesq34/frustum-pointnets)
|
| 72 |
+
size = (size_residuals[batch_id, size_template_id_target]
|
| 73 |
+
+ self.size_templates[size_template_id_target]) # (B, 3)
|
| 74 |
+
corners = get_box_corners_3d(centers=center, headings=heading, sizes=size, with_flip=False) # (B, 3, 8)
|
| 75 |
+
heading_target = self.heading_angle_bin_centers[heading_bin_id_target] + heading_residual_target # (B, )
|
| 76 |
+
size_target = self.size_templates[size_template_id_target] + size_residual_target # (B, 3)
|
| 77 |
+
corners_target, corners_target_flip = get_box_corners_3d(
|
| 78 |
+
centers=center_target, headings=heading_target,
|
| 79 |
+
sizes=size_target, with_flip=True) # (B, 3, 8)
|
| 80 |
+
corners_loss = PF.huber_loss(
|
| 81 |
+
torch.min(
|
| 82 |
+
torch.norm(corners - corners_target, dim=1), torch.norm(corners - corners_target_flip, dim=1)
|
| 83 |
+
), delta=1.0)
|
| 84 |
+
# Summing up
|
| 85 |
+
loss = mask_loss + self.box_loss_weight * (
|
| 86 |
+
center_loss + center_reg_loss + heading_loss + size_loss
|
| 87 |
+
+ self.heading_residual_loss_weight * heading_residual_normalized_loss
|
| 88 |
+
+ self.size_residual_loss_weight * size_residual_normalized_loss
|
| 89 |
+
+ self.corners_loss_weight * corners_loss
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
return loss
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def get_box_corners_3d(centers, headings, sizes, with_flip=False):
|
| 96 |
+
"""
|
| 97 |
+
:param centers: coords of box centers, FloatTensor[N, 3]
|
| 98 |
+
:param headings: heading angles, FloatTensor[N, ]
|
| 99 |
+
:param sizes: box sizes, FloatTensor[N, 3]
|
| 100 |
+
:param with_flip: bool, whether to return flipped box (headings + np.pi)
|
| 101 |
+
:return:
|
| 102 |
+
coords of box corners, FloatTensor[N, 3, 8]
|
| 103 |
+
NOTE: corner points are in counter clockwise order, e.g.,
|
| 104 |
+
2--1
|
| 105 |
+
3--0 5
|
| 106 |
+
7--4
|
| 107 |
+
"""
|
| 108 |
+
l = sizes[:, 0] # (N,)
|
| 109 |
+
w = sizes[:, 1] # (N,)
|
| 110 |
+
h = sizes[:, 2] # (N,)
|
| 111 |
+
x_corners = torch.stack([l / 2, l / 2, -l / 2, -l / 2, l / 2, l / 2, -l / 2, -l / 2], dim=1) # (N, 8)
|
| 112 |
+
y_corners = torch.stack([h / 2, h / 2, h / 2, h / 2, -h / 2, -h / 2, -h / 2, -h / 2], dim=1) # (N, 8)
|
| 113 |
+
z_corners = torch.stack([w / 2, -w / 2, -w / 2, w / 2, w / 2, -w / 2, -w / 2, w / 2], dim=1) # (N, 8)
|
| 114 |
+
|
| 115 |
+
c = torch.cos(headings) # (N,)
|
| 116 |
+
s = torch.sin(headings) # (N,)
|
| 117 |
+
o = torch.ones_like(headings) # (N,)
|
| 118 |
+
z = torch.zeros_like(headings) # (N,)
|
| 119 |
+
|
| 120 |
+
centers = centers.unsqueeze(-1) # (B, 3, 1)
|
| 121 |
+
corners = torch.stack([x_corners, y_corners, z_corners], dim=1) # (N, 3, 8)
|
| 122 |
+
R = torch.stack([c, z, s, z, o, z, -s, z, c], dim=1).view(-1, 3, 3) # roty matrix: (N, 3, 3)
|
| 123 |
+
if with_flip:
|
| 124 |
+
R_flip = torch.stack([-c, z, -s, z, o, z, s, z, -c], dim=1).view(-1, 3, 3)
|
| 125 |
+
return torch.matmul(R, corners) + centers, torch.matmul(R_flip, corners) + centers
|
| 126 |
+
else:
|
| 127 |
+
return torch.matmul(R, corners) + centers
|
| 128 |
+
|
| 129 |
+
# centers = centers.unsqueeze(1) # (B, 1, 3)
|
| 130 |
+
# corners = torch.stack([x_corners, y_corners, z_corners], dim=-1) # (N, 8, 3)
|
| 131 |
+
# RT = torch.stack([c, z, -s, z, o, z, s, z, c], dim=1).view(-1, 3, 3) # (N, 3, 3)
|
| 132 |
+
# if with_flip:
|
| 133 |
+
# RT_flip = torch.stack([-c, z, s, z, o, z, -s, z, -c], dim=1).view(-1, 3, 3) # (N, 3, 3)
|
| 134 |
+
# return torch.matmul(corners, RT) + centers, torch.matmul(corners, RT_flip) + centers # (N, 8, 3)
|
| 135 |
+
# else:
|
| 136 |
+
# return torch.matmul(corners, RT) + centers # (N, 8, 3)
|
| 137 |
+
|
| 138 |
+
# corners = torch.stack([x_corners, y_corners, z_corners], dim=1) # (N, 3, 8)
|
| 139 |
+
# R = torch.stack([c, z, s, z, o, z, -s, z, c], dim=1).view(-1, 3, 3) # (N, 3, 3)
|
| 140 |
+
# corners = torch.matmul(R, corners) + centers.unsqueeze(2) # (N, 3, 8)
|
| 141 |
+
# corners = corners.transpose(1, 2) # (N, 8, 3)
|
PartField/partfield/model/PVCNN/pv_module/functional/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .devoxelization import trilinear_devoxelize
|
PartField/partfield/model/PVCNN/pv_module/functional/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (273 Bytes). View file
|
|
|
PartField/partfield/model/PVCNN/pv_module/functional/__pycache__/devoxelization.cpython-310.pyc
ADDED
|
Binary file (748 Bytes). View file
|
|
|
PartField/partfield/model/PVCNN/pv_module/functional/devoxelization.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch.autograd import Function
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
__all__ = ['trilinear_devoxelize']
|
| 6 |
+
|
| 7 |
+
def trilinear_devoxelize(c, coords, r, training=None):
|
| 8 |
+
coords = (coords * 2 + 1.0) / r - 1.0
|
| 9 |
+
coords = coords.permute(0, 2, 1).reshape(c.shape[0], 1, 1, -1, 3)
|
| 10 |
+
f = F.grid_sample(input=c, grid=coords, padding_mode='border', align_corners=False)
|
| 11 |
+
f = f.squeeze(dim=2).squeeze(dim=2)
|
| 12 |
+
return f
|
PartField/partfield/model/PVCNN/pv_module/loss.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
|
| 3 |
+
from . import functional as F
|
| 4 |
+
|
| 5 |
+
__all__ = ['KLLoss']
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class KLLoss(nn.Module):
|
| 9 |
+
def forward(self, x, y):
|
| 10 |
+
return F.kl_loss(x, y)
|
PartField/partfield/model/PVCNN/pv_module/pointnet.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
from . import functional as F
|
| 5 |
+
from .ball_query import BallQuery
|
| 6 |
+
from .shared_mlp import SharedMLP
|
| 7 |
+
|
| 8 |
+
__all__ = ['PointNetAModule', 'PointNetSAModule', 'PointNetFPModule']
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class PointNetAModule(nn.Module):
|
| 12 |
+
def __init__(self, in_channels, out_channels, include_coordinates=True):
|
| 13 |
+
super().__init__()
|
| 14 |
+
if not isinstance(out_channels, (list, tuple)):
|
| 15 |
+
out_channels = [[out_channels]]
|
| 16 |
+
elif not isinstance(out_channels[0], (list, tuple)):
|
| 17 |
+
out_channels = [out_channels]
|
| 18 |
+
|
| 19 |
+
mlps = []
|
| 20 |
+
total_out_channels = 0
|
| 21 |
+
for _out_channels in out_channels:
|
| 22 |
+
mlps.append(
|
| 23 |
+
SharedMLP(
|
| 24 |
+
in_channels=in_channels + (3 if include_coordinates else 0),
|
| 25 |
+
out_channels=_out_channels, dim=1)
|
| 26 |
+
)
|
| 27 |
+
total_out_channels += _out_channels[-1]
|
| 28 |
+
|
| 29 |
+
self.include_coordinates = include_coordinates
|
| 30 |
+
self.out_channels = total_out_channels
|
| 31 |
+
self.mlps = nn.ModuleList(mlps)
|
| 32 |
+
|
| 33 |
+
def forward(self, inputs):
|
| 34 |
+
features, coords = inputs
|
| 35 |
+
if self.include_coordinates:
|
| 36 |
+
features = torch.cat([features, coords], dim=1)
|
| 37 |
+
coords = torch.zeros((coords.size(0), 3, 1), device=coords.device)
|
| 38 |
+
if len(self.mlps) > 1:
|
| 39 |
+
features_list = []
|
| 40 |
+
for mlp in self.mlps:
|
| 41 |
+
features_list.append(mlp(features).max(dim=-1, keepdim=True).values)
|
| 42 |
+
return torch.cat(features_list, dim=1), coords
|
| 43 |
+
else:
|
| 44 |
+
return self.mlps[0](features).max(dim=-1, keepdim=True).values, coords
|
| 45 |
+
|
| 46 |
+
def extra_repr(self):
|
| 47 |
+
return f'out_channels={self.out_channels}, include_coordinates={self.include_coordinates}'
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class PointNetSAModule(nn.Module):
|
| 51 |
+
def __init__(self, num_centers, radius, num_neighbors, in_channels, out_channels, include_coordinates=True):
|
| 52 |
+
super().__init__()
|
| 53 |
+
if not isinstance(radius, (list, tuple)):
|
| 54 |
+
radius = [radius]
|
| 55 |
+
if not isinstance(num_neighbors, (list, tuple)):
|
| 56 |
+
num_neighbors = [num_neighbors] * len(radius)
|
| 57 |
+
assert len(radius) == len(num_neighbors)
|
| 58 |
+
if not isinstance(out_channels, (list, tuple)):
|
| 59 |
+
out_channels = [[out_channels]] * len(radius)
|
| 60 |
+
elif not isinstance(out_channels[0], (list, tuple)):
|
| 61 |
+
out_channels = [out_channels] * len(radius)
|
| 62 |
+
assert len(radius) == len(out_channels)
|
| 63 |
+
|
| 64 |
+
groupers, mlps = [], []
|
| 65 |
+
total_out_channels = 0
|
| 66 |
+
for _radius, _out_channels, _num_neighbors in zip(radius, out_channels, num_neighbors):
|
| 67 |
+
groupers.append(
|
| 68 |
+
BallQuery(radius=_radius, num_neighbors=_num_neighbors, include_coordinates=include_coordinates)
|
| 69 |
+
)
|
| 70 |
+
mlps.append(
|
| 71 |
+
SharedMLP(
|
| 72 |
+
in_channels=in_channels + (3 if include_coordinates else 0),
|
| 73 |
+
out_channels=_out_channels, dim=2)
|
| 74 |
+
)
|
| 75 |
+
total_out_channels += _out_channels[-1]
|
| 76 |
+
|
| 77 |
+
self.num_centers = num_centers
|
| 78 |
+
self.out_channels = total_out_channels
|
| 79 |
+
self.groupers = nn.ModuleList(groupers)
|
| 80 |
+
self.mlps = nn.ModuleList(mlps)
|
| 81 |
+
|
| 82 |
+
def forward(self, inputs):
|
| 83 |
+
features, coords = inputs
|
| 84 |
+
centers_coords = F.furthest_point_sample(coords, self.num_centers)
|
| 85 |
+
features_list = []
|
| 86 |
+
for grouper, mlp in zip(self.groupers, self.mlps):
|
| 87 |
+
features_list.append(mlp(grouper(coords, centers_coords, features)).max(dim=-1).values)
|
| 88 |
+
if len(features_list) > 1:
|
| 89 |
+
return torch.cat(features_list, dim=1), centers_coords
|
| 90 |
+
else:
|
| 91 |
+
return features_list[0], centers_coords
|
| 92 |
+
|
| 93 |
+
def extra_repr(self):
|
| 94 |
+
return f'num_centers={self.num_centers}, out_channels={self.out_channels}'
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class PointNetFPModule(nn.Module):
|
| 98 |
+
def __init__(self, in_channels, out_channels):
|
| 99 |
+
super().__init__()
|
| 100 |
+
self.mlp = SharedMLP(in_channels=in_channels, out_channels=out_channels, dim=1)
|
| 101 |
+
|
| 102 |
+
def forward(self, inputs):
|
| 103 |
+
if len(inputs) == 3:
|
| 104 |
+
points_coords, centers_coords, centers_features = inputs
|
| 105 |
+
points_features = None
|
| 106 |
+
else:
|
| 107 |
+
points_coords, centers_coords, centers_features, points_features = inputs
|
| 108 |
+
interpolated_features = F.nearest_neighbor_interpolate(points_coords, centers_coords, centers_features)
|
| 109 |
+
if points_features is not None:
|
| 110 |
+
interpolated_features = torch.cat(
|
| 111 |
+
[interpolated_features, points_features], dim=1
|
| 112 |
+
)
|
| 113 |
+
return self.mlp(interpolated_features), points_coords
|
PartField/partfield/model/PVCNN/pv_module/pvconv.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
|
| 3 |
+
from . import functional as F
|
| 4 |
+
from .voxelization import Voxelization
|
| 5 |
+
from .shared_mlp import SharedMLP
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
__all__ = ['PVConv']
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class PVConv(nn.Module):
|
| 12 |
+
def __init__(
|
| 13 |
+
self, in_channels, out_channels, kernel_size, resolution, with_se=False, normalize=True, eps=0, scale_pvcnn=False,
|
| 14 |
+
device='cuda'):
|
| 15 |
+
super().__init__()
|
| 16 |
+
self.in_channels = in_channels
|
| 17 |
+
self.out_channels = out_channels
|
| 18 |
+
self.kernel_size = kernel_size
|
| 19 |
+
self.resolution = resolution
|
| 20 |
+
self.voxelization = Voxelization(resolution, normalize=normalize, eps=eps, scale_pvcnn=scale_pvcnn)
|
| 21 |
+
voxel_layers = [
|
| 22 |
+
nn.Conv3d(in_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2, device=device),
|
| 23 |
+
nn.InstanceNorm3d(out_channels, eps=1e-4, device=device),
|
| 24 |
+
nn.LeakyReLU(0.1, True),
|
| 25 |
+
nn.Conv3d(out_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2, device=device),
|
| 26 |
+
nn.InstanceNorm3d(out_channels, eps=1e-4, device=device),
|
| 27 |
+
nn.LeakyReLU(0.1, True),
|
| 28 |
+
]
|
| 29 |
+
self.voxel_layers = nn.Sequential(*voxel_layers)
|
| 30 |
+
self.point_features = SharedMLP(in_channels, out_channels, device=device)
|
| 31 |
+
|
| 32 |
+
def forward(self, inputs):
|
| 33 |
+
features, coords = inputs
|
| 34 |
+
voxel_features, voxel_coords = self.voxelization(features, coords)
|
| 35 |
+
voxel_features = self.voxel_layers(voxel_features)
|
| 36 |
+
devoxel_features = F.trilinear_devoxelize(voxel_features, voxel_coords, self.resolution, self.training)
|
| 37 |
+
fused_features = devoxel_features + self.point_features(features)
|
| 38 |
+
return fused_features, coords, voxel_features
|
PartField/partfield/model/PVCNN/pv_module/shared_mlp.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
|
| 3 |
+
__all__ = ['SharedMLP']
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class SharedMLP(nn.Module):
|
| 7 |
+
def __init__(self, in_channels, out_channels, dim=1, device='cuda'):
|
| 8 |
+
super().__init__()
|
| 9 |
+
# print('==> SharedMLP device: ', device)
|
| 10 |
+
if dim == 1:
|
| 11 |
+
conv = nn.Conv1d
|
| 12 |
+
bn = nn.InstanceNorm1d
|
| 13 |
+
elif dim == 2:
|
| 14 |
+
conv = nn.Conv2d
|
| 15 |
+
bn = nn.InstanceNorm1d
|
| 16 |
+
else:
|
| 17 |
+
raise ValueError
|
| 18 |
+
if not isinstance(out_channels, (list, tuple)):
|
| 19 |
+
out_channels = [out_channels]
|
| 20 |
+
layers = []
|
| 21 |
+
for oc in out_channels:
|
| 22 |
+
layers.extend(
|
| 23 |
+
[
|
| 24 |
+
conv(in_channels, oc, 1, device=device),
|
| 25 |
+
bn(oc, device=device),
|
| 26 |
+
nn.ReLU(True),
|
| 27 |
+
])
|
| 28 |
+
in_channels = oc
|
| 29 |
+
self.layers = nn.Sequential(*layers)
|
| 30 |
+
|
| 31 |
+
def forward(self, inputs):
|
| 32 |
+
if isinstance(inputs, (list, tuple)):
|
| 33 |
+
return (self.layers(inputs[0]), *inputs[1:])
|
| 34 |
+
else:
|
| 35 |
+
return self.layers(inputs)
|
PartField/partfield/model/PVCNN/pv_module/voxelization.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
from . import functional as F
|
| 5 |
+
|
| 6 |
+
__all__ = ['Voxelization']
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def my_voxelization(features, coords, resolution):
|
| 10 |
+
b, c, _ = features.shape
|
| 11 |
+
result = torch.zeros(b, c + 1, resolution * resolution * resolution, device=features.device, dtype=torch.float)
|
| 12 |
+
r = resolution
|
| 13 |
+
r2 = resolution * resolution
|
| 14 |
+
indices = coords[:, 0] * r2 + coords[:, 1] * r + coords[:, 2]
|
| 15 |
+
indices = indices.unsqueeze(dim=1).expand(-1, result.shape[1], -1)
|
| 16 |
+
features = torch.cat([features, torch.ones(features.shape[0], 1, features.shape[2], device=features.device, dtype=features.dtype)], dim=1)
|
| 17 |
+
out_feature = result.scatter_(index=indices.long(), src=features, dim=2, reduce='add')
|
| 18 |
+
cnt = out_feature[:, -1:, :]
|
| 19 |
+
zero_mask = (cnt == 0).float()
|
| 20 |
+
cnt = cnt * (1 - zero_mask) + zero_mask * 1e-5
|
| 21 |
+
vox_feature = out_feature[:, :-1, :] / cnt
|
| 22 |
+
return vox_feature.view(b, c, resolution, resolution, resolution)
|
| 23 |
+
|
| 24 |
+
class Voxelization(nn.Module):
|
| 25 |
+
def __init__(self, resolution, normalize=True, eps=0, scale_pvcnn=False):
|
| 26 |
+
super().__init__()
|
| 27 |
+
self.r = int(resolution)
|
| 28 |
+
self.normalize = normalize
|
| 29 |
+
self.eps = eps
|
| 30 |
+
self.scale_pvcnn = scale_pvcnn
|
| 31 |
+
assert not normalize
|
| 32 |
+
|
| 33 |
+
def forward(self, features, coords):
|
| 34 |
+
with torch.no_grad():
|
| 35 |
+
coords = coords.detach()
|
| 36 |
+
|
| 37 |
+
if self.normalize:
|
| 38 |
+
norm_coords = norm_coords / (norm_coords.norm(dim=1, keepdim=True).max(dim=2, keepdim=True).values * 2.0 + self.eps) + 0.5
|
| 39 |
+
else:
|
| 40 |
+
if self.scale_pvcnn:
|
| 41 |
+
norm_coords = (coords + 1) / 2.0 # [0, 1]
|
| 42 |
+
else:
|
| 43 |
+
norm_coords = (norm_coords + 1) / 2.0
|
| 44 |
+
norm_coords = torch.clamp(norm_coords * self.r, 0, self.r - 1)
|
| 45 |
+
vox_coords = torch.round(norm_coords)
|
| 46 |
+
new_vox_feat = my_voxelization(features, vox_coords, self.r)
|
| 47 |
+
return new_vox_feat, norm_coords
|
| 48 |
+
|
| 49 |
+
def extra_repr(self):
|
| 50 |
+
return 'resolution={}{}'.format(self.r, ', normalized eps = {}'.format(self.eps) if self.normalize else '')
|
PartField/partfield/model/PVCNN/unet_3daware.py
ADDED
|
@@ -0,0 +1,427 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from torch.nn import init
|
| 7 |
+
|
| 8 |
+
import einops
|
| 9 |
+
|
| 10 |
+
def conv3x3(in_channels, out_channels, stride=1,
|
| 11 |
+
padding=1, bias=True, groups=1):
|
| 12 |
+
return nn.Conv2d(
|
| 13 |
+
in_channels,
|
| 14 |
+
out_channels,
|
| 15 |
+
kernel_size=3,
|
| 16 |
+
stride=stride,
|
| 17 |
+
padding=padding,
|
| 18 |
+
bias=bias,
|
| 19 |
+
groups=groups)
|
| 20 |
+
|
| 21 |
+
def upconv2x2(in_channels, out_channels, mode='transpose'):
|
| 22 |
+
if mode == 'transpose':
|
| 23 |
+
return nn.ConvTranspose2d(
|
| 24 |
+
in_channels,
|
| 25 |
+
out_channels,
|
| 26 |
+
kernel_size=2,
|
| 27 |
+
stride=2)
|
| 28 |
+
else:
|
| 29 |
+
# out_channels is always going to be the same
|
| 30 |
+
# as in_channels
|
| 31 |
+
return nn.Sequential(
|
| 32 |
+
nn.Upsample(mode='bilinear', scale_factor=2),
|
| 33 |
+
conv1x1(in_channels, out_channels))
|
| 34 |
+
|
| 35 |
+
def conv1x1(in_channels, out_channels, groups=1):
|
| 36 |
+
return nn.Conv2d(
|
| 37 |
+
in_channels,
|
| 38 |
+
out_channels,
|
| 39 |
+
kernel_size=1,
|
| 40 |
+
groups=groups,
|
| 41 |
+
stride=1)
|
| 42 |
+
|
| 43 |
+
class ConvTriplane3dAware(nn.Module):
|
| 44 |
+
""" 3D aware triplane conv (as described in RODIN) """
|
| 45 |
+
def __init__(self, internal_conv_f, in_channels, out_channels, order='xz'):
|
| 46 |
+
"""
|
| 47 |
+
Args:
|
| 48 |
+
internal_conv_f: function that should return a 2D convolution Module
|
| 49 |
+
given in and out channels
|
| 50 |
+
order: if triplane input is in 'xz' order
|
| 51 |
+
"""
|
| 52 |
+
super(ConvTriplane3dAware, self).__init__()
|
| 53 |
+
# Need 3 seperate convolutions
|
| 54 |
+
self.in_channels = in_channels
|
| 55 |
+
self.out_channels = out_channels
|
| 56 |
+
assert order in ['xz', 'zx']
|
| 57 |
+
self.order = order
|
| 58 |
+
# Going to stack from other planes
|
| 59 |
+
self.plane_convs = nn.ModuleList([
|
| 60 |
+
internal_conv_f(3*self.in_channels, self.out_channels) for _ in range(3)])
|
| 61 |
+
|
| 62 |
+
def forward(self, triplanes_list):
|
| 63 |
+
"""
|
| 64 |
+
Args:
|
| 65 |
+
triplanes_list: [(B,Ci,H,W)]*3 in xy,yz,(zx or xz) depending on order
|
| 66 |
+
Returns:
|
| 67 |
+
out_triplanes_list: [(B,Co,H,W)]*3 in xy,yz,(zx or xz) depending on order
|
| 68 |
+
"""
|
| 69 |
+
inps = list(triplanes_list)
|
| 70 |
+
xp = 1 #(yz)
|
| 71 |
+
yp = 2 #(zx)
|
| 72 |
+
zp = 0 #(xy)
|
| 73 |
+
|
| 74 |
+
if self.order == 'xz':
|
| 75 |
+
# get into zx order
|
| 76 |
+
inps[yp] = einops.rearrange(inps[yp], 'b c x z -> b c z x')
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
oplanes = [None]*3
|
| 80 |
+
# order shouldn't matter
|
| 81 |
+
for iplane in [zp, xp, yp]:
|
| 82 |
+
# i_plane -> (j,k)
|
| 83 |
+
|
| 84 |
+
# need to average out i and convert to (j,k)
|
| 85 |
+
# j_plane -> (k,i)
|
| 86 |
+
# k_plane -> (i,j)
|
| 87 |
+
jplane = (iplane+1)%3
|
| 88 |
+
kplane = (iplane+2)%3
|
| 89 |
+
|
| 90 |
+
ifeat = inps[iplane]
|
| 91 |
+
# need to average out nonshared dim
|
| 92 |
+
# Average pool across
|
| 93 |
+
|
| 94 |
+
# j_plane -> (k,i) -> (k,1) -> (1,k) -> (j,k)
|
| 95 |
+
# b c k i -> b c k 1
|
| 96 |
+
jpool = torch.mean(inps[jplane], dim=3 ,keepdim=True)
|
| 97 |
+
jpool = einops.rearrange(jpool, 'b c k 1 -> b c 1 k')
|
| 98 |
+
jpool = einops.repeat(jpool, 'b c 1 k -> b c j k', j=ifeat.size(2))
|
| 99 |
+
|
| 100 |
+
# k_plane -> (i,j) -> (1,j) -> (j,1) -> (j,k)
|
| 101 |
+
# b c i j -> b c 1 j
|
| 102 |
+
kpool = torch.mean(inps[kplane], dim=2 ,keepdim=True)
|
| 103 |
+
kpool = einops.rearrange(kpool, 'b c 1 j -> b c j 1')
|
| 104 |
+
kpool = einops.repeat(kpool, 'b c j 1 -> b c j k', k=ifeat.size(3))
|
| 105 |
+
|
| 106 |
+
# b c h w
|
| 107 |
+
# jpool = jpool.expand_as(ifeat)
|
| 108 |
+
# kpool = kpool.expand_as(ifeat)
|
| 109 |
+
|
| 110 |
+
# concat and conv on feature dim
|
| 111 |
+
catfeat = torch.cat([ifeat, jpool, kpool], dim=1)
|
| 112 |
+
oplane = self.plane_convs[iplane](catfeat)
|
| 113 |
+
oplanes[iplane] = oplane
|
| 114 |
+
|
| 115 |
+
if self.order == 'xz':
|
| 116 |
+
# get back into xz order
|
| 117 |
+
oplanes[yp] = einops.rearrange(oplanes[yp], 'b c z x -> b c x z')
|
| 118 |
+
|
| 119 |
+
return oplanes
|
| 120 |
+
|
| 121 |
+
def roll_triplanes(triplanes_list):
|
| 122 |
+
# B, C, tri, h, w
|
| 123 |
+
tristack = torch.stack((triplanes_list),dim=2)
|
| 124 |
+
return einops.rearrange(tristack, 'b c tri h w -> b c (tri h) w', tri=3)
|
| 125 |
+
|
| 126 |
+
def unroll_triplanes(rolled_triplane):
|
| 127 |
+
# B, C, tri*h, w
|
| 128 |
+
tristack = einops.rearrange(rolled_triplane, 'b c (tri h) w -> b c tri h w', tri=3)
|
| 129 |
+
return torch.unbind(tristack, dim=2)
|
| 130 |
+
|
| 131 |
+
def conv1x1triplane3daware(in_channels, out_channels, order='xz', **kwargs):
|
| 132 |
+
return ConvTriplane3dAware(lambda inp, out: conv1x1(inp,out,**kwargs),
|
| 133 |
+
in_channels, out_channels,order=order)
|
| 134 |
+
|
| 135 |
+
def Normalize(in_channels, num_groups=32):
|
| 136 |
+
num_groups = min(in_channels, num_groups) # avoid error if in_channels < 32
|
| 137 |
+
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
| 138 |
+
|
| 139 |
+
def nonlinearity(x):
|
| 140 |
+
# return F.relu(x)
|
| 141 |
+
# Swish
|
| 142 |
+
return x*torch.sigmoid(x)
|
| 143 |
+
|
| 144 |
+
class Upsample(nn.Module):
|
| 145 |
+
def __init__(self, in_channels, with_conv):
|
| 146 |
+
super().__init__()
|
| 147 |
+
self.with_conv = with_conv
|
| 148 |
+
if self.with_conv:
|
| 149 |
+
self.conv = torch.nn.Conv2d(in_channels,
|
| 150 |
+
in_channels,
|
| 151 |
+
kernel_size=3,
|
| 152 |
+
stride=1,
|
| 153 |
+
padding=1)
|
| 154 |
+
|
| 155 |
+
def forward(self, x):
|
| 156 |
+
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
| 157 |
+
if self.with_conv:
|
| 158 |
+
x = self.conv(x)
|
| 159 |
+
return x
|
| 160 |
+
|
| 161 |
+
class Downsample(nn.Module):
|
| 162 |
+
def __init__(self, in_channels, with_conv):
|
| 163 |
+
super().__init__()
|
| 164 |
+
self.with_conv = with_conv
|
| 165 |
+
if self.with_conv:
|
| 166 |
+
# no asymmetric padding in torch conv, must do it ourselves
|
| 167 |
+
self.conv = torch.nn.Conv2d(in_channels,
|
| 168 |
+
in_channels,
|
| 169 |
+
kernel_size=3,
|
| 170 |
+
stride=2,
|
| 171 |
+
padding=0)
|
| 172 |
+
|
| 173 |
+
def forward(self, x):
|
| 174 |
+
if self.with_conv:
|
| 175 |
+
pad = (0,1,0,1)
|
| 176 |
+
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
| 177 |
+
x = self.conv(x)
|
| 178 |
+
else:
|
| 179 |
+
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
| 180 |
+
return x
|
| 181 |
+
|
| 182 |
+
class ResnetBlock3dAware(nn.Module):
|
| 183 |
+
def __init__(self, in_channels, out_channels=None):
|
| 184 |
+
#, conv_shortcut=False):
|
| 185 |
+
super().__init__()
|
| 186 |
+
self.in_channels = in_channels
|
| 187 |
+
out_channels = in_channels if out_channels is None else out_channels
|
| 188 |
+
self.out_channels = out_channels
|
| 189 |
+
# self.use_conv_shortcut = conv_shortcut
|
| 190 |
+
|
| 191 |
+
self.norm1 = Normalize(in_channels)
|
| 192 |
+
self.conv1 = conv3x3(self.in_channels, self.out_channels)
|
| 193 |
+
|
| 194 |
+
self.norm_mid = Normalize(out_channels)
|
| 195 |
+
self.conv_3daware = conv1x1triplane3daware(self.out_channels, self.out_channels)
|
| 196 |
+
|
| 197 |
+
self.norm2 = Normalize(out_channels)
|
| 198 |
+
self.conv2 = conv3x3(self.out_channels, self.out_channels)
|
| 199 |
+
|
| 200 |
+
if self.in_channels != self.out_channels:
|
| 201 |
+
self.nin_shortcut = torch.nn.Conv2d(in_channels,
|
| 202 |
+
out_channels,
|
| 203 |
+
kernel_size=1,
|
| 204 |
+
stride=1,
|
| 205 |
+
padding=0)
|
| 206 |
+
|
| 207 |
+
def forward(self, x):
|
| 208 |
+
# 3x3 plane comm
|
| 209 |
+
h = x
|
| 210 |
+
h = self.norm1(h)
|
| 211 |
+
h = nonlinearity(h)
|
| 212 |
+
h = self.conv1(h)
|
| 213 |
+
|
| 214 |
+
# 1x1 3d aware, crossplane comm
|
| 215 |
+
h = self.norm_mid(h)
|
| 216 |
+
h = nonlinearity(h)
|
| 217 |
+
h = unroll_triplanes(h)
|
| 218 |
+
h = self.conv_3daware(h)
|
| 219 |
+
h = roll_triplanes(h)
|
| 220 |
+
|
| 221 |
+
# 3x3 plane comm
|
| 222 |
+
h = self.norm2(h)
|
| 223 |
+
h = nonlinearity(h)
|
| 224 |
+
h = self.conv2(h)
|
| 225 |
+
|
| 226 |
+
if self.in_channels != self.out_channels:
|
| 227 |
+
x = self.nin_shortcut(x)
|
| 228 |
+
|
| 229 |
+
return x+h
|
| 230 |
+
|
| 231 |
+
class DownConv3dAware(nn.Module):
|
| 232 |
+
"""
|
| 233 |
+
A helper Module that performs 2 convolutions and 1 MaxPool.
|
| 234 |
+
A ReLU activation follows each convolution.
|
| 235 |
+
"""
|
| 236 |
+
def __init__(self, in_channels, out_channels, downsample=True, with_conv=False):
|
| 237 |
+
super(DownConv3dAware, self).__init__()
|
| 238 |
+
|
| 239 |
+
self.in_channels = in_channels
|
| 240 |
+
self.out_channels = out_channels
|
| 241 |
+
|
| 242 |
+
self.block = ResnetBlock3dAware(in_channels=in_channels,
|
| 243 |
+
out_channels=out_channels)
|
| 244 |
+
|
| 245 |
+
self.do_downsample = downsample
|
| 246 |
+
self.downsample = Downsample(out_channels, with_conv=with_conv)
|
| 247 |
+
|
| 248 |
+
def forward(self, x):
|
| 249 |
+
"""
|
| 250 |
+
rolled input, rolled output
|
| 251 |
+
Args:
|
| 252 |
+
x: rolled (b c (tri*h) w)
|
| 253 |
+
"""
|
| 254 |
+
x = self.block(x)
|
| 255 |
+
before_pool = x
|
| 256 |
+
# if self.pooling:
|
| 257 |
+
# x = self.pool(x)
|
| 258 |
+
if self.do_downsample:
|
| 259 |
+
# unroll and cat channel-wise (to prevent pooling across triplane boundaries)
|
| 260 |
+
x = einops.rearrange(x, 'b c (tri h) w -> b (c tri) h w', tri=3)
|
| 261 |
+
x = self.downsample(x)
|
| 262 |
+
# undo
|
| 263 |
+
x = einops.rearrange(x, 'b (c tri) h w -> b c (tri h) w', tri=3)
|
| 264 |
+
return x, before_pool
|
| 265 |
+
|
| 266 |
+
class UpConv3dAware(nn.Module):
|
| 267 |
+
"""
|
| 268 |
+
A helper Module that performs 2 convolutions and 1 UpConvolution.
|
| 269 |
+
A ReLU activation follows each convolution.
|
| 270 |
+
"""
|
| 271 |
+
def __init__(self, in_channels, out_channels,
|
| 272 |
+
merge_mode='concat', with_conv=False): #up_mode='transpose', ):
|
| 273 |
+
super(UpConv3dAware, self).__init__()
|
| 274 |
+
|
| 275 |
+
self.in_channels = in_channels
|
| 276 |
+
self.out_channels = out_channels
|
| 277 |
+
self.merge_mode = merge_mode
|
| 278 |
+
|
| 279 |
+
self.upsample = Upsample(in_channels, with_conv)
|
| 280 |
+
|
| 281 |
+
if self.merge_mode == 'concat':
|
| 282 |
+
self.norm1 = Normalize(in_channels+out_channels)
|
| 283 |
+
self.block = ResnetBlock3dAware(in_channels=in_channels+out_channels,
|
| 284 |
+
out_channels=out_channels)
|
| 285 |
+
else:
|
| 286 |
+
self.norm1 = Normalize(in_channels)
|
| 287 |
+
self.block = ResnetBlock3dAware(in_channels=in_channels,
|
| 288 |
+
out_channels=out_channels)
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
def forward(self, from_down, from_up):
|
| 292 |
+
""" Forward pass
|
| 293 |
+
rolled inputs, rolled output
|
| 294 |
+
rolled (b c (tri*h) w)
|
| 295 |
+
Arguments:
|
| 296 |
+
from_down: tensor from the encoder pathway
|
| 297 |
+
from_up: upconv'd tensor from the decoder pathway
|
| 298 |
+
"""
|
| 299 |
+
# from_up = self.upconv(from_up)
|
| 300 |
+
from_up = self.upsample(from_up)
|
| 301 |
+
if self.merge_mode == 'concat':
|
| 302 |
+
x = torch.cat((from_up, from_down), 1)
|
| 303 |
+
else:
|
| 304 |
+
x = from_up + from_down
|
| 305 |
+
|
| 306 |
+
x = self.norm1(x)
|
| 307 |
+
x = self.block(x)
|
| 308 |
+
return x
|
| 309 |
+
|
| 310 |
+
class UNetTriplane3dAware(nn.Module):
|
| 311 |
+
def __init__(self, out_channels, in_channels=3, depth=5,
|
| 312 |
+
start_filts=64,# up_mode='transpose',
|
| 313 |
+
use_initial_conv=False,
|
| 314 |
+
merge_mode='concat', **kwargs):
|
| 315 |
+
"""
|
| 316 |
+
Arguments:
|
| 317 |
+
in_channels: int, number of channels in the input tensor.
|
| 318 |
+
Default is 3 for RGB images.
|
| 319 |
+
depth: int, number of MaxPools in the U-Net.
|
| 320 |
+
start_filts: int, number of convolutional filters for the
|
| 321 |
+
first conv.
|
| 322 |
+
"""
|
| 323 |
+
super(UNetTriplane3dAware, self).__init__()
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
self.out_channels = out_channels
|
| 327 |
+
self.in_channels = in_channels
|
| 328 |
+
self.start_filts = start_filts
|
| 329 |
+
self.depth = depth
|
| 330 |
+
|
| 331 |
+
self.use_initial_conv = use_initial_conv
|
| 332 |
+
if use_initial_conv:
|
| 333 |
+
self.conv_initial = conv1x1(self.in_channels, self.start_filts)
|
| 334 |
+
|
| 335 |
+
self.down_convs = []
|
| 336 |
+
self.up_convs = []
|
| 337 |
+
|
| 338 |
+
# create the encoder pathway and add to a list
|
| 339 |
+
for i in range(depth):
|
| 340 |
+
if i == 0:
|
| 341 |
+
ins = self.start_filts if use_initial_conv else self.in_channels
|
| 342 |
+
else:
|
| 343 |
+
ins = outs
|
| 344 |
+
outs = self.start_filts*(2**i)
|
| 345 |
+
downsamp_it = True if i < depth-1 else False
|
| 346 |
+
|
| 347 |
+
down_conv = DownConv3dAware(ins, outs, downsample = downsamp_it)
|
| 348 |
+
self.down_convs.append(down_conv)
|
| 349 |
+
|
| 350 |
+
for i in range(depth-1):
|
| 351 |
+
ins = outs
|
| 352 |
+
outs = ins // 2
|
| 353 |
+
up_conv = UpConv3dAware(ins, outs,
|
| 354 |
+
merge_mode=merge_mode)
|
| 355 |
+
self.up_convs.append(up_conv)
|
| 356 |
+
|
| 357 |
+
# add the list of modules to current module
|
| 358 |
+
self.down_convs = nn.ModuleList(self.down_convs)
|
| 359 |
+
self.up_convs = nn.ModuleList(self.up_convs)
|
| 360 |
+
|
| 361 |
+
self.norm_out = Normalize(outs)
|
| 362 |
+
self.conv_final = conv1x1(outs, self.out_channels)
|
| 363 |
+
|
| 364 |
+
self.reset_params()
|
| 365 |
+
|
| 366 |
+
@staticmethod
|
| 367 |
+
def weight_init(m):
|
| 368 |
+
if isinstance(m, nn.Conv2d):
|
| 369 |
+
# init.xavier_normal_(m.weight, gain=0.1)
|
| 370 |
+
init.xavier_normal_(m.weight)
|
| 371 |
+
init.constant_(m.bias, 0)
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
def reset_params(self):
|
| 375 |
+
for i, m in enumerate(self.modules()):
|
| 376 |
+
self.weight_init(m)
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
def forward(self, x):
|
| 380 |
+
"""
|
| 381 |
+
Args:
|
| 382 |
+
x: Stacked triplane expected to be in (B,3,C,H,W)
|
| 383 |
+
"""
|
| 384 |
+
# Roll
|
| 385 |
+
x = einops.rearrange(x, 'b tri c h w -> b c (tri h) w', tri=3)
|
| 386 |
+
|
| 387 |
+
if self.use_initial_conv:
|
| 388 |
+
x = self.conv_initial(x)
|
| 389 |
+
|
| 390 |
+
encoder_outs = []
|
| 391 |
+
# encoder pathway, save outputs for merging
|
| 392 |
+
for i, module in enumerate(self.down_convs):
|
| 393 |
+
x, before_pool = module(x)
|
| 394 |
+
encoder_outs.append(before_pool)
|
| 395 |
+
|
| 396 |
+
# Spend a block in the middle
|
| 397 |
+
# x = self.block_mid(x)
|
| 398 |
+
|
| 399 |
+
for i, module in enumerate(self.up_convs):
|
| 400 |
+
before_pool = encoder_outs[-(i+2)]
|
| 401 |
+
x = module(before_pool, x)
|
| 402 |
+
|
| 403 |
+
x = self.norm_out(x)
|
| 404 |
+
|
| 405 |
+
# No softmax is used. This means you need to use
|
| 406 |
+
# nn.CrossEntropyLoss is your training script,
|
| 407 |
+
# as this module includes a softmax already.
|
| 408 |
+
x = self.conv_final(nonlinearity(x))
|
| 409 |
+
|
| 410 |
+
# Unroll
|
| 411 |
+
x = einops.rearrange(x, 'b c (tri h) w -> b tri c h w', tri=3)
|
| 412 |
+
return x
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
def setup_unet(output_channels, input_channels, unet_cfg):
|
| 416 |
+
if unet_cfg['use_3d_aware']:
|
| 417 |
+
assert(unet_cfg['rolled'])
|
| 418 |
+
unet = UNetTriplane3dAware(
|
| 419 |
+
out_channels=output_channels,
|
| 420 |
+
in_channels=input_channels,
|
| 421 |
+
depth=unet_cfg['depth'],
|
| 422 |
+
use_initial_conv=unet_cfg['use_initial_conv'],
|
| 423 |
+
start_filts=unet_cfg['start_hidden_channels'],)
|
| 424 |
+
else:
|
| 425 |
+
raise NotImplementedError
|
| 426 |
+
return unet
|
| 427 |
+
|
PartField/partfield/model/UNet/__pycache__/buildingblocks.cpython-310.pyc
ADDED
|
Binary file (17.2 kB). View file
|
|
|