# code/xtuner/model/llava_no_longnet_simple_sampler.py
# Copyright (c) OpenMMLab. All rights reserved.
import math
import os
import os.path as osp
import warnings
from collections import OrderedDict
from functools import partial
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from accelerate import init_empty_weights
from mmengine import print_log
from mmengine.config import Config, ConfigDict
from mmengine.model import BaseModel
from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training
from peft.tuners.lora.layer import LoraLayer
from safetensors.torch import load_file, save_file
from torch.nn.init import trunc_normal_
from torch.utils.checkpoint import checkpoint
from transformers import (AddedToken, AutoConfig, CLIPImageProcessor,
CLIPVisionModel, LlamaForCausalLM,
LlamaTokenizerFast, LlavaConfig,
LlavaForConditionalGeneration, LlavaProcessor)
from transformers.integrations import is_deepspeed_zero3_enabled
from xtuner.model.torchscale.component.multihead_attention import MultiheadAttention
from xtuner.model.torchscale.architecture.config import EncoderConfig
from xtuner.registry import BUILDER
from xtuner.utils import DEFAULT_IMAGE_TOKEN
from .modules import ProjectorConfig, ProjectorModel, dispatch_modules
from .modules.dispatch import SUPPORT_FLASH1, SUPPORT_FLASH2
from .sparse_token_merge import SparsePatchMerging
from .utils import (LoadWoInit, find_all_linear_names,
get_peft_model_state_dict, guess_load_checkpoint,
make_inputs_require_grad,
prepare_inputs_labels_for_multimodal, traverse_dict)
# --- Helper functions: 2D sin-cos positional embeddings (kept as-is) ---
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
grid_h = np.arange(grid_size, dtype=np.float32)
grid_w = np.arange(grid_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h)
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size, grid_size])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token:
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
emb = np.concatenate([emb_h, emb_w], axis=1)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float32)
omega /= embed_dim / 2.
omega = 1. / 10000**omega
pos = pos.reshape(-1)
out = np.einsum('m,d->md', pos, omega)
emb_sin = np.sin(out)
emb_cos = np.cos(out)
emb = np.concatenate([emb_sin, emb_cos], axis=1)
return emb
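# Example sketch (comments only; the small sizes are illustrative assumptions):
#   pe = get_2d_sincos_pos_embed(embed_dim=8, grid_size=2)
#   pe.shape       # (grid_size**2, embed_dim) == (4, 8)
#   pe_cls = get_2d_sincos_pos_embed(8, 2, cls_token=True)
#   pe_cls.shape   # (5, 8): a zero row is prepended for the cls token
# The Resampler below wraps this array in a fixed (requires_grad=False) Parameter
# to serve as the positional embedding of its latent queries.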
# --- Revised Resampler class ---
class Resampler(nn.Module):
"""
修正后的 Resampler 版本:
1. 区分 query_pos_embed 和 input_pos_embed,解决变量冲突。
2. 解除对外部 llm 模块的依赖,提高封装性。
3. 修正 forward 方法中的位置编码应用逻辑和维度匹配。
4. 集成梯度检查点(gradient_checkpointing)功能以节省显存。
"""
def __init__(
self,
grid_size,
embed_dim,
num_heads,
slide_ngrids=1000,  # grid size of the slide coordinate space, passed in from outside
kv_dim=None,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
gradient_checkpointing=False  # whether to enable gradient checkpointing
):
super().__init__()
self.num_queries = grid_size ** 2
self.embed_dim = embed_dim
self.num_heads = num_heads
self.slide_ngrids = slide_ngrids
self.gradient_checkpointing = gradient_checkpointing
# 1. Positional embedding for the queries (fixed, not trained)
self.query_pos_embed = nn.Parameter(
torch.from_numpy(get_2d_sincos_pos_embed(embed_dim, grid_size)).float(),
requires_grad=False
)
# 2. Positional embedding for the input visual features (large buffer, generated on device)
num_patches = slide_ngrids ** 2
self.register_buffer(
'input_pos_embed',
torch.zeros(1, num_patches, embed_dim),
persistent=False
)
# Learnable query vectors
self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
trunc_normal_(self.query, std=.02)
# KV projection layer
if kv_dim is not None and kv_dim != embed_dim:
self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)
else:
self.kv_proj = nn.Identity()
# Core cross-attention module
self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
# args = EncoderConfig()
# self.attn = MultiheadAttention(args =args,
# embed_dim= embed_dim,
# num_heads=num_heads,
# self_attention=False,
# encoder_decoder_attention=True,
# )
self.ln_q = norm_layer(embed_dim)
self.ln_kv = norm_layer(embed_dim)
self.ln_post = norm_layer(embed_dim)
self.proj = nn.Parameter((embed_dim ** -0.5) * torch.randn(embed_dim, embed_dim))
# Initialize weights and the input positional embedding
self.apply(self._init_weights)
self.initialize_input_pe_weights()
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.no_grad()
def initialize_input_pe_weights(self, chunk_rows: int = 64, chunk_cols: int = 64):
H = W = self.slide_ngrids
D = self.embed_dim
assert D % 4 == 0, "embed_dim must be a multiple of 4 to exactly match the numpy implementation."
device = self.input_pos_embed.device
dtype64 = torch.float64
if self.input_pos_embed.shape != (1, H * W, D):
self.input_pos_embed.resize_(1, H * W, D)
pos4d = self.input_pos_embed.view(1, H, W, D)
k = D // 4
inv = 1.0 / (10000 ** (torch.arange(k, device=device, dtype=dtype64) / k))
y_lin = torch.arange(H, device=device, dtype=dtype64)
x_lin = torch.arange(W, device=device, dtype=dtype64)
y_phase = y_lin.unsqueeze(1) * inv.unsqueeze(0)
x_phase = x_lin.unsqueeze(1) * inv.unsqueeze(0)
y_enc = torch.cat([torch.sin(y_phase), torch.cos(y_phase)], dim=1)
x_enc = torch.cat([torch.sin(x_phase), torch.cos(x_phase)], dim=1)
for r0 in range(0, H, chunk_rows):
r1 = min(r0 + chunk_rows, H)
R = r1 - r0
y_chunk = y_enc[r0:r1].unsqueeze(1)
for c0 in range(0, W, chunk_cols):
c1 = min(c0 + chunk_cols, W)
C = c1 - c0
x_chunk = x_enc[c0:c1].unsqueeze(0)
emb_rc = torch.cat([
x_chunk.expand(R, C, 2 * k),
y_chunk.expand(R, C, 2 * k)
], dim=2)
pos4d[0, r0:r1, c0:c1, :].copy_(emb_rc.to(pos4d.dtype))
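# Note: the chunked loop above fills the (1, slide_ngrids**2, embed_dim) buffer
# with the same sin/cos values that get_2d_sincos_pos_embed(embed_dim, slide_ngrids)
# would produce, but without materialising the full dense numpy grid at once.
# A rough equivalence check (sketch, assuming a grid small enough to build densely):
#   r = Resampler(grid_size=2, embed_dim=8, num_heads=2, slide_ngrids=8)
#   dense = torch.from_numpy(get_2d_sincos_pos_embed(8, 8)).float()
#   torch.allclose(r.input_pos_embed[0], dense, atol=1e-5)   # expected: True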
def _checkpointed_forward(self, q_embed, kv_embed):
# Wraps the attention and subsequent layers for gradient checkpointing.
# With batch_first=True: q_embed is [N, num_queries, C], kv_embed is [N, L, C].
# print(f"_checkpointed_forward q_embed shape: {q_embed.shape}, kv_embed shape: {kv_embed.shape}")
attn_out = self.attn(q_embed, kv_embed, kv_embed)[0]
permuted_out = attn_out
ln_out = self.ln_post(permuted_out)
proj_out = ln_out @ self.proj
return proj_out
def forward(self, x, coords_rc, attn_mask=None):
# x shape: [N, L, C], coords_rc: [L, 2] (row, col indices)
# 1. Look up positional embeddings for the input tokens from the buffer, indexed by coordinates.
# .squeeze(0) then removes the buffer's batch dimension.
# print(f"Resampler input x shape: {x.shape}, coords_rc shape: {coords_rc.shape}")
pos_indices = (coords_rc[..., 0] * self.slide_ngrids + coords_rc[..., 1]).long()
# print(f"Resampler input pos_indices shape: {pos_indices.shape}, values: {pos_indices}")
input_pos = self.input_pos_embed[:, pos_indices, :].squeeze(0) # Shape: [L, C]
# print(f"Resampler input_pos shape: {input_pos.shape}")
# [MODIFIED] Operate directly in (N, L, C) layout; no permute needed.
x = self.kv_proj(x)
kv_embed = self.ln_kv(x)
N = x.shape[0]
q = self.ln_q(self.query) # Shape: [num_queries, C]
# [MODIFIED] Adjust the expansion to batch-first layout:
# expand the query from [num_queries, C] to [N, num_queries, C].
q_embed = q.unsqueeze(0).expand(N, -1, -1) + self.query_pos_embed.unsqueeze(0)
# [MODIFIED] input_pos [L, C] broadcasts against kv_embed [N, L, C] in the addition.
kv_embed = kv_embed + input_pos
if self.training and self.gradient_checkpointing:
q_embed.requires_grad_(True)
kv_embed.requires_grad_(True)
out = checkpoint(self._checkpointed_forward, q_embed, kv_embed, use_reentrant=False)
else:
out = self._checkpointed_forward(q_embed, kv_embed)
return out
def enable_input_require_grads(self):
print_log('enable input require grads for resampler', 'current')
def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)
self.register_forward_hook(make_inputs_require_grad)
def gradient_checkpointing_enable(self):
self.gradient_checkpointing = True
def gradient_checkpointing_disable(self):
self.gradient_checkpointing = False
def _repeat(self, query, N: int):
return query.unsqueeze(1).repeat(1, N, 1)
# =================================================================================================
# End of Resampler code
# =================================================================================================
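# Usage sketch for the Resampler above (comments only; the sizes are illustrative
# assumptions, kept small so the slide_ngrids**2 positional buffer stays tiny):
#   resampler = Resampler(grid_size=4, embed_dim=64, num_heads=4,
#                         slide_ngrids=32, kv_dim=128)
#   x = torch.randn(1, 10, 128)                 # [N, L, kv_dim] visual features
#   coords_rc = torch.randint(0, 32, (10, 2))   # per-token (row, col) grid indices
#   out = resampler(x, coords_rc)               # -> [1, 16, 64], i.e. grid_size**2 latents
# In this file it is built with embed_dim = kv_dim = llm.config.hidden_size and
# grid_size = sqrt(resampler_num_latents), compressing a variable number of patch
# tokens into a fixed number of LLM-sized latent tokens.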
def convert_state_dict_to_hf(state_dict, mapping):
new_state_dict = {}
for key, value in state_dict.items():
if key.endswith('.inv_freq'):
continue
for key_to_modify, new_key in mapping.items():
if key_to_modify in key:
key = key.replace(key_to_modify, new_key)
new_state_dict[key] = value
return new_state_dict
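# How the mapping is applied, sketched with hypothetical keys:
#   sd = {'model.layers.0.self_attn.q_proj.weight': w,
#         'model.layers.0.self_attn.rotary_emb.inv_freq': f}
#   convert_state_dict_to_hf(sd, {'model': 'language_model.model'})
#   # -> {'language_model.model.layers.0.self_attn.q_proj.weight': w}
# Keys ending in '.inv_freq' are dropped; every other key has each matching
# substring from the mapping replaced.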
class AdaptiveAvgPool1dLayer(nn.Module):
def __init__(self, output_size):
super(AdaptiveAvgPool1dLayer, self).__init__()
self.output_size = output_size
def forward(self, x):
return F.adaptive_avg_pool1d(x, self.output_size)
class LLaVAModel(BaseModel):
def __init__(self,
llm,
freeze_llm=True,
visual_select_layer=-2,
pretrained_pth=None,
projector_depth=2,
llm_lora=None,
visual_encoder_lora=None,
use_activation_checkpointing=True,
max_position_embeddings=None,
hidden_size=512,
train_stage='2',
# slide / pos-embed parameters
slide_ngrids=1000,
tile_size=224,
# weight paths for the individual sub-modules
projector_pth=None,
resampler_pth=None,
token_merge_pth=None,
# Token Merge
enable_token_merge=True,
# Resampler configuration
use_resampler=True,
resampler_num_latents=256,
resampler_heads=16,
# === Stage-2 freezing options ===
freeze_mm_in_stage2=False,  # master switch: freeze projector / resampler / token_merge in stage 2
freeze_projector_stage2=None,  # per-module switch (None follows the master switch)
freeze_resampler_stage2=None,  # per-module switch (None follows the master switch)
freeze_token_merge_stage2=None  # per-module switch (None follows the master switch)
):
super().__init__()
self.freeze_llm = freeze_llm
self.freeze_visual_encoder = True
self.tile_size = tile_size
# Training-stage control
if train_stage == '0':
print_log('train_stage == 0', 'current')
self.freeze_llm = True
if train_stage == '1':
print_log('train_stage == 1', 'current')
self.freeze_llm = True
elif train_stage == '2':
print_log('train_stage == 2', 'current')
self.freeze_llm = False
# Resolve the stage-2 freezing flags
def _resolve(flag):
return freeze_mm_in_stage2 if flag is None else bool(flag)
self._freeze_projector_in_s2 = _resolve(freeze_projector_stage2)
self._freeze_resampler_in_s2 = _resolve(freeze_resampler_stage2)
self._freeze_token_merge_in_s2 = _resolve(freeze_token_merge_stage2)
# Build / dispatch the LLM
with LoadWoInit():
if isinstance(llm, dict):
llm = self._dispatch_lm_model_cfg(llm, max_position_embeddings)
self.llm = self._build_from_cfg_or_module(llm)
self.llm.config.use_cache = False
dispatch_modules(self.llm)
# Token Merge
self.enable_token_merge = enable_token_merge
if self.enable_token_merge:
self.token_merge = SparsePatchMerging(
embed_dim=hidden_size,
layernorm_eps=1e-6,
merge_size=2
)
# Projector
self.projector_depth = projector_depth
projector_config = ProjectorConfig(
visual_hidden_size=hidden_size * 4 if self.enable_token_merge else hidden_size,
llm_hidden_size=self.llm.config.hidden_size,
depth=self.projector_depth
)
self.projector = ProjectorModel(projector_config).to(self.llm.dtype)
self.projector.requires_grad_(True)
# Resampler
self.use_resampler = use_resampler
self.slide_ngrids = slide_ngrids
if self.use_resampler:
self.resampler_num_latents = resampler_num_latents
print_log(f'using simple Resampler with {resampler_num_latents} latents', 'current')
self.resampler = Resampler(
grid_size=int(math.sqrt(self.resampler_num_latents)),
embed_dim=self.llm.config.hidden_size,
num_heads=resampler_heads,
kv_dim=self.llm.config.hidden_size,
).to(self.llm.dtype)
# Freeze the LLM
if self.freeze_llm:
print('freeze_llm')
self.llm.requires_grad_(False)
# Activation checkpointing (skip enabling input grads for modules frozen in stage 2)
if use_activation_checkpointing:
if hasattr(self.llm, 'enable_input_require_grads'):
self.llm.enable_input_require_grads()
else:
self.llm.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
# Resampler is a simple nn.Module and does not have this method.
# If checkpointing is desired for it, its forward pass should be wrapped.
# For this modification, we will omit its specific checkpointing setup.
_projector_frozen = (train_stage == '2' and self._freeze_projector_in_s2)
if not _projector_frozen:
print('enable projector input require grads')
print_log('enable projector input require grads', 'current')
self.projector.enable_input_require_grads()
else:
print_log('[stage-2] Skipping projector.enable_input_require_grads() (frozen)', 'current')
# Enable activation checkpointing
self.gradient_checkpointing_enable()
# LoRA
self.use_llm_lora = llm_lora is not None
self.use_visual_encoder_lora = False
if self.use_llm_lora:
print_log(f'Building LoRA with config: {llm_lora}', 'current')
self._prepare_llm_for_lora(llm_lora, use_activation_checkpointing)
self.verify_lora()
# Load safetensors weights for token_merge / projector / resampler
if token_merge_pth is not None and enable_token_merge and hasattr(self, 'token_merge'):
print_log(f'loading token_merge from {token_merge_pth}', 'current')
merger_sd = load_file(token_merge_pth, device='cpu')
self.token_merge.load_state_dict(merger_sd, strict=False)
self.token_merge.to(self.llm.dtype)
if projector_pth is not None:
print_log(f"Loading projector from {projector_pth}", "current")
proj_sd = load_file(projector_pth, device="cpu")
self.projector.load_state_dict(proj_sd, strict=False)
self.projector.to(self.llm.dtype)
if resampler_pth is not None and self.use_resampler and hasattr(self, 'resampler'):
print_log(f'Loading resampler from {resampler_pth}', 'current')
resampler_sd = load_file(resampler_pth, device="cpu")
self.resampler.load_state_dict(resampler_sd, strict=False)
self.resampler.to(self.llm.dtype)
# Optionally load additional float weights from a full checkpoint
if pretrained_pth is not None:
sd = guess_load_checkpoint(pretrained_pth)
model_sd = self.state_dict()
filtered = {k: v for k, v in sd.items() if k in model_sd and model_sd[k].shape == v.shape}
missing, unexpected = self.load_state_dict(filtered, strict=False)
print_log(f"Loaded float ckpt from {pretrained_pth}", "current")
print_log(f" missing: {missing}", "current")
print_log(f" unexpected:{unexpected}", "current")
# Record which visual feature layer to select
self.visual_select_layer = visual_select_layer
# Initialization flags
self._is_init = True
self.is_first_iter = True
# === Optionally freeze the three multimodal sub-modules in stage 2 ===
if train_stage == '2':
# projector
if hasattr(self, 'projector') and self._freeze_projector_in_s2:
self.projector.requires_grad_(False)
self.projector.eval()
print_log('[stage-2] Freezing projector parameters', 'current')
# resampler
if getattr(self, 'use_resampler', False) and hasattr(self, 'resampler') and self._freeze_resampler_in_s2:
self.resampler.requires_grad_(False)
self.resampler.eval()
print_log('[stage-2] Freezing resampler parameters', 'current')
# token_merge
if getattr(self, 'enable_token_merge', False) and hasattr(self, 'token_merge') and self._freeze_token_merge_in_s2:
self.token_merge.requires_grad_(False)
self.token_merge.eval()
print_log('[stage-2] Freezing token_merge parameters', 'current')
def _parse_lora_config(self, lora_config):
if isinstance(lora_config, dict) or isinstance(
lora_config, Config) or isinstance(lora_config, ConfigDict):
lora_config = BUILDER.build(lora_config)
return lora_config
def _init_weights(self, m):
if isinstance(m, nn.Linear):
# we use xavier_uniform following official JAX ViT:
torch.nn.init.xavier_uniform_(m.weight)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def _prepare_llm_for_lora(self,
lora_config,
use_activation_checkpointing=True):
lora_config = self._parse_lora_config(lora_config)
self.llm = prepare_model_for_kbit_training(
self.llm, use_activation_checkpointing)
if lora_config.target_modules is None:
modules = find_all_linear_names(self.llm)
lora_config.target_modules = modules
self.llm = get_peft_model(self.llm, lora_config)
def verify_lora(self):
m = self.llm
# 1) Wrapped as a PEFT model
assert isinstance(m, PeftModel), "LoRA not applied: model is not a PeftModel"
# 2) Adapters are registered and active
adapters = m.peft_config # dict: {adapter_name: LoraConfig}
assert len(adapters) > 0, "No adapters registered in peft_config"
active = m.active_adapter if hasattr(m, "active_adapter") else None
assert active in adapters, f"Active adapter {active} not found in peft_config"
# 3) LoRA layers are present on target modules
lora_modules = [mod for mod in m.modules() if isinstance(mod, LoraLayer)]
assert len(lora_modules) > 0, "No LoraLayer modules found (check target_modules)"
# 4) LoRA params are the only trainable ones (typical for QLoRA)
trainable = [(n,p) for n,p in m.named_parameters() if p.requires_grad]
assert len(trainable) > 0, "No trainable parameters (LoRA params are not set to requires_grad=True)"
# Optional: sanity-check that trainable params look like LoRA
suspicious = [n for n,_ in trainable if "lora_" not in n and "modules_to_save" not in n]
# It's okay if you intentionally left some modules_to_save; adjust as needed.
assert len(suspicious) == 0, f"Unexpected trainable params (not LoRA): {suspicious[:5]}"
# 5) Quick count + readable log
total = sum(p.numel() for _,p in m.named_parameters())
trainable_cnt = sum(p.numel() for _,p in trainable)
ratio = trainable_cnt / total
print(f"[LoRA OK] adapters={list(adapters.keys())}, active={active}, "
f"LoraLayers={len(lora_modules)}, trainable={trainable_cnt}/{total} ({ratio:.4%})")
# 6) Forward+backward smoke test to confirm gradients flow to LoRA only
m.train()
dummy_inp = torch.randint(0, m.get_input_embeddings().num_embeddings, (1, 8)).to(next(m.parameters()).device)
out = m(input_ids=dummy_inp, labels=dummy_inp)
out.loss.backward() # should not error
# Ensure some LoRA grads exist
lora_grads = [p.grad for _,p in m.named_parameters() if p.requires_grad and p.grad is not None]
assert len(lora_grads) > 0, "No gradients on LoRA parameters after backward()"
def _prepare_visual_encoder_for_lora(self,
lora_config,
use_activation_checkpointing=True):
lora_config = self._parse_lora_config(lora_config)
if lora_config.target_modules is None:
modules = find_all_linear_names(self.visual_encoder)
lora_config.target_modules = modules
self.visual_encoder = get_peft_model(self.visual_encoder, lora_config)
def gradient_checkpointing_enable(self, use_reentrant=False):
self.activation_checkpointing_enable(use_reentrant=use_reentrant)
def activation_checkpointing_enable(self, use_reentrant=False):
# LLM
try:
self.llm.gradient_checkpointing_enable(use_reentrant=use_reentrant)
except TypeError:
# older HF versions
self.llm.gradient_checkpointing_enable()
# projector
try:
self.projector.gradient_checkpointing_enable(use_reentrant=use_reentrant)
except TypeError:
self.projector.gradient_checkpointing_enable()
if getattr(self, 'use_resampler', False) and getattr(self, 'resampler', None) is not None:
try:
self.resampler.gradient_checkpointing_enable(use_reentrant=use_reentrant)
except TypeError:
self.resampler.gradient_checkpointing_enable()
def gradient_checkpointing_disable(self):
self.activation_checkpointing_disable()
def activation_checkpointing_disable(self):
self.llm.gradient_checkpointing_disable()
self.projector.gradient_checkpointing_disable()
if getattr(self, 'use_resampler', False) and getattr(self, 'resampler', None) is not None:
self.resampler.gradient_checkpointing_disable()
def init_weights(self):
pass
def state_dict(self, *args, **kwargs):
state_dict = super().state_dict(*args, **kwargs)
to_return = OrderedDict()
# Step 1. visual_encoder
if self.use_visual_encoder_lora:
to_return.update(
get_peft_model_state_dict(
self.visual_encoder, state_dict=state_dict))
elif not self.freeze_visual_encoder:
to_return.update({
k: v
for k, v in state_dict.items() if 'visual_encoder.' in k
})
# Step 2. LLM
if self.use_llm_lora:
to_return.update(
get_peft_model_state_dict(self.llm, state_dict=state_dict))
elif not self.freeze_llm:
to_return.update(
{k: v
for k, v in state_dict.items() if 'llm.' in k})
# Step 3. Projector
to_return.update(
{k: v
for k, v in state_dict.items() if 'projector.' in k})
# Step 4. Resampler
if getattr(self, 'use_resampler', False) and getattr(self, 'resampler', None) is not None:
to_return.update({k: v for k, v in state_dict.items() if 'resampler.' in k})
# step 5 token merger
if getattr(self, 'token_merge', False):
to_return.update({k: v for k, v in state_dict.items() if 'token_merge.' in k})
return to_return
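# The state_dict above is intentionally partial: frozen pieces (the base LLM when
# it is frozen or LoRA-wrapped, the frozen visual encoder) are dropped, and only
# the trainable parts (LoRA adapters, projector.*, resampler.*, token_merge.*)
# are returned, which keeps the saved checkpoints small.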
@staticmethod
def _prepare_for_long_context_training(cfg, llm_cfg,
max_position_embeddings):
orig_rope_scaling = getattr(llm_cfg, 'rope_scaling', None)
if orig_rope_scaling is None:
orig_rope_scaling = {'factor': 1}
orig_rope_scaling_factor = orig_rope_scaling[
'factor'] if 'factor' in orig_rope_scaling.keys() else 1
orig_ctx_len = getattr(llm_cfg, 'max_position_embeddings', None)
if orig_ctx_len:
orig_ctx_len *= orig_rope_scaling_factor
if max_position_embeddings > orig_ctx_len:
scaling_factor = float(
math.ceil(max_position_embeddings / orig_ctx_len))
llm_cfg.rope_scaling = {
'type': 'linear',
'factor': scaling_factor
}
# hardcode for internlm2
llm_cfg.attn_implementation = 'flash_attention_2'
cfg.config = llm_cfg
return cfg, llm_cfg
@staticmethod
def _prepare_for_flash_attn(cfg, llm_cfg):
cls_name = type(llm_cfg).__name__
SUPPORT_SDPA_ATTN = ('LlamaConfig', 'GemmaConfig', 'MistralConfig',
'MixtralConfig', 'Qwen2Config', 'Qwen2MoeConfig',
'Starcoder2Config', 'Phi3Config')
SUPPORT_FLASH_ATTN2 = ('InternLM2Config', 'LlamaConfig', 'GemmaConfig',
'MistralConfig', 'MixtralConfig', 'Qwen2Config',
'Qwen2MoeConfig', 'Starcoder2Config', 'Phi3Config')
torch_dtype = torch.bfloat16 if (
torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
else torch.float16
if getattr(cfg, 'attn_implementation', None) is not None:
# Flash Attention 2.0 only supports torch.float16 and
# torch.bfloat16 dtypes
if cfg.attn_implementation == 'flash_attention_2':
cfg.torch_dtype = torch_dtype
elif SUPPORT_FLASH2 and cls_name in SUPPORT_FLASH_ATTN2:
cfg.torch_dtype = torch_dtype
cfg.attn_implementation = 'flash_attention_2'
elif SUPPORT_FLASH1 and cls_name in SUPPORT_SDPA_ATTN:
cfg.attn_implementation = 'sdpa'
return cfg, llm_cfg
@staticmethod
def _prepare_for_qlora_zero3(cfg):
if (not is_deepspeed_zero3_enabled()) or (not hasattr(
cfg, 'quantization_config')):
return cfg
torch_dtype = torch.bfloat16 if (
torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
else torch.float16
cfg.torch_dtype = torch_dtype
quantization_config = cfg.quantization_config
quantization_config.bnb_4bit_compute_dtype = torch_dtype
quantization_config.bnb_4bit_quant_storage = torch_dtype
return cfg
def _dispatch_lm_model_cfg(self, cfg, max_position_embeddings=None):
cfg = self._prepare_for_qlora_zero3(cfg)
pretrained_model_name_or_path = cfg.pretrained_model_name_or_path
llm_cfg = AutoConfig.from_pretrained(
pretrained_model_name_or_path, trust_remote_code=True)
cfg, llm_cfg = self._prepare_for_flash_attn(cfg, llm_cfg)
if max_position_embeddings is not None:
cfg, llm_cfg = self._prepare_for_long_context_training(
cfg, llm_cfg, max_position_embeddings)
return cfg
def _build_from_cfg_or_module(self, cfg_or_mod):
if isinstance(cfg_or_mod, nn.Module):
return cfg_or_mod
elif isinstance(cfg_or_mod, dict):
traverse_dict(cfg_or_mod)
return BUILDER.build(cfg_or_mod)
else:
raise NotImplementedError
def coords_to_pos(self, coords, tile_size: int = 224):
"""
Convert patch coordinates (in pixels) to flattened grid position indices.
Arguments:
----------
coords: torch.Tensor
The coordinates of the patches, of shape [N, L, 2]
Returns:
-------
torch.Tensor
The positional indices of the patches, of shape [N, L]
"""
coords_ = torch.floor(coords / tile_size)
pos = coords_[..., 0] * self.slide_ngrids + coords_[..., 1]
return pos.long()
@staticmethod
def _coords_rc_to_pos(coords_rc: torch.Tensor, ngrids: int) -> torch.Tensor:
if coords_rc.dtype.is_floating_point:
coords_rc = coords_rc.round().to(torch.long)
# row = coords_rc[:, 0].clamp_(0, ngrids-1)
# col = coords_rc[:, 1].clamp_(0, ngrids-1)
return (coords_rc[..., 0] * ngrids + coords_rc[..., 1]).long()
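# Coordinate helpers, illustrated (comments only; the numbers are made up):
#   self.coords_to_pos(torch.tensor([[[0., 0.], [224., 448.]]]), tile_size=224)
#       # floor(coords / 224) -> [[0, 0], [1, 2]]
#       # -> [[0 * slide_ngrids + 0, 1 * slide_ngrids + 2]]  (flattened grid indices)
#   self._coords_rc_to_pos(torch.tensor([[1, 2]]), ngrids=1000)
#       # -> tensor([1002]), i.e. row * ngrids + col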
def forward(self, data, data_samples=None, mode='loss'):
if self.is_first_iter:
# hardcode for qlora DeepSpeed ZeRO3, put buffers and QuantState to
# device
# Only required in `LLaVAModel` .
# We do not need this in `SupervisedFinetune` .
self.to(data['input_ids'].device)
self.is_first_iter = False
coords = None
if 'pixel_values' in data:
feat_to_proj = data['pixel_values'].to(self.llm.dtype) # torch.Size([1, img_num, 512])
# ensure requires_grad for gradient checkpointing
feat_to_proj.requires_grad_(True)
if 'coords' in data:
coords = data['coords'].to(self.llm.dtype)
# Accept: list[tensor], [L,2] tensor, or [B,L,2] tensor
coords_t = coords[0] if isinstance(coords, list) else coords
Bx = feat_to_proj.size(0) # actual batch size of inputs
if not torch.is_tensor(coords_t):
raise ValueError("coords must be a Tensor or list[Tensor].")
if coords_t.dim() == 2:
# [L, 2]
coords_rc = coords_t
elif coords_t.dim() == 3:
# [B, L, 2] -> ensure B matches and either B==1 or all examples share coords
if coords_t.size(0) != Bx:
raise ValueError(f"coords batch dim mismatch: got {coords_t.size(0)} but inputs have B={Bx}")
if Bx == 1:
coords_rc = coords_t[0]
else:
# require same coords across the batch (cheap equality check)
if not torch.equal(coords_t, coords_t[0].unsqueeze(0).expand_as(coords_t)):
raise NotImplementedError(
"Per-example coords (varying across batch) are not supported by the current "
"patch-merging/layout path. Use batch size 1 or share coords across the batch."
)
coords_rc = coords_t[0]
else:
raise ValueError("coords must have shape [L,2] or [B,L,2].")
if coords_rc.size(-1) != 2:
raise ValueError("coords last dimension must be 2.")
else:
raise RuntimeError("'pixel_values' not found in data; this model expects precomputed visual features.")
# only works for batch size one
if self.enable_token_merge:
feat_to_proj, coords_rc_merged, _ = self.token_merge(
x=feat_to_proj,
coords_rc=self._coords_to_rowcol(coords_rc),
padmask=torch.zeros([feat_to_proj.size(0), feat_to_proj.size(1)],
device=feat_to_proj.device, dtype=torch.bool)
)
# print(f"After token_merge, feat_to_proj: {feat_to_proj.shape}, coords_rc_merged: {coords_rc_merged.shape}")
else:
coords_rc_merged = self._coords_to_rowcol(coords_rc)
padmask_merged = torch.zeros([feat_to_proj.size(0), feat_to_proj.size(1)],
device=feat_to_proj.device, dtype=torch.bool)
pixel_values = self.projector(feat_to_proj.to(self.llm.dtype)) # output shape [1, patch_num, hidden_size]
# print(f"After projector, pixel_values: {pixel_values.shape}")
if self.use_resampler and getattr(self, 'resampler', None) is not None:
pixel_values = self.resampler(pixel_values, coords_rc_merged,
attn_mask=None)  # [1, num_latents, hidden_size]
data['pixel_values'] = pixel_values
# remove coords
data.pop('coords', None)
data = prepare_inputs_labels_for_multimodal(llm=self.llm, **data)
if mode == 'loss':
return self.compute_loss(data, data_samples)
elif mode == 'predict':
return self.predict(data, data_samples)
elif mode == 'tensor':
return self._forward(data, data_samples)
else:
raise NotImplementedError
@staticmethod
def _coords_to_rowcol(coords_xy: torch.Tensor) -> torch.Tensor:
with torch.no_grad():
x = coords_xy[:, 0]
y = coords_xy[:, 1]
x_for_unique = x
y_for_unique = y
if x_for_unique.dtype.is_floating_point:
x_for_unique = x_for_unique.round().to(torch.int)
y_for_unique = y_for_unique.round().to(torch.int)
x_sorted = torch.unique(x_for_unique, sorted=True)
y_sorted = torch.unique(y_for_unique, sorted = True)
col = torch.searchsorted(x_sorted, x)
row = torch.searchsorted(y_sorted, y)
return torch.stack([row, col], dim=-1)
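# _coords_to_rowcol illustrated (comments only; the pixel values are made up):
# unique sorted x values become column ranks and unique sorted y values become
# row ranks, so absolute pixel offsets are compacted into a dense grid layout.
#   coords_xy = torch.tensor([[448., 0.], [0., 0.], [224., 224.]])
#   # unique x -> [0, 224, 448], unique y -> [0, 224]
#   # result   -> [[0, 2], [0, 0], [1, 1]]   ([row, col] per token)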
def _forward(self, data, data_samples=None):
outputs = self.llm(**data)
return outputs
def predict(self, data, data_samples=None):
outputs = self.llm(**data)
logits_dict = [{'logits': logits} for logits in outputs.logits]
return logits_dict
def compute_loss(self, data, data_samples=None):
"""
计算损失的修改版实现。
该版本通过计算批次中每个样本的平均损失来解决长短文本的梯度失衡问题,
使得每个样本对总损失的贡献相等,无论其token长度如何。
"""
# If labels are absent, let the HF model compute the loss itself
if "labels" not in data:
outputs = self.llm(**data)
return {"loss": outputs.loss}
# Pop labels from data so they are not passed straight to the model
labels = data.pop("labels")
# Forward pass to obtain the logits
outputs = self.llm(**data)
logits = outputs.logits
# Check that the logits and labels shapes match
if logits.shape[:-1] != labels.shape:
raise ValueError(
f"Logits and labels shape mismatch. Logits: {logits.shape}, Labels: {labels.shape}"
)
# Shift logits and labels for next-token prediction
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Per-token cross-entropy without aggregation (reduction='none');
# this yields one loss value per token position
loss = F.cross_entropy(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1),
ignore_index=-100,
reduction='none'
)
# Reshape the loss tensor back to [B, L-1]
loss = loss.view(shift_logits.size(0), -1)
# Compute the mean loss per sample (per sequence):
# count the valid (non -100) tokens in each sample
num_tokens_per_sample = (shift_labels != -100).sum(dim=1)
# and sum the loss over each sample
loss_per_sample = loss.sum(dim=1)
# Avoid division by zero
valid_samples_mask = num_tokens_per_sample > 0
# Initialize the per-sample mean loss
mean_loss_per_sample = torch.zeros_like(loss_per_sample)
# Average only over samples that contain valid tokens
if valid_samples_mask.any():
mean_loss_per_sample[valid_samples_mask] = loss_per_sample[valid_samples_mask] / num_tokens_per_sample[valid_samples_mask]
# The final loss is the mean of the per-sample means
final_loss = mean_loss_per_sample.mean()
return {"loss": final_loss}
def __getattr__(self, name: str):
try:
return super().__getattr__(name)
except AttributeError:
return getattr(self.llm, name)
def to_hf(self,
cfg,
save_dir,
fp32=False,
save_pretrained_kwargs={},
save_format='xtuner',
**kwargs):
if save_format == 'xtuner':
self.to_xtuner_llava(cfg, save_dir, fp32, save_pretrained_kwargs)
elif save_format == 'huggingface':
self.to_huggingface_llava(cfg, save_dir, fp32,
save_pretrained_kwargs)
elif save_format == 'official':
self.to_official_llava(cfg, save_dir, fp32, save_pretrained_kwargs)
else:
raise NotImplementedError
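# Export sketch (assumed call site; cfg must provide the tokenizer / image_processor
# entries that the chosen format builds):
#   model.to_hf(cfg, save_dir='work_dirs/export', save_format='xtuner')
# 'xtuner' saves the llm / projector / resampler / token_merger pieces separately,
# while 'huggingface' and 'official' merge into LLaVA-style checkpoints and do not
# keep the resampler weights (see the warnings below).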
def to_xtuner_llava(self,
cfg,
save_dir,
fp32=False,
save_pretrained_kwargs={}):
# LLM
self.llm.config.use_cache = True
if not fp32:
print_log('Convert LLM to float16', 'current')
self.llm.half()
if self.use_llm_lora:
llm_path = osp.join(save_dir, 'llm_adapter')
print_log(f'Saving LLM adapter to {llm_path}', 'current')
self.llm.save_pretrained(llm_path, **save_pretrained_kwargs)
elif not self.freeze_llm:
llm_path = save_dir
print_log(f'Saving LLM tokenizer to {llm_path}', 'current')
tokenizer = BUILDER.build(cfg.tokenizer)
tokenizer.save_pretrained(llm_path, **save_pretrained_kwargs)
print_log(f'Saving LLM to {llm_path}', 'current')
self.llm.save_pretrained(llm_path, **save_pretrained_kwargs)
self.llm.config.use_cache = False
# Visual Encoder
if self.use_visual_encoder_lora:
visual_encoder_path = osp.join(save_dir, 'visual_encoder_adapter')
print_log(
f'Saving visual_encoder adapter to {visual_encoder_path}',
'current')
self.visual_encoder.save_pretrained(visual_encoder_path,
**save_pretrained_kwargs)
elif not self.freeze_visual_encoder:
visual_encoder_path = osp.join(save_dir, 'visual_encoder')
print_log(
'Saving visual_encoder image_processor to'
f'{visual_encoder_path}', 'current')
image_processor = BUILDER.build(cfg.image_processor)
image_processor.save_pretrained(visual_encoder_path,
**save_pretrained_kwargs)
print_log(f'Saving visual_encoder to {visual_encoder_path}',
'current')
self.visual_encoder.save_pretrained(visual_encoder_path,
**save_pretrained_kwargs)
# Projector
projector_path = osp.join(save_dir, 'projector')
print_log(f'Saving projector to {projector_path}', 'current')
os.makedirs(projector_path, exist_ok=True)
output_path = os.path.join(projector_path, 'projector.safetensors')
save_file(self.projector.state_dict(), output_path)
if self.use_resampler and hasattr(self, 'resampler'):
resampler_path = osp.join(save_dir, "resampler")
print_log(f'Saving Resampler to {resampler_path}', 'current')
os.makedirs(resampler_path, exist_ok=True)
resampler_output_path = os.path.join(resampler_path, 'resampler.safetensors')
save_file(self.resampler.state_dict(), resampler_output_path)
if self.enable_token_merge and hasattr(self, 'token_merge'):
merger_path = osp.join(save_dir, 'token_merger')
print_log(f'Saving token merger to {merger_path}', 'current')
os.makedirs(merger_path, exist_ok=True)
merger_path = os.path.join(merger_path, 'merger.safetensors')
save_file(self.token_merge.state_dict(), merger_path)
def to_huggingface_llava(self,
cfg,
save_dir,
fp32=False,
save_pretrained_kwargs={}):
if self.use_resampler:
warnings.warn("Conversion to HuggingFace LLaVA format with a custom resampler is not supported. "
"The resampler weights will not be saved.")
LLM_MAPPING = {
'model': 'language_model.model',
'lm_head': 'language_model.lm_head',
}
VIT_MAPPING = {
'vision_model': 'vision_tower.vision_model',
}
PROJECTOR_MAPPING = {
'model.0': 'multi_modal_projector.linear_1',
'model.2': 'multi_modal_projector.linear_2',
}
assert getattr(self.llm, 'hf_quantizer', None) is None, \
'This conversion format does not support quantized LLM.'
# get state_dict
llm = self.llm
if self.use_llm_lora:
llm = self.llm.merge_and_unload()
llm.config.use_cache = True
if not fp32:
print_log('Convert LLM to float16', 'current')
llm.half()
assert isinstance(llm, LlamaForCausalLM), \
'This conversion format only supports LlamaForCausalLM.'
llm_state_dict = llm.state_dict()
llm_state_dict = convert_state_dict_to_hf(llm_state_dict, LLM_MAPPING)
need_visual_encoder = (not self.freeze_visual_encoder
or self.use_visual_encoder_lora)
visual_encoder = self.visual_encoder
if self.use_visual_encoder_lora:
visual_encoder = self.visual_encoder.merge_and_unload()
assert isinstance(visual_encoder, CLIPVisionModel),\
'This conversion format only supports CLIPVisionModel.'
if need_visual_encoder:
visual_encoder_state_dict = visual_encoder.state_dict()
visual_encoder_state_dict = convert_state_dict_to_hf(
visual_encoder_state_dict, VIT_MAPPING)
else:
visual_encoder_state_dict = {}
projector_state_dict = self.projector.state_dict()
projector_state_dict = convert_state_dict_to_hf(
projector_state_dict, PROJECTOR_MAPPING)
state_dict = {
**projector_state_dict,
**llm_state_dict,
**visual_encoder_state_dict,
}
# init model
text_config = llm.config
vision_config = visual_encoder.config
config = LlavaConfig(
text_config=text_config,
vision_config=vision_config,
attn_implementation='eager')
with init_empty_weights():
with warnings.catch_warnings():
warnings.filterwarnings(
'ignore', message='.*non-meta.*', category=UserWarning)
model = LlavaForConditionalGeneration(config)
model.load_state_dict(state_dict, strict=False, assign=True) # strict=False to ignore missing resampler
# processor
cfg.tokenizer.type = LlamaTokenizerFast.from_pretrained
tokenizer = BUILDER.build(cfg.tokenizer)
tokenizer.add_tokens(
AddedToken(DEFAULT_IMAGE_TOKEN, special=True, normalized=False),
special_tokens=True)
tokenizer.add_special_tokens({'pad_token': '<pad>'})
image_processor = BUILDER.build(cfg.image_processor)
assert isinstance(image_processor, CLIPImageProcessor),\
'This conversion format only supports CLIPImageProcessor.'
processor = LlavaProcessor(
tokenizer=tokenizer, image_processor=image_processor)
# Pad to 64 for performance reasons
pad_shape = 64
pre_expansion_embeddings = \
model.language_model.model.embed_tokens.weight.data
mu = torch.mean(pre_expansion_embeddings, dim=0).float()
n = pre_expansion_embeddings.size()[0]
sigma = ((pre_expansion_embeddings - mu).T
@ (pre_expansion_embeddings - mu)) / n
dist = torch.distributions.multivariate_normal.MultivariateNormal(
mu, covariance_matrix=1e-5 * sigma)
# We add an image token so we need to resize the model
ori_vocab_size = config.text_config.vocab_size
tokenizer_vocab_size = tokenizer.encode('<pad>')[-1]
added_token = tokenizer_vocab_size - ori_vocab_size
if added_token > 0:
model.resize_token_embeddings(ori_vocab_size + added_token,
pad_shape)
model.language_model.model.embed_tokens.weight.data[
ori_vocab_size:] = torch.stack(
tuple(
dist.sample()
for _ in range(model.language_model.model.embed_tokens.
weight.data[ori_vocab_size:].shape[0])),
dim=0,
)
model.language_model.lm_head.weight.data[
ori_vocab_size:] = torch.stack(
tuple(dist.sample()
for _ in range(model.language_model.lm_head.weight.
data[ori_vocab_size:].shape[0])),
dim=0,
)
model.config.image_token_index = tokenizer.encode(
DEFAULT_IMAGE_TOKEN)[-1]
model.config.pad_token_id = tokenizer.encode('<pad>')[-1]
# save
print_log(f'Saving to {save_dir}', 'current')
model.save_pretrained(save_dir, **save_pretrained_kwargs)
processor.save_pretrained(save_dir, **save_pretrained_kwargs)
def to_official_llava(self,
cfg,
save_dir,
fp32=False,
save_pretrained_kwargs={}):
if self.use_resampler:
warnings.warn("Conversion to official LLaVA format with a custom resampler is not supported. "
"The resampler weights will not be saved.")
VIT_MAPPING = {
'vision_model': 'model.vision_tower.vision_tower.vision_model',
}
PROJECTOR_MAPPING = {
'model.0': 'model.mm_projector.0',
'model.2': 'model.mm_projector.2',
}
try:
from llava.model import LlavaConfig, LlavaLlamaForCausalLM
except ImportError:
raise ImportError(
'Please install llava with '
'`pip install git+https://github.com/haotian-liu/LLaVA.git '
'--no-deps`.')
assert getattr(self.llm, 'hf_quantizer', None) is None, \
'This conversion format does not support quantized LLM.'
# get state_dict
llm = self.llm
if self.use_llm_lora:
llm = self.llm.merge_and_unload()
llm.config.use_cache = True
if not fp32:
print_log('Convert LLM to float16', 'current')
llm.half()
assert isinstance(llm, LlamaForCausalLM), \
'This conversion format only supports LlamaForCausalLM.'
llm_state_dict = llm.state_dict()
need_visual_encoder = (not self.freeze_visual_encoder
or self.use_visual_encoder_lora)
visual_encoder = self.visual_encoder
if self.use_visual_encoder_lora:
visual_encoder = self.visual_encoder.merge_and_unload()
assert isinstance(visual_encoder, CLIPVisionModel),\
'This conversion format only supports CLIPVisionModel.'
if need_visual_encoder:
visual_encoder_state_dict = visual_encoder.state_dict()
visual_encoder_state_dict = convert_state_dict_to_hf(
visual_encoder_state_dict, VIT_MAPPING)
else:
visual_encoder_state_dict = {}
projector_state_dict = self.projector.state_dict()
projector_state_dict = convert_state_dict_to_hf(
projector_state_dict, PROJECTOR_MAPPING)
state_dict = {
**projector_state_dict,
**llm_state_dict,
**visual_encoder_state_dict,
}
# init model
tokenizer = BUILDER.build(cfg.tokenizer)
image_processor = BUILDER.build(cfg.image_processor)
assert isinstance(image_processor, CLIPImageProcessor),\
'This conversion format only supports CLIPImageProcessor.'
llava_config_dict = llm.config.__dict__.copy()
llava_config_dict.update(
dict(
image_aspect_ratio='pad',
mm_hidden_size=visual_encoder.config.hidden_size,
mm_projector_type=f'mlp{self.projector_depth}x_gelu',
mm_use_im_patch_token=False,
mm_use_im_start_end=False,
mm_vision_select_feature='patch',
mm_vision_select_layer=self.visual_select_layer,
mm_vision_tower=visual_encoder.config.name_or_path,
unfreeze_mm_vision_tower=need_visual_encoder,
model_type='llava',
use_cache=True,
use_mm_proj=True))
llava_config = LlavaConfig(**llava_config_dict)
with init_empty_weights():
with warnings.catch_warnings():
warnings.filterwarnings(
'ignore', message='.*non-meta.*', category=UserWarning)
model = LlavaLlamaForCausalLM(llava_config)
model.load_state_dict(state_dict, strict=False, assign=True) # strict=False to ignore missing resampler
# save
print_log(f'Saving to {save_dir}', 'current')
model.save_pretrained(save_dir, **save_pretrained_kwargs)
image_processor.save_pretrained(save_dir, **save_pretrained_kwargs)
tokenizer.save_pretrained(save_dir, **save_pretrained_kwargs)