|
|
|
|
|
import math |
|
|
import os |
|
|
import os.path as osp |
|
|
import warnings |
|
|
from collections import OrderedDict |
|
|
from functools import partial |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from accelerate import init_empty_weights |
|
|
from mmengine import print_log |
|
|
from mmengine.config import Config, ConfigDict |
|
|
from mmengine.model import BaseModel |
|
|
from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training |
|
|
from peft.tuners.lora.layer import LoraLayer |
|
|
from safetensors.torch import load_file, save_file |
|
|
from torch.nn.init import trunc_normal_ |
|
|
from torch.utils.checkpoint import checkpoint |
|
|
from transformers import (AddedToken, AutoConfig, CLIPImageProcessor, |
|
|
CLIPVisionModel, LlamaForCausalLM, |
|
|
LlamaTokenizerFast, LlavaConfig, |
|
|
LlavaForConditionalGeneration, LlavaProcessor) |
|
|
from transformers.integrations import is_deepspeed_zero3_enabled |
|
|
|
|
|
from xtuner.model.torchscale.component.multihead_attention import MultiheadAttention |
|
|
from xtuner.model.torchscale.architecture.config import EncoderConfig |
|
|
|
|
|
from xtuner.model.torchscale.model.pos_embed import get_2d_sincos_pos_embed |
|
|
from xtuner.registry import BUILDER |
|
|
from xtuner.utils import DEFAULT_IMAGE_TOKEN |
|
|
|
|
|
from .modules import ProjectorConfig, ProjectorModel, dispatch_modules |
|
|
from .modules.dispatch import SUPPORT_FLASH1, SUPPORT_FLASH2 |
|
|
from .sparse_token_merge import SparsePatchMerging |
|
|
from .utils import (LoadWoInit, find_all_linear_names, |
|
|
get_peft_model_state_dict, guess_load_checkpoint, |
|
|
make_inputs_require_grad, |
|
|
prepare_inputs_labels_for_multimodal, traverse_dict) |
|
|
|
|
|
|
|
|
|
|
|
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): |
|
|
grid_h = np.arange(grid_size, dtype=np.float32) |
|
|
grid_w = np.arange(grid_size, dtype=np.float32) |
|
|
grid = np.meshgrid(grid_w, grid_h) |
|
|
grid = np.stack(grid, axis=0) |
|
|
grid = grid.reshape([2, 1, grid_size, grid_size]) |
|
|
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) |
|
|
if cls_token: |
|
|
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) |
|
|
return pos_embed |
|
|
|
|
|
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): |
|
|
assert embed_dim % 2 == 0 |
|
|
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) |
|
|
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) |
|
|
emb = np.concatenate([emb_h, emb_w], axis=1) |
|
|
return emb |
|
|
|
|
|
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): |
|
|
assert embed_dim % 2 == 0 |
|
|
omega = np.arange(embed_dim // 2, dtype=np.float32) |
|
|
omega /= embed_dim / 2. |
|
|
omega = 1. / 10000**omega |
|
|
pos = pos.reshape(-1) |
|
|
out = np.einsum('m,d->md', pos, omega) |
|
|
emb_sin = np.sin(out) |
|
|
emb_cos = np.cos(out) |
|
|
emb = np.concatenate([emb_sin, emb_cos], axis=1) |
|
|
return emb |
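# Illustrative sanity check for the sin-cos helpers above (doctest-style, not
# executed on import); the shapes follow the functions as written, matching the
# MAE reference layout of one row per grid cell in row-major order.
#
#   >>> pe = get_2d_sincos_pos_embed(embed_dim=64, grid_size=4)
#   >>> pe.shape
#   (16, 64)
#   >>> get_2d_sincos_pos_embed(embed_dim=64, grid_size=4, cls_token=True).shape
#   (17, 64)    # one extra all-zero row prepended for the CLS token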
|
|
|
|
|
|
|
|
class Resampler(nn.Module): |
|
|
""" |
|
|
修正后的 Resampler 版本: |
|
|
1. 区分 query_pos_embed 和 input_pos_embed,解决变量冲突。 |
|
|
2. 解除对外部 llm 模块的依赖,提高封装性。 |
|
|
3. 修正 forward 方法中的位置编码应用逻辑和维度匹配。 |
|
|
4. 集成梯度检查点(gradient_checkpointing)功能以节省显存。 |
|
|
""" |
|
|
def __init__( |
|
|
self, |
|
|
grid_size, |
|
|
embed_dim, |
|
|
num_heads, |
|
|
slide_ngrids=1000, |
|
|
kv_dim=None, |
|
|
norm_layer=partial(nn.LayerNorm, eps=1e-6), |
|
|
gradient_checkpointing=False |
|
|
): |
|
|
super().__init__() |
|
|
self.num_queries = grid_size ** 2 |
|
|
self.embed_dim = embed_dim |
|
|
self.num_heads = num_heads |
|
|
self.slide_ngrids = slide_ngrids |
|
|
self.gradient_checkpointing = gradient_checkpointing |
|
|
|
|
|
|
|
|
self.query_pos_embed = nn.Parameter( |
|
|
torch.from_numpy(get_2d_sincos_pos_embed(embed_dim, grid_size)).float(), |
|
|
requires_grad=False |
|
|
) |
|
|
|
|
|
|
|
|
num_patches = slide_ngrids ** 2 |
|
|
self.register_buffer( |
|
|
'input_pos_embed', |
|
|
torch.zeros(1, num_patches, embed_dim), |
|
|
persistent=False |
|
|
) |
|
|
|
|
|
|
|
|
self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim)) |
|
|
trunc_normal_(self.query, std=.02) |
|
|
|
|
|
|
|
|
if kv_dim is not None and kv_dim != embed_dim: |
|
|
self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False) |
|
|
else: |
|
|
self.kv_proj = nn.Identity() |
|
|
|
|
|
|
|
|
self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.ln_q = norm_layer(embed_dim) |
|
|
self.ln_kv = norm_layer(embed_dim) |
|
|
self.ln_post = norm_layer(embed_dim) |
|
|
self.proj = nn.Parameter((embed_dim ** -0.5) * torch.randn(embed_dim, embed_dim)) |
|
|
|
|
|
|
|
|
self.apply(self._init_weights) |
|
|
self.initialize_input_pe_weights() |
|
|
|
|
|
def _init_weights(self, m): |
|
|
if isinstance(m, nn.Linear): |
|
|
trunc_normal_(m.weight, std=.02) |
|
|
if m.bias is not None: |
|
|
nn.init.constant_(m.bias, 0) |
|
|
elif isinstance(m, nn.LayerNorm): |
|
|
nn.init.constant_(m.bias, 0) |
|
|
nn.init.constant_(m.weight, 1.0) |
|
|
|
|
|
@torch.no_grad() |
|
|
def initialize_input_pe_weights(self, chunk_rows: int = 64, chunk_cols: int = 64): |
|
|
H = W = self.slide_ngrids |
|
|
D = self.embed_dim |
|
|
        assert D % 4 == 0, 'embed_dim must be a multiple of 4 to match the numpy sincos implementation exactly.'
|
|
|
|
|
device = self.input_pos_embed.device |
|
|
dtype64 = torch.float64 |
|
|
|
|
|
if self.input_pos_embed.shape != (1, H * W, D): |
|
|
self.input_pos_embed.resize_(1, H * W, D) |
|
|
|
|
|
pos4d = self.input_pos_embed.view(1, H, W, D) |
|
|
|
|
|
k = D // 4 |
|
|
inv = 1.0 / (10000 ** (torch.arange(k, device=device, dtype=dtype64) / k)) |
|
|
|
|
|
y_lin = torch.arange(H, device=device, dtype=dtype64) |
|
|
x_lin = torch.arange(W, device=device, dtype=dtype64) |
|
|
|
|
|
y_phase = y_lin.unsqueeze(1) * inv.unsqueeze(0) |
|
|
x_phase = x_lin.unsqueeze(1) * inv.unsqueeze(0) |
|
|
y_enc = torch.cat([torch.sin(y_phase), torch.cos(y_phase)], dim=1) |
|
|
x_enc = torch.cat([torch.sin(x_phase), torch.cos(x_phase)], dim=1) |
|
|
|
|
|
for r0 in range(0, H, chunk_rows): |
|
|
r1 = min(r0 + chunk_rows, H) |
|
|
R = r1 - r0 |
|
|
y_chunk = y_enc[r0:r1].unsqueeze(1) |
|
|
|
|
|
for c0 in range(0, W, chunk_cols): |
|
|
c1 = min(c0 + chunk_cols, W) |
|
|
C = c1 - c0 |
|
|
x_chunk = x_enc[c0:c1].unsqueeze(0) |
|
|
emb_rc = torch.cat([ |
|
|
x_chunk.expand(R, C, 2 * k), |
|
|
y_chunk.expand(R, C, 2 * k) |
|
|
], dim=2) |
|
|
pos4d[0, r0:r1, c0:c1, :].copy_(emb_rc.to(pos4d.dtype)) |
|
|
|
|
|
def _checkpointed_forward(self, q_embed, kv_embed): |
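        # One checkpoint segment: the cross-attention, post-norm and output
        # projection are recomputed together in the backward pass when gradient
        # checkpointing is enabled.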
|
|
|
|
|
|
|
|
|
|
|
        attn_out = self.attn(q_embed, kv_embed, kv_embed)[0]
        ln_out = self.ln_post(attn_out)
        proj_out = ln_out @ self.proj
        return proj_out
|
|
|
|
|
def forward(self, x, coords_rc, attn_mask=None): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pos_indices = (coords_rc[..., 0] * self.slide_ngrids + coords_rc[..., 1]).long() |
|
|
|
|
|
input_pos = self.input_pos_embed[:, pos_indices, :].squeeze(0) |
|
|
|
|
|
|
|
|
|
|
|
x = self.kv_proj(x) |
|
|
kv_embed = self.ln_kv(x) |
|
|
|
|
|
N = x.shape[0] |
|
|
q = self.ln_q(self.query) |
|
|
|
|
|
|
|
|
|
|
|
q_embed = q.unsqueeze(0).expand(N, -1, -1) + self.query_pos_embed.unsqueeze(0) |
|
|
|
|
|
|
|
|
kv_embed = kv_embed + input_pos |
|
|
|
|
|
if self.training and self.gradient_checkpointing: |
|
|
q_embed.requires_grad_(True) |
|
|
kv_embed.requires_grad_(True) |
|
|
out = checkpoint(self._checkpointed_forward, q_embed, kv_embed, use_reentrant=False) |
|
|
else: |
|
|
out = self._checkpointed_forward(q_embed, kv_embed) |
|
|
|
|
|
return out |
|
|
|
|
|
    def enable_input_require_grads(self):
        print_log('enable input require grads for resampler', 'current')

        def make_inputs_require_grad(module, input, output):
            output.requires_grad_(True)

        # The resampler has no `self.model` attribute; hook the module itself so
        # that its output requires grad, which keeps downstream activation
        # checkpointing usable even when the resampler parameters are frozen.
        self.register_forward_hook(make_inputs_require_grad)
|
|
|
|
|
def gradient_checkpointing_enable(self): |
|
|
self.gradient_checkpointing = True |
|
|
|
|
|
def gradient_checkpointing_disable(self): |
|
|
self.gradient_checkpointing = False |
|
|
|
|
|
def _repeat(self, query, N: int): |
|
|
return query.unsqueeze(1).repeat(1, N, 1) |
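# Minimal usage sketch for the Resampler above. The sizes here are made up for
# illustration and are far smaller than the real configuration, where embed_dim
# matches the LLM hidden size and slide_ngrids defaults to 1000; the helper is
# never called at import time.
def _resampler_smoke_test():
    torch.manual_seed(0)
    resampler = Resampler(grid_size=4, embed_dim=64, num_heads=4, slide_ngrids=8)
    feats = torch.randn(1, 10, 64)              # [B, L, C] patch features
    coords_rc = torch.randint(0, 8, (10, 2))    # [L, 2] (row, col) grid coordinates
    out = resampler(feats, coords_rc)
    assert out.shape == (1, 16, 64)             # grid_size ** 2 query tokens
    return out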
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def convert_state_dict_to_hf(state_dict, mapping): |
|
|
new_state_dict = {} |
|
|
for key, value in state_dict.items(): |
|
|
if key.endswith('.inv_freq'): |
|
|
continue |
|
|
for key_to_modify, new_key in mapping.items(): |
|
|
if key_to_modify in key: |
|
|
key = key.replace(key_to_modify, new_key) |
|
|
new_state_dict[key] = value |
|
|
return new_state_dict |
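# Example of the key rewriting performed above (hypothetical keys, for
# illustration only):
#   convert_state_dict_to_hf({'model.layers.0.mlp.up_proj.weight': w},
#                            {'model': 'language_model.model'})
#   -> {'language_model.model.layers.0.mlp.up_proj.weight': w}
# Keys ending in '.inv_freq' (rotary-embedding caches) are dropped entirely.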
|
|
|
|
|
class AdaptiveAvgPool1dLayer(nn.Module): |
|
|
def __init__(self, output_size): |
|
|
super(AdaptiveAvgPool1dLayer, self).__init__() |
|
|
self.output_size = output_size |
|
|
|
|
|
def forward(self, x): |
|
|
return F.adaptive_avg_pool1d(x, self.output_size) |
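# Shape behaviour of the wrapper above (illustrative): F.adaptive_avg_pool1d pools
# over the last dimension, so AdaptiveAvgPool1dLayer(output_size=64) maps an input
# of shape [B, C, L] (e.g. torch.randn(2, 4096, 577)) to shape [B, C, 64].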
|
|
|
|
|
|
|
|
class LLaVAModel(BaseModel): |
|
|
|
|
|
def __init__(self, |
|
|
llm, |
|
|
freeze_llm=True, |
|
|
visual_select_layer=-2, |
|
|
pretrained_pth=None, |
|
|
projector_depth=2, |
|
|
llm_lora=None, |
|
|
visual_encoder_lora=None, |
|
|
use_activation_checkpointing=True, |
|
|
max_position_embeddings=None, |
|
|
hidden_size=512, |
|
|
train_stage='2', |
|
|
|
|
|
|
|
|
slide_ngrids=1000, |
|
|
tile_size=224, |
|
|
|
|
|
|
|
|
projector_pth=None, |
|
|
resampler_pth=None, |
|
|
token_merge_pth=None, |
|
|
|
|
|
|
|
|
enable_token_merge=True, |
|
|
|
|
|
|
|
|
use_resampler=True, |
|
|
resampler_num_latents=256, |
|
|
resampler_heads = 16, |
|
|
|
|
|
|
|
|
freeze_mm_in_stage2=False, |
|
|
freeze_projector_stage2=None, |
|
|
freeze_resampler_stage2=None, |
|
|
freeze_token_merge_stage2=None |
|
|
): |
|
|
super().__init__() |
|
|
|
|
|
self.freeze_llm = freeze_llm |
|
|
self.freeze_visual_encoder = True |
|
|
self.tile_size = tile_size |
|
|
|
|
|
|
|
|
if train_stage == '0': |
|
|
print_log('train_stage == 0', 'current') |
|
|
self.freeze_llm = True |
|
|
if train_stage == '1': |
|
|
print_log('train_stage == 1', 'current') |
|
|
self.freeze_llm = True |
|
|
elif train_stage == '2': |
|
|
print_log('train_stage == 2', 'current') |
|
|
self.freeze_llm = False |
|
|
|
|
|
|
|
|
def _resolve(flag): |
|
|
return freeze_mm_in_stage2 if flag is None else bool(flag) |
|
|
self._freeze_projector_in_s2 = _resolve(freeze_projector_stage2) |
|
|
self._freeze_resampler_in_s2 = _resolve(freeze_resampler_stage2) |
|
|
self._freeze_token_merge_in_s2 = _resolve(freeze_token_merge_stage2) |
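        # The per-module flags take precedence over the blanket default: e.g. with
        # freeze_mm_in_stage2=True and freeze_projector_stage2=False, the resampler
        # and token-merge modules are frozen in stage 2 while the projector stays trainable.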
|
|
|
|
|
|
|
|
with LoadWoInit(): |
|
|
if isinstance(llm, dict): |
|
|
llm = self._dispatch_lm_model_cfg(llm, max_position_embeddings) |
|
|
self.llm = self._build_from_cfg_or_module(llm) |
|
|
|
|
|
self.llm.config.use_cache = False |
|
|
dispatch_modules(self.llm) |
|
|
|
|
|
|
|
|
self.enable_token_merge = enable_token_merge |
|
|
if self.enable_token_merge: |
|
|
self.token_merge = SparsePatchMerging( |
|
|
embed_dim=hidden_size, |
|
|
layernorm_eps=1e-6, |
|
|
merge_size=2 |
|
|
) |
|
|
|
|
|
|
|
|
self.projector_depth = projector_depth |
|
|
projector_config = ProjectorConfig( |
|
|
visual_hidden_size=hidden_size * 4 if self.enable_token_merge else hidden_size, |
|
|
llm_hidden_size=self.llm.config.hidden_size, |
|
|
depth=self.projector_depth |
|
|
) |
|
|
self.projector = ProjectorModel(projector_config).to(self.llm.dtype) |
|
|
self.projector.requires_grad_(True) |
|
|
|
|
|
|
|
|
self.use_resampler = use_resampler |
|
|
self.slide_ngrids = slide_ngrids |
|
|
if self.use_resampler: |
|
|
self.resampler_num_latents = resampler_num_latents |
|
|
print_log(f'using simple Resampler with {resampler_num_latents} latents', 'current') |
|
|
self.resampler = Resampler( |
|
|
grid_size=int(math.sqrt(self.resampler_num_latents)), |
|
|
embed_dim=self.llm.config.hidden_size, |
|
|
num_heads=resampler_heads, |
|
|
kv_dim=self.llm.config.hidden_size, |
|
|
).to(self.llm.dtype) |
|
|
|
|
|
|
|
|
|
|
|
if self.freeze_llm: |
|
|
            print_log('freeze_llm', 'current')
|
|
self.llm.requires_grad_(False) |
|
|
|
|
|
|
|
|
if use_activation_checkpointing: |
|
|
if hasattr(self.llm, 'enable_input_require_grads'): |
|
|
self.llm.enable_input_require_grads() |
|
|
else: |
|
|
self.llm.get_input_embeddings().register_forward_hook(make_inputs_require_grad) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
_projector_frozen = (train_stage == '2' and self._freeze_projector_in_s2) |
|
|
if not _projector_frozen: |
|
|
            print_log('enable projector input require grads', 'current')
|
|
self.projector.enable_input_require_grads() |
|
|
else: |
|
|
print_log('[stage-2] Skipping projector.enable_input_require_grads() (frozen)', 'current') |
|
|
|
|
|
|
|
|
self.gradient_checkpointing_enable() |
|
|
|
|
|
|
|
|
self.use_llm_lora = llm_lora is not None |
|
|
        self.use_visual_encoder_lora = False
|
|
if self.use_llm_lora: |
|
|
print_log(f"Building lora {llm_lora.__str__}", "current") |
|
|
self._prepare_llm_for_lora(llm_lora, use_activation_checkpointing) |
|
|
self.verify_lora() |
|
|
|
|
|
|
|
|
if token_merge_pth is not None and enable_token_merge and hasattr(self, 'token_merge'): |
|
|
print_log(f'loading token_merge from {token_merge_pth}', 'current') |
|
|
merger_sd = load_file(token_merge_pth, device='cpu') |
|
|
self.token_merge.load_state_dict(merger_sd, strict=False) |
|
|
self.token_merge.to(self.llm.dtype) |
|
|
|
|
|
if projector_pth is not None: |
|
|
print_log(f"Loading projector from {projector_pth}", "current") |
|
|
proj_sd = load_file(projector_pth, device="cpu") |
|
|
self.projector.load_state_dict(proj_sd, strict=False) |
|
|
self.projector.to(self.llm.dtype) |
|
|
|
|
|
if resampler_pth is not None and self.use_resampler and hasattr(self, 'resampler'): |
|
|
print_log(f'Loading resampler from {resampler_pth}', 'current') |
|
|
resampler_sd = load_file(resampler_pth, device="cpu") |
|
|
self.resampler.load_state_dict(resampler_sd, strict=False) |
|
|
self.resampler.to(self.llm.dtype) |
|
|
|
|
|
|
|
|
if pretrained_pth is not None: |
|
|
sd = guess_load_checkpoint(pretrained_pth) |
|
|
model_sd = self.state_dict() |
|
|
filtered = {k: v for k, v in sd.items() if k in model_sd and model_sd[k].shape == v.shape} |
|
|
missing, unexpected = self.load_state_dict(filtered, strict=False) |
|
|
print_log(f"Loaded float ckpt from {pretrained_pth}", "current") |
|
|
print_log(f" missing: {missing}", "current") |
|
|
print_log(f" unexpected:{unexpected}", "current") |
|
|
|
|
|
|
|
|
self.visual_select_layer = visual_select_layer |
|
|
|
|
|
|
|
|
self._is_init = True |
|
|
self.is_first_iter = True |
|
|
|
|
|
|
|
|
if train_stage == '2': |
|
|
|
|
|
if hasattr(self, 'projector') and self._freeze_projector_in_s2: |
|
|
self.projector.requires_grad_(False) |
|
|
self.projector.eval() |
|
|
print_log('[stage-2] Freezing projector parameters', 'current') |
|
|
|
|
|
|
|
|
if getattr(self, 'use_resampler', False) and hasattr(self, 'resampler') and self._freeze_resampler_in_s2: |
|
|
self.resampler.requires_grad_(False) |
|
|
self.resampler.eval() |
|
|
print_log('[stage-2] Freezing resampler parameters', 'current') |
|
|
|
|
|
|
|
|
if getattr(self, 'enable_token_merge', False) and hasattr(self, 'token_merge') and self._freeze_token_merge_in_s2: |
|
|
self.token_merge.requires_grad_(False) |
|
|
self.token_merge.eval() |
|
|
print_log('[stage-2] Freezing token_merge parameters', 'current') |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _parse_lora_config(self, lora_config): |
|
|
if isinstance(lora_config, dict) or isinstance( |
|
|
lora_config, Config) or isinstance(lora_config, ConfigDict): |
|
|
lora_config = BUILDER.build(lora_config) |
|
|
return lora_config |
|
|
|
|
|
def _init_weights(self, m): |
|
|
if isinstance(m, nn.Linear): |
|
|
|
|
|
torch.nn.init.xavier_uniform_(m.weight) |
|
|
if isinstance(m, nn.Linear) and m.bias is not None: |
|
|
nn.init.constant_(m.bias, 0) |
|
|
elif isinstance(m, nn.LayerNorm): |
|
|
nn.init.constant_(m.bias, 0) |
|
|
nn.init.constant_(m.weight, 1.0) |
|
|
|
|
|
def _prepare_llm_for_lora(self, |
|
|
lora_config, |
|
|
use_activation_checkpointing=True): |
|
|
lora_config = self._parse_lora_config(lora_config) |
|
|
self.llm = prepare_model_for_kbit_training( |
|
|
self.llm, use_activation_checkpointing) |
|
|
if lora_config.target_modules is None: |
|
|
modules = find_all_linear_names(self.llm) |
|
|
lora_config.target_modules = modules |
|
|
self.llm = get_peft_model(self.llm, lora_config) |
|
|
|
|
|
def verify_lora(self): |
|
|
m = self.llm |
|
|
|
|
|
|
|
|
assert isinstance(m, PeftModel), "LoRA not applied: model is not a PeftModel" |
|
|
|
|
|
|
|
|
adapters = m.peft_config |
|
|
assert len(adapters) > 0, "No adapters registered in peft_config" |
|
|
active = m.active_adapter if hasattr(m, "active_adapter") else None |
|
|
assert active in adapters, f"Active adapter {active} not found in peft_config" |
|
|
|
|
|
|
|
|
lora_modules = [mod for mod in m.modules() if isinstance(mod, LoraLayer)] |
|
|
assert len(lora_modules) > 0, "No LoraLayer modules found (check target_modules)" |
|
|
|
|
|
|
|
|
trainable = [(n,p) for n,p in m.named_parameters() if p.requires_grad] |
|
|
assert len(trainable) > 0, "No trainable parameters (LoRA params are not set to requires_grad=True)" |
|
|
|
|
|
suspicious = [n for n,_ in trainable if "lora_" not in n and "modules_to_save" not in n] |
|
|
|
|
|
assert len(suspicious) == 0, f"Unexpected trainable params (not LoRA): {suspicious[:5]}" |
|
|
|
|
|
|
|
|
total = sum(p.numel() for _,p in m.named_parameters()) |
|
|
trainable_cnt = sum(p.numel() for _,p in trainable) |
|
|
ratio = trainable_cnt / total |
|
|
print(f"[LoRA OK] adapters={list(adapters.keys())}, active={active}, " |
|
|
f"LoraLayers={len(lora_modules)}, trainable={trainable_cnt}/{total} ({ratio:.4%})") |
|
|
|
|
|
|
|
|
m.train() |
|
|
dummy_inp = torch.randint(0, m.get_input_embeddings().num_embeddings, (1, 8)).to(next(m.parameters()).device) |
|
|
out = m(input_ids=dummy_inp, labels=dummy_inp) |
|
|
out.loss.backward() |
|
|
|
|
|
lora_grads = [p.grad for _,p in m.named_parameters() if p.requires_grad and p.grad is not None] |
|
|
assert len(lora_grads) > 0, "No gradients on LoRA parameters after backward()" |
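        # The probe backward above leaves gradients on the LoRA parameters; clear
        # them so they cannot leak into the first real optimizer step (a defensive
        # safeguard; harmless if the trainer zeroes gradients itself).
        m.zero_grad(set_to_none=True)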
|
|
|
|
|
def _prepare_visual_encoder_for_lora(self, |
|
|
lora_config, |
|
|
use_activation_checkpointing=True): |
|
|
lora_config = self._parse_lora_config(lora_config) |
|
|
if lora_config.target_modules is None: |
|
|
modules = find_all_linear_names(self.visual_encoder) |
|
|
lora_config.target_modules = modules |
|
|
self.visual_encoder = get_peft_model(self.visual_encoder, lora_config) |
|
|
|
|
|
def gradient_checkpointing_enable(self, use_reentrant=False): |
|
|
self.activation_checkpointing_enable(use_reentrant=use_reentrant) |
|
|
|
|
|
def activation_checkpointing_enable(self, use_reentrant=False): |
|
|
|
|
|
try: |
|
|
self.llm.gradient_checkpointing_enable(use_reentrant=use_reentrant) |
|
|
except TypeError: |
|
|
|
|
|
self.llm.gradient_checkpointing_enable() |
|
|
|
|
|
|
|
|
try: |
|
|
self.projector.gradient_checkpointing_enable(use_reentrant=use_reentrant) |
|
|
except TypeError: |
|
|
self.projector.gradient_checkpointing_enable() |
|
|
|
|
|
if getattr(self, 'use_resampler', False) and getattr(self, 'resampler', None) is not None: |
|
|
try: |
|
|
self.resampler.gradient_checkpointing_enable(use_reentrant=use_reentrant) |
|
|
            except TypeError:
|
|
self.resampler.gradient_checkpointing_enable() |
|
|
|
|
|
|
|
|
def gradient_checkpointing_disable(self): |
|
|
self.activation_checkpointing_disable() |
|
|
|
|
|
def activation_checkpointing_disable(self): |
|
|
self.llm.gradient_checkpointing_disable() |
|
|
self.projector.gradient_checkpointing_disable() |
|
|
if getattr(self, 'use_resampler', False) and getattr(self, 'resampler', None) is not None: |
|
|
self.resampler.gradient_checkpointing_disable() |
|
|
|
|
|
|
|
|
def init_weights(self): |
|
|
pass |
|
|
|
|
|
def state_dict(self, *args, **kwargs): |
|
|
state_dict = super().state_dict(*args, **kwargs) |
|
|
to_return = OrderedDict() |
|
|
|
|
|
if self.use_visual_encoder_lora: |
|
|
to_return.update( |
|
|
get_peft_model_state_dict( |
|
|
self.visual_encoder, state_dict=state_dict)) |
|
|
elif not self.freeze_visual_encoder: |
|
|
to_return.update({ |
|
|
k: v |
|
|
for k, v in state_dict.items() if 'visual_encoder.' in k |
|
|
}) |
|
|
|
|
|
if self.use_llm_lora: |
|
|
to_return.update( |
|
|
get_peft_model_state_dict(self.llm, state_dict=state_dict)) |
|
|
|
|
|
elif not self.freeze_llm: |
|
|
to_return.update( |
|
|
{k: v |
|
|
for k, v in state_dict.items() if 'llm.' in k}) |
|
|
|
|
|
to_return.update( |
|
|
{k: v |
|
|
for k, v in state_dict.items() if 'projector.' in k}) |
|
|
|
|
|
|
|
|
if getattr(self, 'use_resampler', False) and getattr(self, 'resampler', None) is not None: |
|
|
to_return.update({k: v for k, v in state_dict.items() if 'resampler.' in k}) |
|
|
|
|
|
|
|
|
if getattr(self, 'token_merge', False): |
|
|
to_return.update({k: v for k, v in state_dict.items() if 'token_merge.' in k}) |
|
|
return to_return |
|
|
|
|
|
@staticmethod |
|
|
def _prepare_for_long_context_training(cfg, llm_cfg, |
|
|
max_position_embeddings): |
|
|
|
|
|
orig_rope_scaling = getattr(llm_cfg, 'rope_scaling', None) |
|
|
if orig_rope_scaling is None: |
|
|
orig_rope_scaling = {'factor': 1} |
|
|
|
|
|
orig_rope_scaling_factor = orig_rope_scaling[ |
|
|
'factor'] if 'factor' in orig_rope_scaling.keys() else 1 |
|
|
orig_ctx_len = getattr(llm_cfg, 'max_position_embeddings', None) |
|
|
if orig_ctx_len: |
|
|
orig_ctx_len *= orig_rope_scaling_factor |
|
|
if max_position_embeddings > orig_ctx_len: |
|
|
scaling_factor = float( |
|
|
math.ceil(max_position_embeddings / orig_ctx_len)) |
|
|
llm_cfg.rope_scaling = { |
|
|
'type': 'linear', |
|
|
'factor': scaling_factor |
|
|
} |
|
|
|
|
|
|
|
|
llm_cfg.attn_implementation = 'flash_attention_2' |
|
|
cfg.config = llm_cfg |
|
|
|
|
|
return cfg, llm_cfg |
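    # Worked example for the RoPE-scaling arithmetic above (illustrative numbers):
    # with an original context of 4096 tokens, no pre-existing rope_scaling
    # (factor 1), and max_position_embeddings=16384, the effective original
    # context is 4096 * 1 = 4096 < 16384, so
    #   scaling_factor = ceil(16384 / 4096) = 4.0
    # and llm_cfg.rope_scaling becomes {'type': 'linear', 'factor': 4.0}.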
|
|
|
|
|
@staticmethod |
|
|
def _prepare_for_flash_attn(cfg, llm_cfg): |
|
|
cls_name = type(llm_cfg).__name__ |
|
|
        SUPPORT_SDPA_ATTN = ('LlamaConfig', 'GemmaConfig', 'MistralConfig',
                             'MixtralConfig', 'Qwen2Config', 'Qwen2MoeConfig',
                             'Starcoder2Config', 'Phi3Config')
        SUPPORT_FLASH_ATTN2 = ('InternLM2Config', 'LlamaConfig', 'GemmaConfig',
                               'MistralConfig', 'MixtralConfig', 'Qwen2Config',
                               'Qwen2MoeConfig', 'Starcoder2Config',
                               'Phi3Config')
|
|
|
|
|
torch_dtype = torch.bfloat16 if ( |
|
|
torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \ |
|
|
else torch.float16 |
|
|
|
|
|
if getattr(cfg, 'attn_implementation', None) is not None: |
|
|
|
|
|
|
|
|
if cfg.attn_implementation == 'flash_attention_2': |
|
|
cfg.torch_dtype = torch_dtype |
|
|
elif SUPPORT_FLASH2 and cls_name in SUPPORT_FLASH_ATTN2: |
|
|
cfg.torch_dtype = torch_dtype |
|
|
cfg.attn_implementation = 'flash_attention_2' |
|
|
elif SUPPORT_FLASH1 and cls_name in SUPPORT_SDPA_ATTN: |
|
|
cfg.attn_implementation = 'sdpa' |
|
|
|
|
|
return cfg, llm_cfg |
|
|
|
|
|
@staticmethod |
|
|
def _prepare_for_qlora_zero3(cfg): |
|
|
if (not is_deepspeed_zero3_enabled()) or (not hasattr( |
|
|
cfg, 'quantization_config')): |
|
|
return cfg |
|
|
|
|
|
torch_dtype = torch.bfloat16 if ( |
|
|
torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \ |
|
|
else torch.float16 |
|
|
|
|
|
cfg.torch_dtype = torch_dtype |
|
|
quantization_config = cfg.quantization_config |
|
|
quantization_config.bnb_4bit_compute_dtype = torch_dtype |
|
|
quantization_config.bnb_4bit_quant_storage = torch_dtype |
|
|
|
|
|
return cfg |
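    # Rationale for the dtype overrides above: under DeepSpeed ZeRO-3 the 4-bit
    # quantized weights are partitioned like ordinary parameters, so the quant
    # storage dtype is aligned with the compute dtype to keep every shard in a
    # uniform floating-point format (this mirrors the usual guidance for QLoRA
    # with sharded training backends).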
|
|
|
|
|
def _dispatch_lm_model_cfg(self, cfg, max_position_embeddings=None): |
|
|
cfg = self._prepare_for_qlora_zero3(cfg) |
|
|
pretrained_model_name_or_path = cfg.pretrained_model_name_or_path |
|
|
llm_cfg = AutoConfig.from_pretrained( |
|
|
pretrained_model_name_or_path, trust_remote_code=True) |
|
|
|
|
|
cfg, llm_cfg = self._prepare_for_flash_attn(cfg, llm_cfg) |
|
|
if max_position_embeddings is not None: |
|
|
cfg, llm_cfg = self._prepare_for_long_context_training( |
|
|
cfg, llm_cfg, max_position_embeddings) |
|
|
return cfg |
|
|
|
|
|
def _build_from_cfg_or_module(self, cfg_or_mod): |
|
|
if isinstance(cfg_or_mod, nn.Module): |
|
|
return cfg_or_mod |
|
|
elif isinstance(cfg_or_mod, dict): |
|
|
traverse_dict(cfg_or_mod) |
|
|
return BUILDER.build(cfg_or_mod) |
|
|
else: |
|
|
raise NotImplementedError |
|
|
|
|
|
def coords_to_pos(self, coords, tile_size: int = 224): |
|
|
""" |
|
|
This function is used to convert the coordinates to the positional indices |
|
|
|
|
|
Arguments: |
|
|
---------- |
|
|
coords: torch.Tensor |
|
|
The coordinates of the patches, of shape [N, L, 2] |
|
|
output: torch.Tensor |
|
|
The positional indices of the patches, of shape [N, L] |
|
|
""" |
|
|
coords_ = torch.floor(coords / tile_size) |
|
|
pos = coords_[..., 0] * self.slide_ngrids + coords_[..., 1] |
|
|
return pos.long() |
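    # Worked example for coords_to_pos (illustrative): with tile_size=224 and
    # slide_ngrids=1000, a patch at pixel coordinates (448., 224.) maps to grid
    # cell (2, 1), so its flattened index is 2 * 1000 + 1 = 2001.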
|
|
|
|
|
@staticmethod |
|
|
def _coords_rc_to_pos(coords_rc: torch.Tensor, ngrids: int) -> torch.Tensor: |
|
|
if coords_rc.dtype.is_floating_point: |
|
|
coords_rc = coords_rc.round().to(torch.long) |
|
|
|
|
|
|
|
|
return (coords_rc[..., 0] * ngrids + coords_rc[..., 1]).long() |
|
|
|
|
|
def forward(self, data, data_samples=None, mode='loss'): |
|
|
if self.is_first_iter: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.to(data['input_ids'].device) |
|
|
self.is_first_iter = False |
|
|
coords = None |
|
|
|
|
|
if 'pixel_values' in data: |
|
|
|
|
|
feat_to_proj = data['pixel_values'].to(self.llm.dtype) |
|
|
|
|
|
feat_to_proj.requires_grad_(True) |
|
|
|
|
|
if 'coords' in data: |
|
|
                coords = data['coords'].float()  # keep full precision; bf16/fp16 would round large pixel coordinates
|
|
|
|
|
coords_t = coords[0] if isinstance(coords, list) else coords |
|
|
Bx = feat_to_proj.size(0) |
|
|
if not torch.is_tensor(coords_t): |
|
|
raise ValueError("coords must be a Tensor or list[Tensor].") |
|
|
|
|
|
if coords_t.dim() == 2: |
|
|
|
|
|
coords_rc = coords_t |
|
|
elif coords_t.dim() == 3: |
|
|
|
|
|
if coords_t.size(0) != Bx: |
|
|
raise ValueError(f"coords batch dim mismatch: got {coords_t.size(0)} but inputs have B={Bx}") |
|
|
if Bx == 1: |
|
|
coords_rc = coords_t[0] |
|
|
else: |
|
|
|
|
|
if not torch.equal(coords_t, coords_t[0].unsqueeze(0).expand_as(coords_t)): |
|
|
raise NotImplementedError( |
|
|
"Per-example coords (varying across batch) are not supported by the current " |
|
|
"patch-merging/layout path. Use batch size 1 or share coords across the batch." |
|
|
) |
|
|
coords_rc = coords_t[0] |
|
|
else: |
|
|
raise ValueError("coords must have shape [L,2] or [B,L,2].") |
|
|
|
|
|
if coords_rc.size(-1) != 2: |
|
|
raise ValueError("coords last dimension must be 2.") |
|
|
else: |
|
|
            raise RuntimeError('`pixel_values` is required in `data` for the multimodal forward pass.')
|
|
|
|
|
|
|
|
if self.enable_token_merge: |
|
|
feat_to_proj, coords_rc_merged, _ = self.token_merge( |
|
|
x=feat_to_proj, |
|
|
coords_rc=self._coords_to_rowcol(coords_rc), |
|
|
padmask=torch.zeros([feat_to_proj.size(0), feat_to_proj.size(1)], |
|
|
device=feat_to_proj.device, dtype=torch.bool) |
|
|
) |
|
|
|
|
|
else: |
|
|
coords_rc_merged = self._coords_to_rowcol(coords_rc) |
|
|
padmask_merged = torch.zeros([feat_to_proj.size(0), feat_to_proj.size(1)], |
|
|
device=feat_to_proj.device, dtype=torch.bool) |
|
|
|
|
|
pixel_values = self.projector(feat_to_proj.to(self.llm.dtype)) |
|
|
|
|
|
if self.use_resampler and getattr(self, 'resampler', None) is not None: |
|
|
            pixel_values = self.resampler(pixel_values, coords_rc_merged,
                                          attn_mask=None)
|
|
|
|
|
data['pixel_values'] = pixel_values |
|
|
|
|
|
|
|
|
data.pop('coords', None) |
|
|
|
|
|
data = prepare_inputs_labels_for_multimodal(llm=self.llm, **data) |
|
|
|
|
|
if mode == 'loss': |
|
|
return self.compute_loss(data, data_samples) |
|
|
elif mode == 'predict': |
|
|
return self.predict(data, data_samples) |
|
|
elif mode == 'tensor': |
|
|
return self._forward(data, data_samples) |
|
|
else: |
|
|
raise NotImplementedError |
|
|
|
|
|
@staticmethod |
|
|
def _coords_to_rowcol(coords_xy: torch.Tensor) -> torch.Tensor: |
|
|
with torch.no_grad(): |
|
|
x = coords_xy[:, 0] |
|
|
y = coords_xy[:, 1] |
|
|
x_for_unique = x |
|
|
y_for_unique = y |
|
|
            if x_for_unique.dtype.is_floating_point:
                x_for_unique = x_for_unique.round().to(torch.int)
                y_for_unique = y_for_unique.round().to(torch.int)
            x_sorted = torch.unique(x_for_unique, sorted=True)
            y_sorted = torch.unique(y_for_unique, sorted=True)

            # Look up the rounded values as well, so values and boundaries share a
            # dtype and the mapping stays exact for floating-point inputs.
            col = torch.searchsorted(x_sorted, x_for_unique)
            row = torch.searchsorted(y_sorted, y_for_unique)
|
|
return torch.stack([row, col], dim=-1) |
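    # Example of the rank mapping above (illustrative): for
    # coords_xy = tensor([[224., 0.], [0., 448.]]) the unique sorted x values are
    # [0, 224] and the unique sorted y values are [0, 448], so the patches map to
    # (row, col) = (0, 1) and (1, 0) respectively; absolute pixel offsets are
    # collapsed into dense row/column ranks.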
|
|
|
|
|
def _forward(self, data, data_samples=None): |
|
|
|
|
|
outputs = self.llm(**data) |
|
|
|
|
|
return outputs |
|
|
|
|
|
def predict(self, data, data_samples=None): |
|
|
outputs = self.llm(**data) |
|
|
logits_dict = [{'logits': logits} for logits in outputs.logits] |
|
|
return logits_dict |
|
|
|
|
|
def compute_loss(self, data, data_samples=None): |
|
|
""" |
|
|
计算损失的修改版实现。 |
|
|
该版本通过计算批次中每个样本的平均损失来解决长短文本的梯度失衡问题, |
|
|
使得每个样本对总损失的贡献相等,无论其token长度如何。 |
|
|
""" |
|
|
|
|
|
if "labels" not in data: |
|
|
outputs = self.llm(**data) |
|
|
return {"loss": outputs.loss} |
|
|
|
|
|
|
|
|
labels = data.pop("labels") |
|
|
|
|
|
|
|
|
outputs = self.llm(**data) |
|
|
logits = outputs.logits |
|
|
|
|
|
|
|
|
if logits.shape[:-1] != labels.shape: |
|
|
raise ValueError( |
|
|
f"Logits and labels shape mismatch. Logits: {logits.shape}, Labels: {labels.shape}" |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
shift_logits = logits[..., :-1, :].contiguous() |
|
|
shift_labels = labels[..., 1:].contiguous() |
|
|
|
|
|
|
|
|
|
|
|
loss = F.cross_entropy( |
|
|
shift_logits.view(-1, shift_logits.size(-1)), |
|
|
shift_labels.view(-1), |
|
|
ignore_index=-100, |
|
|
reduction='none' |
|
|
) |
|
|
|
|
|
|
|
|
loss = loss.view(shift_logits.size(0), -1) |
|
|
|
|
|
|
|
|
|
|
|
num_tokens_per_sample = (shift_labels != -100).sum(dim=1) |
|
|
|
|
|
|
|
|
loss_per_sample = loss.sum(dim=1) |
|
|
|
|
|
|
|
|
valid_samples_mask = num_tokens_per_sample > 0 |
|
|
|
|
|
|
|
|
mean_loss_per_sample = torch.zeros_like(loss_per_sample) |
|
|
|
|
|
|
|
|
if valid_samples_mask.any(): |
|
|
mean_loss_per_sample[valid_samples_mask] = loss_per_sample[valid_samples_mask] / num_tokens_per_sample[valid_samples_mask] |
|
|
|
|
|
|
|
|
final_loss = mean_loss_per_sample.mean() |
|
|
|
|
|
return {"loss": final_loss} |
|
|
|
|
|
|
|
|
|
|
|
def __getattr__(self, name: str): |
|
|
try: |
|
|
return super().__getattr__(name) |
|
|
except AttributeError: |
|
|
return getattr(self.llm, name) |
|
|
|
|
|
def to_hf(self, |
|
|
cfg, |
|
|
save_dir, |
|
|
fp32=False, |
|
|
save_pretrained_kwargs={}, |
|
|
save_format='xtuner', |
|
|
**kwargs): |
|
|
if save_format == 'xtuner': |
|
|
self.to_xtuner_llava(cfg, save_dir, fp32, save_pretrained_kwargs) |
|
|
elif save_format == 'huggingface': |
|
|
self.to_huggingface_llava(cfg, save_dir, fp32, |
|
|
save_pretrained_kwargs) |
|
|
elif save_format == 'official': |
|
|
self.to_official_llava(cfg, save_dir, fp32, save_pretrained_kwargs) |
|
|
else: |
|
|
raise NotImplementedError |
|
|
|
|
|
def to_xtuner_llava(self, |
|
|
cfg, |
|
|
save_dir, |
|
|
fp32=False, |
|
|
save_pretrained_kwargs={}): |
|
|
|
|
|
self.llm.config.use_cache = True |
|
|
if not fp32: |
|
|
print_log('Convert LLM to float16', 'current') |
|
|
self.llm.half() |
|
|
if self.use_llm_lora: |
|
|
llm_path = osp.join(save_dir, 'llm_adapter') |
|
|
print_log(f'Saving LLM adapter to {llm_path}', 'current') |
|
|
self.llm.save_pretrained(llm_path, **save_pretrained_kwargs) |
|
|
elif not self.freeze_llm: |
|
|
llm_path = save_dir |
|
|
print_log(f'Saving LLM tokenizer to {llm_path}', 'current') |
|
|
tokenizer = BUILDER.build(cfg.tokenizer) |
|
|
tokenizer.save_pretrained(llm_path, **save_pretrained_kwargs) |
|
|
print_log(f'Saving LLM to {llm_path}', 'current') |
|
|
self.llm.save_pretrained(llm_path, **save_pretrained_kwargs) |
|
|
self.llm.config.use_cache = False |
|
|
|
|
|
|
|
|
if self.use_visual_encoder_lora: |
|
|
visual_encoder_path = osp.join(save_dir, 'visual_encoder_adapter') |
|
|
print_log( |
|
|
f'Saving visual_encoder adapter to {visual_encoder_path}', |
|
|
'current') |
|
|
self.visual_encoder.save_pretrained(visual_encoder_path, |
|
|
**save_pretrained_kwargs) |
|
|
elif not self.freeze_visual_encoder: |
|
|
visual_encoder_path = osp.join(save_dir, 'visual_encoder') |
|
|
print_log( |
|
|
                'Saving visual_encoder image_processor to '
|
|
f'{visual_encoder_path}', 'current') |
|
|
image_processor = BUILDER.build(cfg.image_processor) |
|
|
image_processor.save_pretrained(visual_encoder_path, |
|
|
**save_pretrained_kwargs) |
|
|
print_log(f'Saving visual_encoder to {visual_encoder_path}', |
|
|
'current') |
|
|
self.visual_encoder.save_pretrained(visual_encoder_path, |
|
|
**save_pretrained_kwargs) |
|
|
|
|
|
|
|
|
projector_path = osp.join(save_dir, 'projector') |
|
|
print_log(f'Saving projector to {projector_path}', 'current') |
|
|
os.makedirs(projector_path, exist_ok=True) |
|
|
output_path = os.path.join(projector_path, 'projector.safetensors') |
|
|
save_file(self.projector.state_dict(), output_path) |
|
|
|
|
|
if self.use_resampler and hasattr(self, 'resampler'): |
|
|
|
|
|
resampler_path = osp.join(save_dir, "resampler") |
|
|
print_log(f'Saving Resampler to {resampler_path}', 'current') |
|
|
os.makedirs(resampler_path, exist_ok=True) |
|
|
resampler_output_path = os.path.join(resampler_path, 'resampler.safetensors') |
|
|
save_file(self.resampler.state_dict(), resampler_output_path) |
|
|
|
|
|
if self.enable_token_merge and hasattr(self, 'token_merge'): |
|
|
            merger_path = osp.join(save_dir, 'token_merger')
            print_log(f'Saving token merger to {merger_path}', 'current')
            os.makedirs(merger_path, exist_ok=True)
            merger_output_path = os.path.join(merger_path, 'merger.safetensors')
            save_file(self.token_merge.state_dict(), merger_output_path)
|
|
|
|
|
def to_huggingface_llava(self, |
|
|
cfg, |
|
|
save_dir, |
|
|
fp32=False, |
|
|
save_pretrained_kwargs={}): |
|
|
|
|
|
if self.use_resampler: |
|
|
warnings.warn("Conversion to HuggingFace LLaVA format with a custom resampler is not supported. " |
|
|
"The resampler weights will not be saved.") |
|
|
|
|
|
LLM_MAPPING = { |
|
|
'model': 'language_model.model', |
|
|
'lm_head': 'language_model.lm_head', |
|
|
} |
|
|
VIT_MAPPING = { |
|
|
'vision_model': 'vision_tower.vision_model', |
|
|
} |
|
|
PROJECTOR_MAPPING = { |
|
|
'model.0': 'multi_modal_projector.linear_1', |
|
|
'model.2': 'multi_modal_projector.linear_2', |
|
|
} |
|
|
|
|
|
assert getattr(self.llm, 'hf_quantizer', None) is None, \ |
|
|
'This conversion format does not support quantized LLM.' |
|
|
|
|
|
|
|
|
llm = self.llm |
|
|
if self.use_llm_lora: |
|
|
llm = self.llm.merge_and_unload() |
|
|
llm.config.use_cache = True |
|
|
if not fp32: |
|
|
print_log('Convert LLM to float16', 'current') |
|
|
llm.half() |
|
|
|
|
|
assert isinstance(llm, LlamaForCausalLM), \ |
|
|
'This conversion format only supports LlamaForCausalLM.' |
|
|
llm_state_dict = llm.state_dict() |
|
|
llm_state_dict = convert_state_dict_to_hf(llm_state_dict, LLM_MAPPING) |
|
|
|
|
|
need_visual_encoder = (not self.freeze_visual_encoder |
|
|
or self.use_visual_encoder_lora) |
|
|
visual_encoder = self.visual_encoder |
|
|
if self.use_visual_encoder_lora: |
|
|
visual_encoder = self.visual_encoder.merge_and_unload() |
|
|
assert isinstance(visual_encoder, CLIPVisionModel),\ |
|
|
'This conversion format only supports CLIPVisionModel.' |
|
|
if need_visual_encoder: |
|
|
visual_encoder_state_dict = visual_encoder.state_dict() |
|
|
visual_encoder_state_dict = convert_state_dict_to_hf( |
|
|
visual_encoder_state_dict, VIT_MAPPING) |
|
|
else: |
|
|
visual_encoder_state_dict = {} |
|
|
|
|
|
projector_state_dict = self.projector.state_dict() |
|
|
projector_state_dict = convert_state_dict_to_hf( |
|
|
projector_state_dict, PROJECTOR_MAPPING) |
|
|
|
|
|
state_dict = { |
|
|
**projector_state_dict, |
|
|
**llm_state_dict, |
|
|
**visual_encoder_state_dict, |
|
|
} |
|
|
|
|
|
|
|
|
text_config = llm.config |
|
|
vision_config = visual_encoder.config |
|
|
config = LlavaConfig( |
|
|
text_config=text_config, |
|
|
vision_config=vision_config, |
|
|
attn_implementation='eager') |
|
|
|
|
|
with init_empty_weights(): |
|
|
with warnings.catch_warnings(): |
|
|
warnings.filterwarnings( |
|
|
'ignore', message='.*non-meta.*', category=UserWarning) |
|
|
model = LlavaForConditionalGeneration(config) |
|
|
model.load_state_dict(state_dict, strict=False, assign=True) |
|
|
|
|
|
|
|
|
cfg.tokenizer.type = LlamaTokenizerFast.from_pretrained |
|
|
tokenizer = BUILDER.build(cfg.tokenizer) |
|
|
|
|
|
tokenizer.add_tokens( |
|
|
AddedToken(DEFAULT_IMAGE_TOKEN, special=True, normalized=False), |
|
|
special_tokens=True) |
|
|
tokenizer.add_special_tokens({'pad_token': '<pad>'}) |
|
|
|
|
|
image_processor = BUILDER.build(cfg.image_processor) |
|
|
assert isinstance(image_processor, CLIPImageProcessor),\ |
|
|
'This conversion format only supports CLIPImageProcessor.' |
|
|
|
|
|
processor = LlavaProcessor( |
|
|
tokenizer=tokenizer, image_processor=image_processor) |
|
|
|
|
|
|
|
|
pad_shape = 64 |
|
|
|
|
|
pre_expansion_embeddings = \ |
|
|
model.language_model.model.embed_tokens.weight.data |
|
|
mu = torch.mean(pre_expansion_embeddings, dim=0).float() |
|
|
n = pre_expansion_embeddings.size()[0] |
|
|
sigma = ((pre_expansion_embeddings - mu).T |
|
|
@ (pre_expansion_embeddings - mu)) / n |
|
|
dist = torch.distributions.multivariate_normal.MultivariateNormal( |
|
|
mu, covariance_matrix=1e-5 * sigma) |
|
|
|
|
|
|
|
|
ori_vocab_size = config.text_config.vocab_size |
|
|
tokenizer_vocab_size = tokenizer.encode('<pad>')[-1] |
|
|
added_token = tokenizer_vocab_size - ori_vocab_size |
|
|
|
|
|
if added_token > 0: |
|
|
model.resize_token_embeddings(ori_vocab_size + added_token, |
|
|
pad_shape) |
|
|
model.language_model.model.embed_tokens.weight.data[ |
|
|
ori_vocab_size:] = torch.stack( |
|
|
tuple( |
|
|
dist.sample() |
|
|
for _ in range(model.language_model.model.embed_tokens. |
|
|
weight.data[ori_vocab_size:].shape[0])), |
|
|
dim=0, |
|
|
) |
|
|
model.language_model.lm_head.weight.data[ |
|
|
ori_vocab_size:] = torch.stack( |
|
|
tuple(dist.sample() |
|
|
for _ in range(model.language_model.lm_head.weight. |
|
|
data[ori_vocab_size:].shape[0])), |
|
|
dim=0, |
|
|
) |
|
|
model.config.image_token_index = tokenizer.encode( |
|
|
DEFAULT_IMAGE_TOKEN)[-1] |
|
|
model.config.pad_token_id = tokenizer.encode('<pad>')[-1] |
|
|
|
|
|
|
|
|
print_log(f'Saving to {save_dir}', 'current') |
|
|
model.save_pretrained(save_dir, **save_pretrained_kwargs) |
|
|
processor.save_pretrained(save_dir, **save_pretrained_kwargs) |
|
|
|
|
|
def to_official_llava(self, |
|
|
cfg, |
|
|
save_dir, |
|
|
fp32=False, |
|
|
save_pretrained_kwargs={}): |
|
|
if self.use_resampler: |
|
|
warnings.warn("Conversion to official LLaVA format with a custom resampler is not supported. " |
|
|
"The resampler weights will not be saved.") |
|
|
VIT_MAPPING = { |
|
|
'vision_model': 'model.vision_tower.vision_tower.vision_model', |
|
|
} |
|
|
PROJECTOR_MAPPING = { |
|
|
'model.0': 'model.mm_projector.0', |
|
|
'model.2': 'model.mm_projector.2', |
|
|
} |
|
|
|
|
|
try: |
|
|
from llava.model import LlavaConfig, LlavaLlamaForCausalLM |
|
|
except ImportError: |
|
|
raise ImportError( |
|
|
'Please install llava with ' |
|
|
'`pip install git+https://github.com/haotian-liu/LLaVA.git ' |
|
|
'--no-deps`.') |
|
|
|
|
|
assert getattr(self.llm, 'hf_quantizer', None) is None, \ |
|
|
'This conversion format does not support quantized LLM.' |
|
|
|
|
|
|
|
|
llm = self.llm |
|
|
if self.use_llm_lora: |
|
|
llm = self.llm.merge_and_unload() |
|
|
llm.config.use_cache = True |
|
|
if not fp32: |
|
|
print_log('Convert LLM to float16', 'current') |
|
|
llm.half() |
|
|
|
|
|
assert isinstance(llm, LlamaForCausalLM), \ |
|
|
'This conversion format only supports LlamaForCausalLM.' |
|
|
llm_state_dict = llm.state_dict() |
|
|
|
|
|
need_visual_encoder = (not self.freeze_visual_encoder |
|
|
or self.use_visual_encoder_lora) |
|
|
visual_encoder = self.visual_encoder |
|
|
if self.use_visual_encoder_lora: |
|
|
visual_encoder = self.visual_encoder.merge_and_unload() |
|
|
assert isinstance(visual_encoder, CLIPVisionModel),\ |
|
|
'This conversion format only supports CLIPVisionModel.' |
|
|
if need_visual_encoder: |
|
|
visual_encoder_state_dict = visual_encoder.state_dict() |
|
|
visual_encoder_state_dict = convert_state_dict_to_hf( |
|
|
visual_encoder_state_dict, VIT_MAPPING) |
|
|
else: |
|
|
visual_encoder_state_dict = {} |
|
|
|
|
|
projector_state_dict = self.projector.state_dict() |
|
|
projector_state_dict = convert_state_dict_to_hf( |
|
|
projector_state_dict, PROJECTOR_MAPPING) |
|
|
|
|
|
state_dict = { |
|
|
**projector_state_dict, |
|
|
**llm_state_dict, |
|
|
**visual_encoder_state_dict, |
|
|
} |
|
|
|
|
|
|
|
|
tokenizer = BUILDER.build(cfg.tokenizer) |
|
|
image_processor = BUILDER.build(cfg.image_processor) |
|
|
assert isinstance(image_processor, CLIPImageProcessor),\ |
|
|
'This conversion format only supports CLIPImageProcessor.' |
|
|
|
|
|
llava_config_dict = llm.config.__dict__.copy() |
|
|
llava_config_dict.update( |
|
|
dict( |
|
|
image_aspect_ratio='pad', |
|
|
mm_hidden_size=visual_encoder.config.hidden_size, |
|
|
mm_projector_type=f'mlp{self.projector_depth}x_gelu', |
|
|
mm_use_im_patch_token=False, |
|
|
mm_use_im_start_end=False, |
|
|
mm_vision_select_feature='patch', |
|
|
mm_vision_select_layer=self.visual_select_layer, |
|
|
mm_vision_tower=visual_encoder.config.name_or_path, |
|
|
unfreeze_mm_vision_tower=need_visual_encoder, |
|
|
model_type='llava', |
|
|
use_cache=True, |
|
|
use_mm_proj=True)) |
|
|
|
|
|
llava_config = LlavaConfig(**llava_config_dict) |
|
|
|
|
|
with init_empty_weights(): |
|
|
with warnings.catch_warnings(): |
|
|
warnings.filterwarnings( |
|
|
'ignore', message='.*non-meta.*', category=UserWarning) |
|
|
model = LlavaLlamaForCausalLM(llava_config) |
|
|
|
|
|
model.load_state_dict(state_dict, strict=False, assign=True) |
|
|
|
|
|
|
|
|
print_log(f'Saving to {save_dir}', 'current') |
|
|
|
|
|
model.save_pretrained(save_dir, **save_pretrained_kwargs) |
|
|
image_processor.save_pretrained(save_dir, **save_pretrained_kwargs) |
|
|
tokenizer.save_pretrained(save_dir, **save_pretrained_kwargs) |