# Copied from a Hugging Face Space file view (Space status: Running on Zero).
# File size: 8,586 bytes. Revision: 4f22fc0.
from partfield.config import default_argument_parser, setup
from lightning.pytorch import seed_everything, Trainer
from lightning.pytorch.strategies import DDPStrategy
from lightning.pytorch.callbacks import ModelCheckpoint
import torch
import glob
import os
import numpy as np
import random
import zipfile
from partfield.model.PVCNN.encoder_pc import sample_triplane_feat
def predict(cfg):
    """Run Lightning prediction for the demo ``Model`` built from ``cfg``.

    Seeding happens first: ``seed_everything(cfg.seed)`` is called, then
    torch, ``random`` and NumPy are each re-seeded with 0. A ``Trainer`` with
    a single ``ModelCheckpoint`` callback (writing to ``cfg.output_dir``) is
    created, and ``trainer.predict`` is invoked with ``cfg.continue_ckpt``.

    Args:
        cfg: Configuration object. Reads ``seed``, ``output_dir``,
            ``save_every_epoch``, ``training_epochs``, ``remesh_demo`` and
            ``continue_ckpt``; sets ``n_point_per_face`` to 10 when
            ``remesh_demo`` is truthy.
    """
    seed_everything(cfg.seed)
    for reseed in (torch.manual_seed, random.seed, np.random.seed):
        reseed(0)

    checkpointer = ModelCheckpoint(
        monitor="train/current_epoch",
        dirpath=cfg.output_dir,
        filename="{epoch:02d}",
        save_top_k=100,
        save_last=True,
        every_n_epochs=cfg.save_every_epoch,
        mode="max",
        verbose=True,
    )

    trainer_options = {
        "devices": -1,
        "accelerator": "gpu",
        "precision": "16-mixed",
        "strategy": DDPStrategy(find_unused_parameters=True),
        "max_epochs": cfg.training_epochs,
        "log_every_n_steps": 1,
        "limit_train_batches": 3500,
        "limit_val_batches": None,
        "callbacks": [checkpointer],
    }
    trainer = Trainer(**trainer_options)

    # Imported at call time rather than at module import time.
    from partfield.model_trainer_pvcnn_only_demo import Model

    model = Model(cfg)

    # NOTE(review): this is assigned after the model has been constructed;
    # confirm the model reads cfg.n_point_per_face lazily.
    if cfg.remesh_demo:
        cfg.n_point_per_face = 10

    trainer.predict(model, ckpt_path=cfg.continue_ckpt)
def _pair_by_basename(encode_files, decode_files):
    """Pair encode/decode paths that share the same file name.

    Pairing by name (instead of by position in two independently sorted
    listings) keeps pairs aligned even when a file is missing from one of
    the directories.

    Args:
        encode_files: Paths of the encode point clouds.
        decode_files: Paths of the decode point clouds.

    Returns:
        A tuple ``(pairs, num_unmatched)``: ``pairs`` is a list of
        ``(encode_path, decode_path)`` in the order of ``decode_files``;
        ``num_unmatched`` counts files from either list without a partner.
    """
    encode_by_name = {os.path.basename(path): path for path in encode_files}
    pairs = [
        (encode_by_name[os.path.basename(path)], path)
        for path in decode_files
        if os.path.basename(path) in encode_by_name
    ]
    num_unmatched = len(encode_files) + len(decode_files) - 2 * len(pairs)
    return pairs, num_unmatched


def _save_npy_atomic(dest_path, array):
    """Save ``array`` to ``dest_path`` without ever exposing a partial file.

    The array is written to a temporary file next to the destination and
    then renamed, so an interrupted job cannot leave a truncated ``.npy``
    that a later run would mistake for a finished result.
    """
    tmp_path = f"{dest_path}.tmp{os.getpid()}"
    try:
        # Passing a file object stops np.save from appending ".npy".
        with open(tmp_path, "wb") as handle:
            np.save(handle, array)
        os.replace(tmp_path, dest_path)
    finally:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def main():
    """Extract per-point features for a shard of point-cloud files.

    Parses the project arguments plus ``--num_jobs`` / ``--job_id``, loads
    the model from ``cfg.continue_ckpt`` onto CUDA, and for every
    encode/decode pair assigned to this job:

    1. skips the pair if the destination feature file already exists;
    2. rejects inputs containing NaN/Inf or a degenerate bounding box;
    3. normalizes both clouds with one shared center and scale;
    4. builds triplanes from the encode cloud and samples them at the
       decode points;
    5. saves the sampled features as float16 ``<uid>.npy``.

    Failures are counted and reported per file; they do not abort the job.

    Raises:
        ValueError: If ``job_id`` is negative or not less than ``num_jobs``.
    """
    from tqdm import tqdm

    parser = default_argument_parser()
    parser.add_argument('--num_jobs', type=int, default=1, help='Total number of parallel jobs')
    parser.add_argument('--job_id', type=int, default=0, help='Current job ID (0 to num_jobs-1)')
    args = parser.parse_args()
    cfg = setup(args, freeze=False)
    cfg.is_pc = True

    # Validate job arguments
    if args.job_id >= args.num_jobs:
        raise ValueError(f"job_id ({args.job_id}) must be less than num_jobs ({args.num_jobs})")
    if args.job_id < 0:
        raise ValueError(f"job_id ({args.job_id}) must be >= 0")

    from partfield.model_trainer_pvcnn_only_demo import Model
    model = Model.load_from_checkpoint(cfg.continue_ckpt, cfg=cfg)
    model.eval()
    model.to('cuda')

    encode_pc_root = "/scratch/shared/beegfs/ruining/data/articulate-3d/Lightwheel/all-uniform-100k-singlestate-pts"
    decode_pc_root = "/scratch/shared/beegfs/ruining/data/articulate-3d/Lightwheel/all-sharp50pct-40k-singlestate-pts"
    dest_feat_root = "/scratch/shared/beegfs/ruining/data/articulate-3d/Lightwheel/all-sharp50pct-40k-singlestate-feats"

    # Create destination directory
    os.makedirs(dest_feat_root, exist_ok=True)

    encode_files = sorted(glob.glob(os.path.join(encode_pc_root, "*.npy")))
    decode_files = sorted(glob.glob(os.path.join(decode_pc_root, "*.npy")))

    # Pair by file name so a file missing from one directory cannot shift
    # every later pair out of alignment.
    all_pairs, num_unmatched = _pair_by_basename(encode_files, decode_files)
    if num_unmatched:
        print(f"Warning: {num_unmatched} files have no counterpart in the other directory and are ignored")

    # Filter files for this job (every num_jobs-th pair, starting at job_id)
    job_files = all_pairs[args.job_id::args.num_jobs]
    print(f"Job {args.job_id}/{args.num_jobs}: Processing {len(job_files)}/{len(encode_files)} files")

    num_bad_zip, num_failed_others = 0, 0
    for encode_file, decode_file in tqdm(job_files, desc=f"Job {args.job_id}"):
        # UID comes from the decode file (the one features are extracted for)
        uid = os.path.basename(decode_file).split('.')[0]
        try:
            dest_feat_file = os.path.join(dest_feat_root, f"{uid}.npy")
            if os.path.exists(dest_feat_file):
                continue

            # Load both encode and decode point clouds
            encode_pc = np.load(encode_file)
            decode_pc = np.load(decode_file)

            # Validate input data
            if np.isnan(encode_pc).any() or np.isnan(decode_pc).any():
                print(f"Skipping {uid}: NaN values in point cloud")
                num_failed_others += 1
                continue
            if np.isinf(encode_pc).any() or np.isinf(decode_pc).any():
                print(f"Skipping {uid}: Inf values in point cloud")
                num_failed_others += 1
                continue

            # Compute bounding box from ALL points (encode + decode) for consistent normalization
            all_points = np.vstack([encode_pc, decode_pc])
            bbmin = all_points.min(0)
            bbmax = all_points.max(0)

            # Check for degenerate bounding box
            bbox_size = (bbmax - bbmin).max()
            if bbox_size < 1e-6:
                print(f"Skipping {uid}: Degenerate bounding box (size={bbox_size})")
                num_failed_others += 1
                continue

            center = (bbmin + bbmax) * 0.5
            scale = 2.0 * 0.9 / bbox_size

            # Apply same normalization to both point clouds
            encode_pc_normalized = (encode_pc - center) * scale
            decode_pc_normalized = (decode_pc - center) * scale

            # Validate normalized coordinates
            if np.isnan(encode_pc_normalized).any() or np.isnan(decode_pc_normalized).any():
                print(f"Skipping {uid}: NaN in normalized coordinates")
                num_failed_others += 1
                continue
            if np.isinf(encode_pc_normalized).any() or np.isinf(decode_pc_normalized).any():
                print(f"Skipping {uid}: Inf in normalized coordinates")
                num_failed_others += 1
                continue

            # Check if normalized coordinates are within reasonable range (should be ~[-1, 1])
            encode_max = np.abs(encode_pc_normalized).max()
            decode_max = np.abs(decode_pc_normalized).max()
            if encode_max > 10 or decode_max > 10:
                print(f"Skipping {uid}: Normalized coordinates out of range (encode_max={encode_max:.2f}, decode_max={decode_max:.2f})")
                num_failed_others += 1
                continue

            # Use encode_pc to generate triplane
            batch_encode_pc = torch.from_numpy(encode_pc_normalized).unsqueeze(0).float().to('cuda')

            with torch.no_grad():
                try:
                    # Generate triplane from encode_pc
                    pc_feat = model.pvcnn(batch_encode_pc, batch_encode_pc)
                    planes = model.triplane_transformer(pc_feat)
                    # Only the part planes are needed; the first 64 channels are discarded.
                    _, part_planes = torch.split(planes, [64, planes.shape[2] - 64], dim=2)

                    # Sample features at decode_pc points
                    tensor_vertices = torch.from_numpy(decode_pc_normalized).reshape(1, -1, 3).to(torch.float32).cuda()

                    # Validate tensor before sampling
                    if torch.isnan(tensor_vertices).any() or torch.isinf(tensor_vertices).any():
                        print(f"Skipping {uid}: Invalid tensor_vertices after conversion to torch")
                        num_failed_others += 1
                        continue

                    point_feat = sample_triplane_feat(part_planes, tensor_vertices)
                    point_feat = point_feat.cpu().detach().numpy().reshape(-1, 448)

                    # Save point features atomically so an interrupted run
                    # never leaves a partial file that the resume check skips.
                    _save_npy_atomic(dest_feat_file, point_feat.astype(np.float16))

                except RuntimeError as e:
                    if "CUDA" in str(e) or "index" in str(e).lower():
                        print(f"Skipping {uid}: CUDA error - {str(e)[:100]}")
                        print(f"  encode shape: {encode_pc.shape}, decode shape: {decode_pc.shape}")
                        print(f"  bbox_size: {bbox_size:.6f}, scale: {scale:.6f}")
                        print(f"  normalized range: [{encode_pc_normalized.min():.3f}, {encode_pc_normalized.max():.3f}]")
                        num_failed_others += 1
                        # Clear CUDA cache to recover from error
                        torch.cuda.empty_cache()
                        continue
                    else:
                        raise

        except zipfile.BadZipFile:
            print(f"Skipping {uid}: bad zip file")
            num_bad_zip += 1
            continue
        except Exception as e:
            # Best-effort batch job: keep going, but say what went wrong.
            print(f"Skipping {uid}: {type(e).__name__} - {str(e)[:100]}")
            num_failed_others += 1
            continue

    print(f"Job {args.job_id} - Number of bad zip files: {num_bad_zip}")
    print(f"Job {args.job_id} - Number of failed others: {num_failed_others}")
    print(f"Job {args.job_id} completed successfully!")
# Script entry point: run the feature-extraction job only when executed directly.
if __name__ == "__main__":
    main()