# Spaces: Running on Zero
| from partfield.config import default_argument_parser, setup | |
| from lightning.pytorch import seed_everything, Trainer | |
| from lightning.pytorch.strategies import DDPStrategy | |
| from lightning.pytorch.callbacks import ModelCheckpoint | |
| import torch | |
| import glob | |
| import os | |
| import numpy as np | |
| import random | |
| import zipfile | |
| from partfield.model.PVCNN.encoder_pc import sample_triplane_feat | |
def predict(cfg):
    """Run Lightning's predict loop for the demo ``Model`` built from ``cfg``.

    Reads ``cfg.seed``, ``cfg.output_dir``, ``cfg.save_every_epoch``,
    ``cfg.training_epochs``, ``cfg.remesh_demo`` and ``cfg.continue_ckpt``.
    When ``cfg.remesh_demo`` is truthy, ``cfg.n_point_per_face`` is set to 10
    before prediction starts. Returns ``None``.
    """
    # Seed through Lightning first, then pin torch / random / numpy to 0
    # (in that order), which takes effect after the cfg.seed seeding.
    seed_everything(cfg.seed)
    for reseed in (torch.manual_seed, random.seed, np.random.seed):
        reseed(0)

    epoch_checkpoint = ModelCheckpoint(
        monitor="train/current_epoch",
        dirpath=cfg.output_dir,
        filename="{epoch:02d}",
        save_top_k=100,
        save_last=True,
        every_n_epochs=cfg.save_every_epoch,
        mode="max",
        verbose=True,
    )

    trainer_options = dict(
        devices=-1,
        accelerator="gpu",
        precision="16-mixed",
        strategy=DDPStrategy(find_unused_parameters=True),
        max_epochs=cfg.training_epochs,
        log_every_n_steps=1,
        limit_train_batches=3500,
        limit_val_batches=None,
        callbacks=[epoch_checkpoint],
    )
    runner = Trainer(**trainer_options)

    # Imported here (after the trainer is built), as in the original flow.
    from partfield.model_trainer_pvcnn_only_demo import Model

    net = Model(cfg)

    # NOTE(review): this mutates cfg after Model(cfg) has been constructed;
    # it only matters if the model reads the value lazily — confirm.
    if cfg.remesh_demo:
        cfg.n_point_per_face = 10

    runner.predict(net, ckpt_path=cfg.continue_ckpt)
class _SkipSample(Exception):
    """Raised when a point-cloud pair fails validation and must be skipped."""


def _file_uid(path):
    """Return the sample UID: the file's base name up to its first dot."""
    return os.path.basename(path).split('.')[0]


def _pair_by_uid(encode_files, decode_files):
    """Pair encode/decode files that share a UID, in ``encode_files`` order.

    Returns ``(pairs, num_unpaired)`` where ``num_unpaired`` counts the files
    on either side that have no counterpart with the same UID.
    """
    decode_by_uid = {_file_uid(path): path for path in decode_files}
    encode_uids = {_file_uid(path) for path in encode_files}

    pairs = []
    num_unpaired = 0
    for encode_file in encode_files:
        decode_file = decode_by_uid.get(_file_uid(encode_file))
        if decode_file is None:
            num_unpaired += 1
        else:
            pairs.append((encode_file, decode_file))
    num_unpaired += sum(1 for uid in decode_by_uid if uid not in encode_uids)
    return pairs, num_unpaired


def _normalize_pair(encode_pc, decode_pc):
    """Validate both point clouds and normalize them with one shared transform.

    The bounding box is computed over the points of both arrays; both are then
    centered on it and scaled by ``2.0 * 0.9 / bbox_size`` where ``bbox_size``
    is the largest bounding-box extent.

    Returns ``(encode_normalized, decode_normalized, bbox_size, scale)``.
    Raises ``_SkipSample`` (with the reason as its message) on NaN/Inf input,
    a degenerate bounding box, or out-of-range normalized coordinates.
    """
    if np.isnan(encode_pc).any() or np.isnan(decode_pc).any():
        raise _SkipSample("NaN values in point cloud")
    if np.isinf(encode_pc).any() or np.isinf(decode_pc).any():
        raise _SkipSample("Inf values in point cloud")

    # Bounding box from ALL points (encode + decode) for consistent normalization.
    all_points = np.vstack([encode_pc, decode_pc])
    bbmin = all_points.min(0)
    bbmax = all_points.max(0)
    bbox_size = (bbmax - bbmin).max()
    if bbox_size < 1e-6:
        raise _SkipSample(f"Degenerate bounding box (size={bbox_size})")

    center = (bbmin + bbmax) * 0.5
    scale = 2.0 * 0.9 / bbox_size
    encode_pc_normalized = (encode_pc - center) * scale
    decode_pc_normalized = (decode_pc - center) * scale

    if np.isnan(encode_pc_normalized).any() or np.isnan(decode_pc_normalized).any():
        raise _SkipSample("NaN in normalized coordinates")
    if np.isinf(encode_pc_normalized).any() or np.isinf(decode_pc_normalized).any():
        raise _SkipSample("Inf in normalized coordinates")

    # Normalized coordinates should be roughly within [-1, 1]; 10 is a sanity bound.
    encode_max = np.abs(encode_pc_normalized).max()
    decode_max = np.abs(decode_pc_normalized).max()
    if encode_max > 10 or decode_max > 10:
        raise _SkipSample(
            f"Normalized coordinates out of range (encode_max={encode_max:.2f}, decode_max={decode_max:.2f})"
        )
    return encode_pc_normalized, decode_pc_normalized, bbox_size, scale


def _extract_point_features(model, encode_pc_normalized, decode_pc_normalized):
    """Build triplanes from the encode cloud and sample them at the decode points.

    Returns a NumPy array reshaped to ``(-1, 448)``. Raises ``_SkipSample`` if
    the decode tensor contains NaN/Inf after conversion to torch.
    """
    batch_encode_pc = torch.from_numpy(encode_pc_normalized).unsqueeze(0).float().to('cuda')
    with torch.no_grad():
        # Generate triplanes from the encode point cloud.
        pc_feat = model.pvcnn(batch_encode_pc, batch_encode_pc)
        planes = model.triplane_transformer(pc_feat)
        # The first 64 channels along dim 2 are split off and not used here.
        _sdf_planes, part_planes = torch.split(planes, [64, planes.shape[2] - 64], dim=2)

        tensor_vertices = torch.from_numpy(decode_pc_normalized).reshape(1, -1, 3).to(torch.float32).cuda()
        if torch.isnan(tensor_vertices).any() or torch.isinf(tensor_vertices).any():
            raise _SkipSample("Invalid tensor_vertices after conversion to torch")

        point_feat = sample_triplane_feat(part_planes, tensor_vertices)
    return point_feat.cpu().detach().numpy().reshape(-1, 448)


def _save_npy_atomic(path, array):
    """Save ``array`` to ``path`` via a temp file so no partial file is left behind.

    The resume logic treats any existing destination file as finished, so a
    truncated file from an interrupted write must never appear under ``path``.
    """
    tmp_path = f"{path}.{os.getpid()}.tmp"
    try:
        with open(tmp_path, "wb") as handle:
            np.save(handle, array)
        os.replace(tmp_path, path)
    finally:
        # Only present if the write or the rename failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def main():
    """Extract per-point features for this job's shard of point-cloud files.

    Parses the default arguments plus ``--num_jobs`` / ``--job_id``, loads the
    demo ``Model`` from ``cfg.continue_ckpt`` onto CUDA, pairs the ``*.npy``
    files of the encode and decode directories by UID, and for every pair whose
    index satisfies ``index % num_jobs == job_id``: normalizes both clouds
    jointly, runs the model on the encode cloud, samples features at the decode
    points, and writes them as float16 to ``<dest_feat_root>/<uid>.npy``.
    Existing destination files are skipped; failures are counted and reported.

    Raises:
        ValueError: if ``job_id`` is negative or not less than ``num_jobs``.
    """
    from tqdm import tqdm

    parser = default_argument_parser()
    parser.add_argument('--num_jobs', type=int, default=1, help='Total number of parallel jobs')
    parser.add_argument('--job_id', type=int, default=0, help='Current job ID (0 to num_jobs-1)')
    args = parser.parse_args()
    cfg = setup(args, freeze=False)
    cfg.is_pc = True

    # Validate job arguments
    if args.job_id >= args.num_jobs:
        raise ValueError(f"job_id ({args.job_id}) must be less than num_jobs ({args.num_jobs})")
    if args.job_id < 0:
        raise ValueError(f"job_id ({args.job_id}) must be >= 0")

    from partfield.model_trainer_pvcnn_only_demo import Model
    model = Model.load_from_checkpoint(cfg.continue_ckpt, cfg=cfg)
    model.eval()
    model.to('cuda')

    encode_pc_root = "/scratch/shared/beegfs/ruining/data/articulate-3d/Lightwheel/all-uniform-100k-singlestate-pts"
    decode_pc_root = "/scratch/shared/beegfs/ruining/data/articulate-3d/Lightwheel/all-sharp50pct-40k-singlestate-pts"
    dest_feat_root = "/scratch/shared/beegfs/ruining/data/articulate-3d/Lightwheel/all-sharp50pct-40k-singlestate-feats"

    # Create destination directory
    os.makedirs(dest_feat_root, exist_ok=True)

    encode_files = sorted(glob.glob(os.path.join(encode_pc_root, "*.npy")))
    decode_files = sorted(glob.glob(os.path.join(decode_pc_root, "*.npy")))

    # Pair by UID rather than by position: a single missing file on either
    # side would otherwise misalign every subsequent pair.
    all_pairs, num_unpaired = _pair_by_uid(encode_files, decode_files)
    if num_unpaired:
        print(f"Warning: {num_unpaired} point-cloud files have no counterpart with the same UID and are ignored")

    # Filter files for this job
    job_files = [pair for i, pair in enumerate(all_pairs) if i % args.num_jobs == args.job_id]
    print(f"Job {args.job_id}/{args.num_jobs}: Processing {len(job_files)}/{len(encode_files)} files")

    num_bad_zip, num_failed_others = 0, 0
    for encode_file, decode_file in tqdm(job_files, desc=f"Job {args.job_id}"):
        uid = _file_uid(decode_file)
        dest_feat_file = os.path.join(dest_feat_root, f"{uid}.npy")
        if os.path.exists(dest_feat_file):
            continue  # already extracted by a previous run

        try:
            encode_pc = np.load(encode_file)
            decode_pc = np.load(decode_file)
            encode_pc_normalized, decode_pc_normalized, bbox_size, scale = _normalize_pair(encode_pc, decode_pc)

            try:
                point_feat = _extract_point_features(model, encode_pc_normalized, decode_pc_normalized)
            except RuntimeError as e:
                if "CUDA" in str(e) or "index" in str(e).lower():
                    print(f"Skipping {uid}: CUDA error - {str(e)[:100]}")
                    print(f" encode shape: {encode_pc.shape}, decode shape: {decode_pc.shape}")
                    print(f" bbox_size: {bbox_size:.6f}, scale: {scale:.6f}")
                    print(f" normalized range: [{encode_pc_normalized.min():.3f}, {encode_pc_normalized.max():.3f}]")
                    num_failed_others += 1
                    # Clear CUDA cache to recover from error
                    torch.cuda.empty_cache()
                    continue
                raise

            _save_npy_atomic(dest_feat_file, point_feat.astype(np.float16))
        except zipfile.BadZipFile:
            print(f"Skipping {uid}: bad zip file")
            num_bad_zip += 1
        except _SkipSample as reason:
            print(f"Skipping {uid}: {reason}")
            num_failed_others += 1
        except Exception as e:
            # Best-effort batch job: keep going, but never fail silently.
            print(f"Skipping {uid}: unexpected {type(e).__name__} - {str(e)[:200]}")
            num_failed_others += 1

    print(f"Job {args.job_id} - Number of bad zip files: {num_bad_zip}")
    print(f"Job {args.job_id} - Number of failed others: {num_failed_others}")
    print(f"Job {args.job_id} completed successfully!")
# Entry point: run the feature-extraction job only when executed as a script.
if __name__ == "__main__":
    main()