import torch
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path
from ptflops import get_model_complexity_info


def get_llava_flops(model_path):
    """
    Estimates the MACs (commonly reported as FLOPs) and the number of
    parameters of a LLaVA model.
    """
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        model_path=model_path,
        model_base=None,
        model_name=model_name,
        load_4bit=True,  # load the weights in 4-bit to reduce memory; set to False for full precision
    )

    # Prepare dummy inputs for the model. LLaVA's forward pass takes both an
    # image and a text sequence; LLaVA-1.5 expects images preprocessed to
    # (1, 3, 336, 336). The tensor is cast to float16 to match the dtype the
    # vision tower is loaded in.
    image_tensor = torch.randn(1, 3, 336, 336).to(model.device, dtype=torch.float16)

    # The text input is a sequence of token IDs; a dummy sequence of length 512
    # is used here. Note that a real LLaVA prompt also contains the special
    # image placeholder token, so this only approximates the text length.
    input_ids = torch.randint(0, tokenizer.vocab_size, (1, 512)).to(model.device)

    # Use ptflops to measure complexity. The 'aten' backend traces ATen ops and
    # is better suited to transformer models than the default module-hook backend.
    macs, params = get_model_complexity_info(
        model,
        input_res=(3, 336, 336),  # image resolution; only forwarded to input_constructor
        input_constructor=lambda res: {'images': image_tensor, 'input_ids': input_ids},
        as_strings=True,
        print_per_layer_stat=True,
        verbose=True,
        backend='aten',
    )

    print(f"Model: {model_name}")
    print(f"Computational complexity: {macs}")
    print(f"Number of parameters: {params}")


if __name__ == "__main__":
    # Either call the function directly with a model path, as done here,
    # or expose it behind a command-line flag (see the main_cli() sketch below).
    model_path = "liuhaotian/llava-v1.5-7b"
    get_llava_flops(model_path)

    # This can also be integrated into the existing argument parsing logic
    # of the run_llava.py script.
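
# Optional: argparse entry point (a minimal sketch). The --model-path flag
# below is an assumption for this standalone script, not an existing
# run_llava.py argument; to use it, call main_cli() from the __main__ block
# above instead of the hard-coded get_llava_flops(...) call.
def main_cli():
    import argparse

    parser = argparse.ArgumentParser(
        description="Estimate MACs and parameter count for a LLaVA checkpoint"
    )
    parser.add_argument(
        "--model-path",
        type=str,
        default="liuhaotian/llava-v1.5-7b",
        help="Hugging Face repo id or local path of the LLaVA model",
    )
    args = parser.parse_args()
    get_llava_flops(args.model_path)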