# Copyright (c) OpenMMLab. All rights reserved. import math import os.path as osp import warnings from collections import OrderedDict import torch import torch.nn as nn import torch.distributed as dist # === MOD === from accelerate import init_empty_weights from mmengine import print_log from mmengine.config import Config, ConfigDict from mmengine.model import BaseModel from peft import get_peft_model, prepare_model_for_kbit_training from transformers import (AddedToken, AutoConfig, CLIPImageProcessor, CLIPVisionModel, LlamaForCausalLM, LlamaTokenizerFast, LlavaConfig, LlavaForConditionalGeneration, LlavaProcessor) from transformers.integrations import is_deepspeed_zero3_enabled import os from safetensors.torch import load_file, save_file from xtuner.registry import BUILDER from xtuner.utils import DEFAULT_IMAGE_TOKEN from .modules import ProjectorConfig, ProjectorModel, dispatch_modules from .modules.dispatch import SUPPORT_FLASH1, SUPPORT_FLASH2 from .utils import (LoadWoInit, find_all_linear_names, get_peft_model_state_dict, guess_load_checkpoint, make_inputs_require_grad, prepare_inputs_labels_for_multimodal, traverse_dict) import torch.nn.functional as F from .sparse_token_merge import SparsePatchMerging from xtuner.model.torchscale.model.pos_embed import get_2d_sincos_pos_embed from peft import PeftModel from peft.tuners.lora.layer import LoraLayer # ===== 在类前或类内其它位置都可以:新增一个探测函数 ===== def _detect_qwen_major_version(llm) -> int: """ 返回 3 表示 Qwen3,2 表示 Qwen2,0 表示未知/其它。 优先用 config.model_type,其次回退到类名字符串。 """ base = llm.model if hasattr(llm, "model") else llm cfg = getattr(base, "config", None) mt = (getattr(cfg, "model_type", None) or "").lower() if mt == "qwen3": return 3 if mt == "qwen2": return 2 # 回退:根据类名判别 cname = base.__class__.__name__.lower() if "qwen3" in cname: return 3 if "qwen2" in cname: return 2 return 0 def convert_state_dict_to_hf(state_dict, mapping): new_state_dict = {} for key, value in state_dict.items(): if key.endswith('.inv_freq'): continue for key_to_modify, new_key in mapping.items(): if key_to_modify in key: key = key.replace(key_to_modify, new_key) new_state_dict[key] = value return new_state_dict class AdaptiveAvgPool1dLayer(nn.Module): def __init__(self, output_size): super(AdaptiveAvgPool1dLayer, self).__init__() self.output_size = output_size def forward(self, x): return F.adaptive_avg_pool1d(x, self.output_size) class LLaVAModel(BaseModel): def __init__(self, llm, freeze_llm=True, visual_select_layer=-2, pretrained_pth=None, projector_depth=2, llm_lora=None, visual_encoder_lora=None, use_activation_checkpointing=True, max_position_embeddings=None, hidden_size=512, train_stage='2', # slide/pos-embed 参数 slide_ngrids=1000, pe_gate_value=1.0, pe_dropout=0.1, tile_size=224, # 各子模块权重路径 projector_pth=None, perceiver_pth=None, token_merge_pth=None, pe_gate_pth=None, # Token Merge enable_token_merge=True, # Perceiver Resampler 配置 use_perceiver_resampler=True, concat_text_to_queries=True, perceiver_num_latents=64, perceiver_depth=2, # === 新增:Stage-2 冻结选项 === freeze_mm_in_stage2=False, # 总开关:在 stage-2 冻结 projector / perceiver / token_merge freeze_projector_stage2=None, # 子开关(None 表示跟随总开关) freeze_perceiver_stage2=None, # 子开关(None 表示跟随总开关) freeze_token_merge_stage2=None # 子开关(None 表示跟随总开关) ): super().__init__() self.freeze_llm = freeze_llm self.freeze_visual_encoder = True self.tile_size = tile_size # 训练阶段控制 if train_stage == '0': print_log('train_stage == 0', 'current') self.freeze_llm = True if train_stage == '1': print_log('train_stage == 1', 'current') self.freeze_llm = True elif train_stage == '2': print_log('train_stage == 2', 'current') self.freeze_llm = False # 解析 stage-2 的冻结意图 def _resolve(flag): return freeze_mm_in_stage2 if flag is None else bool(flag) self._freeze_projector_in_s2 = _resolve(freeze_projector_stage2) self._freeze_perceiver_in_s2 = _resolve(freeze_perceiver_stage2) self._freeze_token_merge_in_s2 = _resolve(freeze_token_merge_stage2) # 构建 / 派发 LLM with LoadWoInit(): if isinstance(llm, dict): llm = self._dispatch_lm_model_cfg(llm, max_position_embeddings) self.llm = self._build_from_cfg_or_module(llm) self.llm.config.use_cache = False dispatch_modules(self.llm) # Token Merge self.enable_token_merge = enable_token_merge if self.enable_token_merge: self.token_merge = SparsePatchMerging( embed_dim=hidden_size, layernorm_eps=1e-6, merge_size=2 ) # Projector self.projector_depth = projector_depth projector_config = ProjectorConfig( visual_hidden_size=hidden_size * 4 if self.enable_token_merge else hidden_size, llm_hidden_size=self.llm.config.hidden_size, depth=self.projector_depth ) self.projector = ProjectorModel(projector_config).to(self.llm.dtype) self.projector.requires_grad_(True) # Perceiver Resampler self.use_perceiver_resampler = use_perceiver_resampler self.slide_ngrids = slide_ngrids if self.use_perceiver_resampler: self.perceiver_num_latents = perceiver_num_latents self.perceiver_depth = perceiver_depth num_patches = slide_ngrids ** 2 self.pe_gate = nn.Parameter(torch.tensor(pe_gate_value, dtype=self.llm.dtype)) self.pe_drop = nn.Dropout(pe_dropout) self.register_buffer( 'pos_embed', torch.zeros(1, num_patches, self.llm.config.hidden_size), persistent=False ) # 自动选择 Qwen2 / Qwen3 的 Perceiver 实现 qwen_major = _detect_qwen_major_version(self.llm) print_log(f'using qwen version {qwen_major}', 'current') if qwen_major == 3: try: from .qwen3_perceiver_resampler import ( PerceiverResampler as _PR, init_perceiver_from_llm_auto as _init_pr, ) print_log('using qwen3', 'current') except Exception as e: raise RuntimeError( "检测到 Qwen3,但未找到 qwen3_perceiver_resampler,请确认文件存在且 transformers 版本满足要求(>=4.51)。" ) from e elif qwen_major == 2: from .qwen2_perceiver_resampler import ( PerceiverResampler as _PR, init_perceiver_from_llm_auto as _init_pr, ) else: warnings.warn( "未能确定 Qwen 主版本(既不是 qwen3 也不是 qwen2)。将回退到 Qwen2 的 Perceiver 实现。", RuntimeWarning, ) from .qwen2_perceiver_resampler import ( PerceiverResampler as _PR, init_perceiver_from_llm_auto as _init_pr, ) if concat_text_to_queries: print_log("concat text to queries in perceiver", 'current') self.perceiver = _PR( self.llm, num_latents=self.perceiver_num_latents, depth=self.perceiver_depth, concat_text_to_queries=concat_text_to_queries, ).to(self.llm.dtype) # 仅当没有提供 perceiver_pth 或路径不存在时,才尝试从 LLM 自动初始化 if perceiver_pth is None or not os.path.exists(perceiver_pth): _init_pr( perceiver=self.perceiver, llm=self.llm, ckpt_hint=getattr(self.llm.config, "_name_or_path", None), init_from_layers=self.perceiver.depth, layer_offset=0, allow_download=False, ) # 初始化 pos-embed 等 self.initialize_pe_weights() # 冻结 LLM if self.freeze_llm: print('freeze_llm') self.llm.requires_grad_(False) # 激活检查点(按需对冻结模块跳过 input-grad 使能) if use_activation_checkpointing: if hasattr(self.llm, 'enable_input_require_grads'): self.llm.enable_input_require_grads() else: self.llm.get_input_embeddings().register_forward_hook(make_inputs_require_grad) if self.use_perceiver_resampler: _perceiver_frozen = (train_stage == '2' and self._freeze_perceiver_in_s2) if not _perceiver_frozen: self.perceiver.enable_input_require_grads() else: print_log('[stage-2] Skipping perceiver.enable_input_require_grads() (frozen)', 'current') _projector_frozen = (train_stage == '2' and self._freeze_projector_in_s2) if not _projector_frozen: print('enable projector input require grads') print_log('enable projector input require grads', 'current') self.projector.enable_input_require_grads() else: print_log('[stage-2] Skipping projector.enable_input_require_grads() (frozen)', 'current') # 启用激活检查点 self.gradient_checkpointing_enable() # LoRA self.use_llm_lora = llm_lora is not None self.use_visual_encoder_lora = None if self.use_llm_lora: print_log(f"Building lora {llm_lora.__str__}", "current") self._prepare_llm_for_lora(llm_lora, use_activation_checkpointing) self.verify_lora() # 加载 token_merge / projector / perceiver / pe_gate 的 safetensors if token_merge_pth is not None and enable_token_merge and hasattr(self, 'token_merge'): print_log(f'loading token_merge from {token_merge_pth}', 'current') merger_sd = load_file(token_merge_pth, device='cpu') self.token_merge.load_state_dict(merger_sd, strict=False) self.token_merge.to(self.llm.dtype) if projector_pth is not None: print_log(f"Loading projector from {projector_pth}", "current") proj_sd = load_file(projector_pth, device="cpu") self.projector.load_state_dict(proj_sd, strict=False) self.projector.to(self.llm.dtype) if perceiver_pth is not None and self.use_perceiver_resampler and hasattr(self, 'perceiver'): print_log(f'Loading perceiver from {perceiver_pth}', 'current') perceiver_sd = load_file(perceiver_pth, device="cpu") self.perceiver.load_state_dict(perceiver_sd, strict=False) self.perceiver.to(self.llm.dtype) if pe_gate_pth is not None and self.use_perceiver_resampler and hasattr(self, "pe_gate"): print_log(f'Loading pe_gate from {pe_gate_pth}', 'current') sd = load_file(pe_gate_pth, device="cpu") if "pe_gate" not in sd: raise KeyError(f"'pe_gate' not found in {pe_gate_pth}. Keys: {list(sd.keys())}") with torch.no_grad(): self.pe_gate.copy_(sd["pe_gate"].to(dtype=self.llm.dtype, device=self.pe_gate.device)) # 额外加载 float 权重(可选) if pretrained_pth is not None: sd = guess_load_checkpoint(pretrained_pth) model_sd = self.state_dict() filtered = {k: v for k, v in sd.items() if k in model_sd and model_sd[k].shape == v.shape} missing, unexpected = self.load_state_dict(filtered, strict=False) print_log(f"Loaded float ckpt from {pretrained_pth}", "current") print_log(f" missing: {missing}", "current") print_log(f" unexpected:{unexpected}", "current") # 记录可视层 self.visual_select_layer = visual_select_layer # 初始化标志 self._is_init = True self.is_first_iter = True # === 关键新增:在 Stage-2 按需冻结三个多模态子模块 === if train_stage == '2': # projector if hasattr(self, 'projector') and self._freeze_projector_in_s2: self.projector.requires_grad_(False) self.projector.eval() print_log('[stage-2] Freezing projector parameters', 'current') # perceiver(含 pe_gate) if getattr(self, 'use_perceiver_resampler', False) and hasattr(self, 'perceiver') and self._freeze_perceiver_in_s2: self.perceiver.requires_grad_(False) self.perceiver.eval() print_log('[stage-2] Freezing perceiver parameters', 'current') if hasattr(self, 'pe_gate') and self._freeze_perceiver_in_s2: self.pe_gate.requires_grad = False # token_merge if getattr(self, 'enable_token_merge', False) and hasattr(self, 'token_merge') and self._freeze_token_merge_in_s2: self.token_merge.requires_grad_(False) self.token_merge.eval() print_log('[stage-2] Freezing token_merge parameters', 'current') def _parse_lora_config(self, lora_config): if isinstance(lora_config, dict) or isinstance( lora_config, Config) or isinstance(lora_config, ConfigDict): lora_config = BUILDER.build(lora_config) return lora_config # def initialize_pe_weights(self): # # initialization # # initialize (and freeze) pos_embed by sin-cos embedding # if self.use_perceiver_resampler: # pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], self.slide_ngrids, cls_token=False) # self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) @torch.no_grad() def initialize_pe_weights(self, chunk_rows: int = 64, chunk_cols: int = 64): """ 在 GPU 上用 float64 精度生成 2D sin-cos 位置编码, 逻辑与 numpy 版本完全一致,然后写入 self.pos_embed。 """ if not getattr(self, "use_perceiver_resampler", False): return H = W = int(self.slide_ngrids) D = int(self.llm.config.hidden_size) assert D % 4 == 0, "hidden_size 必须是 4 的倍数,才能和 numpy 实现严格对应。" device = self.pos_embed.device dtype64 = torch.float64 # 全程用 float64 # 预分配/调整 buffer 形状 if self.pos_embed.shape != (1, H * W, D): self.pos_embed.resize_(1, H * W, D) pos4d = self.pos_embed.view(1, H, W, D) # 频率向量 k = D // 4 inv = 1.0 / (10000 ** (torch.arange(k, device=device, dtype=dtype64) / k)) # 整数坐标 (与 numpy 一致) y_lin = torch.arange(H, device=device, dtype=dtype64) x_lin = torch.arange(W, device=device, dtype=dtype64) # 一维编码 y_phase = y_lin.unsqueeze(1) * inv.unsqueeze(0) # [H,k] x_phase = x_lin.unsqueeze(1) * inv.unsqueeze(0) # [W,k] y_enc = torch.cat([torch.sin(y_phase), torch.cos(y_phase)], dim=1) # [H,2k] x_enc = torch.cat([torch.sin(x_phase), torch.cos(x_phase)], dim=1) # [W,2k] # 分块写入,避免一次性大张量 for r0 in range(0, H, chunk_rows): r1 = min(r0 + chunk_rows, H) R = r1 - r0 y_chunk = y_enc[r0:r1].unsqueeze(1) # [R,1,2k] for c0 in range(0, W, chunk_cols): c1 = min(c0 + chunk_cols, W) C = c1 - c0 x_chunk = x_enc[c0:c1].unsqueeze(0) # [1,C,2k] # 拼接顺序与 numpy 一致: [emb_w, emb_h] emb_rc = torch.cat( [x_chunk.expand(R, C, 2*k), y_chunk.expand(R, C, 2*k)], dim=2 ) # [R,C,D] # copy 到 buffer(自动 cast 到 buffer dtype) pos4d[0, r0:r1, c0:c1, :].copy_(emb_rc.to(pos4d.dtype)) def _init_weights(self, m): if isinstance(m, nn.Linear): # we use xavier_uniform following official JAX ViT: torch.nn.init.xavier_uniform_(m.weight) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) def _prepare_llm_for_lora(self, lora_config, use_activation_checkpointing=True): lora_config = self._parse_lora_config(lora_config) self.llm = prepare_model_for_kbit_training( self.llm, use_activation_checkpointing) if lora_config.target_modules is None: modules = find_all_linear_names(self.llm) lora_config.target_modules = modules self.llm = get_peft_model(self.llm, lora_config) def verify_lora(self): m = self.llm # 1) Wrapped as a PEFT model assert isinstance(m, PeftModel), "LoRA not applied: model is not a PeftModel" # 2) Adapters are registered and active adapters = m.peft_config # dict: {adapter_name: LoraConfig} assert len(adapters) > 0, "No adapters registered in peft_config" active = m.active_adapter if hasattr(m, "active_adapter") else None assert active in adapters, f"Active adapter {active} not found in peft_config" # 3) LoRA layers are present on target modules lora_modules = [mod for mod in m.modules() if isinstance(mod, LoraLayer)] assert len(lora_modules) > 0, "No LoraLayer modules found (check target_modules)" # 4) LoRA params are the only trainable ones (typical for QLoRA) trainable = [(n,p) for n,p in m.named_parameters() if p.requires_grad] assert len(trainable) > 0, "No trainable parameters (LoRA params are not set to requires_grad=True)" # Optional: sanity-check that trainable params look like LoRA suspicious = [n for n,_ in trainable if "lora_" not in n and "modules_to_save" not in n] # It's okay if you intentionally left some modules_to_save; adjust as needed. assert len(suspicious) == 0, f"Unexpected trainable params (not LoRA): {suspicious[:5]}" # 5) Quick count + readable log total = sum(p.numel() for _,p in m.named_parameters()) trainable_cnt = sum(p.numel() for _,p in trainable) ratio = trainable_cnt / total print(f"[LoRA OK] adapters={list(adapters.keys())}, active={active}, " f"LoraLayers={len(lora_modules)}, trainable={trainable_cnt}/{total} ({ratio:.4%})") # 6) Forward+backward smoke test to confirm gradients flow to LoRA only m.train() dummy_inp = torch.randint(0, m.get_input_embeddings().num_embeddings, (1, 8)).to(next(m.parameters()).device) out = m(input_ids=dummy_inp, labels=dummy_inp) out.loss.backward() # should not error # Ensure some LoRA grads exist lora_grads = [p.grad for _,p in m.named_parameters() if p.requires_grad and p.grad is not None] assert len(lora_grads) > 0, "No gradients on LoRA parameters after backward()" def _prepare_visual_encoder_for_lora(self, lora_config, use_activation_checkpointing=True): lora_config = self._parse_lora_config(lora_config) if lora_config.target_modules is None: modules = find_all_linear_names(self.visual_encoder) lora_config.target_modules = modules self.visual_encoder = get_peft_model(self.visual_encoder, lora_config) def gradient_checkpointing_enable(self, use_reentrant=False): self.activation_checkpointing_enable(use_reentrant=use_reentrant) def activation_checkpointing_enable(self, use_reentrant=False): # LLM try: self.llm.gradient_checkpointing_enable(use_reentrant=use_reentrant) except TypeError: # older HF versions self.llm.gradient_checkpointing_enable() # projector try: self.projector.gradient_checkpointing_enable(use_reentrant=use_reentrant) except TypeError: self.projector.gradient_checkpointing_enable() # perceiver (if present) if getattr(self, 'use_perceiver_resampler', False) and getattr(self, 'perceiver', None) is not None: try: self.perceiver.gradient_checkpointing_enable(use_reentrant=use_reentrant) except AttributeError: # some custom modules only expose input-grad helper if hasattr(self.perceiver, 'enable_input_require_grads'): self.perceiver.enable_input_require_grads() def gradient_checkpointing_disable(self): self.activation_checkpointing_disable() def activation_checkpointing_disable(self): self.llm.gradient_checkpointing_disable() # self.visual_encoder.gradient_checkpointing_disable() self.projector.gradient_checkpointing_disable() if self.use_perceiver_resampler: self.perceiver.disable_gradient_checkpointing() def init_weights(self): pass def state_dict(self, *args, **kwargs): state_dict = super().state_dict(*args, **kwargs) to_return = OrderedDict() # Step 1. visual_encoder if self.use_visual_encoder_lora: to_return.update( get_peft_model_state_dict( self.visual_encoder, state_dict=state_dict)) elif not self.freeze_visual_encoder: to_return.update({ k: v for k, v in state_dict.items() if 'visual_encoder.' in k }) # Step 2. LLM if self.use_llm_lora: to_return.update( get_peft_model_state_dict(self.llm, state_dict=state_dict)) elif not self.freeze_llm: to_return.update( {k: v for k, v in state_dict.items() if 'llm.' in k}) # Step 3. Projector to_return.update( {k: v for k, v in state_dict.items() if 'projector.' in k}) # Step 5. Perceiver Resampler (unchanged) if getattr(self, 'use_perceiver_resampler', False) and getattr(self, 'perceiver', None) is not None: to_return.update({k: v for k, v in state_dict.items() if 'perceiver.' in k}) if getattr(self, 'pe_gate', False): to_return.update({k: v for k, v in state_dict.items() if 'pe_gate.' in k}) # step 5 token merger if getattr(self, 'token_merge', False): to_return.update({k: v for k, v in state_dict.items() if 'token_merge.' in k}) return to_return @staticmethod def _prepare_for_long_context_training(cfg, llm_cfg, max_position_embeddings): orig_rope_scaling = getattr(llm_cfg, 'rope_scaling', None) if orig_rope_scaling is None: orig_rope_scaling = {'factor': 1} orig_rope_scaling_factor = orig_rope_scaling[ 'factor'] if 'factor' in orig_rope_scaling.keys() else 1 orig_ctx_len = getattr(llm_cfg, 'max_position_embeddings', None) if orig_ctx_len: orig_ctx_len *= orig_rope_scaling_factor if max_position_embeddings > orig_ctx_len: scaling_factor = float( math.ceil(max_position_embeddings / orig_ctx_len)) llm_cfg.rope_scaling = { 'type': 'linear', 'factor': scaling_factor } # hardcode for internlm2 llm_cfg.attn_implementation = 'flash_attention_2' cfg.config = llm_cfg return cfg, llm_cfg @staticmethod def _prepare_for_flash_attn(cfg, llm_cfg): cls_name = type(llm_cfg).__name__ SUPPORT_SDPA_ATTN = ('LlamaConfig', 'GemmaConfig', 'MistralConfig', 'MixtralConfig', 'Qwen2Config', 'Qwen2MoeConfig', 'Starcoder2Config', 'Starcoder2Config', 'Phi3Config') SUPPORT_FLASH_ATTN2 = ('InternLM2Config', 'LlamaConfig', 'GemmaConfig', 'MistralConfig', 'MixtralConfig', 'Qwen2Config', 'Qwen2MoeConfig', 'Starcoder2Config', 'Starcoder2Config', 'Phi3Config') torch_dtype = torch.bfloat16 if ( torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \ else torch.float16 if getattr(cfg, 'attn_implementation', None) is not None: # Flash Attention 2.0 only supports torch.float16 and # torch.bfloat16 dtypes if cfg.attn_implementation == 'flash_attention_2': cfg.torch_dtype = torch_dtype elif SUPPORT_FLASH2 and cls_name in SUPPORT_FLASH_ATTN2: cfg.torch_dtype = torch_dtype cfg.attn_implementation = 'flash_attention_2' elif SUPPORT_FLASH1 and cls_name in SUPPORT_SDPA_ATTN: cfg.attn_implementation = 'sdpa' return cfg, llm_cfg @staticmethod def _prepare_for_qlora_zero3(cfg): if (not is_deepspeed_zero3_enabled()) or (not hasattr( cfg, 'quantization_config')): return cfg torch_dtype = torch.bfloat16 if ( torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \ else torch.float16 cfg.torch_dtype = torch_dtype quantization_config = cfg.quantization_config quantization_config.bnb_4bit_compute_dtype = torch_dtype quantization_config.bnb_4bit_quant_storage = torch_dtype return cfg def _dispatch_lm_model_cfg(self, cfg, max_position_embeddings=None): cfg = self._prepare_for_qlora_zero3(cfg) pretrained_model_name_or_path = cfg.pretrained_model_name_or_path llm_cfg = AutoConfig.from_pretrained( pretrained_model_name_or_path, trust_remote_code=True) cfg, llm_cfg = self._prepare_for_flash_attn(cfg, llm_cfg) if max_position_embeddings is not None: cfg, llm_cfg = self._prepare_for_long_context_training( cfg, llm_cfg, max_position_embeddings) return cfg def _build_from_cfg_or_module(self, cfg_or_mod): if isinstance(cfg_or_mod, nn.Module): return cfg_or_mod elif isinstance(cfg_or_mod, dict): traverse_dict(cfg_or_mod) return BUILDER.build(cfg_or_mod) else: raise NotImplementedError def coords_to_pos(self, coords, tile_size: int = 224): """ This function is used to convert the coordinates to the positional indices Arguments: ---------- coords: torch.Tensor The coordinates of the patches, of shape [N, L, 2] output: torch.Tensor The positional indices of the patches, of shape [N, L] """ coords_ = torch.floor(coords / tile_size) pos = coords_[..., 0] * self.slide_ngrids + coords_[..., 1] return pos.long() # add 1 for the cls token @staticmethod def _coords_rc_to_pos(coords_rc: torch.Tensor, ngrids: int) -> torch.Tensor: if coords_rc.dtype.is_floating_point: coords_rc = coords_rc.round().to(torch.long) # row = coords_rc[:, 0].clamp_(0, ngrids-1) # col = coords_rc[:, 1].clamp_(0, ngrids-1) return (coords_rc[..., 0] * ngrids + coords_rc[..., 1]).long() # +1 for cls def forward(self, data, data_samples=None, mode='loss'): if self.is_first_iter: # hardcode for qlora DeepSpeed ZeRO3, put buffers and QuantState to # device # Only required in `LLaVAModel` . # We do not need this in `SupervisedFinetune` . self.to(data['input_ids'].device) self.is_first_iter = False coords = None if 'pixel_values' in data: feat_to_proj = data['pixel_values'].to(self.llm.dtype) # torch.Size([1, img_num, 512]) if 'coords' in data: coords = data['coords'].to(self.llm.dtype) # Accept: list[tensor], [L,2] tensor, or [B,L,2] tensor coords_t = coords[0] if isinstance(coords, list) else coords Bx = feat_to_proj.size(0) # actual batch size of inputs if not torch.is_tensor(coords_t): raise ValueError("coords must be a Tensor or list[Tensor].") if coords_t.dim() == 2: # [L, 2] coords_rc = coords_t elif coords_t.dim() == 3: # [B, L, 2] -> ensure B matches and either B==1 or all examples share coords if coords_t.size(0) != Bx: raise ValueError(f"coords batch dim mismatch: got {coords_t.size(0)} but inputs have B={Bx}") if Bx == 1: coords_rc = coords_t[0] else: # require same coords across the batch (cheap equality check) if not torch.equal(coords_t, coords_t[0].unsqueeze(0).expand_as(coords_t)): raise NotImplementedError( "Per-example coords (varying across batch) are not supported by the current " "patch-merging/layout path. Use batch size 1 or share coords across the batch." ) coords_rc = coords_t[0] else: raise ValueError("coords must have shape [L,2] or [B,L,2].") if coords_rc.size(-1) != 2: raise ValueError("coords last dimension must be 2.") else: raise RuntimeError # only works for batch size one if self.enable_token_merge: feat_to_proj, coords_rc_merged, _ = self.token_merge( x=feat_to_proj, coords_rc=self._coords_to_rowcol(coords_rc), # 你已有,生成 rc padmask=torch.zeros([feat_to_proj.size(0), feat_to_proj.size(1)], device=feat_to_proj.device, dtype=torch.bool) ) else: coords_rc_merged = self._coords_to_rowcol(coords_rc) padmask_merged = torch.zeros([feat_to_proj.size(0), feat_to_proj.size(1)], device=feat_to_proj.device, dtype=torch.bool) pixel_values = self.projector(feat_to_proj.to(self.llm.dtype)) # output shape [1, patch_num, 3584] if self.use_perceiver_resampler and 'input_ids' in data: text_emb = self.llm.get_input_embeddings()(data["input_ids"].clamp(min=0)) \ .to(self.llm.dtype).detach() # 注意:这里的 coords_RC 已经是合并后的 (row, col) # print(coords_rc_merged.max(), coords_rc_merged.shape) pos = self._coords_rc_to_pos(coords_rc_merged, self.slide_ngrids) # B==1 假设 # print(pos.max(), pos.shape) pixel_values = pixel_values + self.pe_drop(self.pos_embed[:, pos, :].squeeze(0) * self.pe_gate) compressed = self.perceiver( # input_ids = data["input_ids"], text_embeddings=text_emb, attention_mask=data.get("attention_mask", None), visual_tokens=pixel_values, ) data["pixel_values"] = compressed else: data['pixel_values'] = pixel_values # shape: [1, patch_num, 3584] # shape: [1, 576, 4096] # remove coords data.pop('coords', None) data = prepare_inputs_labels_for_multimodal(llm=self.llm, **data) if mode == 'loss': return self.compute_loss(data, data_samples) elif mode == 'predict': return self.predict(data, data_samples) elif mode == 'tensor': return self._forward(data, data_samples) else: raise NotImplementedError @staticmethod def _coords_to_rowcol(coords_xy: torch.Tensor) -> torch.Tensor: with torch.no_grad(): x = coords_xy[:, 0] y = coords_xy[:, 1] x_for_unique = x y_for_unique = y if x_for_unique.dtype.is_floating_point: x_for_unique = x_for_unique.round().to(torch.int) y_for_unique = y_for_unique.round().to(torch.int) x_sorted = torch.unique(x_for_unique, sorted=True) y_sorted = torch.unique(y_for_unique, sorted = True) col = torch.searchsorted(x_sorted, x) row = torch.searchsorted(y_sorted, y) return torch.stack([row, col], dim=-1) def _forward(self, data, data_samples=None): outputs = self.llm(**data) return outputs def predict(self, data, data_samples=None): outputs = self.llm(**data) logits_dict = [{'logits': logits} for logits in outputs.logits] return logits_dict # def compute_loss(self, data, data_samples=None): # outputs = self.llm(**data) # # outputs.logits.shape (1, 1094, 152064) for Qwen # loss_dict = {'loss': outputs.loss} # return loss_dict # # === MOD: token-averaged, globally weighted loss (robust to variable lengths) def compute_loss(self, data, data_samples=None): """ 计算 token-level 交叉熵损失(分布式/AMP 兼容)。 - labels 中 -100 为 ignore_index - 自动屏蔽负 ID(如 -200 图像占位)与 special_ids 对应位置 """ # 1) 若无 labels,退回 HF 默认 if "labels" not in data: outputs = self.llm(**data) return {"loss": outputs.loss} labels = data["labels"] # [B, T] input_ids = data.get("input_ids", None) # [B, T] or None attn = data.get("attention_mask", None) # 可无 # 2) 标签清洗(不改原 labels) safe_labels = labels.clone() # 2.1 屏蔽负 ID(如 -200 图像占位) if input_ids is not None: neg_mask = (input_ids < 0) if neg_mask.any(): safe_labels = torch.where(neg_mask, torch.full_like(safe_labels, -100), safe_labels) # 2.2 屏蔽 tokenizer 的特殊 token(模板标记等) if getattr(self, "tokenizer", None) is not None: try: special_ids = set(self.tokenizer.all_special_ids or []) except Exception: special_ids = set() if special_ids: special_mask = torch.zeros_like(input_ids, dtype=torch.bool) for sid in special_ids: special_mask |= (input_ids == sid) if special_mask.any(): safe_labels = torch.where(special_mask, torch.full_like(safe_labels, -100), safe_labels) # 3) 前向,拿 logits(不把 labels 交给 HF,避免其先做 per-device mean) model_inputs = {k: v for k, v in data.items() if k != "labels"} outputs = self.llm(**model_inputs, use_cache=False) logits = outputs.logits # [B, T, V] # 形状断言 if logits.dim() != 3 or logits.shape[:2] != safe_labels.shape[:2]: raise RuntimeError( f"logits/labels length mismatch: logits {tuple(logits.shape)} vs labels {tuple(safe_labels.shape)}" ) # 4) CausalLM 对齐 shift_logits = logits[:, :-1, :].contiguous() shift_labels = safe_labels[:, 1:].contiguous() # 5) 统计有效 token & 分布式聚合 n_tok_local = (shift_labels != -100).sum().to(device=logits.device, dtype=torch.long) world_size = 1 n_tok_global = n_tok_local if dist.is_available() and dist.is_initialized(): world_size = dist.get_world_size() with torch.no_grad(): n_tok_global = n_tok_local.clone() dist.all_reduce(n_tok_global, op=dist.ReduceOp.SUM) # 若全局无监督 token,则返回 0(防 NaN) if n_tok_global.item() == 0: zero = shift_logits.sum() * 0.0 return {"loss": zero, "ntok": n_tok_global.to(zero.dtype)} # 6) 分子(sum over tokens,FP32 更稳) loss_sum_local = F.cross_entropy( shift_logits.float().view(-1, shift_logits.size(-1)), shift_labels.view(-1), ignore_index=-100, reduction="sum", ) # 7) 全局 token 平均的 loss(抵消 DDP 的梯度平均) denom = n_tok_global.clamp_min(1).to(loss_sum_local.dtype) loss = (loss_sum_local / denom) * float(world_size) # 8) 返回 ntok_tensor = denom.detach() return {"loss": loss, "ntok": ntok_tensor} def __getattr__(self, name: str): try: return super().__getattr__(name) except AttributeError: return getattr(self.llm, name) def to_hf(self, cfg, save_dir, fp32=False, save_pretrained_kwargs={}, save_format='xtuner', **kwargs): if save_format == 'xtuner': self.to_xtuner_llava(cfg, save_dir, fp32, save_pretrained_kwargs) elif save_format == 'huggingface': self.to_huggingface_llava(cfg, save_dir, fp32, save_pretrained_kwargs) elif save_format == 'official': self.to_official_llava(cfg, save_dir, fp32, save_pretrained_kwargs) else: raise NotImplementedError def to_xtuner_llava(self, cfg, save_dir, fp32=False, save_pretrained_kwargs={}): # LLM self.llm.config.use_cache = True if not fp32: print_log('Convert LLM to float16', 'current') self.llm.half() if self.use_llm_lora: llm_path = osp.join(save_dir, 'llm_adapter') print_log(f'Saving LLM adapter to {llm_path}', 'current') self.llm.save_pretrained(llm_path, **save_pretrained_kwargs) elif not self.freeze_llm: llm_path = save_dir print_log(f'Saving LLM tokenizer to {llm_path}', 'current') tokenizer = BUILDER.build(cfg.tokenizer) tokenizer.save_pretrained(llm_path, **save_pretrained_kwargs) print_log(f'Saving LLM to {llm_path}', 'current') self.llm.save_pretrained(llm_path, **save_pretrained_kwargs) self.llm.config.use_cache = False # Visual Encoder if self.use_visual_encoder_lora: visual_encoder_path = osp.join(save_dir, 'visual_encoder_adapter') print_log( f'Saving visual_encoder adapter to {visual_encoder_path}', 'current') self.visual_encoder.save_pretrained(visual_encoder_path, **save_pretrained_kwargs) elif not self.freeze_visual_encoder: visual_encoder_path = osp.join(save_dir, 'visual_encoder') print_log( 'Saving visual_encoder image_processor to' f'{visual_encoder_path}', 'current') image_processor = BUILDER.build(cfg.image_processor) image_processor.save_pretrained(visual_encoder_path, **save_pretrained_kwargs) print_log(f'Saving visual_encoder to {visual_encoder_path}', 'current') self.visual_encoder.save_pretrained(visual_encoder_path, **save_pretrained_kwargs) # Projector projector_path = osp.join(save_dir, 'projector') print_log(f'Saving projector to {projector_path}', 'current') # self.projector.save_pretrained(projector_path, # **save_pretrained_kwargs) os.makedirs(projector_path, exist_ok=True) output_path = os.path.join(projector_path, 'projector.safetensors') save_file(self.projector.state_dict(), output_path) if self.use_perceiver_resampler and hasattr(self, 'perceiver'): perceiver_path = osp.join(save_dir, "perceiver") print_log(f'Saving LongNet_encoder to {perceiver_path}', 'current') os.makedirs(perceiver_path, exist_ok=True) perceiver_output_path = os.path.join(perceiver_path, 'perceiver.safetensors') save_file(self.perceiver.state_dict(), perceiver_output_path) if self.enable_token_merge and hasattr(self, 'token_merge'): merger_path = osp.join(save_dir, 'token_merger') print_log(f'Saving token merger to{merger_path}', 'current') os.makedirs(merger_path, exist_ok= True) merger_path = osp.join(merger_path, 'merger.safetensors') save_file(self.token_merge.state_dict(), merger_path) if self.use_perceiver_resampler and hasattr(self, 'pe_gate'): pe_gate_path = osp.join(save_dir, 'pe_gate') print_log(f'saving pe_gate to {pe_gate_path}', 'current') os.makedirs(pe_gate_path, exist_ok= True) pe_gate_output_path = os.path.join(pe_gate_path, 'pe_gate.safetensors') # choose dtype for saving save_dtype = torch.float32 if fp32 else self.llm.dtype # save as a single-tensor safetensors file save_file( {"pe_gate": self.pe_gate.detach().to(save_dtype).cpu()}, pe_gate_output_path ) def to_huggingface_llava(self, cfg, save_dir, fp32=False, save_pretrained_kwargs={}): LLM_MAPPING = { 'model': 'language_model.model', 'lm_head': 'language_model.lm_head', } VIT_MAPPING = { 'vision_model': 'vision_tower.vision_model', } PROJECTOR_MAPPING = { 'model.0': 'multi_modal_projector.linear_1', 'model.2': 'multi_modal_projector.linear_2', } LONGNET_MAPPING = { 'layers.0': 'LongNet_encoder.layers.0', 'layers.1': 'LongNet_encoder.layers.1', 'layer_norm': 'LongNet_encoder.layer_norm' } assert getattr(self.llm, 'hf_quantizer', None) is None, \ 'This conversion format does not support quantized LLM.' # get state_dict llm = self.llm if self.use_llm_lora: llm = self.llm.merge_and_unload() llm.config.use_cache = True if not fp32: print_log('Convert LLM to float16', 'current') llm.half() assert isinstance(llm, LlamaForCausalLM), \ 'This conversion format only supports LlamaForCausalLM.' llm_state_dict = llm.state_dict() llm_state_dict = convert_state_dict_to_hf(llm_state_dict, LLM_MAPPING) need_visual_encoder = (not self.freeze_visual_encoder or self.use_visual_encoder_lora) visual_encoder = self.visual_encoder if self.use_visual_encoder_lora: visual_encoder = self.visual_encoder.merge_and_unload() assert isinstance(visual_encoder, CLIPVisionModel),\ 'This conversion format only supports CLIPVisionModel.' if need_visual_encoder: visual_encoder_state_dict = visual_encoder.state_dict() visual_encoder_state_dict = convert_state_dict_to_hf( visual_encoder_state_dict, VIT_MAPPING) else: visual_encoder_state_dict = {} projector_state_dict = self.projector.state_dict() projector_state_dict = convert_state_dict_to_hf( projector_state_dict, PROJECTOR_MAPPING) LongNet_encoder_state_dict = self.LongNet_encoder.state_dict() LongNet_encoder_state_dict = convert_state_dict_to_hf( LongNet_encoder_state_dict, LONGNET_MAPPING) state_dict = { **projector_state_dict, **llm_state_dict, **visual_encoder_state_dict, **LongNet_encoder_state_dict } # init model text_config = llm.config vision_config = visual_encoder.config config = LlavaConfig( text_config=text_config, vision_config=vision_config, attn_implementation='eager') with init_empty_weights(): with warnings.catch_warnings(): warnings.filterwarnings( 'ignore', message='.*non-meta.*', category=UserWarning) model = LlavaForConditionalGeneration(config) model.load_state_dict(state_dict, strict=True, assign=True) # processor cfg.tokenizer.type = LlamaTokenizerFast.from_pretrained tokenizer = BUILDER.build(cfg.tokenizer) tokenizer.add_tokens( AddedToken(DEFAULT_IMAGE_TOKEN, special=True, normalized=False), special_tokens=True) tokenizer.add_special_tokens({'pad_token': ''}) image_processor = BUILDER.build(cfg.image_processor) assert isinstance(image_processor, CLIPImageProcessor),\ 'This conversion format only supports CLIPImageProcessor.' processor = LlavaProcessor( tokenizer=tokenizer, image_processor=image_processor) # Pad to 64 for performance reasons pad_shape = 64 pre_expansion_embeddings = \ model.language_model.model.embed_tokens.weight.data mu = torch.mean(pre_expansion_embeddings, dim=0).float() n = pre_expansion_embeddings.size()[0] sigma = ((pre_expansion_embeddings - mu).T @ (pre_expansion_embeddings - mu)) / n dist = torch.distributions.multivariate_normal.MultivariateNormal( mu, covariance_matrix=1e-5 * sigma) # We add an image token so we need to resize the model ori_vocab_size = config.text_config.vocab_size tokenizer_vocab_size = tokenizer.encode('')[-1] added_token = tokenizer_vocab_size - ori_vocab_size if added_token > 0: model.resize_token_embeddings(ori_vocab_size + added_token, pad_shape) model.language_model.model.embed_tokens.weight.data[ ori_vocab_size:] = torch.stack( tuple( dist.sample() for _ in range(model.language_model.model.embed_tokens. weight.data[ori_vocab_size:].shape[0])), dim=0, ) model.language_model.lm_head.weight.data[ ori_vocab_size:] = torch.stack( tuple(dist.sample() for _ in range(model.language_model.lm_head.weight. data[ori_vocab_size:].shape[0])), dim=0, ) model.config.image_token_index = tokenizer.encode( DEFAULT_IMAGE_TOKEN)[-1] model.config.pad_token_id = tokenizer.encode('')[-1] # save print_log(f'Saving to {save_dir}', 'current') model.save_pretrained(save_dir, **save_pretrained_kwargs) processor.save_pretrained(save_dir, **save_pretrained_kwargs) def to_official_llava(self, cfg, save_dir, fp32=False, save_pretrained_kwargs={}): VIT_MAPPING = { 'vision_model': 'model.vision_tower.vision_tower.vision_model', } PROJECTOR_MAPPING = { 'model.0': 'model.mm_projector.0', 'model.2': 'model.mm_projector.2', } LONGNET_MAPPING = { 'layers.0': 'LongNet_encoder.layers.0', 'layers.1': 'LongNet_encoder.layers.1', 'layer_norm': 'LongNet_encoder.layer_norm' } try: from llava.model import LlavaConfig, LlavaLlamaForCausalLM except ImportError: raise ImportError( 'Please install llava with ' '`pip install git+https://github.com/haotian-liu/LLaVA.git ' '--no-deps`.') assert getattr(self.llm, 'hf_quantizer', None) is None, \ 'This conversion format does not support quantized LLM.' # get state_dict llm = self.llm if self.use_llm_lora: llm = self.llm.merge_and_unload() llm.config.use_cache = True if not fp32: print_log('Convert LLM to float16', 'current') llm.half() assert isinstance(llm, LlamaForCausalLM), \ 'This conversion format only supports LlamaForCausalLM.' llm_state_dict = llm.state_dict() need_visual_encoder = (not self.freeze_visual_encoder or self.use_visual_encoder_lora) visual_encoder = self.visual_encoder if self.use_visual_encoder_lora: visual_encoder = self.visual_encoder.merge_and_unload() assert isinstance(visual_encoder, CLIPVisionModel),\ 'This conversion format only supports CLIPVisionModel.' if need_visual_encoder: visual_encoder_state_dict = visual_encoder.state_dict() visual_encoder_state_dict = convert_state_dict_to_hf( visual_encoder_state_dict, VIT_MAPPING) else: visual_encoder_state_dict = {} projector_state_dict = self.projector.state_dict() projector_state_dict = convert_state_dict_to_hf( projector_state_dict, PROJECTOR_MAPPING) LongNet_encoder_state_dict = self.LongNet_encoder.state_dict() LongNet_encoder_state_dict = convert_state_dict_to_hf( LongNet_encoder_state_dict, LONGNET_MAPPING) state_dict = { **projector_state_dict, **llm_state_dict, **visual_encoder_state_dict, **LongNet_encoder_state_dict } # init model tokenizer = BUILDER.build(cfg.tokenizer) image_processor = BUILDER.build(cfg.image_processor) assert isinstance(image_processor, CLIPImageProcessor),\ 'This conversion format only supports CLIPImageProcessor.' llava_config_dict = llm.config.__dict__.copy() llava_config_dict.update( dict( image_aspect_ratio='pad', mm_hidden_size=visual_encoder.config.hidden_size, mm_projector_type=f'mlp{self.projector_depth}x_gelu', mm_use_im_patch_token=False, mm_use_im_start_end=False, mm_vision_select_feature='patch', mm_vision_select_layer=self.visual_select_layer, mm_vision_tower=visual_encoder.config.name_or_path, unfreeze_mm_vision_tower=need_visual_encoder, model_type='llava', use_cache=True, use_mm_proj=True)) llava_config = LlavaConfig(**llava_config_dict) with init_empty_weights(): with warnings.catch_warnings(): warnings.filterwarnings( 'ignore', message='.*non-meta.*', category=UserWarning) model = LlavaLlamaForCausalLM(llava_config) model.load_state_dict(state_dict, strict=True, assign=True) # save print_log(f'Saving to {save_dir}', 'current') model.save_pretrained(save_dir, **save_pretrained_kwargs) image_processor.save_pretrained(save_dir, **save_pretrained_kwargs) tokenizer.save_pretrained(save_dir, **save_pretrained_kwargs)