# Copyright (c) 2023-2024 DeepSeek. # # Permission is hereby granted, free of charge, to any person obtaining a copy of # this software and associated documentation files (the "Software"), to deal in # the Software without restriction, including without limitation the rights to # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of # the Software, and to permit persons to whom the Software is furnished to do so, # subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
from dataclasses import dataclass
from math import e  # noqa: F401  (unused in this module; kept for compatibility)

import torch
from attrdict import AttrDict
from einops import rearrange
from torch.nn import CrossEntropyLoss
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    LlamaConfig,
    LlamaForCausalLM,
    PreTrainedModel,
)
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast

from janus.models.clip_encoder import CLIPVisionTower
from janus.models.projector import MlpProjector

# Token id written over the interior (all but first/last position) of the
# unconditional prompt row in the "generation" task (marked "pad_id" upstream).
_GEN_PAD_TOKEN_ID = 100015
# Token id written over the interior of the unconditional row in the
# "image_editing" task. NOTE(review): differs from _GEN_PAD_TOKEN_ID — confirm
# which id is the tokenizer's real pad token.
_EDIT_PAD_TOKEN_ID = 100002
# Token id searched for in "image_editing" prompts; generation-image embeddings
# are written starting two positions after every 4th occurrence.
# NOTE(review): presumably an image-start tag — confirm against the tokenizer.
_IMAGE_TAG_TOKEN_ID = 100580
# Default classifier-free-guidance weight(s) used by the generation tasks.
_DEFAULT_CFG_WEIGHT = 5


class vision_head(torch.nn.Module):
    """Projection head mapping LM hidden states to image-token logits.

    Linear(n_embed -> image_token_embed) -> GELU -> Linear(image_token_embed ->
    image_token_size).
    """

    def __init__(self, params):
        super().__init__()
        self.output_mlp_projector = torch.nn.Linear(
            params.n_embed, params.image_token_embed
        )
        self.vision_activation = torch.nn.GELU()
        self.vision_head = torch.nn.Linear(
            params.image_token_embed, params.image_token_size
        )

    def forward(self, x):
        x = self.output_mlp_projector(x)
        x = self.vision_activation(x)
        x = self.vision_head(x)
        return x


def model_name_to_cls(cls_name):
    """Resolve a config class-name string to the module class to instantiate.

    Matching is by substring and checked in order; ``VQ`` names are looked up in
    ``janus.models.vq_model.VQ_models`` (imported lazily).

    Raises:
        ValueError: if no known class name matches.
    """
    if "MlpProjector" in cls_name:
        cls = MlpProjector
    elif "CLIPVisionTower" in cls_name:
        cls = CLIPVisionTower
    elif "VQ" in cls_name:
        from janus.models.vq_model import VQ_models

        cls = VQ_models[cls_name]
    elif "vision_head" in cls_name:
        cls = vision_head
    else:
        raise ValueError(f"class_name {cls_name} is invalid.")
    return cls


class _ClsParamsConfig(PretrainedConfig):
    """Shared base for sub-configs holding a module class name and its params.

    ``cls`` may be given as a string or as a class (its ``__name__`` is stored);
    ``params`` is wrapped in an ``AttrDict`` for attribute access.
    """

    cls: str = ""
    # Class-level default only; every instance rebinds its own AttrDict below.
    params: AttrDict = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.cls = kwargs.get("cls", "")
        if not isinstance(self.cls, str):
            self.cls = self.cls.__name__
        self.params = AttrDict(kwargs.get("params", {}))


class VisionConfig(_ClsParamsConfig):
    """Config of the understanding vision encoder."""

    model_type = "vision"


class AlignerConfig(_ClsParamsConfig):
    """Config of the understanding aligner (vision -> LM embedding space)."""

    model_type = "aligner"


class GenVisionConfig(_ClsParamsConfig):
    """Config of the generation vision model (image tokenizer)."""

    model_type = "gen_vision"


class GenAlignerConfig(_ClsParamsConfig):
    """Config of the generation aligner (image-token embeds -> LM space)."""

    model_type = "gen_aligner"


class GenHeadConfig(_ClsParamsConfig):
    """Config of the image-token prediction head."""

    model_type = "gen_head"


@dataclass
class VLChatProcessorOutput:
    """One prompt sample handed to ``vl_chat_processor.batchify``."""

    sft_format: str
    input_ids: torch.Tensor
    pixel_values: torch.Tensor
    # NOTE(review): annotated as a tensor, but this module passes a list of ints.
    num_image_tokens: torch.IntTensor

    def __len__(self):
        return len(self.input_ids)


def _as_config(value, config_cls):
    """Return *value* if it is already a *config_cls*, else build one from a mapping."""
    if isinstance(value, config_cls):
        return value
    return config_cls(**value)


class MultiModalityConfig(PretrainedConfig):
    """Top-level config bundling all sub-module configs and the LLaMA config.

    Each sub-config may be passed either as a mapping or as an instance of the
    corresponding config class.
    """

    model_type = "multi_modality"
    vision_config: VisionConfig
    aligner_config: AlignerConfig
    gen_vision_config: GenVisionConfig
    gen_aligner_config: GenAlignerConfig
    gen_head_config: GenHeadConfig
    language_config: LlamaConfig

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.vision_config = _as_config(kwargs.get("vision_config", {}), VisionConfig)
        self.aligner_config = _as_config(kwargs.get("aligner_config", {}), AlignerConfig)
        self.gen_vision_config = _as_config(
            kwargs.get("gen_vision_config", {}), GenVisionConfig
        )
        self.gen_aligner_config = _as_config(
            kwargs.get("gen_aligner_config", {}), GenAlignerConfig
        )
        self.gen_head_config = _as_config(kwargs.get("gen_head_config", {}), GenHeadConfig)
        self.language_config = _as_config(kwargs.get("language_config", {}), LlamaConfig)


class MultiModalityPreTrainedModel(PreTrainedModel):
    config_class = MultiModalityConfig
    base_model_prefix = "multi_modality"
    _no_split_modules = []
    _skip_keys_device_placement = "past_key_values"


def _load_rgb_images(paths):
    """Open each image path, convert to RGB, and close the underlying file."""
    import PIL.Image

    images = []
    for path in paths:
        with PIL.Image.open(path) as img:
            # convert() returns a new, fully loaded image, so it stays valid
            # after the file handle is closed.
            images.append(img.convert("RGB"))
    return images


class MultiModalityCausalLM(MultiModalityPreTrainedModel):
    """LLaMA backbone with understanding, generation and image-editing heads."""

    def __init__(self, config: MultiModalityConfig):
        super().__init__(config)

        vision_config = config.vision_config
        vision_cls = model_name_to_cls(vision_config.cls)
        self.vision_model = vision_cls(**vision_config.params)

        aligner_config = config.aligner_config
        aligner_cls = model_name_to_cls(aligner_config.cls)
        self.aligner = aligner_cls(aligner_config.params)

        gen_vision_config = config.gen_vision_config
        gen_vision_cls = model_name_to_cls(gen_vision_config.cls)
        self.gen_vision_model = gen_vision_cls()

        gen_aligner_config = config.gen_aligner_config
        gen_aligner_cls = model_name_to_cls(gen_aligner_config.cls)
        self.gen_aligner = gen_aligner_cls(gen_aligner_config.params)

        gen_head_config = config.gen_head_config
        gen_head_cls = model_name_to_cls(gen_head_config.cls)
        self.gen_head = gen_head_cls(gen_head_config.params)

        self.gen_embed = torch.nn.Embedding(
            gen_vision_config.params.image_token_size, gen_vision_config.params.n_embed
        )

        language_config = config.language_config
        self.language_model = LlamaForCausalLM(language_config)

    def prepare_inputs_embeds(
        self,
        input_ids: torch.LongTensor,
        pixel_values: torch.FloatTensor,
        images_seq_mask: torch.LongTensor = None,
        images_emb_mask: torch.LongTensor = None,
        **kwargs,
    ):
        """Embed text tokens and splice in the understanding-image embeddings.

        Args:
            input_ids (torch.LongTensor): [b, T]; negative ids are embedded as id 0.
                The caller's tensor is not modified.
            pixel_values (torch.FloatTensor): [b, n_images, 3, h, w]
            images_seq_mask (torch.BoolTensor): [b, T]; positions to overwrite.
            images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens];
                image embeddings to use. Must select as many entries as
                ``images_seq_mask``.

        Returns:
            inputs_embeds (torch.Tensor): [b, T, D]
        """
        bs, n = pixel_values.shape[0:2]
        images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
        # [b x n, T2, D]
        images_embeds = self.aligner(self.vision_model(images))
        # [b x n, T2, D] -> [b, n x T2, D]
        images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
        # [b, n, T2] -> [b, n x T2]
        images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)")

        # Negative ids mark image placeholders; map them to 0 for the embedding
        # lookup without mutating the caller's tensor.
        input_ids = input_ids.masked_fill(input_ids < 0, 0)
        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)

        # Replace placeholder positions with the image embeddings.
        inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]
        return inputs_embeds

    def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
        """Embed generation image-token ids and map them into the LM space."""
        return self.gen_aligner(self.gen_embed(image_ids))

    def forward(
        self,
        vl_chat_processor,
        input_ids,
        labels=None,
        task="understanding",
        return_dict=True,
        pixel_values=None,
        images_seq_mask=None,
        images_emb_mask=None,
        **kwargs,
    ):
        """Dispatch to the task-specific forward pass.

        Args:
            vl_chat_processor: processor providing ``image_processor``,
                ``batchify`` and ``num_image_tokens`` (used by "image_editing").
            input_ids: prompt token ids, [b, T].
            labels: optional target ids, [b, T]; loss is next-token cross-entropy.
            task: one of "understanding", "generation", "generation_direct",
                "image_editing".
            return_dict: return a ``CausalLMOutputWithPast`` (True) or a tuple.
            pixel_values / images_seq_mask / images_emb_mask: understanding inputs.
            **kwargs: "understanding"/"generation_direct" forward them to the
                language model; "generation" reads optional ``cfg_weight``;
                "image_editing" requires ``source_image`` (image paths) and
                ``sft_format`` and reads optional ``cfg_weight``/``cfg_weight2``.

        Raises:
            ValueError: if ``task`` is not recognised.
        """
        if task == "understanding":
            inputs_embeds = self.prepare_inputs_embeds(
                input_ids, pixel_values, images_seq_mask, images_emb_mask
            )
            # Call the module (not .forward) so registered hooks still run.
            return self.language_model(
                inputs_embeds=inputs_embeds,
                labels=labels,
                return_dict=return_dict,
                **kwargs,
            )
        if task == "generation":
            return self._forward_generation(
                input_ids,
                labels,
                return_dict,
                cfg_weight=kwargs.get("cfg_weight", _DEFAULT_CFG_WEIGHT),
            )
        if task == "generation_direct":
            return self._forward_generation_direct(input_ids, labels, return_dict, **kwargs)
        if task == "image_editing":
            return self._forward_image_editing(
                vl_chat_processor,
                input_ids,
                labels,
                return_dict,
                source_image=kwargs["source_image"],
                sft_format=kwargs["sft_format"],
                cfg_weight=kwargs.get("cfg_weight", _DEFAULT_CFG_WEIGHT),
                cfg_weight2=kwargs.get("cfg_weight2", _DEFAULT_CFG_WEIGHT),
            )
        raise ValueError(
            f"Unknown task {task!r}; expected 'understanding', 'generation', "
            "'generation_direct' or 'image_editing'."
        )

    def _shifted_gen_loss(self, logits, labels):
        """Next-token cross-entropy over the image-token vocabulary, or None without labels."""
        if labels is None:
            return None
        vocab_size = self.config.gen_head_config.params.image_token_size
        shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
        shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)
        return CrossEntropyLoss()(shift_logits, shift_labels)

    @staticmethod
    def _pack_output(loss, logits, outputs, return_dict):
        """Wrap loss/logits plus the backbone outputs as a tuple or ``CausalLMOutputWithPast``."""
        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def _forward_generation(self, input_ids, labels, return_dict, cfg_weight):
        """Teacher-forced text-to-image pass with classifier-free guidance.

        Rows are interleaved: row 2k is prompt k, row 2k+1 is prompt k with its
        interior replaced by ``_GEN_PAD_TOKEN_ID`` (the unconditional branch).
        The loss uses the guided logits; the returned logits are the raw
        per-row logits of all 2b rows.
        """
        tokens = input_ids.to(device=self.device, dtype=torch.int).repeat_interleave(2, dim=0)
        tokens[1::2, 1:-1] = _GEN_PAD_TOKEN_ID
        inputs_embeds = self.language_model.get_input_embeddings()(tokens)

        outputs = self.language_model.model(
            inputs_embeds=inputs_embeds, use_cache=True, past_key_values=None
        )
        logits = self.gen_head(outputs.last_hidden_state)

        logits_cond = logits[0::2]
        logits_uncond = logits[1::2]
        guided_logits = logits_uncond + cfg_weight * (logits_cond - logits_uncond)

        loss = self._shifted_gen_loss(guided_logits, labels)
        return self._pack_output(loss, logits, outputs, return_dict)

    def _forward_generation_direct(self, input_ids, labels, return_dict, **kwargs):
        """Image-token prediction from ``input_ids`` without guidance; logits are float32."""
        outputs = self.language_model.model(input_ids=input_ids, **kwargs)
        logits = self.gen_head(outputs[0]).float()
        loss = self._shifted_gen_loss(logits, labels)
        return self._pack_output(loss, logits, outputs, return_dict)

    def _forward_image_editing(
        self,
        vl_chat_processor,
        input_ids,
        labels,
        return_dict,
        source_image,
        sft_format,
        cfg_weight,
        cfg_weight2,
    ):
        """Teacher-forced image-editing pass with two-level classifier-free guidance.

        Each prompt b yields three rows: 3b (cond_full) and 3b+1 (cond_part) carry
        the source image; 3b+2 (uncond) has its interior replaced by
        ``_EDIT_PAD_TOKEN_ID`` and no image. Guided logits are
        ``uncond + w * ((full + w2 * part) / (1 + w2) - uncond)``; the returned
        logits are the raw per-row logits of all 3b rows.
        """
        device = self.device
        batch_size = input_ids.size(0)

        tokens = input_ids.to(device=device, dtype=torch.int).repeat_interleave(3, dim=0)
        tokens[2::3, 1:-1] = _EDIT_PAD_TOKEN_ID

        images = _load_rgb_images(source_image)
        encoder_pixel_values = vl_chat_processor.image_processor(
            images, return_tensors="pt"
        )["pixel_values"]

        # One batchify entry per token row, in the same row order as `tokens`.
        num_image_tokens = vl_chat_processor.num_image_tokens
        pre_data = []
        for b in range(batch_size):
            row = 3 * b
            for offset in (0, 1):
                pre_data.append(
                    VLChatProcessorOutput(
                        sft_format=sft_format[b],
                        pixel_values=encoder_pixel_values[b, :],
                        input_ids=tokens[row + offset],
                        num_image_tokens=[num_image_tokens],
                    )
                )
            pre_data.append(
                VLChatProcessorOutput(
                    sft_format=sft_format[b],
                    pixel_values=None,
                    input_ids=tokens[row + 2],
                    num_image_tokens=[],
                )
            )
        prepare_inputs = vl_chat_processor.batchify(pre_data)

        inputs_embeds = self.prepare_inputs_embeds(
            input_ids=tokens,
            pixel_values=prepare_inputs["pixel_values"].to(torch.bfloat16).to(device),
            images_emb_mask=prepare_inputs["images_emb_mask"].to(device),
            images_seq_mask=prepare_inputs["images_seq_mask"].to(device),
        )

        # Tokenize the source images with the generation vision model and embed
        # the resulting ids (the same processed pixels are reused).
        gen_pixel_values = encoder_pixel_values.to(torch.bfloat16).to(device)
        _, _, info = self.gen_vision_model.encode(gen_pixel_values)
        image_tokens = info[2].detach().reshape(gen_pixel_values.shape[0], -1)
        image_embeds = self.prepare_gen_img_embeds(image_tokens)

        # Write image k's embeddings two positions after the (4k)-th tag
        # occurrence (row-major). NOTE(review): assumes every sample contributes
        # exactly four tag occurrences — confirm against the prompt template.
        tag_positions = (tokens == _IMAGE_TAG_TOKEN_ID).nonzero()
        for image_idx, (row, col) in enumerate(tag_positions[::4].tolist()):
            start = col + 2
            inputs_embeds[row, start:start + image_embeds.shape[1], :] = image_embeds[image_idx]

        outputs = self.language_model.model(
            inputs_embeds=inputs_embeds, use_cache=True, past_key_values=None
        )
        logits = self.gen_head(outputs.last_hidden_state)

        logit_cond_full = logits[0::3]
        logit_cond_part = logits[1::3]
        logit_uncond = logits[2::3]
        logit_cond = (logit_cond_full + cfg_weight2 * logit_cond_part) / (1 + cfg_weight2)
        guided_logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)

        loss = self._shifted_gen_loss(guided_logits, labels)
        return self._pack_output(loss, logits, outputs, return_dict)


AutoConfig.register("vision", VisionConfig)
AutoConfig.register("aligner", AlignerConfig)
AutoConfig.register("gen_vision", GenVisionConfig)
AutoConfig.register("gen_aligner", GenAlignerConfig)
AutoConfig.register("gen_head", GenHeadConfig)
AutoConfig.register("multi_modality", MultiModalityConfig)
AutoModelForCausalLM.register(MultiModalityConfig, MultiModalityCausalLM)