# particulate/PartField/partfield_inference_pc.py
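"""Extract per-point PartField features for point clouds.

For each shape, a dense "encode" point cloud is lifted to triplane features,
which are then sampled at a second "decode" point cloud; the resulting
448-dim per-point features are saved as float16 .npy files. Work can be
sharded across parallel jobs via --num_jobs/--job_id.
"""
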
import glob
import os
import random
import zipfile

import numpy as np
import torch
from lightning.pytorch import seed_everything, Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.strategies import DDPStrategy

from partfield.config import default_argument_parser, setup
from partfield.model.PVCNN.encoder_pc import sample_triplane_feat


def predict(cfg):
    # Seed all RNG sources for reproducibility.
    seed_everything(cfg.seed)
    torch.manual_seed(0)
    random.seed(0)
    np.random.seed(0)

    checkpoint_callbacks = [ModelCheckpoint(
        monitor="train/current_epoch",
        dirpath=cfg.output_dir,
        filename="{epoch:02d}",
        save_top_k=100,
        save_last=True,
        every_n_epochs=cfg.save_every_epoch,
        mode="max",
        verbose=True,
    )]

    trainer = Trainer(
        devices=-1,
        accelerator="gpu",
        precision="16-mixed",
        strategy=DDPStrategy(find_unused_parameters=True),
        max_epochs=cfg.training_epochs,
        log_every_n_steps=1,
        limit_train_batches=3500,
        limit_val_batches=None,
        callbacks=checkpoint_callbacks,
    )

    # Deferred import so the heavy model module is only loaded when needed.
    from partfield.model_trainer_pvcnn_only_demo import Model
    model = Model(cfg)

    if cfg.remesh_demo:
        cfg.n_point_per_face = 10

    trainer.predict(model, ckpt_path=cfg.continue_ckpt)
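

# NOTE: main() below does not reuse predict(); it loads the checkpoint directly
# and iterates over the data itself so the workload can be sharded across
# independent processes via --num_jobs/--job_id.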
def main():
    from tqdm import tqdm

    parser = default_argument_parser()
    parser.add_argument('--num_jobs', type=int, default=1, help='Total number of parallel jobs')
    parser.add_argument('--job_id', type=int, default=0, help='Current job ID (0 to num_jobs-1)')
    args = parser.parse_args()

    cfg = setup(args, freeze=False)
    cfg.is_pc = True

    # Validate job arguments
    if args.job_id >= args.num_jobs:
        raise ValueError(f"job_id ({args.job_id}) must be less than num_jobs ({args.num_jobs})")
    if args.job_id < 0:
        raise ValueError(f"job_id ({args.job_id}) must be >= 0")

    # Deferred import so the heavy model module is only loaded when needed.
    from partfield.model_trainer_pvcnn_only_demo import Model
    model = Model.load_from_checkpoint(cfg.continue_ckpt, cfg=cfg)
    model.eval()
    model.to('cuda')

    # Machine-specific data roots (adjust for your environment): point clouds
    # used to encode the triplane, point clouds at which features are decoded,
    # and the output directory for the per-point features.
    encode_pc_root = "/scratch/shared/beegfs/ruining/data/articulate-3d/Lightwheel/all-uniform-100k-singlestate-pts"
    decode_pc_root = "/scratch/shared/beegfs/ruining/data/articulate-3d/Lightwheel/all-sharp50pct-40k-singlestate-pts"
    dest_feat_root = "/scratch/shared/beegfs/ruining/data/articulate-3d/Lightwheel/all-sharp50pct-40k-singlestate-feats"

    # Create destination directory
    os.makedirs(dest_feat_root, exist_ok=True)

    # Encode/decode files are paired by sorted order; the per-UID assert in the
    # loop below catches any mismatch. Note that zip() silently truncates if
    # the two listings have different lengths.
    encode_files = sorted(glob.glob(os.path.join(encode_pc_root, "*.npy")))
    decode_files = sorted(glob.glob(os.path.join(decode_pc_root, "*.npy")))

    # Filter files for this job (round-robin assignment by index).
    job_files = [pair for i, pair in enumerate(zip(encode_files, decode_files)) if i % args.num_jobs == args.job_id]
    print(f"Job {args.job_id}/{args.num_jobs}: Processing {len(job_files)}/{len(encode_files)} files")

    num_bad_zip, num_failed_others = 0, 0
    for encode_file, decode_file in tqdm(job_files, desc=f"Job {args.job_id}"):
        try:
            # Get UID from decode file (the one we're extracting features for)
            uid = os.path.basename(decode_file).split('.')[0]
            assert uid == os.path.basename(encode_file).split('.')[0]

            # Skip files that have already been processed.
            dest_feat_file = os.path.join(dest_feat_root, f"{uid}.npy")
            if os.path.exists(dest_feat_file):
                continue

            # Load both encode and decode point clouds
            encode_pc = np.load(encode_file)
            decode_pc = np.load(decode_file)

            # Validate input data
            if np.isnan(encode_pc).any() or np.isnan(decode_pc).any():
                print(f"Skipping {uid}: NaN values in point cloud")
                num_failed_others += 1
                continue
            if np.isinf(encode_pc).any() or np.isinf(decode_pc).any():
                print(f"Skipping {uid}: Inf values in point cloud")
                num_failed_others += 1
                continue

            # Compute bounding box from ALL points (encode + decode) for consistent normalization
            all_points = np.vstack([encode_pc, decode_pc])
            bbmin = all_points.min(0)
            bbmax = all_points.max(0)

            # Check for degenerate bounding box
            bbox_size = (bbmax - bbmin).max()
            if bbox_size < 1e-6:
                print(f"Skipping {uid}: Degenerate bounding box (size={bbox_size})")
                num_failed_others += 1
                continue

            # Center the cloud and scale its longest bounding-box edge to 1.8,
            # so all points land inside [-0.9, 0.9]^3.
            center = (bbmin + bbmax) * 0.5
            scale = 2.0 * 0.9 / bbox_size

            # Apply same normalization to both point clouds
            encode_pc_normalized = (encode_pc - center) * scale
            decode_pc_normalized = (decode_pc - center) * scale

            # Validate normalized coordinates
            if np.isnan(encode_pc_normalized).any() or np.isnan(decode_pc_normalized).any():
                print(f"Skipping {uid}: NaN in normalized coordinates")
                num_failed_others += 1
                continue
            if np.isinf(encode_pc_normalized).any() or np.isinf(decode_pc_normalized).any():
                print(f"Skipping {uid}: Inf in normalized coordinates")
                num_failed_others += 1
                continue

            # Check that normalized coordinates are within a reasonable range (should be ~[-1, 1])
            encode_max = np.abs(encode_pc_normalized).max()
            decode_max = np.abs(decode_pc_normalized).max()
            if encode_max > 10 or decode_max > 10:
                print(f"Skipping {uid}: Normalized coordinates out of range (encode_max={encode_max:.2f}, decode_max={decode_max:.2f})")
                num_failed_others += 1
                continue
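
            # Triplane pipeline: the PVCNN encoder embeds the normalized encode
            # cloud, the triplane transformer lifts the embedding to triplanes,
            # the first 64 channels are split off as SDF planes, and the
            # remaining part-feature planes are sampled at the decode (query)
            # points.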
            # Use encode_pc to generate triplane
            batch_encode_pc = torch.from_numpy(encode_pc_normalized).unsqueeze(0).float().to('cuda')
            with torch.no_grad():
                try:
                    # Generate triplane from encode_pc
                    pc_feat = model.pvcnn(batch_encode_pc, batch_encode_pc)
                    planes = model.triplane_transformer(pc_feat)
                    sdf_planes, part_planes = torch.split(planes, [64, planes.shape[2] - 64], dim=2)

                    # Sample features at decode_pc points
                    tensor_vertices = torch.from_numpy(decode_pc_normalized).reshape(1, -1, 3).to(torch.float32).cuda()

                    # Validate tensor before sampling
                    if torch.isnan(tensor_vertices).any() or torch.isinf(tensor_vertices).any():
                        print(f"Skipping {uid}: Invalid tensor_vertices after conversion to torch")
                        num_failed_others += 1
                        continue

                    point_feat = sample_triplane_feat(part_planes, tensor_vertices)
                    point_feat = point_feat.cpu().detach().numpy().reshape(-1, 448)

                    # Save point features as float16 to halve disk usage.
                    np.save(dest_feat_file, point_feat.astype(np.float16))
                except RuntimeError as e:
                    if "CUDA" in str(e) or "index" in str(e).lower():
                        print(f"Skipping {uid}: CUDA error - {str(e)[:100]}")
                        print(f"  encode shape: {encode_pc.shape}, decode shape: {decode_pc.shape}")
                        print(f"  bbox_size: {bbox_size:.6f}, scale: {scale:.6f}")
                        print(f"  normalized range: [{encode_pc_normalized.min():.3f}, {encode_pc_normalized.max():.3f}]")
                        num_failed_others += 1
                        # Clear CUDA cache to recover from error
                        torch.cuda.empty_cache()
                        continue
                    else:
                        raise
        except zipfile.BadZipFile:
            # np.load raises BadZipFile for truncated or corrupted archives.
            num_bad_zip += 1
            continue
        except Exception as e:
            # Log and skip anything else rather than killing the whole job.
            print(f"Skipping {os.path.basename(decode_file)}: {type(e).__name__}: {str(e)[:100]}")
            num_failed_others += 1
            continue
print(f"Job {args.job_id} - Number of bad zip files: {num_bad_zip}")
print(f"Job {args.job_id} - Number of failed others: {num_failed_others}")
print(f"Job {args.job_id} completed successfully!")


if __name__ == "__main__":
    main()