import math from typing import Optional, Tuple, Iterable, Dict import os, json, warnings import torch import torch.nn as nn from mmengine import print_log from transformers.models.qwen2.modeling_qwen2 import ( Qwen2RMSNorm, Qwen2MLP, eager_attention_forward, ) from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS # ----------------------------- Helpers ---------------------------------------- def sinusoidal_positions(L: int, D: int, device=None, base: float = 10000.0): """ Standard Transformer sinusoidal absolute positions: (1, L, D) Works for any D (even/odd handled by slicing). """ position = torch.arange(L, device=device, dtype=torch.float32).unsqueeze(1) # (L,1) div_term = torch.exp(torch.arange(0, D, 2, device=device, dtype=torch.float32) * (-math.log(base) / max(1, D // 2))) pe = torch.zeros(L, D, dtype=torch.float32, device=device) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) return pe.unsqueeze(0) # (1, L, D) # ------------------------- Cross-Attention ------------------------------------ class Qwen2CrossAttention(nn.Module): """ Cross-attention that mirrors Qwen2Attention's backend interface. - Queries: concatenated [K learnable latents ⊕ text tokens] - Keys/Values: visual tokens - Backends: eager / sdpa / flash_attn2 via ALL_ATTENTION_FUNCTIONS - Positional handling: none here (positions are added in the caller). """ def __init__(self, config, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_heads = config.num_attention_heads self.num_kv_heads = config.num_key_value_heads self.num_key_value_groups = self.num_heads // self.num_kv_heads self.scaling = self.head_dim ** -0.5 self.attention_dropout = config.attention_dropout self.is_causal = False # cross-attn is not causal self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=True) self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=True) self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=True) self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False) def forward( self, query_hidden_states: torch.Tensor, # (B, Q, D) -> [latents ⊕ text] key_value_hidden_states: torch.Tensor, # (B, N, D) -> visual tokens attention_mask: Optional[torch.Tensor] = None, # ignored; always None in our caller **kwargs, ): B, Q, _ = query_hidden_states.shape N = key_value_hidden_states.shape[1] Hd = self.head_dim # Projections q = self.q_proj(query_hidden_states).view(B, Q, self.num_heads, Hd).transpose(1, 2) # (B, H, Q, Hd) k = self.k_proj(key_value_hidden_states).view(B, N, self.num_kv_heads, Hd).transpose(1, 2) v = self.v_proj(key_value_hidden_states).view(B, N, self.num_kv_heads, Hd).transpose(1, 2) attention_interface = eager_attention_forward if getattr(self.config, "_attn_implementation", "eager") != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, _ = attention_interface( self, q, k, v, attention_mask=None, # <- force None (no masks) dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, sliding_window=None, **kwargs, ) attn_output = attn_output.transpose(1, 2).contiguous().view(B, Q, self.num_heads * Hd) return self.o_proj(attn_output) # ------------------------- Perceiver Resampler -------------------------------- class PerceiverResampler(nn.Module): """ Perceiver-style resampler with: - Optional concatenation of text tokens to the learnable query latents (controlled by `concat_text_to_queries`) - NO attention masks - NO RoPE - A single learned absolute positional embedding applied to queries over either [latents ⊕ text] or [latents] only - Optional gradient checkpointing per block Args: llm: Qwen2 model (or wrapper exposing .model and get_input_embeddings()). num_latents: number of learnable latent slots (K). depth: number of Perceiver blocks. max_text_len: maximum supported text length (T_max). concat_text_to_queries: if True, queries are [latents ⊕ text]; if False, queries are [latents] only. """ def __init__( self, llm, num_latents: int = 64, depth: int = 2, *, max_text_len: int = 4096, concat_text_to_queries: bool = True, ): super().__init__() base = llm.model if hasattr(llm, "model") else llm self.config = base.config self.hidden_size = self.config.hidden_size self.num_latents = num_latents self.depth = depth self.max_text_len = max_text_len # NEW: whether to append text to query slots self.concat_text_to_queries: bool = concat_text_to_queries # Learnable latent queries (K, D) self.latents = nn.Parameter( torch.randn(1, num_latents, self.hidden_size) / math.sqrt(self.hidden_size) ) self.visual_ln = Qwen2RMSNorm(self.hidden_size, eps=self.config.rms_norm_eps) # Perceiver blocks self.blocks = nn.ModuleList() for i in range(depth): self.blocks.append(nn.ModuleDict({ "input_ln": Qwen2RMSNorm(self.hidden_size, eps=self.config.rms_norm_eps), "cross_attn": Qwen2CrossAttention(self.config, layer_idx=i), "post_ln": Qwen2RMSNorm(self.hidden_size, eps=self.config.rms_norm_eps), "mlp": Qwen2MLP(self.config), })) # Learned ABS positional embedding across query positions. # We allocate enough positions for the maximum possible query length: K + T_max. self.query_pos = nn.Embedding(self.num_latents + self.max_text_len, self.hidden_size) nn.init.normal_(self.query_pos.weight, mean=0.0, std=0.02) self.resid_scale = 1.0 / math.sqrt(2.0) # Optional (handy if you embed ids upstream) self.text_embed = llm.get_input_embeddings() # ---- Gradient checkpointing controls (off by default) ---- self.gradient_checkpointing: bool = False self.gc_use_reentrant: bool = True self.gc_preserve_rng_state: bool = True # Public helpers to toggle checkpointing def enable_gradient_checkpointing(self, *, use_reentrant: bool = True, preserve_rng_state: bool = True): self.gradient_checkpointing = True self.gc_use_reentrant = use_reentrant self.gc_preserve_rng_state = preserve_rng_state def disable_gradient_checkpointing(self): self.gradient_checkpointing = False def enable_input_require_grads(self): def make_inputs_require_grad(module, input, output): output.requires_grad_(True) self.register_forward_hook(make_inputs_require_grad) # NEW: runtime toggle def set_concat_text_queries(self, enabled: bool): """If False, only the learnable latents are used as queries.""" self.concat_text_to_queries = enabled def forward( self, text_embeddings: torch.FloatTensor, # (B, T, D) visual_tokens: torch.FloatTensor, # (B, N, D) attention_mask: Optional[torch.Tensor] = None, # ignored visual_mask: Optional[torch.Tensor] = None, # ignored ) -> torch.Tensor: device = self.latents.device B, T, D = text_embeddings.shape assert D == self.hidden_size, f"text hidden {D} != {self.hidden_size}" if T > self.max_text_len: raise ValueError( f"text length {T} exceeds max_text_len={self.max_text_len}; " f"increase max_text_len when constructing PerceiverResampler." ) # Queries: either [latents ⊕ text] (default) or [latents] only K = self.num_latents Q_lat = self.latents.expand(B, -1, -1) # (B, K, D) if self.concat_text_to_queries: x = torch.cat([Q_lat, text_embeddings], dim=1) # (B, K+T, D) else: x = Q_lat # (B, K, D) Q = x.size(1) # Learned absolute positions across queries (consume first Q positions) pos_ids = torch.arange(Q, device=device).unsqueeze(0).expand(B, -1) # (B, Q) x = x + self.query_pos(pos_ids) visual_tokens = self.visual_ln(visual_tokens) # Per-block forward (optionally checkpointed) for blk_idx, blk in enumerate(self.blocks): def _block_fn(t_x, t_v): r = t_x xn = blk["input_ln"](t_x) a = blk["cross_attn"](xn, t_v, attention_mask=None) t_x = r + a * self.resid_scale r = t_x xn = blk["post_ln"](t_x) return r + blk["mlp"](xn) * self.resid_scale if self.gradient_checkpointing and self.training: x = torch.utils.checkpoint.checkpoint( _block_fn, x, visual_tokens, use_reentrant=self.gc_use_reentrant, preserve_rng_state=self.gc_preserve_rng_state, ) else: x = _block_fn(x, visual_tokens) # Return only the latent slots (K, D) return x[:, :self.num_latents, :] try: from safetensors.torch import load_file as safe_load_file _HAS_SAFE = True except Exception: _HAS_SAFE = False def _find_weight_index(ckpt_dir: str) -> Optional[str]: """Return the path to the model weight index (if sharded) or None.""" cands = ["model.safetensors.index.json", "pytorch_model.bin.index.json"] for c in cands: p = os.path.join(ckpt_dir, c) if os.path.isfile(p): return p return None def _list_all_weight_files(ckpt_dir: str) -> Iterable[str]: """Yield all likely weight files in a directory.""" for name in os.listdir(ckpt_dir): if name.endswith(".safetensors") or name.endswith(".bin"): # skip the top-level consolidated adapter/optimizer etc. if "optimizer" in name or "trainer" in name or name.endswith(".index.json"): continue yield os.path.join(ckpt_dir, name) def _load_shard(shard_path: str) -> Dict[str, torch.Tensor]: """Load one shard to CPU.""" if shard_path.endswith(".safetensors"): if not _HAS_SAFE: raise RuntimeError("safetensors not available; install safetensors or provide .bin weights.") return safe_load_file(shard_path, device="cpu") return torch.load(shard_path, map_location="cpu") def _gather_needed_tensors_from_checkpoint(ckpt_dir: str, needed_keys: Iterable[str]) -> Dict[str, torch.Tensor]: """ Load only the tensors we need from a (possibly sharded) HF checkpoint dir. """ needed = set(needed_keys) out: Dict[str, torch.Tensor] = {} index_path = _find_weight_index(ckpt_dir) if index_path is None: # Non-sharded: scan files and pick keys if present for fpath in _list_all_weight_files(ckpt_dir): shard = _load_shard(fpath) for k in list(needed): if k in shard: out[k] = shard[k] needed.remove(k) # free asap del shard if not needed: break else: # Sharded: index maps param key -> shard filename with open(index_path, "r") as f: idx = json.load(f) weight_map = idx.get("weight_map") or idx.get("weight_map", {}) # group by shard shard_to_keys: Dict[str, list] = {} for k in needed: shard_name = weight_map.get(k) if shard_name is None: continue shard_to_keys.setdefault(shard_name, []).append(k) # load per shard for shard_name, keys in shard_to_keys.items(): shard_path = os.path.join(ckpt_dir, shard_name) shard = _load_shard(shard_path) for k in keys: if k in shard: out[k] = shard[k] del shard needed = {k for k in needed if k not in out} if needed: missing_sorted = "\n - " + "\n - ".join(sorted(needed)) raise KeyError(f"Missing keys in checkpoint for Perceiver init:{missing_sorted}") return out def _copy_param_like(dst_param: torch.nn.Parameter, src_tensor: torch.Tensor): print_log( f'Copying param {dst_param.shape} <- {src_tensor.shape}', logger='current' ) dst_param.data.copy_(src_tensor.to(dtype=dst_param.dtype, device=dst_param.device)) def _safe_copy_linear_from_tensor(dst_linear: torch.nn.Linear, w: torch.Tensor, b: Optional[torch.Tensor]): if dst_linear.weight.shape != w.shape: raise RuntimeError( f"Shape mismatch copying linear: dst {tuple(dst_linear.weight.shape)} vs src {tuple(w.shape)}" ) print_log( f'Copying linear {dst_linear.weight.shape}, bias={dst_linear.bias is not None}', logger='current' ) dst_linear.weight.data.copy_(w.to(dtype=dst_linear.weight.dtype, device=dst_linear.weight.device)) if dst_linear.bias is not None: if b is not None: dst_linear.bias.data.copy_(b.to(dtype=dst_linear.bias.dtype, device=dst_linear.bias.device)) else: dst_linear.bias.data.zero_() def init_perceiver_from_llm_checkpoint( perceiver, ckpt_dir: str, init_from_layers: Optional[int] = None, layer_offset: int = 0, ): """ Initialize PerceiverResampler from the raw LLM checkpoint files on disk. - Supports .safetensors or .bin, sharded or single-file. - Copies: input/post norms, q/k/v/o, mlp gate/up/down for the first `L` layers. - `layer_offset` lets you start from a later LLM block if you prefer. Args: perceiver: PerceiverResampler instance (with .blocks ModuleList) ckpt_dir: path to LLM checkpoint directory (the one you pass to from_pretrained) init_from_layers: how many LLM layers to use (defaults to perceiver.depth) layer_offset: start copying from LLM layer `layer_offset` (default 0) """ base_depth = perceiver.depth L = min(init_from_layers or base_depth, base_depth) # Build the list of keys we need from the LLM checkpoint needed = [] for i in range(L): li = i + layer_offset prefix = f"model.layers.{li}" # norms needed += [ f"{prefix}.input_layernorm.weight", f"{prefix}.post_attention_layernorm.weight", ] # attention needed += [ f"{prefix}.self_attn.q_proj.weight", f"{prefix}.self_attn.k_proj.weight", f"{prefix}.self_attn.v_proj.weight", f"{prefix}.self_attn.o_proj.weight", ] # mlp needed += [ f"{prefix}.mlp.gate_proj.weight", f"{prefix}.mlp.up_proj.weight", f"{prefix}.mlp.down_proj.weight", ] # optional biases (some Qwen2 variants have none; we’ll tolerate missing) needed += [ f"{prefix}.self_attn.q_proj.bias", f"{prefix}.self_attn.k_proj.bias", f"{prefix}.self_attn.v_proj.bias", f"{prefix}.self_attn.o_proj.bias", f"{prefix}.mlp.gate_proj.bias", f"{prefix}.mlp.up_proj.bias", f"{prefix}.mlp.down_proj.bias", ] # Load what's available; we’ll allow bias keys to be missing without failing: try: tensors = _gather_needed_tensors_from_checkpoint(ckpt_dir, [k for k in needed if "bias" not in k]) except KeyError as e: # Re-raise with a hint if fuse-qkv is detected msg = str(e) if "W_pack" in msg or "qkv" in msg: msg += "\nDetected fused QKV in checkpoint. This loader expects separate q_proj/k_proj/v_proj. "\ "If your checkpoint uses fused QKV (e.g., *.W_pack.weight), we’ll need a small slicer—ping me." raise # Biases: try to load, but don't error if absent bias_tensors: Dict[str, torch.Tensor] = {} idx_path = _find_weight_index(ckpt_dir) if idx_path is not None: with open(idx_path, "r") as f: idx = json.load(f) wmap = idx.get("weight_map") or {} shard_to_biases: Dict[str, list] = {} for k in needed: if "bias" not in k: continue sn = wmap.get(k) if sn is None: continue shard_to_biases.setdefault(sn, []).append(k) for sn, keys in shard_to_biases.items(): shard_path = os.path.join(ckpt_dir, sn) shard = _load_shard(shard_path) for k in keys: if k in shard: bias_tensors[k] = shard[k] del shard else: # non-sharded: scan files once for fpath in _list_all_weight_files(ckpt_dir): shard = _load_shard(fpath) for k in [k for k in needed if "bias" in k]: if k in shard: bias_tensors[k] = shard[k] del shard # Copy into perceiver blocks with torch.no_grad(): for i in range(L): li = i + layer_offset prefix = f"model.layers.{li}" dst = perceiver.blocks[i] # norms _copy_param_like(dst["input_ln"].weight, tensors[f"{prefix}.input_layernorm.weight"]) _copy_param_like(dst["post_ln"].weight, tensors[f"{prefix}.post_attention_layernorm.weight"]) # attention _safe_copy_linear_from_tensor( dst["cross_attn"].q_proj, tensors[f"{prefix}.self_attn.q_proj.weight"], bias_tensors.get(f"{prefix}.self_attn.q_proj.bias"), ) _safe_copy_linear_from_tensor( dst["cross_attn"].k_proj, tensors[f"{prefix}.self_attn.k_proj.weight"], bias_tensors.get(f"{prefix}.self_attn.k_proj.bias"), ) _safe_copy_linear_from_tensor( dst["cross_attn"].v_proj, tensors[f"{prefix}.self_attn.v_proj.weight"], bias_tensors.get(f"{prefix}.self_attn.v_proj.bias"), ) _safe_copy_linear_from_tensor( dst["cross_attn"].o_proj, tensors[f"{prefix}.self_attn.o_proj.weight"], bias_tensors.get(f"{prefix}.self_attn.o_proj.bias"), ) # mlp _safe_copy_linear_from_tensor( dst["mlp"].gate_proj, tensors[f"{prefix}.mlp.gate_proj.weight"], bias_tensors.get(f"{prefix}.mlp.gate_proj.bias"), ) _safe_copy_linear_from_tensor( dst["mlp"].up_proj, tensors[f"{prefix}.mlp.up_proj.weight"], bias_tensors.get(f"{prefix}.mlp.up_proj.bias"), ) _safe_copy_linear_from_tensor( dst["mlp"].down_proj, tensors[f"{prefix}.mlp.down_proj.weight"], bias_tensors.get(f"{prefix}.mlp.down_proj.bias"), ) def resolve_llm_checkpoint_dir(llm, ckpt_hint: str | None = None, allow_download: bool = False) -> str | None: """ Try to find a local directory for the LLM checkpoint. - If 'ckpt_hint' is a real directory, use it. - Else check llm.config._name_or_path / llm.name_or_path. - If those look like HF repo ids, query the local HF cache (snapshot_download with local_files_only=True). - If still not found and 'allow_download' is True, try downloading (optional). Returns an absolute directory path or None if it couldn't be resolved. """ # 1) explicit hint if ckpt_hint and os.path.isdir(ckpt_hint): return os.path.abspath(ckpt_hint) candidates = [] # Usual places Transformers stores the source string if getattr(getattr(llm, "config", None), "_name_or_path", None): candidates.append(llm.config._name_or_path) if getattr(llm, "name_or_path", None): candidates.append(llm.name_or_path) # 2) if any candidate is already a dir, take it for cand in candidates: if isinstance(cand, str) and os.path.isdir(cand): return os.path.abspath(cand) # 3) HF cache lookup for repo ids try: from huggingface_hub import snapshot_download except Exception: snapshot_download = None for cand in candidates: if not isinstance(cand, str): continue # Heuristic: repo ids usually contain '/' looks_like_repo_id = "/" in cand and not os.path.isabs(cand) if snapshot_download is None or not looks_like_repo_id: continue # First: local cache only (offline-safe) try: path = snapshot_download( repo_id=cand, local_files_only=True, # Narrow to model files to avoid pulling huge repos allow_patterns=["*.safetensors", "*.bin", "*.json", "*.index.json", "config.json"], ) return path except Exception: # Optional online fetch if allowed if allow_download: try: path = snapshot_download( repo_id=cand, local_files_only=False, allow_patterns=["*.safetensors", "*.bin", "*.json", "*.index.json", "config.json"], ) return path except Exception: pass return None def init_perceiver_from_llm(perceiver, llm, init_from_layers: int | None = None): """ Copies weights from the LLM's first few layers into the PerceiverResampler blocks. """ base = llm.model if hasattr(llm, "model") else llm depth = perceiver.depth L = min(init_from_layers or depth, depth, len(base.layers)) with torch.no_grad(): for i in range(L): src = base.layers[i] # Qwen2DecoderLayer dst = perceiver.blocks[i] # norms dst["input_ln"].weight.copy_(src.input_layernorm.weight) dst["post_ln"].weight.copy_(src.post_attention_layernorm.weight) # attention projections dst["cross_attn"].q_proj.weight.copy_(src.self_attn.q_proj.weight) dst["cross_attn"].q_proj.bias.copy_(src.self_attn.q_proj.bias) dst["cross_attn"].k_proj.weight.copy_(src.self_attn.k_proj.weight) dst["cross_attn"].k_proj.bias.copy_(src.self_attn.k_proj.bias) dst["cross_attn"].v_proj.weight.copy_(src.self_attn.v_proj.weight) dst["cross_attn"].v_proj.bias.copy_(src.self_attn.v_proj.bias) dst["cross_attn"].o_proj.weight.copy_(src.self_attn.o_proj.weight) # mlp dst["mlp"].gate_proj.weight.copy_(src.mlp.gate_proj.weight) dst["mlp"].up_proj.weight.copy_(src.mlp.up_proj.weight) dst["mlp"].down_proj.weight.copy_(src.mlp.down_proj.weight) def init_perceiver_from_llm_auto( perceiver, llm, ckpt_hint: str | None = None, init_from_layers: int | None = None, layer_offset: int = 0, allow_download: bool = False, ): """ Prefer initializing from the raw checkpoint on disk; if not found, fall back to in-memory quantization-aware init. """ ckpt_dir = resolve_llm_checkpoint_dir(llm, ckpt_hint=ckpt_hint, allow_download=allow_download) if ckpt_dir is not None: print(f"[Perceiver init] Using checkpoint dir: {ckpt_dir}") return init_perceiver_from_llm_checkpoint( perceiver, ckpt_dir=ckpt_dir, init_from_layers=init_from_layers or perceiver.depth, layer_offset=layer_offset, ) warnings.warn( "[Perceiver init] Could not resolve a checkpoint directory; falling back to " "in-memory quantization-aware initialization from the loaded LLM." ) return init_perceiver_from_llm(perceiver, llm, init_from_layers=init_from_layers)