John6666 committed on
Commit
1ffbc6e
·
verified ·
1 Parent(s): eb0e639

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +21 -18
  2. fl2basepromptgen.py +76 -0
app.py CHANGED
@@ -24,6 +24,9 @@ from tagger import (
24
  from fl2sd3longcap import (
25
  predict_tags_fl2_sd3,
26
  )
 
 
 
27
  from promptenhancer import prompt_enhancer
28
 
29
 
@@ -38,7 +41,8 @@ def description_ui():
38
  - Models: p1atdev's [wd-swinv2-tagger-v3-hf](https://huggingface.co/p1atdev/wd-swinv2-tagger-v3-hf),\
39
  gokaygokay's [Florence-2-SD3-Captioner](https://huggingface.co/gokaygokay/Florence-2-SD3-Captioner),\
40
  [Lamini-Prompt-Enchance](https://huggingface.co/gokaygokay/Lamini-Prompt-Enchance),\
41
- [Lamini-Prompt-Enchance-Long](https://huggingface.co/gokaygokay/Lamini-Prompt-Enchance-Long).
 
42
  """
43
  )
44
 
@@ -58,7 +62,7 @@ def main():
58
  input_tag_type = gr.Radio(label="Convert tags to", info="danbooru for Animagine, e621 for Pony.", choices=["danbooru", "e621"], value="danbooru")
59
  recom_prompt = gr.Radio(label="Insert reccomended prompt", choices=["None", "Animagine", "Pony"], value="None", interactive=True)
60
  keep_tags = gr.Radio(label="Remove tags leaving only the following", choices=["body", "dress", "all"], value="all")
61
- image_algorithms = gr.CheckboxGroup(["Use WD Tagger", "Use Florence-2-SD3-Long-Captioner"], label="Algorithms", value=["Use WD Tagger", "Use Florence-2-SD3-Long-Captioner"])
62
  generate_from_image_btn = gr.Button(value="GENERATE TAGS FROM IMAGE", size="lg", variant="primary")
63
 
64
  with gr.Group():
@@ -93,31 +97,30 @@ def main():
93
  output_text_pony = gr.TextArea(label="Output tags (Pony e621 style)", interactive=False, show_copy_button=True)
94
  copy_btn_pony = gr.Button(value="Copy to clipboard", size="sm", interactive=False)
95
 
96
- translate_input_prompt_button.click(translate_prompt, [input_general], [input_general])
97
- translate_input_prompt_button.click(translate_prompt, [input_character], [input_character])
98
- translate_input_prompt_button.click(translate_prompt, [input_copyright], [input_copyright])
99
 
100
  generate_from_image_btn.click(
101
  predict_tags_wd,
102
  [input_image, input_general, image_algorithms, general_threshold, character_threshold],
103
- [
104
- input_copyright,
105
- input_character,
106
- input_general,
107
- copy_input_btn,
108
- ],
109
  ).success(
110
  predict_tags_fl2_sd3,
111
  [input_image, input_general, image_algorithms],
112
  [input_general],
113
  ).success(
114
- remove_specific_prompt, [input_general, keep_tags], [input_general],
 
 
 
 
115
  ).success(
116
- convert_danbooru_to_e621_prompt, [input_general, input_tag_type], [input_general],
117
  ).success(
118
- insert_recom_prompt, [input_general, dummy_np, recom_prompt], [input_general, dummy_np],
119
  )
120
- copy_input_btn.click(compose_prompt_to_copy, [input_character, input_copyright, input_general], [input_tags_to_copy]).success(
121
  gradio_copy_text, [input_tags_to_copy], js=COPY_ACTION_JS,
122
  )
123
 
@@ -126,11 +129,11 @@ def main():
126
  [input_character, input_copyright, input_general, prompt_enhancer_model],
127
  [output_text, copy_btn, copy_btn_pony],
128
  ).success(
129
- convert_danbooru_to_e621_prompt, [output_text, tag_type], [output_text_pony],
130
  ).success(
131
- insert_recom_prompt, [output_text, dummy_np, recom_animagine], [output_text, dummy_np],
132
  ).success(
133
- insert_recom_prompt, [output_text_pony, dummy_np, recom_pony], [output_text_pony, dummy_np],
134
  )
135
  copy_btn.click(gradio_copy_text, [output_text], js=COPY_ACTION_JS)
136
  copy_btn_pony.click(gradio_copy_text, [output_text_pony], js=COPY_ACTION_JS)
 
24
  from fl2sd3longcap import (
25
  predict_tags_fl2_sd3,
26
  )
27
+ from fl2basepromptgen import (
28
+ predict_tags_fl2_base_prompt_gen,
29
+ )
30
  from promptenhancer import prompt_enhancer
31
 
32
 
 
41
  - Models: p1atdev's [wd-swinv2-tagger-v3-hf](https://huggingface.co/p1atdev/wd-swinv2-tagger-v3-hf),\
42
  gokaygokay's [Florence-2-SD3-Captioner](https://huggingface.co/gokaygokay/Florence-2-SD3-Captioner),\
43
  [Lamini-Prompt-Enchance](https://huggingface.co/gokaygokay/Lamini-Prompt-Enchance),\
44
+ [Lamini-Prompt-Enchance-Long](https://huggingface.co/gokaygokay/Lamini-Prompt-Enchance-Long),\
45
+ MiaoshouAI's [Florence-2-base-PromptGen](https://huggingface.co/MiaoshouAI/Florence-2-base-PromptGen).
46
  """
47
  )
48
 
 
62
  input_tag_type = gr.Radio(label="Convert tags to", info="danbooru for Animagine, e621 for Pony.", choices=["danbooru", "e621"], value="danbooru")
63
  recom_prompt = gr.Radio(label="Insert reccomended prompt", choices=["None", "Animagine", "Pony"], value="None", interactive=True)
64
  keep_tags = gr.Radio(label="Remove tags leaving only the following", choices=["body", "dress", "all"], value="all")
65
+ image_algorithms = gr.CheckboxGroup(["Use WD Tagger", "Use Florence-2-SD3-Long-Captioner", "Use Florence-2-base-PromptGen"], label="Algorithms", value=["Use WD Tagger", "Use Florence-2-SD3-Long-Captioner"])
66
  generate_from_image_btn = gr.Button(value="GENERATE TAGS FROM IMAGE", size="lg", variant="primary")
67
 
68
  with gr.Group():
 
97
  output_text_pony = gr.TextArea(label="Output tags (Pony e621 style)", interactive=False, show_copy_button=True)
98
  copy_btn_pony = gr.Button(value="Copy to clipboard", size="sm", interactive=False)
99
 
100
+ translate_input_prompt_button.click(translate_prompt, [input_general], [input_general], queue=False)
101
+ translate_input_prompt_button.click(translate_prompt, [input_character], [input_character], queue=False)
102
+ translate_input_prompt_button.click(translate_prompt, [input_copyright], [input_copyright], queue=False)
103
 
104
  generate_from_image_btn.click(
105
  predict_tags_wd,
106
  [input_image, input_general, image_algorithms, general_threshold, character_threshold],
107
+ [input_copyright, input_character, input_general, copy_input_btn],
 
 
 
 
 
108
  ).success(
109
  predict_tags_fl2_sd3,
110
  [input_image, input_general, image_algorithms],
111
  [input_general],
112
  ).success(
113
+ predict_tags_fl2_base_prompt_gen,
114
+ [input_image, input_general, image_algorithms],
115
+ [input_general],
116
+ ).success(
117
+ remove_specific_prompt, [input_general, keep_tags], [input_general], queue=False,
118
  ).success(
119
+ convert_danbooru_to_e621_prompt, [input_general, input_tag_type], [input_general], queue=False,
120
  ).success(
121
+ insert_recom_prompt, [input_general, dummy_np, recom_prompt], [input_general, dummy_np], queue=False,
122
  )
123
+ copy_input_btn.click(compose_prompt_to_copy, [input_character, input_copyright, input_general], [input_tags_to_copy], queue=False).success(
124
  gradio_copy_text, [input_tags_to_copy], js=COPY_ACTION_JS,
125
  )
126
 
 
129
  [input_character, input_copyright, input_general, prompt_enhancer_model],
130
  [output_text, copy_btn, copy_btn_pony],
131
  ).success(
132
+ convert_danbooru_to_e621_prompt, [output_text, tag_type], [output_text_pony], queue=False,
133
  ).success(
134
+ insert_recom_prompt, [output_text, dummy_np, recom_animagine], [output_text, dummy_np], queue=False,
135
  ).success(
136
+ insert_recom_prompt, [output_text_pony, dummy_np, recom_pony], [output_text_pony, dummy_np], queue=False,
137
  )
138
  copy_btn.click(gradio_copy_text, [output_text], js=COPY_ACTION_JS)
139
  copy_btn_pony.click(gradio_copy_text, [output_text_pony], js=COPY_ACTION_JS)
fl2basepromptgen.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoProcessor, AutoModelForCausalLM
2
+ import spaces
3
+ import re
4
+ from PIL import Image
5
+
6
import os
import subprocess

# Install flash-attn when this module is imported. The skip-CUDA-build flag is
# merged into the current environment instead of replacing it: passing a bare
# dict as `env` would drop PATH, HOME and any virtualenv settings, so `pip`
# could resolve to a different interpreter or not be found at all.
# NOTE(review): the flag presumably makes the package install without
# compiling its CUDA kernels -- confirm against the flash-attn docs.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)

# Load the model and its processor once at import time. trust_remote_code=True
# allows transformers to execute the custom code shipped in the model repo.
fl_model = AutoModelForCausalLM.from_pretrained('MiaoshouAI/Florence-2-base-PromptGen', trust_remote_code=True).eval()
fl_processor = AutoProcessor.from_pretrained('MiaoshouAI/Florence-2-base-PromptGen', trust_remote_code=True)
11
+
12
+
13
def fl_modify_caption(caption: str) -> str:
    """Remove the first known boilerplate phrase from a caption.

    The phrases are matched case-insensitively and may occur anywhere in the
    caption (the pattern is not anchored to the start). Only the first match
    is replaced.

    Args:
        caption (str): A string containing a caption.
    Returns:
        str: The caption with the first matching phrase replaced, or the
            caption unchanged when none of the phrases is present.
    """
    # Phrase -> replacement text. Keys must be lower-case because the matched
    # text is lower-cased before the lookup below.
    replacements = {
        'captured from ': '',
        'captured at ': '',
    }

    # One alternation covering every phrase, escaped so it matches literally.
    pattern = '|'.join(re.escape(phrase) for phrase in replacements)

    def replace_fn(match):
        return replacements[match.group(0).lower()]

    # re.sub already returns the input unchanged when nothing matches, so no
    # "was it modified?" comparison is needed afterwards.
    return re.sub(pattern, replace_fn, caption, count=1, flags=re.IGNORECASE)
40
+
41
+
42
@spaces.GPU
def fl_run_example(image):
    """Generate a descriptive prompt for an image with the module-level model.

    Args:
        image: Image object exposing ``mode``, ``convert``, ``width`` and
            ``height`` (presumably a PIL image -- confirm against callers).
    Returns:
        The value stored under the "GENERATE_PROMPT>" key of the processor's
        post-processed output.
    """
    task_token = "<GENERATE_PROMPT>"
    full_prompt = task_token + "Describe this image in great detail."

    # The processor is fed an RGB image; other modes are converted first.
    rgb_image = image if image.mode == "RGB" else image.convert("RGB")

    model_inputs = fl_processor(text=full_prompt, images=rgb_image, return_tensors="pt")
    # Deterministic beam search: sampling is off, three beams.
    output_ids = fl_model.generate(
        input_ids=model_inputs["input_ids"],
        pixel_values=model_inputs["pixel_values"],
        max_new_tokens=1024,
        do_sample=False,
        num_beams=3,
    )
    decoded = fl_processor.batch_decode(output_ids, skip_special_tokens=False)[0]
    result = fl_processor.post_process_generation(
        decoded,
        task=full_prompt,
        image_size=(rgb_image.width, rgb_image.height),
    )
    # NOTE(review): the key has no leading "<"; presumably this is how the
    # model's processor names the entry -- confirm against the processor code.
    # NOTE: the result is returned as-is; fl_modify_caption is not applied.
    return result["GENERATE_PROMPT>"]
63
+
64
+
65
def predict_tags_fl2_base_prompt_gen(image: "Image.Image", input_tags: str, algo: list[str]):
    """Append tags generated from the image to an existing tag string.

    Args:
        image: Image handed to ``fl_run_example`` when the algorithm is enabled.
        input_tags: Comma-separated tags already collected.
        algo: Names of the selected algorithms.
    Returns:
        ``input_tags`` untouched when "Use Florence-2-base-PromptGen" is not
        selected. Otherwise a ", "-joined string of the existing tags followed
        by the generated ones, with surrounding whitespace stripped, empty
        items dropped and duplicates removed (first occurrence wins).
    """
    def to_list(s):
        # Filter on each stripped item, not on the whole string, so empty
        # entries ("a,,b", trailing commas, blank input) never reach the
        # result and no sentinel/remove("") workaround is needed.
        return [tag for tag in (part.strip() for part in s.split(",")) if tag]

    if "Use Florence-2-base-PromptGen" not in algo:
        return input_tags

    merged = to_list(input_tags) + to_list(fl_run_example(image))
    # dict.fromkeys keeps first-seen order while de-duplicating in linear time.
    return ", ".join(dict.fromkeys(merged))