# pure_model_weights/code/xtuner/model/llava_only_projector.py
# Copyright (c) OpenMMLab. All rights reserved.
import math
import os.path as osp
import warnings
from collections import OrderedDict
import os
from safetensors.torch import load_file, save_file
import torch
import torch.distributed as dist # === MOD ===
import torch.nn as nn
from accelerate import init_empty_weights
from mmengine import print_log
from mmengine.config import Config, ConfigDict
from mmengine.model import BaseModel
from peft import get_peft_model, prepare_model_for_kbit_training
from transformers import (AddedToken, AutoConfig, CLIPImageProcessor,
CLIPVisionModel, LlamaForCausalLM,
LlamaTokenizerFast, LlavaConfig,
LlavaForConditionalGeneration, LlavaProcessor)
from transformers.integrations import is_deepspeed_zero3_enabled
from xtuner.registry import BUILDER
from xtuner.utils import DEFAULT_IMAGE_TOKEN
from .modules import ProjectorConfig, ProjectorModel, dispatch_modules
from .modules.dispatch import SUPPORT_FLASH1, SUPPORT_FLASH2
from .utils import (LoadWoInit, find_all_linear_names,
get_peft_model_state_dict, guess_load_checkpoint,
make_inputs_require_grad,
prepare_inputs_labels_for_multimodal, traverse_dict)
import torch.nn.functional as F
def convert_state_dict_to_hf(state_dict, mapping):
new_state_dict = {}
for key, value in state_dict.items():
if key.endswith('.inv_freq'):
continue
for key_to_modify, new_key in mapping.items():
if key_to_modify in key:
key = key.replace(key_to_modify, new_key)
new_state_dict[key] = value
return new_state_dict
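# Example (illustrative only; keys and mapping are hypothetical):
#   >>> sd = {'model.embed_tokens.weight': torch.zeros(1),
#   ...       'model.layers.0.self_attn.rotary_emb.inv_freq': torch.zeros(1)}
#   >>> convert_state_dict_to_hf(sd, {'model': 'language_model.model'})
#   {'language_model.model.embed_tokens.weight': tensor([0.])}
# i.e. matching substrings are rewritten according to `mapping`, and any
# '.inv_freq' buffers are dropped.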
class AdaptiveAvgPool1dLayer(nn.Module):
"""
自适应平均池化层(沿序列维 L),带输入/输出 LayerNorm,并在大 L 时切换为线性插值,
避免 CUDA AdaptiveAvgPool 的 sharedMem 限制导致的报错。
期望输入:x ∈ [B, H, L]
- 先在 [B, L, H] 上做输入层归一化(LayerNorm(H))。
- 对序列维 L 做池化/插值到 output_size。
- 再在 [B, L_out, H] 上做输出层归一化。
参数:
output_size (int): 池化后的 token 数 L_out。
hidden_size (int): 通道维 H 的大小(用于 LayerNorm 维度)。
eps (float): LayerNorm eps。
affine (bool): LayerNorm 是否带缩放平移参数。
impl (str): 'auto' | 'pool' | 'interp'。auto 根据长度阈值自动切换。
switch_threshold (int): 当 L >= 该阈值且 impl='auto' 时使用插值。
pool_in_fp32 (bool): 池化/插值内部提升到 FP32 计算以增强数稳。
"""
def __init__(self, output_size: int, hidden_size: int, eps: float = 1e-5, affine: bool = True,
impl: str = 'auto', switch_threshold: int = 8192, pool_in_fp32: bool = True):
super().__init__()
if output_size <= 0:
raise ValueError("output_size must be positive")
if hidden_size <= 0:
raise ValueError("hidden_size must be positive")
if impl not in ('auto', 'pool', 'interp'):
raise ValueError("impl must be one of {'auto','pool','interp'}")
self.output_size = int(output_size)
self.hidden_size = int(hidden_size)
self.impl = impl
self.switch_threshold = int(switch_threshold)
self.pool_in_fp32 = bool(pool_in_fp32)
self.in_norm = nn.LayerNorm(hidden_size, eps=eps, elementwise_affine=affine)
self.out_norm = nn.LayerNorm(hidden_size, eps=eps, elementwise_affine=affine)
def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Expect x of shape [B, H, L]
if x.dim() != 3:
raise ValueError(f"AdaptiveAvgPool1dLayer expects 3D tensor [B,H,L], got {tuple(x.shape)}")
B, H, L = x.shape
if H != self.hidden_size:
raise ValueError(f"Channel size mismatch: got H={H}, expected {self.hidden_size}")
        # Input normalization: LayerNorm(H) applied on [B, L, H]
x = x.transpose(1, 2).contiguous() # [B, L, H]
x = self.in_norm(x)
x = x.transpose(1, 2).contiguous() # [B, H, L]
        # Choose the implementation: for large L, use interpolation to avoid the CUDA shared-memory error
use_interp = (self.impl == 'interp') or (self.impl == 'auto' and L >= self.switch_threshold)
orig_dtype = x.dtype
if self.pool_in_fp32 and x.dtype in (torch.float16, torch.bfloat16):
x = x.float()
if use_interp:
            # Linear interpolation on [B, H, L] is numerically stable and differentiable
x = F.interpolate(x, size=self.output_size, mode='linear', align_corners=False)
else:
x = F.adaptive_avg_pool1d(x.contiguous(), self.output_size)
x = x.to(orig_dtype)
        # Output normalization: LayerNorm(H) applied on [B, L_out, H]
x = x.transpose(1, 2).contiguous() # [B, L_out, H]
x = self.out_norm(x)
x = x.transpose(1, 2).contiguous() # [B, H, L_out]
return x
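# Usage sketch for AdaptiveAvgPool1dLayer (illustrative; shapes and values assumed):
#   >>> pool = AdaptiveAvgPool1dLayer(output_size=1024, hidden_size=4096)
#   >>> feats = torch.randn(2, 4096, 9000)   # [B, H, L] with L > output_size
#   >>> pool(feats).shape
#   torch.Size([2, 4096, 1024])
# With impl='auto' and L below switch_threshold, F.adaptive_avg_pool1d is used;
# above it, F.interpolate(..., mode='linear') produces the same [B, H, L_out] shape.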
class LLaVAModel(BaseModel):
def __init__(self,
llm,
freeze_llm=True,
visual_select_layer=-2,
pretrained_pth=None,
projector_depth=2,
llm_lora=None,
visual_encoder_lora=None,
use_activation_checkpointing=True,
max_position_embeddings=None,
hidden_size=512,
train_stage='2',
projector_pth=None,
                 use_projector_pool=False,
                 projector_pool_out_tokens=1024,
                 projector_pool_pth=None,
                 projector_pool_ln_eps=1e-6,
                 projector_pool_ln_affine=True,
):
super().__init__()
self.freeze_llm = freeze_llm
self.freeze_visual_encoder = True
if train_stage == '1':
print('train_stage == 1')
self.freeze_llm = True
elif train_stage == '2':
print('train_stage == 2')
self.freeze_llm = False
with LoadWoInit():
if isinstance(llm, dict):
llm = self._dispatch_lm_model_cfg(llm, max_position_embeddings)
self.llm = self._build_from_cfg_or_module(llm)
self.llm.config.use_cache = False
dispatch_modules(self.llm)
self.projector_depth = projector_depth
projector_config = ProjectorConfig(
visual_hidden_size=hidden_size,
llm_hidden_size=self.llm.config.hidden_size,
depth=self.projector_depth)
self.projector = ProjectorModel(projector_config).to(
self.llm.dtype)
self.use_projector_pool = use_projector_pool
if self.use_projector_pool:
hs = int(self.llm.config.hidden_size)
self.projector_pool = AdaptiveAvgPool1dLayer(
output_size=int(projector_pool_out_tokens),
hidden_size=hs,
eps=float(projector_pool_ln_eps),
affine=bool(projector_pool_ln_affine),
                impl='auto',
                switch_threshold=10240,
                pool_in_fp32=True,
)
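            # Optional: load pre-trained pooling-layer weights. This mirrors the
            # projector loading below and is a sketch not present in the original
            # code path; it assumes the checkpoint was written with
            # torch.save(state_dict) as done in `to_xtuner_llava`.
            if projector_pool_pth is not None:
                print_log(f'Loading projector_pool from {projector_pool_pth}',
                          'current')
                pool_sd = torch.load(projector_pool_pth, map_location='cpu')
                self.projector_pool.load_state_dict(pool_sd, strict=False)
                self.projector_pool.to(self.llm.dtype)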
if self.freeze_llm:
print('freeze_llm')
self.llm.requires_grad_(False)
if use_activation_checkpointing:
# For backward compatibility
if hasattr(self.llm, 'enable_input_require_grads'):
self.llm.enable_input_require_grads()
else:
self.llm.get_input_embeddings().register_forward_hook(
make_inputs_require_grad)
self.projector.enable_input_require_grads()
# self.visual_encoder.enable_input_require_grads() # if used
# enable gradient (activation) checkpointing for memory efficiency
self.gradient_checkpointing_enable()
self.use_llm_lora = None
self.use_visual_encoder_lora = None
if self.use_llm_lora:
self._prepare_llm_for_lora(llm_lora, use_activation_checkpointing)
if projector_pth is not None:
print_log(f"Loading projector from {projector_pth}", "current")
proj_sd = load_file(projector_pth, device="cpu")
# proj_sd = load_file(projector_pth, device="cuda")
self.projector.load_state_dict(proj_sd, strict=False)
self.projector.to(self.llm.dtype)
if pretrained_pth is not None:
pretrained_state_dict = guess_load_checkpoint(pretrained_pth)
self.load_state_dict(pretrained_state_dict, strict=False)
print_log(f'Load pretrained weight from {pretrained_pth}',
'current')
self.visual_select_layer = visual_select_layer
self._is_init = True
self.is_first_iter = True
def _parse_lora_config(self, lora_config):
if isinstance(lora_config, dict) or isinstance(
lora_config, Config) or isinstance(lora_config, ConfigDict):
lora_config = BUILDER.build(lora_config)
return lora_config
def _prepare_llm_for_lora(self,
lora_config,
use_activation_checkpointing=True):
lora_config = self._parse_lora_config(lora_config)
self.llm = prepare_model_for_kbit_training(
self.llm, use_activation_checkpointing)
if lora_config.target_modules is None:
modules = find_all_linear_names(self.llm)
lora_config.target_modules = modules
self.llm = get_peft_model(self.llm, lora_config)
def _prepare_visual_encoder_for_lora(self,
lora_config,
use_activation_checkpointing=True):
lora_config = self._parse_lora_config(lora_config)
if lora_config.target_modules is None:
modules = find_all_linear_names(self.visual_encoder)
lora_config.target_modules = modules
self.visual_encoder = get_peft_model(self.visual_encoder, lora_config)
def gradient_checkpointing_enable(self):
self.activation_checkpointing_enable()
def activation_checkpointing_enable(self):
self.llm.gradient_checkpointing_enable()
# self.visual_encoder.gradient_checkpointing_enable()
self.projector.gradient_checkpointing_enable()
def gradient_checkpointing_disable(self):
self.activation_checkpointing_disable()
def activation_checkpointing_disable(self):
self.llm.gradient_checkpointing_disable()
# self.visual_encoder.gradient_checkpointing_disable()
self.projector.gradient_checkpointing_disable()
def init_weights(self):
pass
def state_dict(self, *args, **kwargs):
state_dict = super().state_dict(*args, **kwargs)
to_return = OrderedDict()
# Step 1. visual_encoder
if self.use_visual_encoder_lora:
to_return.update(
get_peft_model_state_dict(
self.visual_encoder, state_dict=state_dict))
elif not self.freeze_visual_encoder:
to_return.update({
k: v
for k, v in state_dict.items() if 'visual_encoder.' in k
})
# Step 2. LLM
if self.use_llm_lora:
to_return.update(
get_peft_model_state_dict(self.llm, state_dict=state_dict))
elif not self.freeze_llm:
to_return.update(
{k: v
for k, v in state_dict.items() if 'llm.' in k})
# Step 3. Projector
to_return.update(
{k: v
for k, v in state_dict.items() if 'projector.' in k})
to_return.update(
{k: v
for k, v in state_dict.items() if 'projector_pool.' in k})
return to_return
@staticmethod
def _prepare_for_long_context_training(cfg, llm_cfg,
max_position_embeddings):
orig_rope_scaling = getattr(llm_cfg, 'rope_scaling', None)
if orig_rope_scaling is None:
orig_rope_scaling = {'factor': 1}
orig_rope_scaling_factor = orig_rope_scaling[
'factor'] if 'factor' in orig_rope_scaling.keys() else 1
orig_ctx_len = getattr(llm_cfg, 'max_position_embeddings', None)
if orig_ctx_len:
orig_ctx_len *= orig_rope_scaling_factor
if max_position_embeddings > orig_ctx_len:
scaling_factor = float(
math.ceil(max_position_embeddings / orig_ctx_len))
llm_cfg.rope_scaling = {
'type': 'linear',
'factor': scaling_factor
}
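                # Worked example (numbers assumed for illustration): with
                # orig_ctx_len = 4096 and max_position_embeddings = 16384,
                # scaling_factor = ceil(16384 / 4096) = 4.0, i.e. linear RoPE
                # scaling stretches the original context window by 4x.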
# hardcode for internlm2
llm_cfg.attn_implementation = 'flash_attention_2'
cfg.config = llm_cfg
return cfg, llm_cfg
@staticmethod
def _prepare_for_flash_attn(cfg, llm_cfg):
cls_name = type(llm_cfg).__name__
        SUPPORT_SDPA_ATTN = ('LlamaConfig', 'GemmaConfig', 'MistralConfig',
                             'MixtralConfig', 'Qwen2Config', 'Qwen2MoeConfig',
                             'Starcoder2Config', 'Phi3Config')
        SUPPORT_FLASH_ATTN2 = ('InternLM2Config', 'LlamaConfig', 'GemmaConfig',
                               'MistralConfig', 'MixtralConfig', 'Qwen2Config',
                               'Qwen2MoeConfig', 'Starcoder2Config',
                               'Phi3Config')
torch_dtype = torch.bfloat16 if (
torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
else torch.float16
if getattr(cfg, 'attn_implementation', None) is not None:
# Flash Attention 2.0 only supports torch.float16 and
# torch.bfloat16 dtypes
if cfg.attn_implementation == 'flash_attention_2':
cfg.torch_dtype = torch_dtype
elif SUPPORT_FLASH2 and cls_name in SUPPORT_FLASH_ATTN2:
cfg.torch_dtype = torch_dtype
cfg.attn_implementation = 'flash_attention_2'
elif SUPPORT_FLASH1 and cls_name in SUPPORT_SDPA_ATTN:
cfg.attn_implementation = 'sdpa'
return cfg, llm_cfg
@staticmethod
def _prepare_for_qlora_zero3(cfg):
if (not is_deepspeed_zero3_enabled()) or (not hasattr(
cfg, 'quantization_config')):
return cfg
torch_dtype = torch.bfloat16 if (
torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
else torch.float16
cfg.torch_dtype = torch_dtype
quantization_config = cfg.quantization_config
quantization_config.bnb_4bit_compute_dtype = torch_dtype
quantization_config.bnb_4bit_quant_storage = torch_dtype
return cfg
def _dispatch_lm_model_cfg(self, cfg, max_position_embeddings=None):
cfg = self._prepare_for_qlora_zero3(cfg)
pretrained_model_name_or_path = cfg.pretrained_model_name_or_path
llm_cfg = AutoConfig.from_pretrained(
pretrained_model_name_or_path, trust_remote_code=True)
cfg, llm_cfg = self._prepare_for_flash_attn(cfg, llm_cfg)
if max_position_embeddings is not None:
cfg, llm_cfg = self._prepare_for_long_context_training(
cfg, llm_cfg, max_position_embeddings)
return cfg
def _build_from_cfg_or_module(self, cfg_or_mod):
if isinstance(cfg_or_mod, nn.Module):
return cfg_or_mod
elif isinstance(cfg_or_mod, dict):
traverse_dict(cfg_or_mod)
return BUILDER.build(cfg_or_mod)
else:
raise NotImplementedError
def forward(self, data, data_samples=None, mode='loss'):
if self.is_first_iter:
# hardcode for qlora DeepSpeed ZeRO3, put buffers and QuantState to
# device
# Only required in `LLaVAModel` .
# We do not need this in `SupervisedFinetune` .
self.to(data['input_ids'].device)
self.is_first_iter = False
        # data['pixel_values'] = [[pixel_values of img1], [pixel_values of img2], ...]
if 'pixel_values' in data:
feat_to_proj = data['pixel_values'].to(self.llm.dtype)
            # Explicitly enable gradient tracking on the input features. With the
            # projector running under gradient (activation) checkpointing and the
            # features arriving pre-computed (no trainable module before the
            # projector), this is what connects the backward graph to the
            # projector's weights.
            feat_to_proj.requires_grad_(True)
            pixel_values = self.projector(feat_to_proj)  # pass the grad-enabled tensor
# === NEW: pool along the sequence length (tokens) to L'
if self.use_projector_pool:
B, L, H = pixel_values.shape
pv = pixel_values.transpose(1, 2) # [B, H, L]
pv = self.projector_pool(pv) # [B, H, L']
pixel_values = pv.transpose(1, 2).contiguous() # [B, L', H]
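                # e.g. [B, 9000, H] -> [B, 1024, H] (token counts assumed for
                # illustration); the hidden size H stays the LLM hidden size.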
data['pixel_values'] = pixel_values
data.pop('coords', None)
data = prepare_inputs_labels_for_multimodal(llm=self.llm, **data)
if mode == 'loss':
return self.compute_loss(data, data_samples)
elif mode == 'predict':
return self.predict(data, data_samples)
elif mode == 'tensor':
return self._forward(data, data_samples)
else:
raise NotImplementedError
def _forward(self, data, data_samples=None):
outputs = self.llm(**data)
return outputs
def predict(self, data, data_samples=None):
outputs = self.llm(**data)
logits_dict = [{'logits': logits} for logits in outputs.logits]
return logits_dict
    # Replacement for the original compute_loss of LLaVAModel.
    def compute_loss(self, data, data_samples=None):
        """Modified loss computation.

        Instead of averaging the loss over all tokens in the batch, this version
        averages the loss within each sample and then across samples, so every
        sample contributes equally regardless of its token length. This addresses
        the gradient imbalance between long and short sequences.
        """
        # If no labels are given, let the HF model handle things and return its loss.
        if "labels" not in data:
            outputs = self.llm(**data)
            return {"loss": outputs.loss}
        # Split the labels off from `data` so they are not passed to the model directly.
        labels = data.pop("labels")
        # Forward pass to obtain the logits.
        outputs = self.llm(**data)
        logits = outputs.logits
        # Sanity check: logits and labels must have matching shapes.
        if logits.shape[:-1] != labels.shape:
            raise ValueError(
                f"Logits and labels shape mismatch. Logits: {logits.shape}, Labels: {labels.shape}"
            )
        # Standard causal-LM shift: position t predicts token t + 1.
        # logits: [B, L, V] -> [B, L-1, V]
        # labels: [B, L]    -> [B, L-1]
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Per-token cross entropy without any aggregation (reduction='none');
        # this returns a loss tensor with one entry per element of shift_labels.
        loss = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=-100,
            reduction='none'
        )
        # Reshape the loss back to [B, L-1].
        loss = loss.view(shift_logits.size(0), -1)
        # Compute the mean loss per sample (per sequence):
        # count the valid (non -100) tokens in each sample, ...
        num_tokens_per_sample = (shift_labels != -100).sum(dim=1)
        # ... sum the per-token loss within each sample, ...
        loss_per_sample = loss.sum(dim=1)
        # ... and guard against division by zero.
        valid_samples_mask = num_tokens_per_sample > 0
        # Initialize the per-sample mean loss.
        mean_loss_per_sample = torch.zeros_like(loss_per_sample)
        # Average only over samples that contain valid tokens.
        if valid_samples_mask.any():
            mean_loss_per_sample[valid_samples_mask] = loss_per_sample[valid_samples_mask] / num_tokens_per_sample[valid_samples_mask]
        # The final loss is the mean of the per-sample mean losses.
        final_loss = mean_loss_per_sample.mean()
return {"loss": final_loss}
# def compute_loss(self, data, data_samples=None):
# outputs = self.llm(**data)
# loss_dict = {'loss': outputs.loss}
# return loss_dict
    # def compute_loss(self, data, data_samples=None):
    #     """
    #     Token-level cross-entropy loss (distributed / AMP compatible).
    #     - -100 in labels is the ignore_index
    #     - Negative IDs (e.g. the -200 image placeholder) and positions matching
    #       special_ids are masked automatically
    #     """
    #     # 1) If there are no labels, fall back to the HF default
    #     if "labels" not in data:
    #         outputs = self.llm(**data)
    #         return {"loss": outputs.loss}
    #     labels = data["labels"]                   # [B, T]
    #     input_ids = data.get("input_ids", None)   # [B, T] or None
    #     attn = data.get("attention_mask", None)   # may be absent
    #     # 2) Sanitize the labels (without modifying the original labels)
    #     safe_labels = labels.clone()
    #     # 2.1 Mask negative IDs (e.g. the -200 image placeholder)
    #     if input_ids is not None:
    #         neg_mask = (input_ids < 0)
    #         if neg_mask.any():
    #             safe_labels = torch.where(neg_mask, torch.full_like(safe_labels, -100), safe_labels)
    #     # 2.2 Mask the tokenizer's special tokens (template markers, etc.)
    #     if getattr(self, "tokenizer", None) is not None:
    #         try:
    #             special_ids = set(self.tokenizer.all_special_ids or [])
    #         except Exception:
    #             special_ids = set()
    #         if special_ids:
    #             special_mask = torch.zeros_like(input_ids, dtype=torch.bool)
    #             for sid in special_ids:
    #                 special_mask |= (input_ids == sid)
    #             if special_mask.any():
    #                 safe_labels = torch.where(special_mask, torch.full_like(safe_labels, -100), safe_labels)
    #     # 3) Forward pass to get the logits (do not hand labels to HF, to avoid its per-device mean)
    #     model_inputs = {k: v for k, v in data.items() if k != "labels"}
    #     outputs = self.llm(**model_inputs, use_cache=False)
    #     logits = outputs.logits  # [B, T, V]
    #     # Shape assertion
    #     if logits.dim() != 3 or logits.shape[:2] != safe_labels.shape[:2]:
    #         raise RuntimeError(
    #             f"logits/labels length mismatch: logits {tuple(logits.shape)} vs labels {tuple(safe_labels.shape)}"
    #         )
    #     # 4) Causal-LM alignment (shift)
    #     shift_logits = logits[:, :-1, :].contiguous()
    #     shift_labels = safe_labels[:, 1:].contiguous()
    #     # 5) Count valid tokens & aggregate across ranks
    #     n_tok_local = (shift_labels != -100).sum().to(device=logits.device, dtype=torch.long)
    #     world_size = 1
    #     n_tok_global = n_tok_local
    #     if dist.is_available() and dist.is_initialized():
    #         world_size = dist.get_world_size()
    #         with torch.no_grad():
    #             n_tok_global = n_tok_local.clone()
    #             dist.all_reduce(n_tok_global, op=dist.ReduceOp.SUM)
    #     # If there are no supervised tokens globally, return 0 (avoid NaN)
    #     if n_tok_global.item() == 0:
    #         zero = shift_logits.sum() * 0.0
    #         return {"loss": zero, "ntok": n_tok_global.to(zero.dtype)}
    #     # 6) Numerator (sum over tokens; FP32 is more stable)
    #     loss_sum_local = F.cross_entropy(
    #         shift_logits.float().view(-1, shift_logits.size(-1)),
    #         shift_labels.view(-1),
    #         ignore_index=-100,
    #         reduction="sum",
    #     )
    #     # 7) Loss averaged over the global token count (cancels DDP's gradient averaging)
    #     denom = n_tok_global.clamp_min(1).to(loss_sum_local.dtype)
    #     loss = (loss_sum_local / denom) * float(world_size)
    #     # 8) Return
    #     ntok_tensor = denom.detach()
    #     return {"loss": loss, "ntok": ntok_tensor}
def __getattr__(self, name: str):
try:
return super().__getattr__(name)
except AttributeError:
return getattr(self.llm, name)
def to_hf(self,
cfg,
save_dir,
fp32=False,
save_pretrained_kwargs={},
save_format='xtuner',
**kwargs):
if save_format == 'xtuner':
self.to_xtuner_llava(cfg, save_dir, fp32, save_pretrained_kwargs)
elif save_format == 'huggingface':
self.to_huggingface_llava(cfg, save_dir, fp32,
save_pretrained_kwargs)
elif save_format == 'official':
self.to_official_llava(cfg, save_dir, fp32, save_pretrained_kwargs)
else:
raise NotImplementedError
def to_xtuner_llava(self,
cfg,
save_dir,
fp32=False,
save_pretrained_kwargs={}):
# LLM
self.llm.config.use_cache = True
if not fp32:
print_log('Convert LLM to float16', 'current')
self.llm.half()
if self.use_llm_lora:
llm_path = osp.join(save_dir, 'llm_adapter')
print_log(f'Saving LLM adapter to {llm_path}', 'current')
self.llm.save_pretrained(llm_path, **save_pretrained_kwargs)
elif not self.freeze_llm:
llm_path = save_dir
print_log(f'Saving LLM tokenizer to {llm_path}', 'current')
tokenizer = BUILDER.build(cfg.tokenizer)
tokenizer.save_pretrained(llm_path, **save_pretrained_kwargs)
print_log(f'Saving LLM to {llm_path}', 'current')
self.llm.save_pretrained(llm_path, **save_pretrained_kwargs)
self.llm.config.use_cache = False
# Visual Encoder
if self.use_visual_encoder_lora:
visual_encoder_path = osp.join(save_dir, 'visual_encoder_adapter')
print_log(
f'Saving visual_encoder adapter to {visual_encoder_path}',
'current')
self.visual_encoder.save_pretrained(visual_encoder_path,
**save_pretrained_kwargs)
elif not self.freeze_visual_encoder:
visual_encoder_path = osp.join(save_dir, 'visual_encoder')
print_log(
'Saving visual_encoder image_processor to'
f'{visual_encoder_path}', 'current')
image_processor = BUILDER.build(cfg.image_processor)
image_processor.save_pretrained(visual_encoder_path,
**save_pretrained_kwargs)
print_log(f'Saving visual_encoder to {visual_encoder_path}',
'current')
self.visual_encoder.save_pretrained(visual_encoder_path,
**save_pretrained_kwargs)
# Projector
projector_path = osp.join(save_dir, 'projector')
print_log(f'Saving projector to {projector_path}', 'current')
# self.projector.save_pretrained(projector_path,
# **save_pretrained_kwargs)
os.makedirs(projector_path, exist_ok=True)
output_path = os.path.join(projector_path, 'projector.safetensors')
save_file(self.projector.state_dict(), output_path)
if self.use_projector_pool and getattr(self, 'projector_pool', None):
projector_pool_path = osp.join(save_dir, 'projector_pool')
print_log(f'Saving projector_pool to {projector_pool_path}', 'current')
torch.save(self.projector_pool.state_dict(), projector_pool_path)
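    # Reloading sketch (illustrative): the projector saved above can be fed back
    # into this class through the `projector_pth` constructor argument, e.g.
    #   projector_pth='<save_dir>/projector/projector.safetensors'
    # which __init__ reads with safetensors.torch.load_file; the pooling layer is
    # written with torch.save and can be restored with torch.load (cf. the
    # optional projector_pool_pth handling in __init__).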
def to_huggingface_llava(self,
cfg,
save_dir,
fp32=False,
save_pretrained_kwargs={}):
LLM_MAPPING = {
'model': 'language_model.model',
'lm_head': 'language_model.lm_head',
}
VIT_MAPPING = {
'vision_model': 'vision_tower.vision_model',
}
PROJECTOR_MAPPING = {
'model.0': 'multi_modal_projector.linear_1',
'model.2': 'multi_modal_projector.linear_2',
}
LONGNET_MAPPING = {
'layers.0': 'LongNet_encoder.layers.0',
'layers.1': 'LongNet_encoder.layers.1',
'layer_norm': 'LongNet_encoder.layer_norm'
}
assert getattr(self.llm, 'hf_quantizer', None) is None, \
'This conversion format does not support quantized LLM.'
# get state_dict
llm = self.llm
if self.use_llm_lora:
llm = self.llm.merge_and_unload()
llm.config.use_cache = True
if not fp32:
print_log('Convert LLM to float16', 'current')
llm.half()
assert isinstance(llm, LlamaForCausalLM), \
'This conversion format only supports LlamaForCausalLM.'
llm_state_dict = llm.state_dict()
llm_state_dict = convert_state_dict_to_hf(llm_state_dict, LLM_MAPPING)
need_visual_encoder = (not self.freeze_visual_encoder
or self.use_visual_encoder_lora)
visual_encoder = self.visual_encoder
if self.use_visual_encoder_lora:
visual_encoder = self.visual_encoder.merge_and_unload()
assert isinstance(visual_encoder, CLIPVisionModel),\
'This conversion format only supports CLIPVisionModel.'
if need_visual_encoder:
visual_encoder_state_dict = visual_encoder.state_dict()
visual_encoder_state_dict = convert_state_dict_to_hf(
visual_encoder_state_dict, VIT_MAPPING)
else:
visual_encoder_state_dict = {}
projector_state_dict = self.projector.state_dict()
projector_state_dict = convert_state_dict_to_hf(
projector_state_dict, PROJECTOR_MAPPING)
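        # NOTE: LongNet_encoder is not a module defined on this projector-only
        # model; the lookup below falls through __getattr__ to the wrapped LLM
        # and will raise AttributeError unless the LLM happens to define it.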
LongNet_encoder_state_dict = self.LongNet_encoder.state_dict()
LongNet_encoder_state_dict = convert_state_dict_to_hf(
LongNet_encoder_state_dict, LONGNET_MAPPING)
state_dict = {
**projector_state_dict,
**llm_state_dict,
**visual_encoder_state_dict,
**LongNet_encoder_state_dict
}
# init model
text_config = llm.config
vision_config = visual_encoder.config
config = LlavaConfig(
text_config=text_config,
vision_config=vision_config,
attn_implementation='eager')
with init_empty_weights():
with warnings.catch_warnings():
warnings.filterwarnings(
'ignore', message='.*non-meta.*', category=UserWarning)
model = LlavaForConditionalGeneration(config)
model.load_state_dict(state_dict, strict=True, assign=True)
# processor
cfg.tokenizer.type = LlamaTokenizerFast.from_pretrained
tokenizer = BUILDER.build(cfg.tokenizer)
tokenizer.add_tokens(
AddedToken(DEFAULT_IMAGE_TOKEN, special=True, normalized=False),
special_tokens=True)
tokenizer.add_special_tokens({'pad_token': '<pad>'})
image_processor = BUILDER.build(cfg.image_processor)
assert isinstance(image_processor, CLIPImageProcessor),\
'This conversion format only supports CLIPImageProcessor.'
processor = LlavaProcessor(
tokenizer=tokenizer, image_processor=image_processor)
# Pad to 64 for performance reasons
pad_shape = 64
pre_expansion_embeddings = \
model.language_model.model.embed_tokens.weight.data
mu = torch.mean(pre_expansion_embeddings, dim=0).float()
n = pre_expansion_embeddings.size()[0]
sigma = ((pre_expansion_embeddings - mu).T
@ (pre_expansion_embeddings - mu)) / n
dist = torch.distributions.multivariate_normal.MultivariateNormal(
mu, covariance_matrix=1e-5 * sigma)
# We add an image token so we need to resize the model
ori_vocab_size = config.text_config.vocab_size
tokenizer_vocab_size = tokenizer.encode('<pad>')[-1]
added_token = tokenizer_vocab_size - ori_vocab_size
if added_token > 0:
model.resize_token_embeddings(ori_vocab_size + added_token,
pad_shape)
model.language_model.model.embed_tokens.weight.data[
ori_vocab_size:] = torch.stack(
tuple(
dist.sample()
for _ in range(model.language_model.model.embed_tokens.
weight.data[ori_vocab_size:].shape[0])),
dim=0,
)
model.language_model.lm_head.weight.data[
ori_vocab_size:] = torch.stack(
tuple(dist.sample()
for _ in range(model.language_model.lm_head.weight.
data[ori_vocab_size:].shape[0])),
dim=0,
)
model.config.image_token_index = tokenizer.encode(
DEFAULT_IMAGE_TOKEN)[-1]
model.config.pad_token_id = tokenizer.encode('<pad>')[-1]
# save
print_log(f'Saving to {save_dir}', 'current')
model.save_pretrained(save_dir, **save_pretrained_kwargs)
processor.save_pretrained(save_dir, **save_pretrained_kwargs)
def to_official_llava(self,
cfg,
save_dir,
fp32=False,
save_pretrained_kwargs={}):
VIT_MAPPING = {
'vision_model': 'model.vision_tower.vision_tower.vision_model',
}
PROJECTOR_MAPPING = {
'model.0': 'model.mm_projector.0',
'model.2': 'model.mm_projector.2',
}
LONGNET_MAPPING = {
'layers.0': 'LongNet_encoder.layers.0',
'layers.1': 'LongNet_encoder.layers.1',
'layer_norm': 'LongNet_encoder.layer_norm'
}
try:
from llava.model import LlavaConfig, LlavaLlamaForCausalLM
except ImportError:
raise ImportError(
'Please install llava with '
'`pip install git+https://github.com/haotian-liu/LLaVA.git '
'--no-deps`.')
assert getattr(self.llm, 'hf_quantizer', None) is None, \
'This conversion format does not support quantized LLM.'
# get state_dict
llm = self.llm
if self.use_llm_lora:
llm = self.llm.merge_and_unload()
llm.config.use_cache = True
if not fp32:
print_log('Convert LLM to float16', 'current')
llm.half()
assert isinstance(llm, LlamaForCausalLM), \
'This conversion format only supports LlamaForCausalLM.'
llm_state_dict = llm.state_dict()
need_visual_encoder = (not self.freeze_visual_encoder
or self.use_visual_encoder_lora)
visual_encoder = self.visual_encoder
if self.use_visual_encoder_lora:
visual_encoder = self.visual_encoder.merge_and_unload()
assert isinstance(visual_encoder, CLIPVisionModel),\
'This conversion format only supports CLIPVisionModel.'
if need_visual_encoder:
visual_encoder_state_dict = visual_encoder.state_dict()
visual_encoder_state_dict = convert_state_dict_to_hf(
visual_encoder_state_dict, VIT_MAPPING)
else:
visual_encoder_state_dict = {}
projector_state_dict = self.projector.state_dict()
projector_state_dict = convert_state_dict_to_hf(
projector_state_dict, PROJECTOR_MAPPING)
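        # NOTE: as in `to_huggingface_llava`, LongNet_encoder is not defined on
        # this projector-only model, so this lookup will likely fail.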
LongNet_encoder_state_dict = self.LongNet_encoder.state_dict()
LongNet_encoder_state_dict = convert_state_dict_to_hf(
LongNet_encoder_state_dict, LONGNET_MAPPING)
state_dict = {
**projector_state_dict,
**llm_state_dict,
**visual_encoder_state_dict,
**LongNet_encoder_state_dict
}
# init model
tokenizer = BUILDER.build(cfg.tokenizer)
image_processor = BUILDER.build(cfg.image_processor)
assert isinstance(image_processor, CLIPImageProcessor),\
'This conversion format only supports CLIPImageProcessor.'
llava_config_dict = llm.config.__dict__.copy()
llava_config_dict.update(
dict(
image_aspect_ratio='pad',
mm_hidden_size=visual_encoder.config.hidden_size,
mm_projector_type=f'mlp{self.projector_depth}x_gelu',
mm_use_im_patch_token=False,
mm_use_im_start_end=False,
mm_vision_select_feature='patch',
mm_vision_select_layer=self.visual_select_layer,
mm_vision_tower=visual_encoder.config.name_or_path,
unfreeze_mm_vision_tower=need_visual_encoder,
model_type='llava',
use_cache=True,
use_mm_proj=True))
llava_config = LlavaConfig(**llava_config_dict)
with init_empty_weights():
with warnings.catch_warnings():
warnings.filterwarnings(
'ignore', message='.*non-meta.*', category=UserWarning)
model = LlavaLlamaForCausalLM(llava_config)
model.load_state_dict(state_dict, strict=True, assign=True)
# save
print_log(f'Saving to {save_dir}', 'current')
model.save_pretrained(save_dir, **save_pretrained_kwargs)
image_processor.save_pretrained(save_dir, **save_pretrained_kwargs)
tokenizer.save_pretrained(save_dir, **save_pretrained_kwargs)