pure_model_weights / code /xtuner /model /llava_dim_reducer.py
WinstonHu's picture
Upload folder xtuner to code/xtuner
e5e24c9 verified
raw
history blame
39.6 kB
import math
import os.path as osp
import warnings
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from accelerate import init_empty_weights
from mmengine import print_log
from mmengine.config import Config, ConfigDict
from mmengine.model import BaseModel
from peft import get_peft_model, prepare_model_for_kbit_training
from transformers import (AddedToken, AutoConfig, CLIPImageProcessor,
CLIPVisionModel, LlamaForCausalLM,
LlamaTokenizerFast, LlavaConfig,
LlavaForConditionalGeneration, LlavaProcessor,
PreTrainedModel, PretrainedConfig)
from transformers.integrations import is_deepspeed_zero3_enabled
from xtuner.registry import BUILDER
from xtuner.utils import DEFAULT_IMAGE_TOKEN
from .modules import ProjectorConfig, ProjectorModel, dispatch_modules
from .modules.dispatch import SUPPORT_FLASH1, SUPPORT_FLASH2
from .utils import (LoadWoInit, find_all_linear_names,
get_peft_model_state_dict, guess_load_checkpoint,
make_inputs_require_grad,
prepare_inputs_labels_for_multimodal, traverse_dict)
from .torchscale.model.LongNet import make_longnet_from_name
def convert_state_dict_to_hf(state_dict, mapping):
    """Return a new dict whose keys are ``state_dict`` keys renamed via ``mapping``.

    Keys ending in ``.inv_freq`` are dropped. For every other key, each
    ``old -> new`` pair of ``mapping`` is tried in iteration order; whenever
    ``old`` occurs in the (possibly already rewritten) key, every occurrence
    is replaced by ``new``. Values are carried over unchanged.
    """

    def _rename(name):
        # Substitutions accumulate: a later pair sees the result of earlier ones.
        for old_fragment, new_fragment in mapping.items():
            if old_fragment in name:
                name = name.replace(old_fragment, new_fragment)
        return name

    return {
        _rename(name): tensor
        for name, tensor in state_dict.items()
        if not name.endswith('.inv_freq')
    }
class ReducerConfig(PretrainedConfig):
    """Configuration shared by the visual-token reducer modules in this file.

    Args:
        in_tokens: Number of input tokens (read by the MLP reducer).
        out_tokens: Number of output tokens (read by the MLP reducer).
        hidden_tokens: Hidden width of the token MLP. When ``None`` it
            defaults to ``max(in_tokens // 2, out_tokens)``.
        num_queries: Number of learnable queries (attention reducers).
        num_heads: Number of attention heads (attention reducers).
        hidden_size: Feature dimension of the tokens.
        kernel_size: Conv1d kernel size (conv reducer).
        stride: Conv1d stride (conv reducer).
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """
    model_type = 'Reducer'
    _auto_class = 'AutoConfig'

    def __init__(
        self,
        in_tokens=4096,
        out_tokens=2048,
        hidden_tokens=1024,
        num_queries=2048,
        num_heads=8,
        hidden_size=3584,
        kernel_size=4,
        stride=4,
        **kwargs,
    ):
        self.in_tokens = in_tokens
        self.out_tokens = out_tokens
        self.hidden_tokens = hidden_tokens
        # Store plain scalars. The previous trailing commas stored 1-tuples,
        # which become lists when the config is written to JSON and were then
        # wrapped again on reload, producing values nn.Conv1d cannot accept.
        self.kernel_size = self._as_scalar(kernel_size)
        self.stride = self._as_scalar(stride)
        if self.hidden_tokens is None:
            self.hidden_tokens = max(self.in_tokens // 2, self.out_tokens)
        self.hidden_size = hidden_size
        self.num_queries = num_queries
        self.num_heads = num_heads
        super().__init__(**kwargs)

    @staticmethod
    def _as_scalar(value):
        """Unwrap single-element lists/tuples (e.g. from older saved configs)."""
        while isinstance(value, (list, tuple)) and len(value) == 1:
            value = value[0]
        return value
class VisualTokenConvReducer(PreTrainedModel):
    """Token reducer made of a single ``nn.Conv1d`` applied along the token axis.

    Inputs are ``(B, T, D)``; they are transposed to ``(B, D, T)`` for the
    convolution and transposed back afterwards. While training with
    ``gradient_checkpointing`` set, the convolution runs under activation
    checkpointing.
    """
    supports_gradient_checkpointing = True

    def __init__(self, config: ReducerConfig):
        super().__init__(config)
        width = config.hidden_size
        # Channel count is preserved; only the token axis is shortened.
        self.conv = nn.Conv1d(
            in_channels=width,
            out_channels=width,
            kernel_size=config.kernel_size,
            stride=config.stride,
        )
        self.gradient_checkpointing = False

    def enable_input_require_grads(self):
        """Register a forward hook on the conv that marks its output as requiring grad."""

        def _mark_output(module, inputs, output):
            output.requires_grad_(True)

        self.conv.register_forward_hook(_mark_output)

    def forward(self, x):
        """Reduce tokens.

        x: [batch, tokens_in, features]
        returns: [batch, tokens_out, features]
        """
        # Conv1d convolves over the last axis, so put tokens last.
        channels_first = x.transpose(1, 2)
        checkpointed = self.gradient_checkpointing and self.training
        if checkpointed:
            reduced = cp.checkpoint(self.conv, channels_first)
        else:
            reduced = self.conv(channels_first)
        # Back to [batch, tokens_out, features].
        return reduced.transpose(1, 2)
class VisualTokenMLPReducer(PreTrainedModel):
    """Token reducer that applies a two-layer MLP along the token axis.

    A ``(B, T, D)`` input is transposed and flattened to ``(B * D, T)``,
    passed through ``fc1 -> GELU -> fc2`` and reshaped back to
    ``(B, out_tokens, D)``.
    """
    base_model_prefix = "visual_token_mlp_reducer"
    supports_gradient_checkpointing = True

    def __init__(self, cfg: ReducerConfig):
        super().__init__(cfg)
        self.in_tokens = cfg.in_tokens
        self.out_tokens = cfg.out_tokens
        self.hidden_tokens = cfg.hidden_tokens

        # The linear layers act on the token dimension, not on features.
        self.fc1 = nn.Linear(self.in_tokens, self.hidden_tokens, bias=True)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(self.hidden_tokens, self.out_tokens, bias=True)
        self.gradient_checkpointing = False

    def enable_input_require_grads(self):
        """Register a forward hook on ``fc1`` that marks its output as requiring grad."""

        def _mark_output(module, inputs, output):
            output.requires_grad_(True)

        self.fc1.register_forward_hook(_mark_output)

    def forward(self, x):
        """Reduce tokens.

        x: [batch, tokens_in, features]
        returns: [batch, tokens_out, features]
        """
        batch, n_tokens, n_feats = x.shape
        # [B, T, D] -> [B, D, T] -> [B * D, T]: every feature channel becomes a row.
        rows = x.transpose(1, 2).reshape(batch * n_feats, n_tokens)
        if self.gradient_checkpointing and self.training:
            mixed = cp.checkpoint(self._mlp, rows)
        else:
            mixed = self._mlp(rows)
        # [B * D, T'] -> [B, D, T'] -> [B, T', D]
        return mixed.view(batch, n_feats, self.out_tokens).transpose(1, 2)

    def _mlp(self, x):
        """Apply ``fc1``, the activation and ``fc2`` in sequence."""
        return self.fc2(self.act(self.fc1(x)))
class VisualTokenAttentionReducer(PreTrainedModel):
    """Token reducer using learnable queries and one cross-attention layer.

    ``num_queries`` learnable query vectors attend over the input tokens
    (used as both keys and values), so the output has ``num_queries`` tokens.
    """
    base_model_prefix = "visual_token_attention_reducer"
    supports_gradient_checkpointing = True

    def __init__(self, config: ReducerConfig):
        super().__init__(config)
        # One learnable query per output token.
        self.query_emb = nn.Parameter(
            torch.randn(config.num_queries, config.hidden_size))
        # batch_first=True lets the layer consume (B, T, D) tensors directly.
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_heads,
            batch_first=True,
        )
        self.gradient_checkpointing = False

    def enable_input_require_grads(self):
        """Register a forward hook on the attention layer that marks its output as requiring grad."""

        def _mark_output(module, inputs, output):
            # A tuple/list output is handled by flagging only its first element.
            target = output[0] if isinstance(output, (tuple, list)) else output
            target.requires_grad_(True)

        self.cross_attn.register_forward_hook(_mark_output)

    def forward(self, x):
        """Reduce tokens.

        x: (B, T_in, D) -> returns (B, M_out, D)
        """
        batch_size, _, _ = x.shape
        # Share the query table across the batch: (M, D) -> (B, M, D).
        queries = self.query_emb.unsqueeze(0).expand(batch_size, -1, -1)
        if self.gradient_checkpointing and self.training:
            return cp.checkpoint(self._attn, queries, x, x)
        return self._attn(queries, x, x)

    def _attn(self, Q, K, V):
        """Run cross-attention and return only the attended output (B, M, D)."""
        attended, _ = self.cross_attn(Q, K, V)
        return attended
class TextGuidedVisualTokenAttentionReducer(PreTrainedModel):
    """
    An enhanced attention-based token reducer that uses text tokens to guide
    the compression of visual tokens, operating on batch-first tensors.

    Learnable queries cross-attend over the concatenation of visual and text
    tokens, followed by a residual feed-forward block.
    """
    base_model_prefix = "text_guided_visual_token_attention_reducer"
    supports_gradient_checkpointing = True

    def __init__(self, config: ReducerConfig):
        super().__init__(config)
        self.query_emb = nn.Parameter(
            torch.randn(config.num_queries, config.hidden_size))
        self.norm_kv = nn.LayerNorm(config.hidden_size)
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_heads,
            batch_first=True  # Expects (Batch, Seq_len, Dim)
        )
        self.norm_ffn = nn.LayerNorm(config.hidden_size)
        self.ffn = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size * 4),
            nn.GELU(),
            nn.Linear(config.hidden_size * 4, config.hidden_size)
        )
        # NOTE: this flag is toggled by the owning model, but forward() below
        # does not wrap anything in activation checkpointing.
        self.gradient_checkpointing = False

    def enable_input_require_grads(self):
        """Register a forward hook on the attention layer that marks its output as requiring grad."""

        def make_inputs_require_grad(module, inputs, output):
            if isinstance(output, tuple):
                output[0].requires_grad_(True)
            else:
                output.requires_grad_(True)

        self.cross_attn.register_forward_hook(make_inputs_require_grad)

    def forward(self, visual_tokens, text_tokens, text_attention_mask=None):
        """
        Performs text-guided reduction of visual tokens.

        Args:
            visual_tokens (torch.Tensor): Visual tokens of shape (B, T_visual, D).
            text_tokens (torch.Tensor): Text token embeddings of shape (B, T_text, D).
            text_attention_mask (torch.Tensor): Mask for text tokens; entries
                equal to 0 are treated as padding and ignored by attention.

        Returns:
            torch.Tensor: Compressed visual tokens of shape (B, M_out, D).
        """
        B, T_visual, _ = visual_tokens.shape

        # Concatenate along the sequence dimension to form Key (K) and Value (V)
        kv_tokens = torch.cat([visual_tokens, text_tokens], dim=1)

        key_padding_mask = None
        if text_attention_mask is not None:
            # True marks a key to ignore. Visual tokens are never padding; text
            # positions are padding where the attention mask is 0. Building
            # both halves as bool on one device avoids relying on torch.cat
            # dtype promotion and on the caller's mask placement.
            visual_is_padding = torch.zeros(
                B, T_visual, dtype=torch.bool, device=visual_tokens.device)
            text_is_padding = text_attention_mask.to(visual_tokens.device) == 0
            key_padding_mask = torch.cat(
                [visual_is_padding, text_is_padding], dim=1)

        # Prepare queries (Q) with batch dimension first: (B, M, D)
        queries = self.query_emb.unsqueeze(0).expand(B, -1, -1)

        # Normalise once and reuse for both key and value instead of running
        # the LayerNorm twice on the same tensor.
        kv_normed = self.norm_kv(kv_tokens)
        attn_output, _ = self.cross_attn(
            query=queries,
            key=kv_normed,
            value=kv_normed,
            key_padding_mask=key_padding_mask
        )
        queries = queries + attn_output  # Residual connection
        ffn_output = self.ffn(self.norm_ffn(queries))
        queries = queries + ffn_output  # Residual connection

        # Final output is already (B, M, D)
        return queries
class LLaVAModelWithReducer(BaseModel):
# Build the model: LLM, LongNet encoder, projector and an optional visual
# token reducer; then apply freezing, activation checkpointing, LoRA and
# pretrained-weight loading, in that order.
#
# Parameters (as used by the code below):
#   llm: an nn.Module, or a dict config that is first passed through
#       ``_dispatch_lm_model_cfg`` and then built.
#   freeze_llm: stored, but overwritten with True when ``train_stage`` is
#       '1' or '2'.
#   visual_select_layer: only stored on the instance here.
#   pretrained_pth: optional checkpoint path, loaded with ``strict=False``.
#   projector_depth: depth passed to ``ProjectorConfig``.
#   llm_lora: optional LoRA config for the LLM.
#   visual_encoder_lora: only sets ``use_visual_encoder_lora``; no visual
#       encoder is constructed in this method.
#   use_activation_checkpointing: enables input-grad hooks and checkpointing.
#   max_position_embeddings: forwarded to ``_dispatch_lm_model_cfg``.
#   hidden_size: ``visual_hidden_size`` of the projector.
#   train_stage: '1' trains the LongNet encoder, '2' freezes it.
#   enable_long_net: stored only; the LongNet encoder is built regardless.
#   visual_token_reducer_config: mapping with an optional 'type' key
#       ('conv', 'attention', 'text_guided_attention' or 'mlp').
def __init__(self,
llm,
freeze_llm=True,
visual_select_layer=-2,
pretrained_pth=None,
projector_depth=2,
llm_lora=None,
visual_encoder_lora=None,
use_activation_checkpointing=True,
max_position_embeddings=None,
hidden_size=512,
train_stage='2',
enable_long_net=True,
visual_token_reducer_config=None):
super().__init__()
self.enable_long_net = enable_long_net
self.freeze_llm = freeze_llm
self.freeze_visual_encoder = True
# Both known stages force the LLM to be frozen; they differ only in whether
# the LongNet encoder is trainable. For any other value ``freeze_long_net``
# is never assigned and is read later via ``getattr(..., False)``.
if train_stage == '1':
self.freeze_llm = True
self.freeze_long_net = False
elif train_stage == '2':
self.freeze_llm = True
self.freeze_long_net = True
with LoadWoInit():
if isinstance(llm, dict):
llm = self._dispatch_lm_model_cfg(llm, max_position_embeddings)
self.llm = self._build_from_cfg_or_module(llm)
# The encoder name is hard-coded to 2 layers and dimension 512,
# independent of the ``hidden_size`` argument.
self.encoder_name = f"LongNet_{2}_layers_{512}_dim"
self.LongNet_encoder = make_longnet_from_name(self.encoder_name)
self.llm.config.use_cache = False
dispatch_modules(self.llm)
# Projector: maps ``hidden_size`` features to the LLM hidden size and is
# cast to the LLM dtype.
self.projector_depth = projector_depth
proj_cfg = ProjectorConfig(
visual_hidden_size=hidden_size,
llm_hidden_size=self.llm.config.hidden_size,
depth=self.projector_depth)
self.projector = ProjectorModel(proj_cfg).to(self.llm.dtype)
# Visual Token Reducer (optional): the attribute is only created when a
# non-empty config is given, so later code checks for it with ``hasattr``.
if visual_token_reducer_config:
cfg = visual_token_reducer_config
reducer_type = cfg.get('type', 'attention') # Default to attention
if reducer_type == 'conv':
reducer_cfg = ReducerConfig(
hidden_size= self.llm.config.hidden_size,
kernel_size=cfg['kernel_size'],
stride=cfg['stride'],
)
self.visual_token_reducer = VisualTokenConvReducer(
reducer_cfg
).to(self.llm.dtype)
elif reducer_type == 'attention':
reducer_cfg = ReducerConfig(
in_tokens=cfg['in_tokens'],
out_tokens=cfg['out_tokens'],
hidden_tokens=cfg.get('hidden_tokens', None),
num_heads=cfg.get('num_heads', 8),
num_queries=cfg.get('num_queries', 2048),
hidden_size=self.llm.config.hidden_size
)
self.visual_token_reducer = VisualTokenAttentionReducer(reducer_cfg).to(self.llm.dtype)
elif reducer_type == 'text_guided_attention':
reducer_cfg = ReducerConfig(
in_tokens=cfg['in_tokens'],
out_tokens=cfg['out_tokens'],
hidden_tokens=cfg.get('hidden_tokens', None),
num_heads=cfg.get('num_heads', 8),
num_queries=cfg.get('num_queries', 2048),
hidden_size=self.llm.config.hidden_size
)
self.visual_token_reducer = TextGuidedVisualTokenAttentionReducer(reducer_cfg).to(self.llm.dtype)
elif reducer_type == 'mlp':
reducer_cfg = ReducerConfig(
in_tokens=cfg['in_tokens'],
out_tokens=cfg['out_tokens'],
hidden_tokens=cfg.get('hidden_tokens', None),
hidden_size=self.llm.config.hidden_size
)
self.visual_token_reducer = VisualTokenMLPReducer(reducer_cfg).to(self.llm.dtype)
else:
raise ValueError(f"Unknown reducer type: {reducer_type}. "
"Supported types: 'conv', 'attention', 'mlp', 'text_guided_attention'")
# Freezing
if self.freeze_llm:
self.llm.requires_grad_(False)
if getattr(self, 'freeze_long_net', False):
self.LongNet_encoder.requires_grad_(False)
# Activation / gradient checkpointing: register hooks that make module
# outputs require grad, then switch checkpointing on for all components.
if use_activation_checkpointing:
if hasattr(self.llm, 'enable_input_require_grads'):
self.llm.enable_input_require_grads()
else:
self.llm.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
self.projector.enable_input_require_grads()
if hasattr(self, 'visual_token_reducer'):
self.visual_token_reducer.enable_input_require_grads()
self.gradient_checkpointing_enable()
# LoRA: only the LLM adapter is prepared here; the visual-encoder flag is
# recorded but not acted upon in this method.
self.use_llm_lora = llm_lora is not None
self.use_visual_encoder_lora = visual_encoder_lora is not None
if self.use_llm_lora:
print_log('Using LoRA for LLM', 'current')
self._prepare_llm_for_lora(llm_lora, use_activation_checkpointing)
# Load pretrained weights non-strictly, so missing/unexpected keys are
# tolerated.
if pretrained_pth is not None:
state = guess_load_checkpoint(pretrained_pth)
self.load_state_dict(state, strict=False)
print_log(f'Loaded pretrained weights from {pretrained_pth}', 'current')
self.visual_select_layer = visual_select_layer
self._is_init = True
# Consumed by ``forward`` to move the model to the input device once.
self.is_first_iter = True
def _parse_lora_config(self, lora_config):
    """Return a built LoRA config.

    A ``dict``, ``Config`` or ``ConfigDict`` is built through ``BUILDER``;
    anything else is returned unchanged.
    """
    if isinstance(lora_config, (dict, Config, ConfigDict)):
        return BUILDER.build(lora_config)
    return lora_config
def _prepare_llm_for_lora(self, lora_config, use_activation_checkpointing=True):
    """Wrap ``self.llm`` in a PEFT model described by ``lora_config``.

    The LLM is first passed through ``prepare_model_for_kbit_training``.
    When the config names no target modules, all linear layer names found
    in the prepared LLM are used.
    """
    peft_cfg = self._parse_lora_config(lora_config)
    self.llm = prepare_model_for_kbit_training(
        self.llm, use_activation_checkpointing)
    if peft_cfg.target_modules is None:
        peft_cfg.target_modules = find_all_linear_names(self.llm)
    self.llm = get_peft_model(self.llm, peft_cfg)
def _prepare_visual_encoder_for_lora(self,
                                     lora_config,
                                     use_activation_checkpointing=True):
    """Wrap ``self.visual_encoder`` in a PEFT model described by ``lora_config``.

    When the config names no target modules, all linear layer names found
    in the visual encoder are used. ``use_activation_checkpointing`` is
    accepted but not used by this method.
    """
    peft_cfg = self._parse_lora_config(lora_config)
    if peft_cfg.target_modules is None:
        peft_cfg.target_modules = find_all_linear_names(self.visual_encoder)
    self.visual_encoder = get_peft_model(self.visual_encoder, peft_cfg)
def gradient_checkpointing_enable(self):
    """Alias that forwards to ``activation_checkpointing_enable``."""
    enable = self.activation_checkpointing_enable
    enable()
def activation_checkpointing_enable(self):
    """Turn on checkpointing for the LLM, the projector and the reducer (if any)."""
    for component_name in ('llm', 'projector'):
        getattr(self, component_name).gradient_checkpointing_enable()
    # The reducer exposes a plain flag rather than an enable method.
    if hasattr(self, 'visual_token_reducer'):
        self.visual_token_reducer.gradient_checkpointing = True
def gradient_checkpointing_disable(self):
    """Alias that forwards to ``activation_checkpointing_disable``."""
    disable = self.activation_checkpointing_disable
    disable()
def activation_checkpointing_disable(self):
    """Turn off checkpointing for the LLM, the projector and the reducer (if any)."""
    for component_name in ('llm', 'projector'):
        getattr(self, component_name).gradient_checkpointing_disable()
    # The reducer exposes a plain flag rather than a disable method.
    if hasattr(self, 'visual_token_reducer'):
        self.visual_token_reducer.gradient_checkpointing = False
def init_weights(self):
    """No-op: this method performs no weight initialisation."""
    return None
def state_dict(self, *args, **kwargs):
    """Return a reduced state dict containing only the entries this model saves.

    Included, in order: visual-encoder entries (LoRA weights, or all of them
    when the encoder is not frozen), LLM entries (LoRA weights, or all of
    them when the LLM is not frozen), projector entries, LongNet encoder
    entries and, when present, visual token reducer entries.
    """
    full_state = super().state_dict(*args, **kwargs)

    def _entries_containing(fragment):
        # Selection is by substring match on the parameter name.
        return {
            name: tensor
            for name, tensor in full_state.items() if fragment in name
        }

    selected = OrderedDict()

    # Step 1. visual_encoder
    if self.use_visual_encoder_lora:
        selected.update(
            get_peft_model_state_dict(
                self.visual_encoder, state_dict=full_state))
    elif not self.freeze_visual_encoder:
        selected.update(_entries_containing('visual_encoder.'))

    # Step 2. LLM
    if self.use_llm_lora:
        selected.update(
            get_peft_model_state_dict(self.llm, state_dict=full_state))
    elif not self.freeze_llm:
        selected.update(_entries_containing('llm.'))

    # Steps 3 and 4. Projector and LongNet_encoder are always kept.
    for fragment in ('projector.', 'LongNet_encoder.'):
        selected.update(_entries_containing(fragment))

    # Step 5. Visual Token Reducer
    if hasattr(self, 'visual_token_reducer'):
        selected.update(_entries_containing('visual_token_reducer.'))
    return selected
@staticmethod
def _prepare_for_long_context_training(cfg, llm_cfg,
max_position_embeddings):
orig_rope_scaling = getattr(llm_cfg, 'rope_scaling', None)
if orig_rope_scaling is None:
orig_rope_scaling = {'factor': 1}
orig_rope_scaling_factor = orig_rope_scaling[
'factor'] if 'factor' in orig_rope_scaling.keys() else 1
orig_ctx_len = getattr(llm_cfg, 'max_position_embeddings', None)
if orig_ctx_len:
orig_ctx_len *= orig_rope_scaling_factor
if max_position_embeddings > orig_ctx_len:
scaling_factor = float(
math.ceil(max_position_embeddings / orig_ctx_len))
llm_cfg.rope_scaling = {
'type': 'linear',
'factor': scaling_factor
}
llm_cfg.attn_implementation = 'flash_attention_2'
cfg.config = llm_cfg
return cfg, llm_cfg
@staticmethod
def _prepare_for_flash_attn(cfg, llm_cfg):
cls_name = type(llm_cfg).__name__
SUPPORT_SDPA_ATTN = ('LlamaConfig', 'GemmaConfig', 'MistralConfig',
'MixtralConfig', 'Qwen2Config', 'Qwen2MoeConfig',
'Starcoder2Config', 'Starcoder2Config',
'Phi3Config')
SUPPORT_FLASH_ATTN2 = ('InternLM2Config', 'LlamaConfig', 'GemmaConfig',
'MistralConfig', 'MixtralConfig', 'Qwen2Config',
'Qwen2MoeConfig', 'Starcoder2Config',
'Starcoder2Config', 'Phi3Config')
torch_dtype = torch.bfloat16 if (
torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
else torch.float16
if getattr(cfg, 'attn_implementation', None) is not None:
if cfg.attn_implementation == 'flash_attention_2':
cfg.torch_dtype = torch_dtype
elif SUPPORT_FLASH2 and cls_name in SUPPORT_FLASH_ATTN2:
cfg.torch_dtype = torch_dtype
cfg.attn_implementation = 'flash_attention_2'
elif SUPPORT_FLASH1 and cls_name in SUPPORT_SDPA_ATTN:
cfg.attn_implementation = 'sdpa'
return cfg, llm_cfg
@staticmethod
def _prepare_for_qlora_zero3(cfg):
    """Align dtypes on ``cfg`` when DeepSpeed ZeRO-3 and quantization are both in use.

    Returns ``cfg`` untouched unless ZeRO-3 is enabled and ``cfg`` has a
    ``quantization_config``. Otherwise sets ``torch_dtype`` and the 4-bit
    compute/storage dtypes to bfloat16 (when CUDA supports it) or float16.
    """
    if not (is_deepspeed_zero3_enabled()
            and hasattr(cfg, 'quantization_config')):
        return cfg

    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        half_dtype = torch.bfloat16
    else:
        half_dtype = torch.float16

    cfg.torch_dtype = half_dtype
    quant_cfg = cfg.quantization_config
    quant_cfg.bnb_4bit_compute_dtype = half_dtype
    quant_cfg.bnb_4bit_quant_storage = half_dtype
    return cfg
def _dispatch_lm_model_cfg(self, cfg, max_position_embeddings=None):
    """Prepare an LLM build config and return it.

    Applies the QLoRA/ZeRO-3 dtype fix-up, loads the pretrained model's
    config, selects the attention implementation and, when
    ``max_position_embeddings`` is given, the long-context settings.
    """
    cfg = self._prepare_for_qlora_zero3(cfg)
    llm_cfg = AutoConfig.from_pretrained(
        cfg.pretrained_model_name_or_path, trust_remote_code=True)
    cfg, llm_cfg = self._prepare_for_flash_attn(cfg, llm_cfg)
    if max_position_embeddings is None:
        return cfg
    cfg, llm_cfg = self._prepare_for_long_context_training(
        cfg, llm_cfg, max_position_embeddings)
    return cfg
def _build_from_cfg_or_module(self, cfg_or_mod):
if isinstance(cfg_or_mod, nn.Module):
return cfg_or_mod
elif isinstance(cfg_or_mod, dict):
traverse_dict(cfg_or_mod)
return BUILDER.build(cfg_or_mod)
else:
raise NotImplementedError
# Entry point for training and inference.
#
# ``data`` is a mapping that must contain 'input_ids' and may contain
# 'pixel_values' and 'attention_mask'. ``mode`` selects the output:
# 'loss' -> ``compute_loss``, 'predict' -> ``predict``, 'tensor' ->
# ``_forward``; any other value raises NotImplementedError.
def forward(self, data, data_samples=None, mode='loss'):
# On the first call only, move the whole model to the device of the inputs.
if self.is_first_iter:
self.to(data['input_ids'].device)
self.is_first_iter = False
# False when no reducer is configured (getattr then yields None).
is_text_guided_reducer = isinstance(
getattr(self, 'visual_token_reducer', None), TextGuidedVisualTokenAttentionReducer
)
if 'pixel_values' in data:
# Visual features are cast to the LLM dtype, permuted (1, 0, 2) for the
# LongNet encoder and permuted back afterwards.
# NOTE(review): presumably features arrive batch-first and the encoder is
# sequence-first -- confirm against the data pipeline.
feat_to_proj = data['pixel_values'].to(self.llm.dtype)
feat_to_proj = self.LongNet_encoder(src_tokens=None,
token_embeddings=feat_to_proj.permute(1, 0, 2))["encoder_out"].permute(1, 0, 2)
pixel_values = self.projector(feat_to_proj.to(self.llm.dtype))
if hasattr(self, 'visual_token_reducer'):
if is_text_guided_reducer:
# Get text embeddings and attention mask for the guided reducer.
# Negative ids are clamped to 0 before the embedding lookup, and the
# embeddings are detached so no gradient reaches the embedding table
# through the reducer.
input_ids = data['input_ids']
text_attention_mask = data.get('attention_mask')
text_embeddings = self.llm.get_input_embeddings()(input_ids.clamp(min=0)).detach()
pixel_values = self.visual_token_reducer(
pixel_values, text_embeddings, text_attention_mask
)
else:
# Input to reducer is now (B, T, D)
pixel_values = self.visual_token_reducer(pixel_values)
# print_log(f'Visual tokens reduced to shape: {pixel_values.shape}', 'current')
# Replace the raw features with the processed visual tokens, then let the
# helper rebuild ``data`` for the LLM.
data['pixel_values'] = pixel_values
data = prepare_inputs_labels_for_multimodal(llm=self.llm, **data)
if mode == 'loss':
return self.compute_loss(data, data_samples)
elif mode == 'predict':
return self.predict(data, data_samples)
elif mode == 'tensor':
return self._forward(data, data_samples)
else:
raise NotImplementedError
def _forward(self, data, data_samples=None):
outputs = self.llm(**data)
return outputs
def predict(self, data, data_samples=None):
    """Run the LLM and return one ``{'logits': ...}`` dict per item of its logits."""
    llm_output = self.llm(**data)
    return [{'logits': item_logits} for item_logits in llm_output.logits]
def compute_loss(self, data, data_samples=None):
    """Run the LLM and return its loss wrapped as ``{'loss': ...}``."""
    llm_output = self.llm(**data)
    return dict(loss=llm_output.loss)
def __getattr__(self, name: str):
    """Resolve missing attributes on the wrapped LLM.

    The parent lookup is tried first; if it raises ``AttributeError``, the
    attribute is looked up on ``self.llm``.

    Raises:
        AttributeError: if the attribute exists neither here nor on the LLM,
            or if ``llm`` itself has not been assigned yet.
    """
    try:
        return super().__getattr__(name)
    except AttributeError:
        # Without this guard, a lookup made before ``llm`` is assigned would
        # re-enter this method for 'llm' indefinitely and end in a
        # RecursionError instead of a clean AttributeError.
        if name == 'llm':
            raise
        return getattr(self.llm, name)
def to_hf(self,
          cfg,
          save_dir,
          fp32=False,
          save_pretrained_kwargs={},
          save_format='xtuner',
          **kwargs):
    """Export the model in the layout selected by ``save_format``.

    Supported formats are 'xtuner', 'huggingface' and 'official'; each is
    delegated to the matching ``to_*_llava`` method. Extra keyword
    arguments are accepted and ignored.

    Raises:
        NotImplementedError: for any other ``save_format``.
    """
    exporters = (
        ('xtuner', 'to_xtuner_llava'),
        ('huggingface', 'to_huggingface_llava'),
        ('official', 'to_official_llava'),
    )
    for format_name, method_name in exporters:
        if save_format == format_name:
            export = getattr(self, method_name)
            export(cfg, save_dir, fp32, save_pretrained_kwargs)
            return
    raise NotImplementedError
def to_xtuner_llava(self,
                    cfg,
                    save_dir,
                    fp32=False,
                    save_pretrained_kwargs={}):
    """Save the trainable components under ``save_dir``.

    Written, depending on the model's flags: the LLM adapter or the full
    LLM with its tokenizer, the visual-encoder adapter or the full visual
    encoder with its image processor, the projector, the visual token
    reducer (when present) and the LongNet encoder.
    """

    def _under(subdir):
        return osp.join(save_dir, subdir)

    # LLM: caching is switched on for the export and off again afterwards.
    self.llm.config.use_cache = True
    if not fp32:
        self.llm.half()
    if self.use_llm_lora:
        adapter_dir = _under('llm_adapter')
        self.llm.save_pretrained(adapter_dir, **save_pretrained_kwargs)
    elif not self.freeze_llm:
        llm_tokenizer = BUILDER.build(cfg.tokenizer)
        llm_tokenizer.save_pretrained(save_dir, **save_pretrained_kwargs)
        self.llm.save_pretrained(save_dir, **save_pretrained_kwargs)
    self.llm.config.use_cache = False

    # Visual Encoder
    if self.use_visual_encoder_lora:
        adapter_dir = _under('visual_encoder_adapter')
        self.visual_encoder.save_pretrained(adapter_dir,
                                            **save_pretrained_kwargs)
    elif not self.freeze_visual_encoder:
        encoder_dir = _under('visual_encoder')
        image_processor = BUILDER.build(cfg.image_processor)
        image_processor.save_pretrained(encoder_dir,
                                        **save_pretrained_kwargs)
        self.visual_encoder.save_pretrained(encoder_dir,
                                            **save_pretrained_kwargs)

    # Projector
    self.projector.save_pretrained(
        _under('projector'), **save_pretrained_kwargs)

    # Visual Token Reducer
    if hasattr(self, 'visual_token_reducer'):
        self.visual_token_reducer.save_pretrained(
            _under('visual_token_reducer'), **save_pretrained_kwargs)

    # LongNet_encoder
    self.LongNet_encoder.save_pretrained(
        _under('LongNet_encoder'), **save_pretrained_kwargs)
# Convert the model to a Hugging Face ``LlavaForConditionalGeneration``
# checkpoint plus a ``LlavaProcessor`` and save both into ``save_dir``.
#
# cfg: config exposing ``tokenizer`` and ``image_processor`` build configs.
# fp32: when False the LLM is converted to float16 before export.
# save_pretrained_kwargs: forwarded to every ``save_pretrained`` call.
#
# Asserts that the LLM is an unquantized ``LlamaForCausalLM``, the visual
# encoder a ``CLIPVisionModel`` and the image processor a
# ``CLIPImageProcessor``.
def to_huggingface_llava(self,
cfg,
save_dir,
fp32=False,
save_pretrained_kwargs={}):
# Each mapping is ``substring in the source key -> replacement`` and is
# applied by ``convert_state_dict_to_hf``.
LLM_MAPPING = {
'model': 'language_model.model',
'lm_head': 'language_model.lm_head',
}
VIT_MAPPING = {
'vision_model': 'vision_tower.vision_model',
}
PROJECTOR_MAPPING = {
'model.0': 'multi_modal_projector.linear_1',
'model.2': 'multi_modal_projector.linear_2',
}
# Only names of an attention-style reducer (query_emb / cross_attn.*) are
# listed; other reducer keys pass through unchanged.
REDUCER_MAPPING = {
'query_emb': 'visual_token_reducer.query_emb',
'cross_attn.in_proj_weight': 'visual_token_reducer.cross_attn.in_proj_weight',
'cross_attn.in_proj_bias': 'visual_token_reducer.cross_attn.in_proj_bias',
'cross_attn.out_proj.weight':'visual_token_reducer.cross_attn.out_proj.weight',
'cross_attn.out_proj.bias': 'visual_token_reducer.cross_attn.out_proj.bias'
}
# NOTE(review): matching is by substring, so a key containing both
# 'layers.N' and 'layer_norm' would be rewritten twice -- confirm against
# the LongNet encoder's actual parameter names.
LONGNET_MAPPING = {
'layers.0': 'LongNet_encoder.layers.0',
'layers.1': 'LongNet_encoder.layers.1',
'layer_norm': 'LongNet_encoder.layer_norm'
}
assert getattr(self.llm, 'hf_quantizer', None) is None, \
'This conversion format does not support quantized LLM.'
# get state_dict
llm = self.llm
if self.use_llm_lora:
# Fold the LoRA weights into the base model before export.
llm = self.llm.merge_and_unload()
llm.config.use_cache = True
if not fp32:
print_log('Convert LLM to float16', 'current')
llm.half()
assert isinstance(llm, LlamaForCausalLM), \
'This conversion format only supports LlamaForCausalLM.'
llm_state_dict = llm.state_dict()
llm_state_dict = convert_state_dict_to_hf(llm_state_dict, LLM_MAPPING)
need_visual_encoder = (not self.freeze_visual_encoder
or self.use_visual_encoder_lora)
# NOTE(review): ``visual_encoder`` is not assigned in the visible
# ``__init__``; this lookup falls through ``__getattr__`` to the LLM --
# confirm the attribute exists before relying on this export path.
visual_encoder = self.visual_encoder
if self.use_visual_encoder_lora:
visual_encoder = self.visual_encoder.merge_and_unload()
assert isinstance(visual_encoder, CLIPVisionModel),\
'This conversion format only supports CLIPVisionModel.'
if need_visual_encoder:
visual_encoder_state_dict = visual_encoder.state_dict()
visual_encoder_state_dict = convert_state_dict_to_hf(
visual_encoder_state_dict, VIT_MAPPING)
else:
visual_encoder_state_dict = {}
projector_state_dict = self.projector.state_dict()
projector_state_dict = convert_state_dict_to_hf(
projector_state_dict, PROJECTOR_MAPPING)
LongNet_encoder_state_dict = self.LongNet_encoder.state_dict()
LongNet_encoder_state_dict = convert_state_dict_to_hf(
LongNet_encoder_state_dict, LONGNET_MAPPING)
# Empty when no reducer was configured.
red_state = convert_state_dict_to_hf(
self.visual_token_reducer.state_dict(), REDUCER_MAPPING
) if hasattr(self, 'visual_token_reducer') else {}
# Later entries win on key collisions.
state_dict = {
**projector_state_dict,
**llm_state_dict,
**visual_encoder_state_dict,
**LongNet_encoder_state_dict,
**red_state
}
# init model
text_config = llm.config
vision_config = visual_encoder.config
config = LlavaConfig(
text_config=text_config,
vision_config=vision_config,
attn_implementation='eager')
# Instantiate without allocating weights; the real tensors are attached by
# the ``assign=True`` load below.
with init_empty_weights():
with warnings.catch_warnings():
warnings.filterwarnings(
'ignore', message='.*non-meta.*', category=UserWarning)
model = LlavaForConditionalGeneration(config)
# NOTE(review): the load is strict while ``state_dict`` also carries
# LongNet_encoder.* and visual_token_reducer.* keys -- verify the target
# model accepts them.
model.load_state_dict(state_dict, strict=True, assign=True)
# processor
# The tokenizer build config is mutated in place to use LlamaTokenizerFast.
cfg.tokenizer.type = LlamaTokenizerFast.from_pretrained
tokenizer = BUILDER.build(cfg.tokenizer)
tokenizer.add_tokens(
AddedToken(DEFAULT_IMAGE_TOKEN, special=True, normalized=False),
special_tokens=True)
tokenizer.add_special_tokens({'pad_token': '<pad>'})
image_processor = BUILDER.build(cfg.image_processor)
assert isinstance(image_processor, CLIPImageProcessor),\
'This conversion format only supports CLIPImageProcessor.'
processor = LlavaProcessor(
tokenizer=tokenizer, image_processor=image_processor)
# Pad to 64 for performance reasons
pad_shape = 64
# Fit a Gaussian (mean and scaled covariance of the existing embedding
# rows) from which rows for newly added tokens are sampled.
pre_expansion_embeddings = \
model.language_model.model.embed_tokens.weight.data
mu = torch.mean(pre_expansion_embeddings, dim=0).float()
n = pre_expansion_embeddings.size()[0]
sigma = ((pre_expansion_embeddings - mu).T
@ (pre_expansion_embeddings - mu)) / n
dist = torch.distributions.multivariate_normal.MultivariateNormal(
mu, covariance_matrix=1e-5 * sigma)
# We add an image token so we need to resize the model
ori_vocab_size = config.text_config.vocab_size
# The id of the last token produced by encoding '<pad>' is used as the
# tokenizer's vocabulary size.
tokenizer_vocab_size = tokenizer.encode('<pad>')[-1]
added_token = tokenizer_vocab_size - ori_vocab_size
if added_token > 0:
model.resize_token_embeddings(ori_vocab_size + added_token,
pad_shape)
# Overwrite every row past the original vocabulary, in both the input
# embeddings and the LM head, with samples from the fitted distribution.
model.language_model.model.embed_tokens.weight.data[
ori_vocab_size:] = torch.stack(
tuple(
dist.sample()
for _ in range(model.language_model.model.embed_tokens.
weight.data[ori_vocab_size:].shape[0])),
dim=0,
)
model.language_model.lm_head.weight.data[
ori_vocab_size:] = torch.stack(
tuple(dist.sample()
for _ in range(model.language_model.lm_head.weight.
data[ori_vocab_size:].shape[0])),
dim=0,
)
model.config.image_token_index = tokenizer.encode(
DEFAULT_IMAGE_TOKEN)[-1]
model.config.pad_token_id = tokenizer.encode('<pad>')[-1]
# save
print_log(f'Saving to {save_dir}', 'current')
model.save_pretrained(save_dir, **save_pretrained_kwargs)
processor.save_pretrained(save_dir, **save_pretrained_kwargs)
# Convert the model to the official LLaVA (``LlavaLlamaForCausalLM``) layout
# and save the model, image processor and tokenizer into ``save_dir``.
#
# cfg: config exposing ``tokenizer`` and ``image_processor`` build configs.
# fp32: when False the LLM is converted to float16 before export.
# save_pretrained_kwargs: forwarded to every ``save_pretrained`` call.
#
# Raises ImportError when the ``llava`` package is not installed. Asserts
# that the LLM is an unquantized ``LlamaForCausalLM``, the visual encoder a
# ``CLIPVisionModel`` and the image processor a ``CLIPImageProcessor``.
def to_official_llava(self,
cfg,
save_dir,
fp32=False,
save_pretrained_kwargs={}):
# Each mapping is ``substring in the source key -> replacement`` and is
# applied by ``convert_state_dict_to_hf``. LLM keys are not renamed here.
VIT_MAPPING = {
'vision_model': 'model.vision_tower.vision_tower.vision_model',
}
PROJECTOR_MAPPING = {
'model.0': 'model.mm_projector.0',
'model.2': 'model.mm_projector.2',
}
# NOTE(review): matching is by substring, so a key containing both
# 'layers.N' and 'layer_norm' would be rewritten twice -- confirm against
# the LongNet encoder's actual parameter names.
LONGNET_MAPPING = {
'layers.0': 'LongNet_encoder.layers.0',
'layers.1': 'LongNet_encoder.layers.1',
'layer_norm': 'LongNet_encoder.layer_norm'
}
# Only names of an attention-style reducer (query_emb / cross_attn.*) are
# listed; other reducer keys pass through unchanged.
REDUCER_MAPPING = {
'query_emb': 'visual_token_reducer.query_emb',
'cross_attn.in_proj_weight': 'visual_token_reducer.cross_attn.in_proj_weight',
'cross_attn.in_proj_bias': 'visual_token_reducer.cross_attn.in_proj_bias',
'cross_attn.out_proj.weight':'visual_token_reducer.cross_attn.out_proj.weight',
'cross_attn.out_proj.bias': 'visual_token_reducer.cross_attn.out_proj.bias'
}
# Imported lazily so the dependency is only needed for this export format.
# Inside this method these names refer to the ``llava`` package's classes.
try:
from llava.model import LlavaConfig, LlavaLlamaForCausalLM
except ImportError:
raise ImportError(
'Please install llava with '
'`pip install git+https://github.com/haotian-liu/LLaVA.git '
'--no-deps`.')
assert getattr(self.llm, 'hf_quantizer', None) is None, \
'This conversion format does not support quantized LLM.'
# get state_dict
llm = self.llm
if self.use_llm_lora:
# Fold the LoRA weights into the base model before export.
llm = self.llm.merge_and_unload()
llm.config.use_cache = True
if not fp32:
print_log('Convert LLM to float16', 'current')
llm.half()
assert isinstance(llm, LlamaForCausalLM), \
'This conversion format only supports LlamaForCausalLM.'
llm_state_dict = llm.state_dict()
need_visual_encoder = (not self.freeze_visual_encoder
or self.use_visual_encoder_lora)
# NOTE(review): ``visual_encoder`` is not assigned in the visible
# ``__init__``; this lookup falls through ``__getattr__`` to the LLM --
# confirm the attribute exists before relying on this export path.
visual_encoder = self.visual_encoder
if self.use_visual_encoder_lora:
visual_encoder = self.visual_encoder.merge_and_unload()
assert isinstance(visual_encoder, CLIPVisionModel),\
'This conversion format only supports CLIPVisionModel.'
if need_visual_encoder:
visual_encoder_state_dict = visual_encoder.state_dict()
visual_encoder_state_dict = convert_state_dict_to_hf(
visual_encoder_state_dict, VIT_MAPPING)
else:
visual_encoder_state_dict = {}
projector_state_dict = self.projector.state_dict()
projector_state_dict = convert_state_dict_to_hf(
projector_state_dict, PROJECTOR_MAPPING)
LongNet_encoder_state_dict = self.LongNet_encoder.state_dict()
LongNet_encoder_state_dict = convert_state_dict_to_hf(
LongNet_encoder_state_dict, LONGNET_MAPPING)
# Empty when no reducer was configured.
red_state = convert_state_dict_to_hf(
self.visual_token_reducer.state_dict(), REDUCER_MAPPING
) if hasattr(self, 'visual_token_reducer') else {}
# Later entries win on key collisions.
state_dict = {
**projector_state_dict,
**llm_state_dict,
**visual_encoder_state_dict,
**LongNet_encoder_state_dict,
**red_state # add visual token reducer state
}
# init model
tokenizer = BUILDER.build(cfg.tokenizer)
image_processor = BUILDER.build(cfg.image_processor)
assert isinstance(image_processor, CLIPImageProcessor),\
'This conversion format only supports CLIPImageProcessor.'
# Start from a copy of the LLM config's attributes and add the multimodal
# fields below.
llava_config_dict = llm.config.__dict__.copy()
llava_config_dict.update(
dict(
image_aspect_ratio='pad',
mm_hidden_size=visual_encoder.config.hidden_size,
mm_projector_type=f'mlp{self.projector_depth}x_gelu',
mm_use_im_patch_token=False,
mm_use_im_start_end=False,
mm_vision_select_feature='patch',
mm_vision_select_layer=self.visual_select_layer,
mm_vision_tower=visual_encoder.config.name_or_path,
unfreeze_mm_vision_tower=need_visual_encoder,
model_type='llava',
use_cache=True,
use_mm_proj=True))
llava_config = LlavaConfig(**llava_config_dict)
# Instantiate without allocating weights; the real tensors are attached by
# the ``assign=True`` load below.
with init_empty_weights():
with warnings.catch_warnings():
warnings.filterwarnings(
'ignore', message='.*non-meta.*', category=UserWarning)
model = LlavaLlamaForCausalLM(llava_config)
# NOTE(review): the load is strict while ``state_dict`` also carries
# LongNet_encoder.* and visual_token_reducer.* keys -- verify the target
# model accepts them.
model.load_state_dict(state_dict, strict=True, assign=True)
# save
print_log(f'Saving to {save_dir}', 'current')
model.save_pretrained(save_dir, **save_pretrained_kwargs)
image_processor.save_pretrained(save_dir, **save_pretrained_kwargs)
tokenizer.save_pretrained(save_dir, **save_pretrained_kwargs)