"""Gradio app that captions images with Florence-2 checkpoints.

Two tabs are exposed: a single-image caption preview, and a batch mode that
captions every uploaded image and bundles images plus matching ``.txt``
caption files into a downloadable ZIP archive.
"""

import os
import re
import tempfile
import zipfile

import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

# Check for GPU and set device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Checkpoint that needs its own generation/post-processing path.
MIAOSHOU_MODEL = 'MiaoshouAI/Florence-2-large-PromptGen-v2.0'

# File extensions accepted by the batch tab.
SUPPORTED_EXTENSIONS = ('.jpg', '.jpeg', '.png')

# Define model configurations and task prompts
# NOTE(review): every task prompt below is an empty string, so all the
# `== ""` checks further down are always true. This looks like the task
# tokens were lost somewhere upstream -- confirm the intended values.
model_configs = {
    'gokaygokay/Florence-2-Flux': "",
    'gokaygokay/Florence-2-Flux-Large': "",
    'yayayaaa/florence-2-large-ft-moredetailed': "",
    'MiaoshouAI/Florence-2-large-PromptGen-v2.0': "",
}

# Define a description for each model to be shown in UI
model_descriptions = {
    'gokaygokay/Florence-2-Flux': "Faster version with good quality captions",
    'gokaygokay/Florence-2-Flux-Large': "Provides detailed captions with better image understanding",
    'yayayaaa/florence-2-large-ft-moredetailed': "Fine-tuned specifically for more detailed captions",
    'MiaoshouAI/Florence-2-large-PromptGen-v2.0': "Memory efficient model with high quality detailed captions",
}


def _load_checkpoint(name):
    """Load model and processor for *name* and return them as a pair.

    The model is loaded without a device_map, switched to eval mode and
    moved to CUDA when the module-level ``device`` is ``"cuda"``.
    """
    loaded_model = AutoModelForCausalLM.from_pretrained(
        name,
        trust_remote_code=True,
    ).eval()

    # Move to GPU if available
    if device == "cuda":
        loaded_model = loaded_model.to("cuda")

    loaded_processor = AutoProcessor.from_pretrained(
        name,
        trust_remote_code=True,
    )
    return loaded_model, loaded_processor


# Load a single model to start with
print("Loading Florence-2 model...")
model_name = 'gokaygokay/Florence-2-Flux'
task_prompt = model_configs[model_name]
model, processor = _load_checkpoint(model_name)
print(f"Successfully loaded model: {model_name}")

# NOTE(review): this is passed to gr.HTML but contains no markup; the
# bracketed names look like they were once links -- confirm.
title = """

Florence-2 Caption Dataset Creator

[Florence-2 Flux Large] [Florence-2 Flux Base] [Florence-2 More Detailed] [MiaoshouAI PromptGen v2.0]

"""

_INSTRUCTIONS_MD = """
## Instructions

### Preview Caption
- Upload a single image and generate a detailed caption
- Try different models to compare results

### Create Dataset
- Upload multiple images to process them all at once
- All images will be captioned and saved with matching .txt files
- By default, captions include `[trigger]` at the beginning (you can modify the trigger word)
- Click "Process Images" to generate captions and create a downloadable dataset
- Use the download button to get a ZIP file containing all images and caption files

## Models Available
"""

_MIAOSHOU_NOTES_MD = """
### MiaoshouAI/Florence-2-large-PromptGen-v2.0 Features
- Improved caption quality for detailed captions
- Memory efficient (requires only ~1GB VRAM)
- Fast generation while maintaining high quality
- Supports multiple caption formats including detailed captions, tags, and analysis

Supported image formats: JPG, JPEG, PNG
"""


# Function to clean caption text
def clean_caption(text, pad_token="<pad>"):
    """Strip trailing padding tokens and surrounding whitespace from *text*.

    Args:
        text: Raw caption string.
        pad_token: Literal token to remove when it is repeated at the end of
            the caption. Defaults to ``"<pad>"`` (assumed to be the padding
            token emitted by the processors used here -- TODO confirm).
            Pass an empty string to skip token removal.

    Returns:
        The caption without trailing pad tokens or outer whitespace.
    """
    if pad_token:
        # The token is escaped and grouped so the `+` repeats the whole
        # token rather than only its last character.
        text = re.sub(rf"(?:{re.escape(pad_token)})+$", "", text)
    return text.strip()


# Function to load a specific model
def load_model(selected_model_name):
    """Make *selected_model_name* the active model, reloading if needed.

    Updates the module-level ``model``, ``processor``, ``model_name`` and
    ``task_prompt``. The bookkeeping names are only set after the new
    checkpoint has loaded; if loading fails, ``model_name`` is left as
    ``None`` so the next request retries instead of using a missing model.

    Raises:
        KeyError: If the name is not in ``model_configs`` (raised before the
            current model is released).
        Exception: Whatever the underlying ``from_pretrained`` calls raise.

    Returns:
        The string ``"Model loaded successfully"``.
    """
    global model, processor, model_name, task_prompt

    # Only reload if the model is different
    if selected_model_name == model_name:
        return "Model loaded successfully"

    # Look the prompt up first so an unknown name cannot cost us the
    # currently working model.
    new_task_prompt = model_configs[selected_model_name]

    print(f"Switching to model: {selected_model_name}")

    # Release the current model before loading the next one so both do not
    # have to fit in memory at the same time.
    model = None
    processor = None
    model_name = None
    torch.cuda.empty_cache()

    new_model, new_processor = _load_checkpoint(selected_model_name)

    model, processor = new_model, new_processor
    model_name = selected_model_name
    task_prompt = new_task_prompt
    print(f"Successfully switched to model: {model_name}")
    return "Model loaded successfully"


def _move_inputs_to_device(inputs):
    """Move every tensor value of *inputs* to the module-level device, in place."""
    for key in inputs:
        if isinstance(inputs[key], torch.Tensor):
            inputs[key] = inputs[key].to(device)
    return inputs


# Special function for MiaoshouAI model
def generate_miaoshou_caption(image):
    """Special handling for MiaoshouAI model.

    Generates with only ``input_ids`` and ``pixel_values`` and then runs the
    processor's ``post_process_generation``; if post-processing raises, the
    plain decoded text (special tokens skipped) is returned instead.

    Args:
        image: PIL image (its ``width``/``height`` are read).
    """
    # Create inputs for MiaoshouAI model
    inputs = _move_inputs_to_device(
        processor(text=task_prompt, images=image, return_tensors="pt")
    )

    # Generate using only input_ids and pixel_values
    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=512,
            do_sample=False,
            num_beams=3,
        )

    # Decode the generated text
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]

    # Use the model's post-processing
    try:
        parsed_answer = processor.post_process_generation(
            generated_text,
            task=task_prompt,
            image_size=(image.width, image.height),
        )
        # Get the generated text from parsed answer
        if isinstance(parsed_answer, dict) and task_prompt in parsed_answer:
            return parsed_answer[task_prompt]
        return str(parsed_answer)
    except Exception as e:
        print(f"Post-processing error: {str(e)}")
        # Fallback to regular decoding if post-processing fails
        return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]


def _caption_image(image):
    """Return the raw (uncleaned) caption for one RGB PIL image.

    Shared by the preview and batch entry points so both use the same
    generation settings for the currently loaded model.
    """
    # Special handling for MiaoshouAI model
    if model_name == MIAOSHOU_MODEL:
        return generate_miaoshou_caption(image)

    # Create an appropriate prompt based on the model
    prompt = task_prompt
    if prompt == "":
        prompt = prompt + "Describe this image in great detail."

    # Process the image and move inputs to the same device as the model
    inputs = _move_inputs_to_device(
        processor(text=prompt, images=image, return_tensors="pt")
    )

    # Generate the caption
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=512,
            num_beams=3,
            repetition_penalty=1.10,
        )

    # Decode the generated text
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    if task_prompt != "":
        # For other models, use the generated text directly
        return generated_text

    # Use the post processing for Florence-2-Flux models
    try:
        decoded_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        parsed_answer = processor.post_process_generation(
            decoded_text,
            task=task_prompt,
            image_size=(image.width, image.height),
        )
        return parsed_answer[task_prompt]
    except Exception as e:
        print(f"Error in post processing: {str(e)}")
        return generated_text  # Fallback to direct output


# Function to generate a caption for a single image
@spaces.GPU
def generate_caption(image, selected_model_name):
    """Caption a single image with the selected model.

    Args:
        image: A file path, a PIL image, or an array accepted by
            ``Image.fromarray``; ``None`` yields a prompt to upload.
        selected_model_name: Key of ``model_configs`` to caption with.

    Returns:
        The cleaned caption, or a human-readable error message. Errors are
        returned as text rather than raised so the UI can display them.
    """
    if image is None:
        return "Please upload an image."

    # Check if we need to switch models
    if selected_model_name != model_name:
        try:
            load_model(selected_model_name)
        except Exception as e:
            return f"Error loading model {selected_model_name}: {str(e)}"

    try:
        if isinstance(image, str):
            # Handle file path input
            image = Image.open(image)
        elif not isinstance(image, Image.Image):
            # Handle numpy array input from gradio
            image = Image.fromarray(image)

        # Ensure image is RGB
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Clean the caption to remove padding tokens
        return clean_caption(_caption_image(image))
    except Exception as e:
        error_msg = f"Error generating caption: {str(e)}"
        print(error_msg)
        return error_msg


# Function to process multiple images and create a downloadable zip
@spaces.GPU
def process_images(images, selected_model_name, add_trigger=True, trigger_word="trigger"):
    """Process multiple images, caption them, and create downloadable zip file.

    Args:
        images: Uploaded files; each item is either a path or an object
            whose ``name`` attribute is the path.
        selected_model_name: Key of ``model_configs`` to caption with.
        add_trigger: When true, prefix each caption with ``[trigger_word]``.
        trigger_word: Word placed in the prefix.

    Returns:
        ``(result_text, zip_path)``; ``zip_path`` is ``None`` when nothing
        could be produced. A failure on one image is reported in the result
        text and does not abort the rest of the batch.
    """
    if not images:
        return "No images uploaded.", None

    # Check if we need to switch models
    if selected_model_name != model_name:
        try:
            load_model(selected_model_name)
        except Exception as e:
            return f"Error loading model {selected_model_name}: {str(e)}", None

    # Create a temporary directory to hold the zip file
    temp_dir = tempfile.mkdtemp()
    zip_path = os.path.join(temp_dir, "captions_dataset.zip")

    results = []

    try:
        with zipfile.ZipFile(zip_path, 'w') as zipf:
            for img_file in images:
                # Resolve the name before the try block so the error
                # message below always refers to the current file.
                img_path = str(getattr(img_file, "name", img_file))
                base_name = os.path.basename(img_path)
                try:
                    file_name, file_ext = os.path.splitext(base_name)

                    # Skip unsupported formats
                    if file_ext.lower() not in SUPPORTED_EXTENSIONS:
                        results.append(f"⚠️ Skipped {base_name}: Unsupported format (only jpg, jpeg, png supported)")
                        continue

                    # convert() yields an independent in-memory image, so
                    # the file handle can be closed right away.
                    with Image.open(img_path) as opened:
                        image = opened.convert("RGB")

                    # Clean caption and add trigger if needed
                    caption = clean_caption(_caption_image(image))
                    if add_trigger:
                        caption = f"[{trigger_word}] {caption}"

                    # Write the caption straight into the archive (UTF-8),
                    # then add the image next to it.
                    zipf.writestr(f"{file_name}.txt", caption)
                    zipf.write(img_path, base_name)

                    # Add to results
                    caption_preview = f"{caption[:50]}..." if len(caption) > 50 else caption
                    results.append(f"✓ {base_name} → {file_name}.txt: {caption_preview}")
                except Exception as e:
                    results.append(f"❌ Error processing {base_name}: {str(e)}")
    except Exception as e:
        error_msg = f"Error creating zip file: {str(e)}"
        print(error_msg)
        return error_msg, None

    # Format results
    summary = f"Processed {len(results)} images. Ready for download.\n\n"
    result_text = summary + "\n".join(results)

    return result_text, zip_path


# Create the Gradio interface
with gr.Blocks() as demo:
    gr.HTML(title)

    with gr.Tabs():
        # Single image preview tab
        with gr.TabItem("Preview Caption"):
            with gr.Row():
                with gr.Column():
                    input_img = gr.Image(label="Input Picture")
                    model_selector = gr.Dropdown(
                        choices=list(model_configs.keys()),
                        label="Model",
                        value=model_name,
                    )
                    preview_btn = gr.Button(value="Generate Caption")
                with gr.Column():
                    output_text = gr.Textbox(label="Generated Caption", lines=8)

            preview_btn.click(generate_caption, [input_img, model_selector], [output_text])

        # Dataset creation tab
        with gr.TabItem("Create Dataset"):
            with gr.Row():
                with gr.Column(scale=1):
                    batch_images = gr.File(
                        file_count="multiple",
                        label="Upload Multiple Images (JPG, JPEG, PNG)",
                    )
                    batch_model_selector = gr.Dropdown(
                        choices=list(model_configs.keys()),
                        label="Model",
                        value=model_name,
                    )
                    with gr.Row():
                        add_trigger = gr.Checkbox(label="Add Trigger Word", value=True)
                        trigger_word = gr.Textbox(
                            label="Trigger Word",
                            placeholder="trigger",
                            value="trigger",
                        )
                    process_btn = gr.Button(value="Process Images")
                with gr.Column(scale=1):
                    batch_results = gr.Textbox(label="Processing Results", lines=15)
                    download_output = gr.File(label="Download Dataset (Images & Captions)")

            # Connect process button
            process_btn.click(
                fn=process_images,
                inputs=[batch_images, batch_model_selector, add_trigger, trigger_word],
                outputs=[batch_results, download_output],
            )

    # Instructions with model information
    with gr.Accordion("Instructions & Model Information", open=True):
        gr.Markdown(_INSTRUCTIONS_MD)

        # Create a markdown description for each model
        model_md = "".join(
            f"- **{model_id.split('/')[-1]}**: {description}\n"
            for model_id, description in model_descriptions.items()
        )
        gr.Markdown(model_md)

        # Add special note for MiaoshouAI model
        gr.Markdown(_MIAOSHOU_NOTES_MD)

if __name__ == "__main__":
    demo.launch()