# Prompt-Enhancer / fl2basepromptgen.py
# Uploaded by John6666 ("Upload 2 files", commit 0b01e52, verified).
import os
import subprocess
import sys

# ZeroGPU: keep `spaces` imported before torch / transformers.
import spaces
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
# Best-effort install of flash-attn at startup. FLASH_ATTENTION_SKIP_CUDA_BUILD
# presumably makes flash-attn's build skip compiling its CUDA extension (see the
# flash-attn setup script). The variable is MERGED into the current environment:
# passing a bare dict as ``env`` would replace the whole environment and drop
# PATH, HOME, etc. for the child process. ``sys.executable -m pip`` guarantees
# the package is installed for the interpreter running this module.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,  # best-effort: a failed install does not abort startup
)

device = "cuda" if torch.cuda.is_available() else "cpu"

# Hugging Face Hub id of the Florence-2 PromptGen checkpoint used by fl_run.
FL_MODEL_ID = "MiaoshouAI/Florence-2-large-PromptGen-v1.5"

# Load once at import time and park the weights on the CPU in eval mode; fl_run
# moves them to ``device`` only while generating. On any failure both globals
# are set to None so the rest of the module can detect the missing model.
# NOTE: trust_remote_code=True executes Python code shipped in the model repo.
try:
    fl_model = (
        AutoModelForCausalLM.from_pretrained(FL_MODEL_ID, trust_remote_code=True)
        .to("cpu")
        .eval()
    )
    fl_processor = AutoProcessor.from_pretrained(FL_MODEL_ID, trust_remote_code=True)
except Exception as e:
    print(f"Failed to load {FL_MODEL_ID}: {e}")
    fl_model = fl_processor = None
@spaces.GPU(duration=30)
def fl_run(image):
    """Describe *image* as a generation prompt using Florence-2 PromptGen.

    The model lives on the CPU between calls; it is moved to ``device`` only
    for generation and always moved back afterwards, even if generation fails.

    Args:
        image: A PIL image. Non-RGB images are converted to RGB first.

    Returns:
        The generated description text, or ``""`` when no image was given or
        the model/processor failed to load at import time (globals are None).
    """
    if image is None or fl_model is None or fl_processor is None:
        return ""

    task_prompt = "<GENERATE_PROMPT>"
    prompt = task_prompt + "Describe this image in great detail."

    # The processor expects 3-channel input.
    if image.mode != "RGB":
        image = image.convert("RGB")

    fl_model.to(device)
    try:
        inputs = fl_processor(text=prompt, images=image, return_tensors="pt").to(device)
        generated_ids = fl_model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            do_sample=False,
            num_beams=3,
        )
    finally:
        # Never leave the weights on the accelerator after a failed call.
        fl_model.to("cpu")

    generated_text = fl_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = fl_processor.post_process_generation(
        generated_text, task=prompt, image_size=(image.width, image.height)
    )
    # The parsed result is indexed by the same string passed as ``task``.
    return parsed_answer[prompt]
def _merge_tags(*tag_texts: str) -> str:
    """Merge comma-separated tag strings into one, de-duplicated, in first-seen order.

    Tags are whitespace-stripped; empty tags and None/empty inputs are ignored.
    """
    tags = []
    for text in tag_texts:
        if not text:
            continue
        tags.extend(tag.strip() for tag in text.split(","))
    # dict.fromkeys keeps first-occurrence order while removing duplicates.
    return ", ".join(tag for tag in dict.fromkeys(tags) if tag)


def predict_tags_fl2_base_prompt_gen(image: "Image.Image", input_tags: str, algo: list[str]) -> str:
    """Append Florence-2 PromptGen output to the user's tags when that algorithm is selected.

    Args:
        image: Image passed to ``fl_run``.
        input_tags: Existing comma-separated tags.
        algo: Selected algorithm names; Florence runs only if
            ``"Use Florence-2-large-PromptGen"`` is among them.

    Returns:
        ``input_tags`` unchanged when Florence is not selected; otherwise the
        input tags followed by the generated ones, stripped, de-duplicated and
        joined with ``", "``.
    """
    if "Use Florence-2-large-PromptGen" not in algo:
        return input_tags
    return _merge_tags(input_tags, fl_run(image))