import gradio as gr
from transformers import AutoProcessor, AutoModelForCausalLM
import os
from PIL import Image
import zipfile
import tempfile
import re
import torch
import spaces
# Check for GPU and set device.
# Chosen once at import time; every later model/tensor move keys off this string.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")
# Define model configurations and task prompts.
# Maps each selectable model id (dict order is the dropdown order) to the task
# prompt given to its processor. Every prompt is currently empty; for the
# non-MiaoshouAI models an empty prompt makes caption generation fall back to
# "Describe this image in great detail." and Florence-2 post-processing.
model_configs = dict.fromkeys(
    (
        'gokaygokay/Florence-2-Flux',
        'gokaygokay/Florence-2-Flux-Large',
        'yayayaaa/florence-2-large-ft-moredetailed',
        'MiaoshouAI/Florence-2-large-PromptGen-v2.0',
    ),
    "",
)

# Short, user-facing blurb per model, rendered in the "Models Available" list.
model_descriptions = {
    'gokaygokay/Florence-2-Flux':
        "Faster version with good quality captions",
    'gokaygokay/Florence-2-Flux-Large':
        "Provides detailed captions with better image understanding",
    'yayayaaa/florence-2-large-ft-moredetailed':
        "Fine-tuned specifically for more detailed captions",
    'MiaoshouAI/Florence-2-large-PromptGen-v2.0':
        "Memory efficient model with high quality detailed captions",
}
# Load a single model to start with; load_model() swaps it later on demand.
print("Loading Florence-2 model...")
model_name = 'gokaygokay/Florence-2-Flux'
task_prompt = model_configs[model_name]

# No device_map is passed, so from_pretrained keeps the weights on CPU and
# the explicit move below decides where they end up.
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
model = model.eval()
if device == "cuda":
    model = model.to("cuda")

processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
print(f"Successfully loaded model: {model_name}")
# Page header passed to gr.HTML. gr.HTML renders raw HTML (not Markdown, and
# newlines collapse), so the model links are explicit anchors to each
# model's Hugging Face page.
title = """<h1 style="text-align: center;">Florence-2 Caption Dataset Creator</h1>
<p style="text-align: center;">
<a href="https://huggingface.co/gokaygokay/Florence-2-Flux-Large" target="_blank">[Florence-2 Flux Large]</a>
<a href="https://huggingface.co/gokaygokay/Florence-2-Flux" target="_blank">[Florence-2 Flux Base]</a>
<a href="https://huggingface.co/yayayaaa/florence-2-large-ft-moredetailed" target="_blank">[Florence-2 More Detailed]</a>
<a href="https://huggingface.co/MiaoshouAI/Florence-2-large-PromptGen-v2.0" target="_blank">[MiaoshouAI PromptGen v2.0]</a>
</p>
"""
# Function to clean caption text
# A run of trailing special tokens (optionally whitespace-separated) plus any
# trailing whitespace. Compiled once because it is applied to every caption.
_TRAILING_SPECIAL_TOKENS = re.compile(r"(?:\s*(?:<pad>|</s>|<s>))+\s*$")


def clean_caption(text):
    """Strip trailing special tokens and surrounding whitespace from a caption.

    Decoding with ``skip_special_tokens=False`` can leave padding or
    end-of-sequence markers at the end of the text; those are removed here.
    Tokens in the middle of the text are left untouched.

    Args:
        text: Raw caption text. ``None`` is treated as an empty caption.

    Returns:
        The cleaned caption string.
    """
    if text is None:
        return ""
    # The previous pattern r'+$' was not a valid regex ("nothing to repeat")
    # and raised re.error on every call.
    text = _TRAILING_SPECIAL_TOKENS.sub("", text)
    return text.strip()
# Function to load a specific model
def load_model(selected_model_name):
    """Make ``selected_model_name`` the active global model and processor.

    Does nothing if that model is already active. Otherwise the current model
    is released first (to free GPU memory), then the new weights and
    processor are loaded with ``trust_remote_code=True`` and published to the
    module globals ``model``, ``processor``, ``model_name`` and ``task_prompt``.

    Args:
        selected_model_name: A key of ``model_configs``.

    Returns:
        The status string ``"Model loaded successfully"``.

    Raises:
        ValueError: If ``selected_model_name`` is not a key of
            ``model_configs``; the current model is kept in that case.
        Exception: Whatever ``from_pretrained`` raises if loading fails. No
            model is active afterwards (``model``/``processor``/``model_name``
            are ``None``), so the next call with any model retries the load
            instead of running against a half-switched state.
    """
    global model, processor, model_name, task_prompt

    # Only reload if the model is different.
    if selected_model_name == model_name:
        return "Model loaded successfully"
    # Validate before tearing anything down.
    if selected_model_name not in model_configs:
        raise ValueError(f"Unknown model: {selected_model_name}")

    print(f"Switching to model: {selected_model_name}")
    # Drop references rather than `del`-ing the globals: a deleted global would
    # turn every later call into a NameError if the load below fails.
    model = None
    processor = None
    model_name = None
    if device == "cuda":
        torch.cuda.empty_cache()

    # Load without device_map, then move to GPU if available.
    new_model = AutoModelForCausalLM.from_pretrained(
        selected_model_name,
        trust_remote_code=True,
    ).eval()
    if device == "cuda":
        new_model = new_model.to("cuda")
    new_processor = AutoProcessor.from_pretrained(
        selected_model_name,
        trust_remote_code=True,
    )

    # Publish the new state only once everything has loaded.
    model = new_model
    processor = new_processor
    task_prompt = model_configs[selected_model_name]
    model_name = selected_model_name
    print(f"Successfully switched to model: {model_name}")
    return "Model loaded successfully"
# Special function for MiaoshouAI model
def generate_miaoshou_caption(image):
    """Caption ``image`` with the MiaoshouAI PromptGen model.

    Uses the module-level ``model``, ``processor``, ``task_prompt`` and
    ``device``. Only ``input_ids`` and ``pixel_values`` are forwarded to
    ``generate`` (deterministic beam search, 3 beams, up to 512 new tokens).

    Args:
        image: A PIL image; callers convert it to RGB first. Its
            ``width``/``height`` are passed to post-processing.

    Returns:
        The caption. If ``processor.post_process_generation`` raises, falls
        back to plain decoding with special tokens skipped. If post-processing
        returns something other than a dict keyed by ``task_prompt``, its
        ``str()`` is returned.
    """
    inputs = processor(
        text=task_prompt,
        images=image,
        return_tensors="pt",
    )
    # Move every tensor in the batch to the model's device.
    for key in inputs:
        if isinstance(inputs[key], torch.Tensor):
            inputs[key] = inputs[key].to(device)

    # Inference only: no_grad keeps autograd from recording the forward pass
    # (the regular captioning path already does the same).
    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=512,
            do_sample=False,
            num_beams=3,
        )

    # Keep special tokens: post_process_generation expects the raw sequence.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    try:
        parsed_answer = processor.post_process_generation(
            generated_text,
            task=task_prompt,
            image_size=(image.width, image.height),
        )
    except Exception as e:
        print(f"Post-processing error: {str(e)}")
        # Fallback to regular decoding if post-processing fails.
        return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    if isinstance(parsed_answer, dict) and task_prompt in parsed_answer:
        return parsed_answer[task_prompt]
    return str(parsed_answer)
# Function to generate a caption for a single image
@spaces.GPU
def generate_caption(image, selected_model_name):
    """Return a cleaned caption for a single image (Preview tab handler).

    Args:
        image: ``None``, a file path (``str``), or an array accepted by
            ``Image.fromarray``.
        selected_model_name: Model id chosen in the dropdown; the active
            global model is switched first when it differs.

    Returns:
        The cleaned caption, or a human-readable message when no image was
        given, the model could not be loaded, or generation failed.
    """
    if image is None:
        return "Please upload an image."

    if selected_model_name != model_name:
        try:
            load_model(selected_model_name)
        except Exception as err:
            return f"Error loading model {selected_model_name}: {str(err)}"

    # Normalise the input to an RGB PIL image.
    pil_image = Image.open(image) if isinstance(image, str) else Image.fromarray(image)
    if pil_image.mode != "RGB":
        pil_image = pil_image.convert("RGB")

    try:
        if model_name == 'MiaoshouAI/Florence-2-large-PromptGen-v2.0':
            raw_caption = generate_miaoshou_caption(pil_image)
        else:
            # An empty task prompt means "use the generic description prompt".
            if task_prompt == "":
                prompt = task_prompt + "Describe this image in great detail."
            else:
                prompt = task_prompt

            batch = processor(text=prompt, images=pil_image, return_tensors="pt")
            batch = {
                name: (value.to(device) if isinstance(value, torch.Tensor) else value)
                for name, value in batch.items()
            }

            with torch.no_grad():
                output_ids = model.generate(
                    **batch,
                    max_new_tokens=512,
                    num_beams=3,
                    repetition_penalty=1.10,
                )

            plain_text = processor.batch_decode(output_ids, skip_special_tokens=True)[0]

            if task_prompt != "":
                # Non-empty task prompts: use the decoded text as-is.
                raw_caption = plain_text
            else:
                # Empty task prompt: run Florence-2 post-processing, falling
                # back to the plain decode if it fails.
                try:
                    text_with_tokens = processor.batch_decode(output_ids, skip_special_tokens=False)[0]
                    parsed = processor.post_process_generation(
                        text_with_tokens,
                        task=task_prompt,
                        image_size=(pil_image.width, pil_image.height),
                    )
                    raw_caption = parsed[task_prompt]
                except Exception as err:
                    print(f"Error in post processing: {str(err)}")
                    raw_caption = plain_text

        return clean_caption(raw_caption)
    except Exception as err:
        error_msg = f"Error generating caption: {str(err)}"
        print(error_msg)
        return error_msg
# Function to process multiple images and create a downloadable zip
# Extensions accepted for dataset creation (compared case-insensitively).
_DATASET_IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")


def _upload_path(img_file):
    """Return the filesystem path of one ``gr.File`` upload.

    Newer Gradio versions hand multi-file uploads over as plain path strings;
    older ones passed tempfile-like objects exposing ``.name``. Both work.
    """
    if isinstance(img_file, (str, os.PathLike)):
        return os.fspath(img_file)
    return img_file.name


def _reserve_stem(stem, used_stems):
    """Return a file stem not yet in ``used_stems`` and record it there.

    Without this, ``photo.png`` + ``photo.jpg`` (or two uploads both named
    ``photo.png``) would both map to ``photo.txt`` and break the image/caption
    pairing. Comparison is case-insensitive so the archive also extracts
    cleanly on case-insensitive filesystems.
    """
    candidate = stem
    suffix = 1
    while candidate.casefold() in used_stems:
        candidate = f"{stem}_{suffix}"
        suffix += 1
    used_stems.add(candidate.casefold())
    return candidate


def _caption_dataset_image(image):
    """Generate a raw (not yet cleaned) caption for one RGB PIL image.

    Uses the active global model. The MiaoshouAI model goes through
    ``generate_miaoshou_caption``. Other models get the generic description
    prompt when their task prompt is empty and, in that case, Florence-2
    post-processing with a plain-decoding fallback.
    """
    if model_name == 'MiaoshouAI/Florence-2-large-PromptGen-v2.0':
        return generate_miaoshou_caption(image)

    prompt = task_prompt
    if prompt == "":
        prompt = "Describe this image in great detail."

    inputs = processor(text=prompt, images=image, return_tensors="pt")
    # Move inputs to the same device as the model.
    for key in inputs:
        if isinstance(inputs[key], torch.Tensor):
            inputs[key] = inputs[key].to(device)

    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=512,
            num_beams=3,
            repetition_penalty=1.10,
        )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    if task_prompt != "":
        return generated_text
    try:
        decoded_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        parsed_answer = processor.post_process_generation(
            decoded_text,
            task=task_prompt,
            image_size=(image.width, image.height),
        )
        return parsed_answer[task_prompt]
    except Exception as e:
        print(f"Error in post processing: {str(e)}")
        return generated_text  # Fallback to direct output


@spaces.GPU
def process_images(images, selected_model_name, add_trigger=True, trigger_word="trigger"):
    """Caption every uploaded image and bundle images + captions into a ZIP.

    Each accepted image is stored in the archive next to a UTF-8 ``.txt`` file
    with the same stem that holds its caption. Unsupported formats are skipped
    and per-image failures are reported without aborting the batch.

    Args:
        images: Uploads from ``gr.File(file_count="multiple")``, either path
            strings or objects exposing ``.name``.
        selected_model_name: Key of ``model_configs``; the active global model
            is switched first when it differs.
        add_trigger: When true and ``trigger_word`` is non-blank, prefix each
            caption with ``[trigger_word] ``.
        trigger_word: Word placed in brackets before each caption
            (surrounding whitespace is stripped).

    Returns:
        ``(report_text, zip_path)``. ``zip_path`` is ``None`` when there was
        no input, the model failed to load, the ZIP could not be written, or
        no image could be captioned.
    """
    if not images:
        return "No images uploaded.", None

    # Check if we need to switch models.
    if selected_model_name != model_name:
        try:
            load_model(selected_model_name)
        except Exception as e:
            return f"Error loading model {selected_model_name}: {str(e)}", None

    temp_dir = tempfile.mkdtemp()
    zip_path = os.path.join(temp_dir, "captions_dataset.zip")
    trigger = (trigger_word or "").strip()
    results = []
    used_stems = set()
    succeeded = 0

    try:
        with zipfile.ZipFile(zip_path, "w") as zipf:
            for img_file in images:
                # Label for the error line, even if the path lookup itself fails.
                label = str(img_file)
                try:
                    img_path = _upload_path(img_file)
                    base_name = os.path.basename(img_path)
                    label = base_name
                    stem, file_ext = os.path.splitext(base_name)

                    if file_ext.lower() not in _DATASET_IMAGE_EXTENSIONS:
                        results.append(f"⚠️ Skipped {base_name}: Unsupported format (only jpg, jpeg, png supported)")
                        continue

                    # convert() returns a loaded copy, so the file can be closed here.
                    with Image.open(img_path) as src:
                        image = src.convert("RGB")

                    caption = clean_caption(_caption_dataset_image(image))
                    if add_trigger and trigger:
                        caption = f"[{trigger}] {caption}"

                    stem = _reserve_stem(stem, used_stems)
                    archived_image = f"{stem}{file_ext}"
                    txt_filename = f"{stem}.txt"
                    zipf.write(img_path, archived_image)
                    # writestr encodes str as UTF-8; no temporary .txt needed.
                    zipf.writestr(txt_filename, caption)

                    caption_preview = f"{caption[:50]}..." if len(caption) > 50 else caption
                    renamed = f" (saved as {archived_image})" if archived_image != base_name else ""
                    results.append(f"✓ {base_name}{renamed} → {txt_filename}: {caption_preview}")
                    succeeded += 1
                except Exception as e:
                    results.append(f"❌ Error processing {label}: {str(e)}")
    except Exception as e:
        error_msg = f"Error creating zip file: {str(e)}"
        print(error_msg)
        return error_msg, None

    details = "\n".join(results)
    if succeeded == 0:
        return "No images could be processed.\n\n" + details, None
    summary = f"Processed {succeeded} of {len(images)} images. Ready for download.\n\n"
    return summary + details, zip_path
# Create the Gradio interface
with gr.Blocks() as demo:
    gr.HTML(title)
    with gr.Tabs():
        # Single image preview tab: one image in, one caption out.
        with gr.TabItem("Preview Caption"):
            with gr.Row():
                with gr.Column():
                    input_img = gr.Image(label="Input Picture")
                    model_selector = gr.Dropdown(
                        choices=list(model_configs.keys()),
                        label="Model",
                        value=model_name,
                    )
                    preview_btn = gr.Button(value="Generate Caption")
                with gr.Column():
                    output_text = gr.Textbox(label="Generated Caption", lines=8)
            preview_btn.click(generate_caption, [input_img, model_selector], [output_text])

        # Dataset creation tab: many images in, one ZIP of image/caption pairs out.
        with gr.TabItem("Create Dataset"):
            with gr.Row():
                with gr.Column(scale=1):
                    # Restrict the picker to the formats process_images accepts,
                    # so unsupported files are not uploaded only to be skipped.
                    batch_images = gr.File(
                        file_count="multiple",
                        file_types=[".jpg", ".jpeg", ".png"],
                        label="Upload Multiple Images (JPG, JPEG, PNG)",
                    )
                    batch_model_selector = gr.Dropdown(
                        choices=list(model_configs.keys()),
                        label="Model",
                        value=model_name,
                    )
                    with gr.Row():
                        add_trigger = gr.Checkbox(label="Add Trigger Word", value=True)
                        trigger_word = gr.Textbox(
                            label="Trigger Word",
                            placeholder="trigger",
                            value="trigger",
                        )
                    process_btn = gr.Button(value="Process Images")
                with gr.Column(scale=1):
                    batch_results = gr.Textbox(label="Processing Results", lines=15)
                    download_output = gr.File(label="Download Dataset (Images & Captions)")
            # Connect process button.
            process_btn.click(
                fn=process_images,
                inputs=[batch_images, batch_model_selector, add_trigger, trigger_word],
                outputs=[batch_results, download_output],
            )

    # Instructions with model information.
    with gr.Accordion("Instructions & Model Information", open=True):
        gr.Markdown("""
## Instructions
### Preview Caption
- Upload a single image and generate a detailed caption
- Try different models to compare results
### Create Dataset
- Upload multiple images to process them all at once
- All images will be captioned and saved with matching .txt files
- By default, captions include `[trigger]` at the beginning (you can modify the trigger word)
- Click "Process Images" to generate captions and create a downloadable dataset
- Use the download button to get a ZIP file containing all images and caption files
## Models Available
""")
        # One bullet per model: short repo name (after the "/") plus its description.
        model_md = "".join(
            f"- **{model_id.split('/')[-1]}**: {description}\n"
            for model_id, description in model_descriptions.items()
        )
        gr.Markdown(model_md)
        # Add special note for MiaoshouAI model.
        gr.Markdown("""
### MiaoshouAI/Florence-2-large-PromptGen-v2.0 Features
- Improved caption quality for detailed captions
- Memory efficient (requires only ~1GB VRAM)
- Fast generation while maintaining high quality
- Supports multiple caption formats including detailed captions, tags, and analysis
Supported image formats: JPG, JPEG, PNG
""")

if __name__ == "__main__":
    demo.launch()