# Copyright (c) OpenMMLab. All rights reserved.
import math
import os
import os.path as osp
import warnings
from collections import OrderedDict

import torch
import torch.distributed as dist  # === MOD ===
import torch.nn as nn
import torch.nn.functional as F
from accelerate import init_empty_weights
from mmengine import print_log
from mmengine.config import Config, ConfigDict
from mmengine.model import BaseModel
from peft import get_peft_model, prepare_model_for_kbit_training
from safetensors.torch import load_file, save_file
from transformers import (AddedToken, AutoConfig, CLIPImageProcessor,
                          CLIPVisionModel, LlamaForCausalLM,
                          LlamaTokenizerFast, LlavaConfig,
                          LlavaForConditionalGeneration, LlavaProcessor)
from transformers.integrations import is_deepspeed_zero3_enabled

from xtuner.registry import BUILDER
from xtuner.utils import DEFAULT_IMAGE_TOKEN
from .modules import ProjectorConfig, ProjectorModel, dispatch_modules
from .modules.dispatch import SUPPORT_FLASH1, SUPPORT_FLASH2
from .utils import (LoadWoInit, find_all_linear_names,
                    get_peft_model_state_dict, guess_load_checkpoint,
                    make_inputs_require_grad,
                    prepare_inputs_labels_for_multimodal, traverse_dict)


def convert_state_dict_to_hf(state_dict, mapping):
    new_state_dict = {}
    for key, value in state_dict.items():
        if key.endswith('.inv_freq'):
            continue
        for key_to_modify, new_key in mapping.items():
            if key_to_modify in key:
                key = key.replace(key_to_modify, new_key)
        new_state_dict[key] = value
    return new_state_dict


class AdaptiveAvgPool1dLayer(nn.Module):
    """Adaptive average pooling over the sequence dimension L, wrapped with an
    input and an output LayerNorm. For large L it switches to linear
    interpolation to avoid the shared-memory limit errors of CUDA
    AdaptiveAvgPool.

    Expected input: x of shape [B, H, L].
    - Apply the input LayerNorm(H) on [B, L, H].
    - Pool / interpolate the sequence dimension L down to ``output_size``.
    - Apply the output LayerNorm(H) on [B, L_out, H].

    Args:
        output_size (int): Number of tokens L_out after pooling.
        hidden_size (int): Channel dimension H (used as the LayerNorm shape).
        eps (float): LayerNorm eps.
        affine (bool): Whether the LayerNorms carry learnable scale/shift.
        impl (str): 'auto' | 'pool' | 'interp'. 'auto' switches based on the
            length threshold below.
        switch_threshold (int): When L >= this threshold and impl='auto',
            interpolation is used.
        pool_in_fp32 (bool): Promote pooling/interpolation to FP32 internally
            for better numerical stability.
    """

    def __init__(self,
                 output_size: int,
                 hidden_size: int,
                 eps: float = 1e-5,
                 affine: bool = True,
                 impl: str = 'auto',
                 switch_threshold: int = 8192,
                 pool_in_fp32: bool = True):
        super().__init__()
        if output_size <= 0:
            raise ValueError('output_size must be positive')
        if hidden_size <= 0:
            raise ValueError('hidden_size must be positive')
        if impl not in ('auto', 'pool', 'interp'):
            raise ValueError("impl must be one of {'auto', 'pool', 'interp'}")
        self.output_size = int(output_size)
        self.hidden_size = int(hidden_size)
        self.impl = impl
        self.switch_threshold = int(switch_threshold)
        self.pool_in_fp32 = bool(pool_in_fp32)
        self.in_norm = nn.LayerNorm(
            hidden_size, eps=eps, elementwise_affine=affine)
        self.out_norm = nn.LayerNorm(
            hidden_size, eps=eps, elementwise_affine=affine)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Expect x of shape [B, H, L].
        if x.dim() != 3:
            raise ValueError('AdaptiveAvgPool1dLayer expects a 3D tensor '
                             f'[B, H, L], got {tuple(x.shape)}')
        B, H, L = x.shape
        if H != self.hidden_size:
            raise ValueError(f'Channel size mismatch: got H={H}, '
                             f'expected {self.hidden_size}')

        # Input normalization: LayerNorm(H) on [B, L, H].
        x = x.transpose(1, 2).contiguous()  # [B, L, H]
        x = self.in_norm(x)
        x = x.transpose(1, 2).contiguous()  # [B, H, L]

        # Choose the implementation: use interpolation for large L to avoid
        # the CUDA shared-memory error.
        use_interp = (self.impl == 'interp') or (
            self.impl == 'auto' and L >= self.switch_threshold)

        orig_dtype = x.dtype
        if self.pool_in_fp32 and x.dtype in (torch.float16, torch.bfloat16):
            x = x.float()

        if use_interp:
            # Linear interpolation over [B, H, L] is stable and differentiable.
            x = F.interpolate(
                x, size=self.output_size, mode='linear', align_corners=False)
        else:
            x = F.adaptive_avg_pool1d(x.contiguous(), self.output_size)
        x = x.to(orig_dtype)

        # Output normalization: LayerNorm(H) on [B, L_out, H].
        x = x.transpose(1, 2).contiguous()  # [B, L_out, H]
        x = self.out_norm(x)
        x = x.transpose(1, 2).contiguous()  # [B, H, L_out]
        return x
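

# Illustrative sketch (made-up sizes, not called by the training pipeline):
# the shape contract of `AdaptiveAvgPool1dLayer` is [B, H, L] -> [B, H,
# output_size], with impl='auto' switching from adaptive average pooling to
# linear interpolation once L reaches `switch_threshold`.
def _projector_pool_shape_example():
    pool = AdaptiveAvgPool1dLayer(
        output_size=16, hidden_size=64, switch_threshold=10240)
    short_seq = torch.randn(2, 64, 3000)   # L < threshold: adaptive_avg_pool1d
    long_seq = torch.randn(2, 64, 20000)   # L >= threshold: F.interpolate
    assert pool(short_seq).shape == (2, 64, 16)
    assert pool(long_seq).shape == (2, 64, 16)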


class LLaVAModel(BaseModel):

    def __init__(self,
                 llm,
                 freeze_llm=True,
                 visual_select_layer=-2,
                 pretrained_pth=None,
                 projector_depth=2,
                 llm_lora=None,
                 visual_encoder_lora=None,
                 use_activation_checkpointing=True,
                 max_position_embeddings=None,
                 hidden_size=512,
                 train_stage='2',
                 projector_pth=None,
                 use_projector_pool=False,
                 projector_pool_out_tokens=1024,
                 projector_pool_pth=None,
                 projector_pool_ln_eps=1e-6,
                 projector_pool_ln_affine=True):
        super().__init__()
        self.freeze_llm = freeze_llm
        self.freeze_visual_encoder = True
        if train_stage == '1':
            print('train_stage == 1')
            self.freeze_llm = True
        elif train_stage == '2':
            print('train_stage == 2')
            self.freeze_llm = False

        with LoadWoInit():
            if isinstance(llm, dict):
                llm = self._dispatch_lm_model_cfg(llm,
                                                  max_position_embeddings)
            self.llm = self._build_from_cfg_or_module(llm)
        self.llm.config.use_cache = False
        dispatch_modules(self.llm)

        self.projector_depth = projector_depth
        projector_config = ProjectorConfig(
            visual_hidden_size=hidden_size,
            llm_hidden_size=self.llm.config.hidden_size,
            depth=self.projector_depth)
        self.projector = ProjectorModel(projector_config).to(self.llm.dtype)

        self.use_projector_pool = use_projector_pool
        if self.use_projector_pool:
            hs = int(self.llm.config.hidden_size)
            self.projector_pool = AdaptiveAvgPool1dLayer(
                output_size=int(projector_pool_out_tokens),
                hidden_size=hs,
                eps=float(projector_pool_ln_eps),
                affine=bool(projector_pool_ln_affine),
                impl='auto',
                switch_threshold=10240,
                pool_in_fp32=True)

        if self.freeze_llm:
            print('freeze_llm')
            self.llm.requires_grad_(False)

        if use_activation_checkpointing:
            # For backward compatibility
            if hasattr(self.llm, 'enable_input_require_grads'):
                self.llm.enable_input_require_grads()
            else:
                self.llm.get_input_embeddings().register_forward_hook(
                    make_inputs_require_grad)
            self.projector.enable_input_require_grads()
            # self.visual_encoder.enable_input_require_grads()  # if used

            # enable gradient (activation) checkpointing for memory efficiency
            self.gradient_checkpointing_enable()

        self.use_llm_lora = None
        self.use_visual_encoder_lora = None
        if self.use_llm_lora:
            self._prepare_llm_for_lora(llm_lora, use_activation_checkpointing)

        if projector_pth is not None:
            print_log(f'Loading projector from {projector_pth}', 'current')
            proj_sd = load_file(projector_pth, device='cpu')
            # proj_sd = load_file(projector_pth, device='cuda')
            self.projector.load_state_dict(proj_sd, strict=False)
            self.projector.to(self.llm.dtype)

        if pretrained_pth is not None:
            pretrained_state_dict = guess_load_checkpoint(pretrained_pth)
            self.load_state_dict(pretrained_state_dict, strict=False)
            print_log(f'Load pretrained weight from {pretrained_pth}',
                      'current')

        self.visual_select_layer = visual_select_layer

        self._is_init = True
        self.is_first_iter = True

    def _parse_lora_config(self, lora_config):
        if isinstance(lora_config, (dict, Config, ConfigDict)):
            lora_config = BUILDER.build(lora_config)
        return lora_config

    def _prepare_llm_for_lora(self,
                              lora_config,
                              use_activation_checkpointing=True):
        lora_config = self._parse_lora_config(lora_config)
        self.llm = prepare_model_for_kbit_training(
            self.llm, use_activation_checkpointing)
        if lora_config.target_modules is None:
            modules = find_all_linear_names(self.llm)
            lora_config.target_modules = modules
        self.llm = get_peft_model(self.llm, lora_config)
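
    # Illustrative sketch (hypothetical values): `llm_lora` is expected to be
    # a peft `LoraConfig` or an xtuner-style config dict that
    # `_parse_lora_config` resolves via `BUILDER.build`, e.g.
    #     llm_lora=dict(
    #         type=LoraConfig, r=64, lora_alpha=16, lora_dropout=0.05,
    #         bias='none', task_type='CAUSAL_LM')
    # Note that `use_llm_lora` is hard-coded to None in `__init__`, so this
    # path only runs if that flag is re-enabled.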

    def _prepare_visual_encoder_for_lora(self,
                                         lora_config,
                                         use_activation_checkpointing=True):
        lora_config = self._parse_lora_config(lora_config)
        if lora_config.target_modules is None:
            modules = find_all_linear_names(self.visual_encoder)
            lora_config.target_modules = modules
        self.visual_encoder = get_peft_model(self.visual_encoder, lora_config)

    def gradient_checkpointing_enable(self):
        self.activation_checkpointing_enable()

    def activation_checkpointing_enable(self):
        self.llm.gradient_checkpointing_enable()
        # self.visual_encoder.gradient_checkpointing_enable()
        self.projector.gradient_checkpointing_enable()

    def gradient_checkpointing_disable(self):
        self.activation_checkpointing_disable()

    def activation_checkpointing_disable(self):
        self.llm.gradient_checkpointing_disable()
        # self.visual_encoder.gradient_checkpointing_disable()
        self.projector.gradient_checkpointing_disable()

    def init_weights(self):
        pass

    def state_dict(self, *args, **kwargs):
        state_dict = super().state_dict(*args, **kwargs)
        to_return = OrderedDict()
        # Step 1. visual_encoder
        if self.use_visual_encoder_lora:
            to_return.update(
                get_peft_model_state_dict(
                    self.visual_encoder, state_dict=state_dict))
        elif not self.freeze_visual_encoder:
            to_return.update({
                k: v
                for k, v in state_dict.items() if 'visual_encoder.' in k
            })
        # Step 2. LLM
        if self.use_llm_lora:
            to_return.update(
                get_peft_model_state_dict(self.llm, state_dict=state_dict))
        elif not self.freeze_llm:
            to_return.update(
                {k: v
                 for k, v in state_dict.items() if 'llm.' in k})
        # Step 3. Projector
        to_return.update(
            {k: v
             for k, v in state_dict.items() if 'projector.' in k})
        to_return.update(
            {k: v
             for k, v in state_dict.items() if 'projector_pool.' in k})
        return to_return
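
    # Illustrative note: with the defaults used here (frozen visual features,
    # `use_llm_lora` / `use_visual_encoder_lora` left as None and, in stage-1
    # training, a frozen LLM), the `state_dict()` override above keeps only
    # the trainable parts, e.g. keys such as
    #     'projector.model.0.weight', 'projector.model.2.bias',
    #     'projector_pool.in_norm.weight', 'projector_pool.out_norm.bias'
    # which is what `guess_load_checkpoint(pretrained_pth)` is expected to
    # restore in `__init__`.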

    @staticmethod
    def _prepare_for_long_context_training(cfg, llm_cfg,
                                           max_position_embeddings):

        orig_rope_scaling = getattr(llm_cfg, 'rope_scaling', None)
        if orig_rope_scaling is None:
            orig_rope_scaling = {'factor': 1}

        orig_rope_scaling_factor = orig_rope_scaling[
            'factor'] if 'factor' in orig_rope_scaling.keys() else 1
        orig_ctx_len = getattr(llm_cfg, 'max_position_embeddings', None)
        if orig_ctx_len:
            orig_ctx_len *= orig_rope_scaling_factor
            if max_position_embeddings > orig_ctx_len:
                scaling_factor = float(
                    math.ceil(max_position_embeddings / orig_ctx_len))
                llm_cfg.rope_scaling = {
                    'type': 'linear',
                    'factor': scaling_factor
                }

        # hardcode for internlm2
        llm_cfg.attn_implementation = 'flash_attention_2'
        cfg.config = llm_cfg

        return cfg, llm_cfg
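
    # Worked example (made-up numbers) for the method above: with an original
    # `max_position_embeddings` of 4096, no pre-existing rope_scaling
    # (factor 1) and a requested `max_position_embeddings` of 16384, the
    # effective original context is 4096 * 1 = 4096, so the linear RoPE
    # scaling factor becomes ceil(16384 / 4096) = 4.0, i.e.
    # `llm_cfg.rope_scaling = {'type': 'linear', 'factor': 4.0}`.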

    @staticmethod
    def _prepare_for_flash_attn(cfg, llm_cfg):
        cls_name = type(llm_cfg).__name__
        SUPPORT_SDPA_ATTN = ('LlamaConfig', 'GemmaConfig', 'MistralConfig',
                             'MixtralConfig', 'Qwen2Config', 'Qwen2MoeConfig',
                             'Starcoder2Config', 'Phi3Config')
        SUPPORT_FLASH_ATTN2 = ('InternLM2Config', 'LlamaConfig', 'GemmaConfig',
                               'MistralConfig', 'MixtralConfig', 'Qwen2Config',
                               'Qwen2MoeConfig', 'Starcoder2Config',
                               'Phi3Config')

        torch_dtype = torch.bfloat16 if (
            torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
            else torch.float16

        if getattr(cfg, 'attn_implementation', None) is not None:
            # Flash Attention 2.0 only supports torch.float16 and
            # torch.bfloat16 dtypes
            if cfg.attn_implementation == 'flash_attention_2':
                cfg.torch_dtype = torch_dtype
        elif SUPPORT_FLASH2 and cls_name in SUPPORT_FLASH_ATTN2:
            cfg.torch_dtype = torch_dtype
            cfg.attn_implementation = 'flash_attention_2'
        elif SUPPORT_FLASH1 and cls_name in SUPPORT_SDPA_ATTN:
            cfg.attn_implementation = 'sdpa'

        return cfg, llm_cfg

    @staticmethod
    def _prepare_for_qlora_zero3(cfg):
        if (not is_deepspeed_zero3_enabled()) or (not hasattr(
                cfg, 'quantization_config')):
            return cfg

        torch_dtype = torch.bfloat16 if (
            torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
            else torch.float16

        cfg.torch_dtype = torch_dtype
        quantization_config = cfg.quantization_config
        quantization_config.bnb_4bit_compute_dtype = torch_dtype
        quantization_config.bnb_4bit_quant_storage = torch_dtype

        return cfg

    def _dispatch_lm_model_cfg(self, cfg, max_position_embeddings=None):
        cfg = self._prepare_for_qlora_zero3(cfg)
        pretrained_model_name_or_path = cfg.pretrained_model_name_or_path
        llm_cfg = AutoConfig.from_pretrained(
            pretrained_model_name_or_path, trust_remote_code=True)
        cfg, llm_cfg = self._prepare_for_flash_attn(cfg, llm_cfg)
        if max_position_embeddings is not None:
            cfg, llm_cfg = self._prepare_for_long_context_training(
                cfg, llm_cfg, max_position_embeddings)
        return cfg

    def _build_from_cfg_or_module(self, cfg_or_mod):
        if isinstance(cfg_or_mod, nn.Module):
            return cfg_or_mod
        elif isinstance(cfg_or_mod, dict):
            traverse_dict(cfg_or_mod)
            return BUILDER.build(cfg_or_mod)
        else:
            raise NotImplementedError

    def forward(self, data, data_samples=None, mode='loss'):
        if self.is_first_iter:
            # hardcode for qlora DeepSpeed ZeRO3, put buffers and QuantState
            # to device
            # Only required in `LLaVAModel`.
            # We do not need this in `SupervisedFinetune`.
            self.to(data['input_ids'].device)
            self.is_first_iter = False

        # data['pixel_values'] holds the pre-extracted visual features, one
        # entry per image.
        if 'pixel_values' in data:
            feat_to_proj = data['pixel_values'].to(self.llm.dtype)
            # Explicitly enable gradient tracking on the input features so the
            # backward graph reaches the projector's weights, even though the
            # features come from a frozen (offline) encoder and gradient
            # checkpointing is enabled.
            feat_to_proj.requires_grad_(True)
            pixel_values = self.projector(feat_to_proj)

            # Pool along the sequence length (tokens) down to L'.
            if self.use_projector_pool:
                B, L, H = pixel_values.shape
                pv = pixel_values.transpose(1, 2)  # [B, H, L]
                pv = self.projector_pool(pv)  # [B, H, L']
                pixel_values = pv.transpose(1, 2).contiguous()  # [B, L', H]

            data['pixel_values'] = pixel_values
            data.pop('coords', None)
            data = prepare_inputs_labels_for_multimodal(llm=self.llm, **data)

        if mode == 'loss':
            return self.compute_loss(data, data_samples)
        elif mode == 'predict':
            return self.predict(data, data_samples)
        elif mode == 'tensor':
            return self._forward(data, data_samples)
        else:
            raise NotImplementedError

    def _forward(self, data, data_samples=None):
        outputs = self.llm(**data)
        return outputs

    def predict(self, data, data_samples=None):
        outputs = self.llm(**data)
        logits_dict = [{'logits': logits} for logits in outputs.logits]
        return logits_dict

    # Overrides the default compute_loss of LLaVAModel.
    def compute_loss(self, data, data_samples=None):
        """Compute the loss with per-sample averaging.

        Token losses are first averaged within each sample and then averaged
        across samples, so every sample contributes equally to the total loss
        regardless of its token length. This mitigates the gradient imbalance
        between long and short sequences.
        """
        # If the HF model can handle it on its own, return directly.
        if 'labels' not in data:
            outputs = self.llm(**data)
            return {'loss': outputs.loss}

        # Detach labels from data so they are not passed to the model.
        labels = data.pop('labels')

        # Forward pass to obtain the logits.
        outputs = self.llm(**data)
        logits = outputs.logits

        # Verify that logits and labels shapes match.
        if logits.shape[:-1] != labels.shape:
            raise ValueError('Logits and labels shape mismatch. '
                             f'Logits: {logits.shape}, '
                             f'Labels: {labels.shape}')

        # Standard causal-LM shift: position t predicts token t + 1.
        # logits: [B, L, V] -> [B, L-1, V]; labels: [B, L] -> [B, L-1]
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        # Per-token cross entropy without any aggregation (reduction='none');
        # this returns one loss value per (shifted) label position.
        loss = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=-100,
            reduction='none')

        # Reshape the loss back to [B, L-1].
        loss = loss.view(shift_logits.size(0), -1)

        # Average the loss within each sample (each sequence).
        # Count the valid (non -100) tokens per sample.
        num_tokens_per_sample = (shift_labels != -100).sum(dim=1)
        # Total loss per sample.
        loss_per_sample = loss.sum(dim=1)

        # Avoid division by zero.
        valid_samples_mask = num_tokens_per_sample > 0
        # Initialize the per-sample mean loss.
        mean_loss_per_sample = torch.zeros_like(loss_per_sample)
        # Only average over samples that have valid tokens.
        if valid_samples_mask.any():
            mean_loss_per_sample[valid_samples_mask] = (
                loss_per_sample[valid_samples_mask] /
                num_tokens_per_sample[valid_samples_mask])

        # The final loss is the mean of the per-sample mean losses.
        final_loss = mean_loss_per_sample.mean()

        return {'loss': final_loss}
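
    # Worked example (made-up numbers) of why the per-sample averaging above
    # differs from a plain token-level mean: if sample A has 10 supervised
    # tokens with a mean token loss of 2.0 and sample B has 1000 supervised
    # tokens with a mean token loss of 1.0, the token-level mean is
    # (10 * 2.0 + 1000 * 1.0) / 1010 ≈ 1.01 and is dominated by the long
    # sample, whereas the per-sample mean is (2.0 + 1.0) / 2 = 1.5, weighting
    # both samples equally.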

    # def compute_loss(self, data, data_samples=None):
    #     outputs = self.llm(**data)
    #     loss_dict = {'loss': outputs.loss}
    #     return loss_dict

    # def compute_loss(self, data, data_samples=None):
    #     """Token-level cross-entropy loss (distributed / AMP compatible).
    #
    #     - -100 in labels is the ignore_index.
    #     - Negative IDs (e.g. the -200 image placeholder) and positions of
    #       the tokenizer's special_ids are masked out automatically.
    #     """
    #     # 1) Without labels, fall back to the HF default.
    #     if "labels" not in data:
    #         outputs = self.llm(**data)
    #         return {"loss": outputs.loss}

    #     labels = data["labels"]                    # [B, T]
    #     input_ids = data.get("input_ids", None)    # [B, T] or None
    #     attn = data.get("attention_mask", None)    # optional

    #     # 2) Label cleaning (do not modify the original labels).
    #     safe_labels = labels.clone()

    #     # 2.1 Mask negative IDs (e.g. the -200 image placeholder).
    #     if input_ids is not None:
    #         neg_mask = (input_ids < 0)
    #         if neg_mask.any():
    #             safe_labels = torch.where(
    #                 neg_mask, torch.full_like(safe_labels, -100), safe_labels)

    #     # 2.2 Mask the tokenizer's special tokens (template markers, etc.).
    #     if getattr(self, "tokenizer", None) is not None:
    #         try:
    #             special_ids = set(self.tokenizer.all_special_ids or [])
    #         except Exception:
    #             special_ids = set()
    #         if special_ids:
    #             special_mask = torch.zeros_like(input_ids, dtype=torch.bool)
    #             for sid in special_ids:
    #                 special_mask |= (input_ids == sid)
    #             if special_mask.any():
    #                 safe_labels = torch.where(
    #                     special_mask, torch.full_like(safe_labels, -100),
    #                     safe_labels)

    #     # 3) Forward pass for the logits (do not hand labels to HF, which
    #     #    would apply a per-device mean first).
    #     model_inputs = {k: v for k, v in data.items() if k != "labels"}
    #     outputs = self.llm(**model_inputs, use_cache=False)
    #     logits = outputs.logits  # [B, T, V]

    #     # Shape assertion.
    #     if logits.dim() != 3 or logits.shape[:2] != safe_labels.shape[:2]:
    #         raise RuntimeError(
    #             f"logits/labels length mismatch: logits {tuple(logits.shape)} "
    #             f"vs labels {tuple(safe_labels.shape)}")

    #     # 4) Causal-LM alignment.
    #     shift_logits = logits[:, :-1, :].contiguous()
    #     shift_labels = safe_labels[:, 1:].contiguous()

    #     # 5) Count valid tokens and aggregate across ranks.
    #     n_tok_local = (shift_labels != -100).sum().to(
    #         device=logits.device, dtype=torch.long)
    #     world_size = 1
    #     n_tok_global = n_tok_local
    #     if dist.is_available() and dist.is_initialized():
    #         world_size = dist.get_world_size()
    #         with torch.no_grad():
    #             n_tok_global = n_tok_local.clone()
    #             dist.all_reduce(n_tok_global, op=dist.ReduceOp.SUM)

    #     # If there are no supervised tokens globally, return 0 (avoid NaN).
    #     if n_tok_global.item() == 0:
    #         zero = shift_logits.sum() * 0.0
    #         return {"loss": zero, "ntok": n_tok_global.to(zero.dtype)}

    #     # 6) Numerator (sum over tokens; FP32 for numerical stability).
    #     loss_sum_local = F.cross_entropy(
    #         shift_logits.float().view(-1, shift_logits.size(-1)),
    #         shift_labels.view(-1),
    #         ignore_index=-100,
    #         reduction="sum",
    #     )

    #     # 7) Loss averaged over the global token count (multiplying by
    #     #    world_size cancels DDP's gradient averaging).
    #     denom = n_tok_global.clamp_min(1).to(loss_sum_local.dtype)
    #     loss = (loss_sum_local / denom) * float(world_size)

    #     # 8) Return.
    #     ntok_tensor = denom.detach()
    #     return {"loss": loss, "ntok": ntok_tensor}

    def __getattr__(self, name: str):
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.llm, name)

    def to_hf(self,
              cfg,
              save_dir,
              fp32=False,
              save_pretrained_kwargs={},
              save_format='xtuner',
              **kwargs):
        if save_format == 'xtuner':
            self.to_xtuner_llava(cfg, save_dir, fp32, save_pretrained_kwargs)
        elif save_format == 'huggingface':
            self.to_huggingface_llava(cfg, save_dir, fp32,
                                      save_pretrained_kwargs)
        elif save_format == 'official':
            self.to_official_llava(cfg, save_dir, fp32,
                                   save_pretrained_kwargs)
        else:
            raise NotImplementedError

    def to_xtuner_llava(self,
                        cfg,
                        save_dir,
                        fp32=False,
                        save_pretrained_kwargs={}):
        # LLM
        self.llm.config.use_cache = True
        if not fp32:
            print_log('Convert LLM to float16', 'current')
            self.llm.half()
        if self.use_llm_lora:
            llm_path = osp.join(save_dir, 'llm_adapter')
            print_log(f'Saving LLM adapter to {llm_path}', 'current')
            self.llm.save_pretrained(llm_path, **save_pretrained_kwargs)
        elif not self.freeze_llm:
            llm_path = save_dir
            print_log(f'Saving LLM tokenizer to {llm_path}', 'current')
            tokenizer = BUILDER.build(cfg.tokenizer)
            tokenizer.save_pretrained(llm_path, **save_pretrained_kwargs)
            print_log(f'Saving LLM to {llm_path}', 'current')
            self.llm.save_pretrained(llm_path, **save_pretrained_kwargs)
        self.llm.config.use_cache = False

        # Visual Encoder
        if self.use_visual_encoder_lora:
            visual_encoder_path = osp.join(save_dir, 'visual_encoder_adapter')
            print_log(
                f'Saving visual_encoder adapter to {visual_encoder_path}',
                'current')
            self.visual_encoder.save_pretrained(visual_encoder_path,
                                                **save_pretrained_kwargs)
        elif not self.freeze_visual_encoder:
            visual_encoder_path = osp.join(save_dir, 'visual_encoder')
            print_log(
                'Saving visual_encoder image_processor to '
                f'{visual_encoder_path}', 'current')
            image_processor = BUILDER.build(cfg.image_processor)
            image_processor.save_pretrained(visual_encoder_path,
                                            **save_pretrained_kwargs)
            print_log(f'Saving visual_encoder to {visual_encoder_path}',
                      'current')
            self.visual_encoder.save_pretrained(visual_encoder_path,
                                                **save_pretrained_kwargs)

        # Projector
        projector_path = osp.join(save_dir, 'projector')
        print_log(f'Saving projector to {projector_path}', 'current')
        # self.projector.save_pretrained(projector_path,
        #                                **save_pretrained_kwargs)
        os.makedirs(projector_path, exist_ok=True)
        output_path = os.path.join(projector_path, 'projector.safetensors')
        save_file(self.projector.state_dict(), output_path)

        if self.use_projector_pool and getattr(self, 'projector_pool',
                                               None) is not None:
            projector_pool_path = osp.join(save_dir, 'projector_pool')
            print_log(f'Saving projector_pool to {projector_pool_path}',
                      'current')
            torch.save(self.projector_pool.state_dict(), projector_pool_path)
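
    # Illustrative sketch (hypothetical paths): the projector saved above can
    # be reloaded the same way `projector_pth` is handled in `__init__`, e.g.
    #     proj_sd = load_file(
    #         osp.join(save_dir, 'projector', 'projector.safetensors'),
    #         device='cpu')
    #     self.projector.load_state_dict(proj_sd, strict=False)
    # and the pooling layer saved with `torch.save` can be restored with
    #     self.projector_pool.load_state_dict(
    #         torch.load(osp.join(save_dir, 'projector_pool'),
    #                    map_location='cpu'))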

    def to_huggingface_llava(self,
                             cfg,
                             save_dir,
                             fp32=False,
                             save_pretrained_kwargs={}):

        LLM_MAPPING = {
            'model': 'language_model.model',
            'lm_head': 'language_model.lm_head',
        }
        VIT_MAPPING = {
            'vision_model': 'vision_tower.vision_model',
        }
        PROJECTOR_MAPPING = {
            'model.0': 'multi_modal_projector.linear_1',
            'model.2': 'multi_modal_projector.linear_2',
        }
        LONGNET_MAPPING = {
            'layers.0': 'LongNet_encoder.layers.0',
            'layers.1': 'LongNet_encoder.layers.1',
            'layer_norm': 'LongNet_encoder.layer_norm'
        }

        assert getattr(self.llm, 'hf_quantizer', None) is None, \
            'This conversion format does not support quantized LLM.'

        # get state_dict
        llm = self.llm
        if self.use_llm_lora:
            llm = self.llm.merge_and_unload()
        llm.config.use_cache = True
        if not fp32:
            print_log('Convert LLM to float16', 'current')
            llm.half()

        assert isinstance(llm, LlamaForCausalLM), \
            'This conversion format only supports LlamaForCausalLM.'
        llm_state_dict = llm.state_dict()
        llm_state_dict = convert_state_dict_to_hf(llm_state_dict, LLM_MAPPING)

        need_visual_encoder = (not self.freeze_visual_encoder
                               or self.use_visual_encoder_lora)
        visual_encoder = self.visual_encoder
        if self.use_visual_encoder_lora:
            visual_encoder = self.visual_encoder.merge_and_unload()
        assert isinstance(visual_encoder, CLIPVisionModel),\
            'This conversion format only supports CLIPVisionModel.'
        if need_visual_encoder:
            visual_encoder_state_dict = visual_encoder.state_dict()
            visual_encoder_state_dict = convert_state_dict_to_hf(
                visual_encoder_state_dict, VIT_MAPPING)
        else:
            visual_encoder_state_dict = {}

        projector_state_dict = self.projector.state_dict()
        projector_state_dict = convert_state_dict_to_hf(
            projector_state_dict, PROJECTOR_MAPPING)

        LongNet_encoder_state_dict = self.LongNet_encoder.state_dict()
        LongNet_encoder_state_dict = convert_state_dict_to_hf(
            LongNet_encoder_state_dict, LONGNET_MAPPING)

        state_dict = {
            **projector_state_dict,
            **llm_state_dict,
            **visual_encoder_state_dict,
            **LongNet_encoder_state_dict
        }

        # init model
        text_config = llm.config
        vision_config = visual_encoder.config
        config = LlavaConfig(
            text_config=text_config,
            vision_config=vision_config,
            attn_implementation='eager')

        with init_empty_weights():
            with warnings.catch_warnings():
                warnings.filterwarnings(
                    'ignore', message='.*non-meta.*', category=UserWarning)
                model = LlavaForConditionalGeneration(config)
        model.load_state_dict(state_dict, strict=True, assign=True)

        # processor
        cfg.tokenizer.type = LlamaTokenizerFast.from_pretrained
        tokenizer = BUILDER.build(cfg.tokenizer)

        tokenizer.add_tokens(
            AddedToken(DEFAULT_IMAGE_TOKEN, special=True, normalized=False),
            special_tokens=True)
        tokenizer.add_special_tokens({'pad_token': '<pad>'})

        image_processor = BUILDER.build(cfg.image_processor)
        assert isinstance(image_processor, CLIPImageProcessor),\
            'This conversion format only supports CLIPImageProcessor.'

        processor = LlavaProcessor(
            tokenizer=tokenizer, image_processor=image_processor)

        # Pad to 64 for performance reasons
        pad_shape = 64

        pre_expansion_embeddings = \
            model.language_model.model.embed_tokens.weight.data
        mu = torch.mean(pre_expansion_embeddings, dim=0).float()
        n = pre_expansion_embeddings.size()[0]
        sigma = ((pre_expansion_embeddings - mu).T
                 @ (pre_expansion_embeddings - mu)) / n
        dist = torch.distributions.multivariate_normal.MultivariateNormal(
            mu, covariance_matrix=1e-5 * sigma)

        # We add an image token so we need to resize the model
        ori_vocab_size = config.text_config.vocab_size
        tokenizer_vocab_size = tokenizer.encode('<pad>')[-1]
        added_token = tokenizer_vocab_size - ori_vocab_size

        if added_token > 0:
            model.resize_token_embeddings(ori_vocab_size + added_token,
                                          pad_shape)
            # New embedding rows are sampled from a Gaussian fitted to the
            # existing embeddings.
            model.language_model.model.embed_tokens.weight.data[
                ori_vocab_size:] = torch.stack(
                    tuple(
                        dist.sample()
                        for _ in range(model.language_model.model.embed_tokens.
                                       weight.data[ori_vocab_size:].shape[0])),
                    dim=0,
                )
            model.language_model.lm_head.weight.data[
                ori_vocab_size:] = torch.stack(
                    tuple(dist.sample()
                          for _ in range(model.language_model.lm_head.weight.
                                         data[ori_vocab_size:].shape[0])),
                    dim=0,
                )
        model.config.image_token_index = tokenizer.encode(
            DEFAULT_IMAGE_TOKEN)[-1]
        model.config.pad_token_id = tokenizer.encode('<pad>')[-1]

        # save
        print_log(f'Saving to {save_dir}', 'current')
        model.save_pretrained(save_dir, **save_pretrained_kwargs)
        processor.save_pretrained(save_dir, **save_pretrained_kwargs)

    def to_official_llava(self,
                          cfg,
                          save_dir,
                          fp32=False,
                          save_pretrained_kwargs={}):

        VIT_MAPPING = {
            'vision_model': 'model.vision_tower.vision_tower.vision_model',
        }
        PROJECTOR_MAPPING = {
            'model.0': 'model.mm_projector.0',
            'model.2': 'model.mm_projector.2',
        }
        LONGNET_MAPPING = {
            'layers.0': 'LongNet_encoder.layers.0',
            'layers.1': 'LongNet_encoder.layers.1',
            'layer_norm': 'LongNet_encoder.layer_norm'
        }

        try:
            from llava.model import LlavaConfig, LlavaLlamaForCausalLM
        except ImportError:
            raise ImportError(
                'Please install llava with '
                '`pip install git+https://github.com/haotian-liu/LLaVA.git '
                '--no-deps`.')

        assert getattr(self.llm, 'hf_quantizer', None) is None, \
            'This conversion format does not support quantized LLM.'

        # get state_dict
        llm = self.llm
        if self.use_llm_lora:
            llm = self.llm.merge_and_unload()
        llm.config.use_cache = True
        if not fp32:
            print_log('Convert LLM to float16', 'current')
            llm.half()

        assert isinstance(llm, LlamaForCausalLM), \
            'This conversion format only supports LlamaForCausalLM.'
        llm_state_dict = llm.state_dict()

        need_visual_encoder = (not self.freeze_visual_encoder
                               or self.use_visual_encoder_lora)
        visual_encoder = self.visual_encoder
        if self.use_visual_encoder_lora:
            visual_encoder = self.visual_encoder.merge_and_unload()
        assert isinstance(visual_encoder, CLIPVisionModel),\
            'This conversion format only supports CLIPVisionModel.'
        if need_visual_encoder:
            visual_encoder_state_dict = visual_encoder.state_dict()
            visual_encoder_state_dict = convert_state_dict_to_hf(
                visual_encoder_state_dict, VIT_MAPPING)
        else:
            visual_encoder_state_dict = {}

        projector_state_dict = self.projector.state_dict()
        projector_state_dict = convert_state_dict_to_hf(
            projector_state_dict, PROJECTOR_MAPPING)

        LongNet_encoder_state_dict = self.LongNet_encoder.state_dict()
        LongNet_encoder_state_dict = convert_state_dict_to_hf(
            LongNet_encoder_state_dict, LONGNET_MAPPING)

        state_dict = {
            **projector_state_dict,
            **llm_state_dict,
            **visual_encoder_state_dict,
            **LongNet_encoder_state_dict
        }

        # init model
        tokenizer = BUILDER.build(cfg.tokenizer)
        image_processor = BUILDER.build(cfg.image_processor)
        assert isinstance(image_processor, CLIPImageProcessor),\
            'This conversion format only supports CLIPImageProcessor.'

        llava_config_dict = llm.config.__dict__.copy()
        llava_config_dict.update(
            dict(
                image_aspect_ratio='pad',
                mm_hidden_size=visual_encoder.config.hidden_size,
                mm_projector_type=f'mlp{self.projector_depth}x_gelu',
                mm_use_im_patch_token=False,
                mm_use_im_start_end=False,
                mm_vision_select_feature='patch',
                mm_vision_select_layer=self.visual_select_layer,
                mm_vision_tower=visual_encoder.config.name_or_path,
                unfreeze_mm_vision_tower=need_visual_encoder,
                model_type='llava',
                use_cache=True,
                use_mm_proj=True))

        llava_config = LlavaConfig(**llava_config_dict)

        with init_empty_weights():
            with warnings.catch_warnings():
                warnings.filterwarnings(
                    'ignore', message='.*non-meta.*', category=UserWarning)
                model = LlavaLlamaForCausalLM(llava_config)

        model.load_state_dict(state_dict, strict=True, assign=True)

        # save
        print_log(f'Saving to {save_dir}', 'current')
        model.save_pretrained(save_dir, **save_pretrained_kwargs)
        image_processor.save_pretrained(save_dir, **save_pretrained_kwargs)
        tokenizer.save_pretrained(save_dir, **save_pretrained_kwargs)
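

# Usage sketch (assumed workflow, not executed on import): after training, the
# conversion entry point is `LLaVAModel.to_hf`, which dispatches on
# `save_format`:
#     model.to_hf(cfg, save_dir, save_format='xtuner')       # xtuner layout
#     model.to_hf(cfg, save_dir, save_format='huggingface')  # LlavaForConditionalGeneration
#     model.to_hf(cfg, save_dir, save_format='official')     # official LLaVA repo format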