import math
import os.path as osp
import warnings
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from accelerate import init_empty_weights
from mmengine import print_log
from mmengine.config import Config, ConfigDict
from mmengine.model import BaseModel
from peft import get_peft_model, prepare_model_for_kbit_training
from transformers import (AddedToken, AutoConfig, CLIPImageProcessor,
                          CLIPVisionModel, LlamaForCausalLM,
                          LlamaTokenizerFast, LlavaConfig,
                          LlavaForConditionalGeneration, LlavaProcessor,
                          PreTrainedModel, PretrainedConfig)
from transformers.integrations import is_deepspeed_zero3_enabled

from xtuner.registry import BUILDER
from xtuner.utils import DEFAULT_IMAGE_TOKEN

from .modules import ProjectorConfig, ProjectorModel, dispatch_modules
from .modules.dispatch import SUPPORT_FLASH1, SUPPORT_FLASH2
from .torchscale.model.LongNet import make_longnet_from_name
from .utils import (LoadWoInit, find_all_linear_names,
                    get_peft_model_state_dict, guess_load_checkpoint,
                    make_inputs_require_grad,
                    prepare_inputs_labels_for_multimodal, traverse_dict)


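# Remaps checkpoint keys before export. For example, with
# mapping = {'model': 'language_model.model'}, the key
# 'model.layers.0.self_attn.q_proj.weight' becomes
# 'language_model.model.layers.0.self_attn.q_proj.weight';
# rotary '.inv_freq' buffers are dropped entirely.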
def convert_state_dict_to_hf(state_dict, mapping):
    new_state_dict = {}
    for key, value in state_dict.items():
        if key.endswith('.inv_freq'):
            continue
        for key_to_modify, new_key in mapping.items():
            if key_to_modify in key:
                key = key.replace(key_to_modify, new_key)
        new_state_dict[key] = value
    return new_state_dict


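# Shared configuration for the token reducers below. Each variant reads only a
# subset of the fields: the conv reducer uses ``kernel_size`` and ``stride``;
# the MLP reducer uses ``in_tokens``/``hidden_tokens``/``out_tokens``; the
# attention-based reducers use ``num_queries`` and ``num_heads``.
# ``hidden_size`` must match the embedding dim the reducer operates on
# (here, the LLM hidden size).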
class ReducerConfig(PretrainedConfig):
    model_type = 'Reducer'
    _auto_class = 'AutoConfig'

    def __init__(
        self,
        in_tokens=4096,
        out_tokens=2048,
        hidden_tokens=1024,
        num_queries=2048,
        num_heads=8,
        hidden_size=3584,
        kernel_size=4,
        stride=4,
        **kwargs,
    ):
        self.in_tokens = in_tokens
        self.out_tokens = out_tokens
        self.hidden_tokens = hidden_tokens
        self.kernel_size = kernel_size
        self.stride = stride
        if self.hidden_tokens is None:
            self.hidden_tokens = max(self.in_tokens // 2, self.out_tokens)

        self.hidden_size = hidden_size
        self.num_queries = num_queries
        self.num_heads = num_heads
        super().__init__(**kwargs)


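# With no padding and unit dilation, the conv reducer maps T_in tokens to
# floor((T_in - kernel_size) / stride) + 1 tokens. For example, with the
# default kernel_size=4 and stride=4, 4096 input tokens become
# (4096 - 4) // 4 + 1 = 1024 output tokens.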
class VisualTokenConvReducer(PreTrainedModel):
    """
    Wraps a Conv1d reducer with activation-checkpointing and input-grad support.
    Handles (B, T, D) inputs by transposing internally for Conv1d.
    """
    supports_gradient_checkpointing = True

    def __init__(self, config: ReducerConfig):
        super().__init__(config)
        self.conv = nn.Conv1d(
            in_channels=config.hidden_size,
            out_channels=config.hidden_size,
            kernel_size=config.kernel_size,
            stride=config.stride
        )
        self.gradient_checkpointing = False

    def enable_input_require_grads(self):

        def make_inputs_require_grad(module, inputs, output):
            output.requires_grad_(True)

        self.conv.register_forward_hook(make_inputs_require_grad)

    def forward(self, x):
        """
        x: [batch, tokens_in, features]
        returns: [batch, tokens_out, features]
        """
        x = x.transpose(1, 2)

        if self.gradient_checkpointing and self.training:
            out = cp.checkpoint(self.conv, x)
        else:
            out = self.conv(x)

        out = out.transpose(1, 2)
        return out


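# The MLP reducer mixes information along the token axis, independently for
# each feature channel: the (B, T, D) input is flattened to (B * D, T), passed
# through Linear(in_tokens -> hidden_tokens) -> GELU ->
# Linear(hidden_tokens -> out_tokens), and reshaped back to (B, out_tokens, D).
# It therefore requires the input to carry exactly ``in_tokens`` visual tokens.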
class VisualTokenMLPReducer(PreTrainedModel):
    """
    MLP-based token reducer. Handles (B, T, D) inputs by transposing internally.
    """
    base_model_prefix = "visual_token_mlp_reducer"
    supports_gradient_checkpointing = True

    def __init__(self, cfg: ReducerConfig):
        super().__init__(cfg)
        self.in_tokens = cfg.in_tokens
        self.out_tokens = cfg.out_tokens
        self.hidden_tokens = cfg.hidden_tokens

        self.fc1 = nn.Linear(self.in_tokens, self.hidden_tokens, bias=True)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(self.hidden_tokens, self.out_tokens, bias=True)
        self.gradient_checkpointing = False

    def enable_input_require_grads(self):

        def make_inputs_require_grad(module, inputs, output):
            output.requires_grad_(True)

        self.fc1.register_forward_hook(make_inputs_require_grad)

    def forward(self, x):
        """
        x: [batch, tokens_in, features]
        returns: [batch, tokens_out, features]
        """
        B, T, D = x.shape

        x = x.transpose(1, 2)
        x = x.reshape(B * D, T)

        if self.gradient_checkpointing and self.training:
            out = cp.checkpoint(self._mlp, x)
        else:
            out = self._mlp(x)

        out = out.view(B, D, self.out_tokens)
        out = out.transpose(1, 2)
        return out

    def _mlp(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x


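# Perceiver-style reducer: a fixed set of learned query embeddings
# cross-attends to the visual tokens, so the output always has
# ``num_queries`` tokens regardless of the input length T_in.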
class VisualTokenAttentionReducer(PreTrainedModel):
    base_model_prefix = "visual_token_attention_reducer"
    supports_gradient_checkpointing = True

    def __init__(self, config: ReducerConfig):
        super().__init__(config)
        self.query_emb = nn.Parameter(
            torch.randn(config.num_queries, config.hidden_size))
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_heads,
            batch_first=True
        )
        self.gradient_checkpointing = False

    def enable_input_require_grads(self):

        def make_inputs_require_grad(module, inputs, output):
            if isinstance(output, (tuple, list)):
                output[0].requires_grad_(True)
            else:
                output.requires_grad_(True)

        self.cross_attn.register_forward_hook(make_inputs_require_grad)

    def forward(self, x):
        """
        x: (B, T_in, D) -> returns (B, num_queries, D)
        """
        B, T, D = x.shape
        tokens = x

        Q = self.query_emb.unsqueeze(0).expand(B, -1, -1)

        if self.gradient_checkpointing and self.training:
            out = cp.checkpoint(self._attn, Q, tokens, tokens)
        else:
            out = self._attn(Q, tokens, tokens)

        return out

    def _attn(self, Q, K, V):
        out, _ = self.cross_attn(Q, K, V)
        return out


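# Text-guided variant: the learned queries attend over the concatenation of
# visual and text tokens. ``text_attention_mask`` (1 = real token, 0 = pad) is
# converted into a key_padding_mask where True marks positions to ignore,
# matching nn.MultiheadAttention semantics; visual tokens are never masked.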
class TextGuidedVisualTokenAttentionReducer(PreTrainedModel):
    """
    An enhanced attention-based token reducer that uses text tokens to guide
    the compression of visual tokens, operating on batch-first tensors.
    """
    base_model_prefix = "text_guided_visual_token_attention_reducer"
    supports_gradient_checkpointing = True

    def __init__(self, config: ReducerConfig):
        super().__init__(config)
        self.query_emb = nn.Parameter(
            torch.randn(config.num_queries, config.hidden_size))

        self.norm_kv = nn.LayerNorm(config.hidden_size)
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_heads,
            batch_first=True
        )
        self.norm_ffn = nn.LayerNorm(config.hidden_size)
        self.ffn = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size * 4),
            nn.GELU(),
            nn.Linear(config.hidden_size * 4, config.hidden_size)
        )
        self.gradient_checkpointing = False

    def enable_input_require_grads(self):

        def make_inputs_require_grad(module, inputs, output):
            if isinstance(output, tuple):
                output[0].requires_grad_(True)
            else:
                output.requires_grad_(True)

        self.cross_attn.register_forward_hook(make_inputs_require_grad)

    def forward(self, visual_tokens, text_tokens, text_attention_mask=None):
        """
        Performs text-guided reduction of visual tokens.

        Args:
            visual_tokens (torch.Tensor): Visual tokens of shape (B, T_visual, D).
            text_tokens (torch.Tensor): Text token embeddings of shape (B, T_text, D).
            text_attention_mask (torch.Tensor, optional): Mask for text tokens.

        Returns:
            torch.Tensor: Compressed visual tokens of shape (B, num_queries, D).
        """
        B, T_visual, D = visual_tokens.shape

        kv_tokens = torch.cat([visual_tokens, text_tokens], dim=1)

        key_padding_mask = None
        if text_attention_mask is not None:
            visual_padding_mask = torch.ones(
                B, T_visual, dtype=torch.bool, device=visual_tokens.device)
            combined_mask = torch.cat(
                [visual_padding_mask, text_attention_mask], dim=1)
            key_padding_mask = (combined_mask == 0)

        queries = self.query_emb.unsqueeze(0).expand(B, -1, -1)

        attn_output, _ = self.cross_attn(
            query=queries,
            key=self.norm_kv(kv_tokens),
            value=self.norm_kv(kv_tokens),
            key_padding_mask=key_padding_mask
        )
        queries = queries + attn_output

        ffn_output = self.ffn(self.norm_ffn(queries))
        queries = queries + ffn_output

        return queries


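# Illustrative (not from a shipped config): ``visual_token_reducer_config`` is
# a plain dict selecting and parameterising one of the reducers above, e.g.
#     visual_token_reducer_config = dict(
#         type='attention',  # 'conv' | 'attention' | 'mlp' | 'text_guided_attention'
#         in_tokens=4096,
#         out_tokens=2048,
#         num_queries=2048,
#         num_heads=8,
#     )
# The conv variant instead expects ``kernel_size`` and ``stride``.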
class LLaVAModelWithReducer(BaseModel):
    """LLaVA-style model whose pre-extracted visual features pass through a
    LongNet encoder, a projector and an optional visual-token reducer before
    being spliced into the LLM input sequence."""

    def __init__(self,
                 llm,
                 freeze_llm=True,
                 visual_select_layer=-2,
                 pretrained_pth=None,
                 projector_depth=2,
                 llm_lora=None,
                 visual_encoder_lora=None,
                 use_activation_checkpointing=True,
                 max_position_embeddings=None,
                 hidden_size=512,
                 train_stage='2',
                 enable_long_net=True,
                 visual_token_reducer_config=None):
        super().__init__()
        self.enable_long_net = enable_long_net
        self.freeze_llm = freeze_llm
        self.freeze_visual_encoder = True
        if train_stage == '1':
            self.freeze_llm = True
            self.freeze_long_net = False
        elif train_stage == '2':
            self.freeze_llm = True
            self.freeze_long_net = True

        with LoadWoInit():
            if isinstance(llm, dict):
                llm = self._dispatch_lm_model_cfg(llm, max_position_embeddings)
            self.llm = self._build_from_cfg_or_module(llm)

        self.encoder_name = 'LongNet_2_layers_512_dim'
        self.LongNet_encoder = make_longnet_from_name(self.encoder_name)
        self.llm.config.use_cache = False
        dispatch_modules(self.llm)

        self.projector_depth = projector_depth
        proj_cfg = ProjectorConfig(
            visual_hidden_size=hidden_size,
            llm_hidden_size=self.llm.config.hidden_size,
            depth=self.projector_depth)
        self.projector = ProjectorModel(proj_cfg).to(self.llm.dtype)

        if visual_token_reducer_config:
            cfg = visual_token_reducer_config
            reducer_type = cfg.get('type', 'attention')

            if reducer_type == 'conv':
                reducer_cfg = ReducerConfig(
                    hidden_size=self.llm.config.hidden_size,
                    kernel_size=cfg['kernel_size'],
                    stride=cfg['stride'],
                )
                self.visual_token_reducer = VisualTokenConvReducer(
                    reducer_cfg).to(self.llm.dtype)

            elif reducer_type == 'attention':
                reducer_cfg = ReducerConfig(
                    in_tokens=cfg['in_tokens'],
                    out_tokens=cfg['out_tokens'],
                    hidden_tokens=cfg.get('hidden_tokens', None),
                    num_heads=cfg.get('num_heads', 8),
                    num_queries=cfg.get('num_queries', 2048),
                    hidden_size=self.llm.config.hidden_size
                )
                self.visual_token_reducer = VisualTokenAttentionReducer(
                    reducer_cfg).to(self.llm.dtype)

            elif reducer_type == 'text_guided_attention':
                reducer_cfg = ReducerConfig(
                    in_tokens=cfg['in_tokens'],
                    out_tokens=cfg['out_tokens'],
                    hidden_tokens=cfg.get('hidden_tokens', None),
                    num_heads=cfg.get('num_heads', 8),
                    num_queries=cfg.get('num_queries', 2048),
                    hidden_size=self.llm.config.hidden_size
                )
                self.visual_token_reducer = TextGuidedVisualTokenAttentionReducer(
                    reducer_cfg).to(self.llm.dtype)

            elif reducer_type == 'mlp':
                reducer_cfg = ReducerConfig(
                    in_tokens=cfg['in_tokens'],
                    out_tokens=cfg['out_tokens'],
                    hidden_tokens=cfg.get('hidden_tokens', None),
                    hidden_size=self.llm.config.hidden_size
                )
                self.visual_token_reducer = VisualTokenMLPReducer(
                    reducer_cfg).to(self.llm.dtype)

            else:
                raise ValueError(
                    f"Unknown reducer type: {reducer_type}. Supported types: "
                    "'conv', 'attention', 'mlp', 'text_guided_attention'")

        if self.freeze_llm:
            self.llm.requires_grad_(False)
        if getattr(self, 'freeze_long_net', False):
            self.LongNet_encoder.requires_grad_(False)

        if use_activation_checkpointing:
            if hasattr(self.llm, 'enable_input_require_grads'):
                self.llm.enable_input_require_grads()
            else:
                self.llm.get_input_embeddings().register_forward_hook(
                    make_inputs_require_grad)
            self.projector.enable_input_require_grads()
            if hasattr(self, 'visual_token_reducer'):
                self.visual_token_reducer.enable_input_require_grads()
            self.gradient_checkpointing_enable()

        self.use_llm_lora = llm_lora is not None
        self.use_visual_encoder_lora = visual_encoder_lora is not None
        if self.use_llm_lora:
            print_log('Using LoRA for LLM', 'current')
            self._prepare_llm_for_lora(llm_lora, use_activation_checkpointing)

        if pretrained_pth is not None:
            state = guess_load_checkpoint(pretrained_pth)
            self.load_state_dict(state, strict=False)
            print_log(f'Loaded pretrained weights from {pretrained_pth}',
                      'current')

        self.visual_select_layer = visual_select_layer
        self._is_init = True
        self.is_first_iter = True

    def _parse_lora_config(self, lora_config):
        if isinstance(lora_config, dict) or isinstance(
                lora_config, Config) or isinstance(lora_config, ConfigDict):
            lora_config = BUILDER.build(lora_config)
        return lora_config

    def _prepare_llm_for_lora(self,
                              lora_config,
                              use_activation_checkpointing=True):
        lora_config = self._parse_lora_config(lora_config)
        self.llm = prepare_model_for_kbit_training(
            self.llm, use_activation_checkpointing)
        if lora_config.target_modules is None:
            modules = find_all_linear_names(self.llm)
            lora_config.target_modules = modules
        self.llm = get_peft_model(self.llm, lora_config)

    def _prepare_visual_encoder_for_lora(self,
                                         lora_config,
                                         use_activation_checkpointing=True):
        lora_config = self._parse_lora_config(lora_config)
        if lora_config.target_modules is None:
            modules = find_all_linear_names(self.visual_encoder)
            lora_config.target_modules = modules
        self.visual_encoder = get_peft_model(self.visual_encoder, lora_config)

    def gradient_checkpointing_enable(self):
        self.activation_checkpointing_enable()

    def activation_checkpointing_enable(self):
        self.llm.gradient_checkpointing_enable()
        self.projector.gradient_checkpointing_enable()
        if hasattr(self, 'visual_token_reducer'):
            self.visual_token_reducer.gradient_checkpointing = True

    def gradient_checkpointing_disable(self):
        self.activation_checkpointing_disable()

    def activation_checkpointing_disable(self):
        self.llm.gradient_checkpointing_disable()
        self.projector.gradient_checkpointing_disable()
        if hasattr(self, 'visual_token_reducer'):
            self.visual_token_reducer.gradient_checkpointing = False

    def init_weights(self):
        pass

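    # Only the trainable pieces are exported: the projector, the LongNet
    # encoder, the visual-token reducer (if any), plus LoRA adapters or full
    # weights for the LLM / visual encoder when those are unfrozen.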
    def state_dict(self, *args, **kwargs):
        state_dict = super().state_dict(*args, **kwargs)
        to_return = OrderedDict()

        if self.use_visual_encoder_lora:
            to_return.update(
                get_peft_model_state_dict(
                    self.visual_encoder, state_dict=state_dict))
        elif not self.freeze_visual_encoder:
            to_return.update({
                k: v
                for k, v in state_dict.items() if 'visual_encoder.' in k
            })

        if self.use_llm_lora:
            to_return.update(
                get_peft_model_state_dict(self.llm, state_dict=state_dict))
        elif not self.freeze_llm:
            to_return.update(
                {k: v
                 for k, v in state_dict.items() if 'llm.' in k})

        to_return.update(
            {k: v
             for k, v in state_dict.items() if 'projector.' in k})

        to_return.update(
            {k: v
             for k, v in state_dict.items() if 'LongNet_encoder.' in k})

        if hasattr(self, 'visual_token_reducer'):
            to_return.update({
                k: v
                for k, v in state_dict.items()
                if 'visual_token_reducer.' in k
            })

        return to_return

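    # Linear RoPE scaling: if the requested context exceeds the model's
    # original (already-scaled) context length, the factor is rounded up to an
    # integer. For example, an LLM with max_position_embeddings=4096 and no
    # prior scaling, asked for 16384 positions, gets
    # rope_scaling={'type': 'linear', 'factor': 4.0}.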
    @staticmethod
    def _prepare_for_long_context_training(cfg, llm_cfg,
                                           max_position_embeddings):

        orig_rope_scaling = getattr(llm_cfg, 'rope_scaling', None)
        if orig_rope_scaling is None:
            orig_rope_scaling = {'factor': 1}

        orig_rope_scaling_factor = orig_rope_scaling[
            'factor'] if 'factor' in orig_rope_scaling.keys() else 1
        orig_ctx_len = getattr(llm_cfg, 'max_position_embeddings', None)
        if orig_ctx_len:
            orig_ctx_len *= orig_rope_scaling_factor
            if max_position_embeddings > orig_ctx_len:
                scaling_factor = float(
                    math.ceil(max_position_embeddings / orig_ctx_len))
                llm_cfg.rope_scaling = {
                    'type': 'linear',
                    'factor': scaling_factor
                }

        llm_cfg.attn_implementation = 'flash_attention_2'
        cfg.config = llm_cfg

        return cfg, llm_cfg

    @staticmethod
    def _prepare_for_flash_attn(cfg, llm_cfg):
        cls_name = type(llm_cfg).__name__
        SUPPORT_SDPA_ATTN = ('LlamaConfig', 'GemmaConfig', 'MistralConfig',
                             'MixtralConfig', 'Qwen2Config', 'Qwen2MoeConfig',
                             'Starcoder2Config', 'Phi3Config')
        SUPPORT_FLASH_ATTN2 = ('InternLM2Config', 'LlamaConfig', 'GemmaConfig',
                               'MistralConfig', 'MixtralConfig', 'Qwen2Config',
                               'Qwen2MoeConfig', 'Starcoder2Config',
                               'Phi3Config')

        torch_dtype = torch.bfloat16 if (
            torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
            else torch.float16

        if getattr(cfg, 'attn_implementation', None) is not None:
            if cfg.attn_implementation == 'flash_attention_2':
                cfg.torch_dtype = torch_dtype
        elif SUPPORT_FLASH2 and cls_name in SUPPORT_FLASH_ATTN2:
            cfg.torch_dtype = torch_dtype
            cfg.attn_implementation = 'flash_attention_2'
        elif SUPPORT_FLASH1 and cls_name in SUPPORT_SDPA_ATTN:
            cfg.attn_implementation = 'sdpa'

        return cfg, llm_cfg

    @staticmethod
    def _prepare_for_qlora_zero3(cfg):
        if (not is_deepspeed_zero3_enabled()) or (not hasattr(
                cfg, 'quantization_config')):
            return cfg

        torch_dtype = torch.bfloat16 if (
            torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
            else torch.float16

        cfg.torch_dtype = torch_dtype
        quantization_config = cfg.quantization_config
        quantization_config.bnb_4bit_compute_dtype = torch_dtype
        quantization_config.bnb_4bit_quant_storage = torch_dtype

        return cfg

    def _dispatch_lm_model_cfg(self, cfg, max_position_embeddings=None):
        cfg = self._prepare_for_qlora_zero3(cfg)
        pretrained_model_name_or_path = cfg.pretrained_model_name_or_path
        llm_cfg = AutoConfig.from_pretrained(
            pretrained_model_name_or_path, trust_remote_code=True)
        cfg, llm_cfg = self._prepare_for_flash_attn(cfg, llm_cfg)
        if max_position_embeddings is not None:
            cfg, llm_cfg = self._prepare_for_long_context_training(
                cfg, llm_cfg, max_position_embeddings)
        return cfg

    def _build_from_cfg_or_module(self, cfg_or_mod):
        if isinstance(cfg_or_mod, nn.Module):
            return cfg_or_mod
        elif isinstance(cfg_or_mod, dict):
            traverse_dict(cfg_or_mod)
            return BUILDER.build(cfg_or_mod)
        else:
            raise NotImplementedError

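    # Visual path (shapes assume the default 512-dim LongNet features):
    # data['pixel_values'] (B, N, 512) -> LongNet_encoder (fed as (N, B, 512))
    # -> projector (B, N, llm_hidden) -> optional visual_token_reducer
    # (B, M, llm_hidden), after which prepare_inputs_labels_for_multimodal
    # splices the visual tokens into the text embedding sequence.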
    def forward(self, data, data_samples=None, mode='loss'):
        if self.is_first_iter:
            self.to(data['input_ids'].device)
            self.is_first_iter = False

        is_text_guided_reducer = isinstance(
            getattr(self, 'visual_token_reducer', None),
            TextGuidedVisualTokenAttentionReducer)

        if 'pixel_values' in data:
            feat_to_proj = data['pixel_values'].to(self.llm.dtype)
            feat_to_proj = self.LongNet_encoder(
                src_tokens=None,
                token_embeddings=feat_to_proj.permute(1, 0, 2)
            )["encoder_out"].permute(1, 0, 2)
            pixel_values = self.projector(feat_to_proj.to(self.llm.dtype))

            if hasattr(self, 'visual_token_reducer'):
                if is_text_guided_reducer:
                    input_ids = data['input_ids']
                    text_attention_mask = data.get('attention_mask')
                    text_embeddings = self.llm.get_input_embeddings()(
                        input_ids.clamp(min=0)).detach()

                    pixel_values = self.visual_token_reducer(
                        pixel_values, text_embeddings, text_attention_mask)
                else:
                    pixel_values = self.visual_token_reducer(pixel_values)

            data['pixel_values'] = pixel_values

            data = prepare_inputs_labels_for_multimodal(llm=self.llm, **data)

        if mode == 'loss':
            return self.compute_loss(data, data_samples)
        elif mode == 'predict':
            return self.predict(data, data_samples)
        elif mode == 'tensor':
            return self._forward(data, data_samples)
        else:
            raise NotImplementedError

    def _forward(self, data, data_samples=None):
        outputs = self.llm(**data)
        return outputs

    def predict(self, data, data_samples=None):
        outputs = self.llm(**data)
        logits_dict = [{'logits': logits} for logits in outputs.logits]
        return logits_dict

    def compute_loss(self, data, data_samples=None):
        outputs = self.llm(**data)
        loss_dict = {'loss': outputs.loss}
        return loss_dict

    def __getattr__(self, name: str):
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.llm, name)

    def to_hf(self,
              cfg,
              save_dir,
              fp32=False,
              save_pretrained_kwargs={},
              save_format='xtuner',
              **kwargs):
        if save_format == 'xtuner':
            self.to_xtuner_llava(cfg, save_dir, fp32, save_pretrained_kwargs)
        elif save_format == 'huggingface':
            self.to_huggingface_llava(cfg, save_dir, fp32,
                                      save_pretrained_kwargs)
        elif save_format == 'official':
            self.to_official_llava(cfg, save_dir, fp32, save_pretrained_kwargs)
        else:
            raise NotImplementedError

    def to_xtuner_llava(self,
                        cfg,
                        save_dir,
                        fp32=False,
                        save_pretrained_kwargs={}):
        # LLM
        self.llm.config.use_cache = True
        if not fp32:
            self.llm.half()
        if self.use_llm_lora:
            llm_adapter = osp.join(save_dir, 'llm_adapter')
            self.llm.save_pretrained(llm_adapter, **save_pretrained_kwargs)
        elif not self.freeze_llm:
            tokenizer = BUILDER.build(cfg.tokenizer)
            tokenizer.save_pretrained(save_dir, **save_pretrained_kwargs)
            self.llm.save_pretrained(save_dir, **save_pretrained_kwargs)
        self.llm.config.use_cache = False

        # Visual encoder
        if self.use_visual_encoder_lora:
            visual_adapter = osp.join(save_dir, 'visual_encoder_adapter')
            self.visual_encoder.save_pretrained(visual_adapter,
                                                **save_pretrained_kwargs)
        elif not self.freeze_visual_encoder:
            vis_dir = osp.join(save_dir, 'visual_encoder')
            BUILDER.build(cfg.image_processor).save_pretrained(
                vis_dir, **save_pretrained_kwargs)
            self.visual_encoder.save_pretrained(vis_dir,
                                                **save_pretrained_kwargs)

        # Projector
        proj_dir = osp.join(save_dir, 'projector')
        self.projector.save_pretrained(proj_dir, **save_pretrained_kwargs)

        # Visual token reducer (if present)
        if hasattr(self, 'visual_token_reducer'):
            red_dir = osp.join(save_dir, 'visual_token_reducer')
            self.visual_token_reducer.save_pretrained(
                red_dir, **save_pretrained_kwargs)

        # LongNet encoder
        longnet_dir = osp.join(save_dir, 'LongNet_encoder')
        self.LongNet_encoder.save_pretrained(longnet_dir,
                                             **save_pretrained_kwargs)

    def to_huggingface_llava(self,
                             cfg,
                             save_dir,
                             fp32=False,
                             save_pretrained_kwargs={}):

        LLM_MAPPING = {
            'model': 'language_model.model',
            'lm_head': 'language_model.lm_head',
        }
        VIT_MAPPING = {
            'vision_model': 'vision_tower.vision_model',
        }
        PROJECTOR_MAPPING = {
            'model.0': 'multi_modal_projector.linear_1',
            'model.2': 'multi_modal_projector.linear_2',
        }
        REDUCER_MAPPING = {
            'query_emb': 'visual_token_reducer.query_emb',
            'cross_attn.in_proj_weight':
            'visual_token_reducer.cross_attn.in_proj_weight',
            'cross_attn.in_proj_bias':
            'visual_token_reducer.cross_attn.in_proj_bias',
            'cross_attn.out_proj.weight':
            'visual_token_reducer.cross_attn.out_proj.weight',
            'cross_attn.out_proj.bias':
            'visual_token_reducer.cross_attn.out_proj.bias'
        }
        LONGNET_MAPPING = {
            'layers.0': 'LongNet_encoder.layers.0',
            'layers.1': 'LongNet_encoder.layers.1',
            'layer_norm': 'LongNet_encoder.layer_norm'
        }

        assert getattr(self.llm, 'hf_quantizer', None) is None, \
            'This conversion format does not support quantized LLM.'

        # LLM
        llm = self.llm
        if self.use_llm_lora:
            llm = self.llm.merge_and_unload()
        llm.config.use_cache = True
        if not fp32:
            print_log('Convert LLM to float16', 'current')
            llm.half()

        assert isinstance(llm, LlamaForCausalLM), \
            'This conversion format only supports LlamaForCausalLM.'
        llm_state_dict = llm.state_dict()
        llm_state_dict = convert_state_dict_to_hf(llm_state_dict, LLM_MAPPING)

        # Visual encoder
        need_visual_encoder = (not self.freeze_visual_encoder
                               or self.use_visual_encoder_lora)
        visual_encoder = self.visual_encoder
        if self.use_visual_encoder_lora:
            visual_encoder = self.visual_encoder.merge_and_unload()
        assert isinstance(visual_encoder, CLIPVisionModel), \
            'This conversion format only supports CLIPVisionModel.'
        if need_visual_encoder:
            visual_encoder_state_dict = visual_encoder.state_dict()
            visual_encoder_state_dict = convert_state_dict_to_hf(
                visual_encoder_state_dict, VIT_MAPPING)
        else:
            visual_encoder_state_dict = {}

        # Projector
        projector_state_dict = self.projector.state_dict()
        projector_state_dict = convert_state_dict_to_hf(
            projector_state_dict, PROJECTOR_MAPPING)

        # LongNet encoder
        LongNet_encoder_state_dict = self.LongNet_encoder.state_dict()
        LongNet_encoder_state_dict = convert_state_dict_to_hf(
            LongNet_encoder_state_dict, LONGNET_MAPPING)

        # Visual token reducer (if present)
        red_state = convert_state_dict_to_hf(
            self.visual_token_reducer.state_dict(), REDUCER_MAPPING
        ) if hasattr(self, 'visual_token_reducer') else {}

        state_dict = {
            **projector_state_dict,
            **llm_state_dict,
            **visual_encoder_state_dict,
            **LongNet_encoder_state_dict,
            **red_state
        }

        # Assemble the HF LLaVA model and load the merged state dict
        text_config = llm.config
        vision_config = visual_encoder.config
        config = LlavaConfig(
            text_config=text_config,
            vision_config=vision_config,
            attn_implementation='eager')

        with init_empty_weights():
            with warnings.catch_warnings():
                warnings.filterwarnings(
                    'ignore', message='.*non-meta.*', category=UserWarning)
                model = LlavaForConditionalGeneration(config)
        model.load_state_dict(state_dict, strict=True, assign=True)

        # Tokenizer, image processor and processor
        cfg.tokenizer.type = LlamaTokenizerFast.from_pretrained
        tokenizer = BUILDER.build(cfg.tokenizer)

        tokenizer.add_tokens(
            AddedToken(DEFAULT_IMAGE_TOKEN, special=True, normalized=False),
            special_tokens=True)
        tokenizer.add_special_tokens({'pad_token': '<pad>'})

        image_processor = BUILDER.build(cfg.image_processor)
        assert isinstance(image_processor, CLIPImageProcessor), \
            'This conversion format only supports CLIPImageProcessor.'

        processor = LlavaProcessor(
            tokenizer=tokenizer, image_processor=image_processor)

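        # The newly added tokens (<image>, <pad>) get embeddings sampled from a
        # multivariate normal fitted to the existing rows (their mean and a
        # 1e-5-scaled covariance), and the embedding matrix is resized to a
        # multiple of ``pad_shape``; this mirrors the usual HF LLaVA weight
        # conversion recipe.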
        pad_shape = 64

        pre_expansion_embeddings = \
            model.language_model.model.embed_tokens.weight.data
        mu = torch.mean(pre_expansion_embeddings, dim=0).float()
        n = pre_expansion_embeddings.size()[0]
        sigma = ((pre_expansion_embeddings - mu).T
                 @ (pre_expansion_embeddings - mu)) / n
        dist = torch.distributions.multivariate_normal.MultivariateNormal(
            mu, covariance_matrix=1e-5 * sigma)

        ori_vocab_size = config.text_config.vocab_size
        tokenizer_vocab_size = tokenizer.encode('<pad>')[-1]
        added_token = tokenizer_vocab_size - ori_vocab_size

        if added_token > 0:
            model.resize_token_embeddings(ori_vocab_size + added_token,
                                          pad_shape)
            model.language_model.model.embed_tokens.weight.data[
                ori_vocab_size:] = torch.stack(
                    tuple(
                        dist.sample()
                        for _ in range(model.language_model.model.embed_tokens.
                                       weight.data[ori_vocab_size:].shape[0])),
                    dim=0,
                )
            model.language_model.lm_head.weight.data[
                ori_vocab_size:] = torch.stack(
                    tuple(dist.sample()
                          for _ in range(model.language_model.lm_head.weight.
                                         data[ori_vocab_size:].shape[0])),
                    dim=0,
                )
        model.config.image_token_index = tokenizer.encode(
            DEFAULT_IMAGE_TOKEN)[-1]
        model.config.pad_token_id = tokenizer.encode('<pad>')[-1]

        print_log(f'Saving to {save_dir}', 'current')
        model.save_pretrained(save_dir, **save_pretrained_kwargs)
        processor.save_pretrained(save_dir, **save_pretrained_kwargs)

    def to_official_llava(self,
                          cfg,
                          save_dir,
                          fp32=False,
                          save_pretrained_kwargs={}):

        VIT_MAPPING = {
            'vision_model': 'model.vision_tower.vision_tower.vision_model',
        }
        PROJECTOR_MAPPING = {
            'model.0': 'model.mm_projector.0',
            'model.2': 'model.mm_projector.2',
        }
        LONGNET_MAPPING = {
            'layers.0': 'LongNet_encoder.layers.0',
            'layers.1': 'LongNet_encoder.layers.1',
            'layer_norm': 'LongNet_encoder.layer_norm'
        }
        REDUCER_MAPPING = {
            'query_emb': 'visual_token_reducer.query_emb',
            'cross_attn.in_proj_weight':
            'visual_token_reducer.cross_attn.in_proj_weight',
            'cross_attn.in_proj_bias':
            'visual_token_reducer.cross_attn.in_proj_bias',
            'cross_attn.out_proj.weight':
            'visual_token_reducer.cross_attn.out_proj.weight',
            'cross_attn.out_proj.bias':
            'visual_token_reducer.cross_attn.out_proj.bias'
        }

        try:
            from llava.model import LlavaConfig, LlavaLlamaForCausalLM
        except ImportError:
            raise ImportError(
                'Please install llava with '
                '`pip install git+https://github.com/haotian-liu/LLaVA.git '
                '--no-deps`.')

        assert getattr(self.llm, 'hf_quantizer', None) is None, \
            'This conversion format does not support quantized LLM.'

        llm = self.llm
        if self.use_llm_lora:
            llm = self.llm.merge_and_unload()
        llm.config.use_cache = True
        if not fp32:
            print_log('Convert LLM to float16', 'current')
            llm.half()

        assert isinstance(llm, LlamaForCausalLM), \
            'This conversion format only supports LlamaForCausalLM.'
        llm_state_dict = llm.state_dict()

        need_visual_encoder = (not self.freeze_visual_encoder
                               or self.use_visual_encoder_lora)
        visual_encoder = self.visual_encoder
        if self.use_visual_encoder_lora:
            visual_encoder = self.visual_encoder.merge_and_unload()
        assert isinstance(visual_encoder, CLIPVisionModel), \
            'This conversion format only supports CLIPVisionModel.'
        if need_visual_encoder:
            visual_encoder_state_dict = visual_encoder.state_dict()
            visual_encoder_state_dict = convert_state_dict_to_hf(
                visual_encoder_state_dict, VIT_MAPPING)
        else:
            visual_encoder_state_dict = {}

        projector_state_dict = self.projector.state_dict()
        projector_state_dict = convert_state_dict_to_hf(
            projector_state_dict, PROJECTOR_MAPPING)

        LongNet_encoder_state_dict = self.LongNet_encoder.state_dict()
        LongNet_encoder_state_dict = convert_state_dict_to_hf(
            LongNet_encoder_state_dict, LONGNET_MAPPING)

        red_state = convert_state_dict_to_hf(
            self.visual_token_reducer.state_dict(), REDUCER_MAPPING
        ) if hasattr(self, 'visual_token_reducer') else {}

        state_dict = {
            **projector_state_dict,
            **llm_state_dict,
            **visual_encoder_state_dict,
            **LongNet_encoder_state_dict,
            **red_state
        }

        tokenizer = BUILDER.build(cfg.tokenizer)
        image_processor = BUILDER.build(cfg.image_processor)
        assert isinstance(image_processor, CLIPImageProcessor), \
            'This conversion format only supports CLIPImageProcessor.'

        llava_config_dict = llm.config.__dict__.copy()
        llava_config_dict.update(
            dict(
                image_aspect_ratio='pad',
                mm_hidden_size=visual_encoder.config.hidden_size,
                mm_projector_type=f'mlp{self.projector_depth}x_gelu',
                mm_use_im_patch_token=False,
                mm_use_im_start_end=False,
                mm_vision_select_feature='patch',
                mm_vision_select_layer=self.visual_select_layer,
                mm_vision_tower=visual_encoder.config.name_or_path,
                unfreeze_mm_vision_tower=need_visual_encoder,
                model_type='llava',
                use_cache=True,
                use_mm_proj=True))

        llava_config = LlavaConfig(**llava_config_dict)

        with init_empty_weights():
            with warnings.catch_warnings():
                warnings.filterwarnings(
                    'ignore', message='.*non-meta.*', category=UserWarning)
                model = LlavaLlamaForCausalLM(llava_config)

        model.load_state_dict(state_dict, strict=True, assign=True)

        print_log(f'Saving to {save_dir}', 'current')

        model.save_pretrained(save_dir, **save_pretrained_kwargs)
        image_processor.save_pretrained(save_dir, **save_pretrained_kwargs)
        tokenizer.save_pretrained(save_dir, **save_pretrained_kwargs)