K1Z3M1112 commited on
Commit
7e0f360
·
verified ·
1 Parent(s): 47dcfc2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +524 -283
app.py CHANGED
@@ -6,6 +6,7 @@ from PIL import Image
6
  import torch
7
  from gradio.themes import Soft
8
  from gradio.themes.utils import colors, fonts, sizes
 
9
 
10
  colors.steel_blue = colors.Color(
11
  name="steel_blue",
@@ -36,186 +37,208 @@ class SteelBlueTheme(Soft):
36
  steel_blue_theme = SteelBlueTheme()
37
 
38
  print("=" * 50)
39
- print("🎨 Style2Paints - Enhanced Multi-Model Edition")
40
  print("=" * 50)
41
 
42
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
43
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
44
 
45
- from diffusers import StableDiffusionControlNetPipeline, StableDiffusionXLControlNetPipeline, ControlNetModel, DDIMScheduler, StableDiffusionPipeline, StableDiffusionXLPipeline
46
  from controlnet_aux import LineartDetector, LineartAnimeDetector
47
 
48
- # Model configurations
49
  AVAILABLE_MODELS = {
50
- "Anything V3 (Anime)": {
51
- "repo": "Linaqruf/anything-v3.0",
 
 
 
 
 
 
 
 
 
 
 
52
  "type": "sd15",
53
- "description": "Original anime model"
 
54
  },
55
- "ChikMix V3": {
56
- "repo": "digiplay/ChikMix_V3",
57
  "type": "sd15",
58
- "description": "High quality anime/realistic mix"
 
59
  },
60
- "ChilloutMix": {
61
- "repo": "digiplay/chilloutmix_NiPrunedFp16Fix",
62
  "type": "sd15",
63
- "description": "Popular realistic/anime hybrid"
 
64
  },
65
- "Pony Diffusion V6 XL": {
66
- "repo": "LyliaEngine/Pony_Diffusion_V6_XL",
67
  "type": "sdxl",
68
- "description": "SDXL anime model"
 
69
  },
70
- "AbyssOrangeMix3": {
71
- "repo": "wootwoot/abyssorangemix3-popupparade-fp16",
72
  "type": "sd15",
73
- "description": "Vibrant anime style"
 
74
  },
75
- "WAI-NSFW Illustrious XL": {
76
- "repo": "John6666/wai-nsfw-illustrious-v80-sdxl",
77
  "type": "sdxl",
78
- "description": "SDXL NSFW-focused model"
 
79
  }
80
  }
81
 
82
- # Global variables for models
83
- loaded_models = {}
84
- controlnet_standard = None
85
- controlnet_anime = None
86
  lineart_detector = None
87
  lineart_anime_detector = None
88
- txt2img_pipe = None
89
- current_txt2img_model = None
90
-
91
- def load_controlnet_models():
92
- """Load ControlNet models"""
93
- global controlnet_standard, controlnet_anime, lineart_detector, lineart_anime_detector
94
-
95
- print("📦 Loading ControlNet models...")
96
-
97
- try:
98
- controlnet_standard = ControlNetModel.from_pretrained(
99
- "lllyasviel/control_v11p_sd15_lineart",
100
- torch_dtype=dtype
101
- ).to(device)
102
- print("✅ Standard ControlNet loaded")
103
- except Exception as e:
104
- print(f"❌ Failed to load standard ControlNet: {e}")
105
-
106
- try:
107
- controlnet_anime = ControlNetModel.from_pretrained(
108
- "lllyasviel/control_v11p_sd15s2_lineart_anime",
109
- torch_dtype=dtype
110
- ).to(device)
111
- print("✅ Anime ControlNet loaded")
112
- except Exception as e:
113
- print(f"❌ Failed to load anime ControlNet: {e}")
114
-
115
- try:
116
- lineart_detector = LineartDetector.from_pretrained("lllyasviel/Annotators")
117
- lineart_anime_detector = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
118
- print("✅ Line art detectors loaded")
119
- except Exception as e:
120
- print(f"❌ Failed to load detectors: {e}")
121
 
122
- def load_model_for_controlnet(model_name, lineart_type):
123
- """Load a specific model with ControlNet"""
124
- global loaded_models
125
-
126
- cache_key = f"{model_name}_{lineart_type}"
127
 
128
- if cache_key in loaded_models:
129
- print(f"♻️ Using cached model: {cache_key}")
130
- return loaded_models[cache_key]
131
-
132
- model_info = AVAILABLE_MODELS[model_name]
133
- controlnet = controlnet_anime if lineart_type == "Anime" else controlnet_standard
134
-
135
- if controlnet is None:
136
- raise gr.Error("ControlNet model not loaded")
137
-
138
- print(f"📦 Loading {model_name}...")
139
 
140
  try:
141
- if model_info["type"] == "sdxl":
142
- # SDXL models don't support SD1.5 ControlNet currently
143
- raise gr.Error("SDXL models are not yet supported for line art colorization. Use SD1.5 models.")
 
 
 
 
 
144
 
145
- pipe = StableDiffusionControlNetPipeline.from_pretrained(
146
- model_info["repo"],
147
- controlnet=controlnet,
148
- torch_dtype=dtype,
149
- safety_checker=None,
150
- requires_safety_checker=False
151
- ).to(device)
152
-
153
- pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
154
-
155
- if device.type == "cuda":
156
- pipe.enable_model_cpu_offload()
157
- try:
158
- pipe.enable_xformers_memory_efficient_attention()
159
- except:
160
- pass
161
- pipe.enable_attention_slicing()
162
 
163
- loaded_models[cache_key] = pipe
164
- print(f" {model_name} loaded successfully")
165
- return pipe
166
 
167
- except Exception as e:
168
- print(f"❌ Error loading {model_name}: {e}")
169
- raise gr.Error(f"Failed to load {model_name}: {str(e)}")
170
-
171
- def load_txt2img_model(model_name):
172
- """Load model for text-to-image"""
173
- global txt2img_pipe, current_txt2img_model
174
-
175
- if current_txt2img_model == model_name and txt2img_pipe is not None:
176
- print(f"♻️ Using cached txt2img model: {model_name}")
177
- return txt2img_pipe
178
-
179
- model_info = AVAILABLE_MODELS[model_name]
180
- print(f"📦 Loading {model_name} for text-to-image...")
181
-
182
- try:
183
- if model_info["type"] == "sdxl":
184
  pipe = StableDiffusionXLPipeline.from_pretrained(
185
- model_info["repo"],
186
  torch_dtype=dtype,
187
  safety_checker=None,
188
- requires_safety_checker=False
 
 
189
  ).to(device)
 
 
 
 
190
  else:
 
191
  pipe = StableDiffusionPipeline.from_pretrained(
192
- model_info["repo"],
193
  torch_dtype=dtype,
194
  safety_checker=None,
195
- requires_safety_checker=False
 
 
196
  ).to(device)
 
 
 
197
 
198
- pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
199
-
200
  if device.type == "cuda":
201
  pipe.enable_model_cpu_offload()
202
  try:
203
  pipe.enable_xformers_memory_efficient_attention()
204
  except:
205
- pass
206
- pipe.enable_attention_slicing()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
 
208
- txt2img_pipe = pipe
209
- current_txt2img_model = model_name
210
- print(f"✅ {model_name} loaded for txt2img")
211
- return pipe
212
 
213
  except Exception as e:
214
- print(f"❌ Error loading {model_name}: {e}")
215
- raise gr.Error(f"Failed to load {model_name}: {str(e)}")
216
 
217
- # Load ControlNet models at startup
218
- load_controlnet_models()
219
 
220
  COLOR_STYLES = {
221
  "Anime Style": "anime, masterpiece, best quality, highly detailed, vibrant colors",
@@ -239,7 +262,7 @@ CONTENT_TEMPLATES = {
239
  }
240
 
241
  def is_already_lineart(image):
242
- """Check if image is already line art"""
243
  if isinstance(image, Image.Image):
244
  image = np.array(image)
245
 
@@ -250,15 +273,15 @@ def is_already_lineart(image):
250
  return black_white_ratio > 0.7 or unique_vals < 30
251
 
252
  def extract_lineart(image, lineart_type="Standard", skip_if_lineart=True):
253
- """Extract line art from image"""
254
  if isinstance(image, np.ndarray):
255
  image = Image.fromarray(image)
256
 
257
  if skip_if_lineart and is_already_lineart(image):
258
- print("✅ Already line art, skipping extraction")
259
  return image.convert('RGB')
260
 
261
- print(f"🔄 Extracting line art ({lineart_type})...")
262
 
263
  if lineart_type == "Anime" and lineart_anime_detector is not None:
264
  lineart = lineart_anime_detector(image, detect_resolution=512, image_resolution=512)
@@ -272,7 +295,6 @@ def extract_lineart(image, lineart_type="Standard", skip_if_lineart=True):
272
 
273
  def colorize_lineart(
274
  sketch_image,
275
- model_name,
276
  lineart_type,
277
  content_type,
278
  style,
@@ -286,14 +308,22 @@ def colorize_lineart(
286
  controlnet_strength,
287
  progress=gr.Progress(track_tqdm=True)
288
  ):
289
- """Colorize with explicit content support"""
290
  if sketch_image is None:
291
- raise gr.Error("Please upload a sketch/line art image")
292
 
293
- # Load the selected model
294
- pipe = load_model_for_controlnet(model_name, lineart_type)
 
 
 
 
 
295
 
296
- # Convert all numeric inputs to proper types
 
 
 
297
  seed = int(seed)
298
  guidance_scale = float(guidance_scale)
299
  num_steps = int(num_steps)
@@ -305,7 +335,7 @@ def colorize_lineart(
305
 
306
  generator = torch.Generator(device=device).manual_seed(seed)
307
 
308
- # Convert and resize
309
  if isinstance(sketch_image, np.ndarray):
310
  sketch_image = Image.fromarray(sketch_image)
311
 
@@ -324,30 +354,30 @@ def colorize_lineart(
324
  new_height = (new_height // 8) * 8
325
  sketch_image = sketch_image.resize((new_width, new_height), Image.LANCZOS)
326
 
327
- # Extract lineart with selected type
328
  lineart = extract_lineart(sketch_image, lineart_type=lineart_type, skip_if_lineart=True)
329
 
330
- # Build prompt intelligently
331
  prompt_parts = []
332
 
333
- # Content template
334
  content_template = CONTENT_TEMPLATES.get(content_type, "")
335
  if content_template:
336
  prompt_parts.append(content_template)
337
 
338
- # Custom prompt (most important)
339
  if custom_prompt.strip():
340
  prompt_parts.append(custom_prompt.strip())
341
 
342
- # Quality tags
343
  if quality_tags:
344
  prompt_parts.append(quality_tags)
345
 
346
- # Style
347
  style_prompt = COLOR_STYLES.get(style, COLOR_STYLES["Anime Style"])
348
  prompt_parts.append(style_prompt)
349
 
350
- # NSFW level tags
351
  if nsfw_level == "Safe":
352
  nsfw_tags = "sfw, safe for work"
353
  elif nsfw_level == "Suggestive":
@@ -363,18 +393,18 @@ def colorize_lineart(
363
 
364
  full_prompt = ", ".join(prompt_parts)
365
 
366
- # Negative prompt
367
  negative_prompt = "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, jpeg artifacts, signature, watermark, username, blurry, artist name, black and white, monochrome"
368
 
369
  if nsfw_level in ["Moderate", "Explicit"]:
370
  negative_prompt = negative_prompt.replace("nsfw, ", "")
371
 
372
- print(f"🎨 Model: {model_name}")
373
- print(f"🎨 Prompt: {full_prompt}")
374
- print(f"🎛️ ControlNet Strength: {controlnet_strength}")
375
 
376
  try:
377
- progress(0.3, desc="Generating colors...")
378
 
379
  result = pipe(
380
  prompt=full_prompt,
@@ -393,10 +423,10 @@ def colorize_lineart(
393
 
394
  except Exception as e:
395
  import traceback
396
- print(f"❌ Full error: {traceback.format_exc()}")
397
- raise gr.Error(f"Error: {str(e)}")
398
 
399
- def generate_txt2img(
400
  model_name,
401
  content_type,
402
  style,
@@ -411,12 +441,12 @@ def generate_txt2img(
411
  height,
412
  progress=gr.Progress(track_tqdm=True)
413
  ):
414
- """Generate image from text only"""
415
-
416
- # Load the selected model
417
- pipe = load_txt2img_model(model_name)
418
 
419
- # Convert all numeric inputs to proper types
420
  seed = int(seed)
421
  guidance_scale = float(guidance_scale)
422
  num_steps = int(num_steps)
@@ -429,27 +459,27 @@ def generate_txt2img(
429
 
430
  generator = torch.Generator(device=device).manual_seed(seed)
431
 
432
- # Build prompt
433
  prompt_parts = []
434
 
435
- # Content template
436
  content_template = CONTENT_TEMPLATES.get(content_type, "")
437
  if content_template:
438
  prompt_parts.append(content_template)
439
 
440
- # Custom prompt
441
  if custom_prompt.strip():
442
  prompt_parts.append(custom_prompt.strip())
443
 
444
- # Quality tags
445
  if quality_tags:
446
  prompt_parts.append(quality_tags)
447
 
448
- # Style
449
  style_prompt = COLOR_STYLES.get(style, COLOR_STYLES["Anime Style"])
450
  prompt_parts.append(style_prompt)
451
 
452
- # NSFW level tags
453
  if nsfw_level == "Safe":
454
  nsfw_tags = "sfw, safe for work"
455
  elif nsfw_level == "Suggestive":
@@ -465,28 +495,27 @@ def generate_txt2img(
465
 
466
  full_prompt = ", ".join(prompt_parts)
467
 
468
- # Negative prompt
469
- negative_prompt = "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, jpeg artifacts, signature, watermark, username, blurry, artist name"
470
 
471
  if nsfw_level in ["Moderate", "Explicit"]:
472
  negative_prompt = negative_prompt.replace("nsfw, ", "")
473
 
474
- print(f"🎨 Model: {model_name}")
475
- print(f"🎨 Prompt: {full_prompt}")
 
 
476
 
477
  try:
478
- progress(0.3, desc="Generating image...")
479
-
480
- # Check if SDXL model
481
- model_info = AVAILABLE_MODELS[model_name]
482
 
483
- result = pipe(
484
  prompt=full_prompt,
485
  negative_prompt=negative_prompt,
486
- num_inference_steps=num_steps,
487
- guidance_scale=guidance_scale,
488
  width=width,
489
  height=height,
 
 
490
  generator=generator,
491
  ).images[0]
492
 
@@ -497,8 +526,21 @@ def generate_txt2img(
497
 
498
  except Exception as e:
499
  import traceback
500
- print(f"❌ Full error: {traceback.format_exc()}")
501
- raise gr.Error(f"Error: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
502
 
503
  css="""
504
  #col-container {
@@ -544,212 +586,411 @@ css="""
544
  font-size: 0.9em;
545
  margin: 5px;
546
  }
 
 
 
 
 
 
547
  """
548
 
549
  with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
550
  with gr.Column(elem_id="col-container"):
551
- gr.Markdown("# 🎨 Style2Paints - Multi-Model Edition", elem_id="main-title")
552
- gr.Markdown("### ✨ Professional AI Art Generation with Multiple Models")
553
 
554
  gr.HTML("""
555
  <div class="warning-box">
556
- <strong>⚠️ Content Warning:</strong> This tool supports generation of all content types including NSFW/adult material.
557
- Use responsibly and ensure compliance with local laws. Users must be 18+ for explicit content.
558
  </div>
559
  """)
560
 
561
- with gr.Tabs():
562
- # Tab 1: Line Art Colorization
563
- with gr.Tab("🖊️ Line Art Colorization"):
564
- gr.HTML("""
565
- <div class="feature-box">
566
- <h3>✨ Colorize Your Line Art</h3>
567
- <ul style="color:white; font-size:1.1em;">
568
- <li>🎨 <strong>6 Anime Models</strong> - Choose your favorite style</li>
569
- <li>🖊️ <strong>Dual Lineart Detection</strong> - Standard & Anime optimized</li>
570
- <li>🎭 <strong>Smart Prompt Building</strong> - Easy content creation</li>
571
- <li>🎚️ <strong>NSFW Support</strong> - All content levels supported</li>
572
- </ul>
573
- </div>
574
- """)
575
-
 
 
 
 
 
 
576
  with gr.Row():
577
  with gr.Column(scale=1):
578
  input_image = gr.Image(
579
- label="📤 Upload Line Art",
580
  type="pil",
581
  height=400
582
  )
583
 
584
- # Get SD1.5 models only for ControlNet
585
- sd15_models = [name for name, info in AVAILABLE_MODELS.items() if info["type"] == "sd15"]
586
-
587
- model_selector = gr.Dropdown(
588
- choices=sd15_models,
589
- label="🤖 Select Model",
590
- value="Anything V3 (Anime)",
591
- info="Choose your preferred model (SD1.5 only for line art)"
592
- )
593
 
594
  lineart_type = gr.Radio(
595
  choices=["Standard", "Anime"],
596
- label="🖊️ Lineart Model",
597
  value="Anime",
598
- info="Anime works better for anime/manga"
599
  )
600
 
601
  content_type = gr.Dropdown(
602
  choices=list(CONTENT_TEMPLATES.keys()),
603
- label="📋 Content Template",
604
- value="Character Portrait"
 
605
  )
606
 
607
  custom_prompt = gr.Textbox(
608
- label="✍️ Detailed Description",
609
- placeholder="hair color, outfit, pose, features...",
610
- lines=3
 
611
  )
612
 
 
 
 
 
 
 
 
 
 
 
613
  with gr.Row():
614
  style = gr.Dropdown(
615
  choices=list(COLOR_STYLES.keys()),
616
- label="🎨 Style",
617
  value="Anime Style"
618
  )
619
 
620
  nsfw_level = gr.Dropdown(
621
  choices=["Safe", "Suggestive", "Mild", "Moderate", "Explicit"],
622
- label="🔞 Content Level",
623
- value="Moderate"
 
624
  )
625
 
626
  quality_tags = gr.Textbox(
627
- label="⭐ Quality Tags",
 
628
  value="masterpiece, best quality, highly detailed"
629
  )
630
 
631
- colorize_button = gr.Button("✨ Colorize!", variant="primary", size="lg")
632
 
633
  with gr.Column(scale=2):
634
  with gr.Row():
635
- lineart_output = gr.Image(label="🖊️ Line Art", type="pil", height=380)
636
- colorized_output = gr.Image(label="🎨 Result", type="pil", height=380)
 
 
 
 
 
 
 
 
637
 
638
- colorized_prompt = gr.Textbox(
639
- label="📝 Generated Prompt",
640
  lines=3,
641
  interactive=False,
642
  show_copy_button=True
643
  )
644
 
645
- with gr.Accordion("⚙️ Advanced Settings", open=False):
646
  with gr.Row():
647
- c_seed = gr.Slider(label="🎲 Seed", minimum=0, maximum=2**32-1, step=1, value=42)
648
- c_random = gr.Checkbox(label="🔀 Random", value=True)
 
 
 
 
 
 
 
 
 
649
 
650
  with gr.Row():
651
- c_guidance = gr.Slider(label="💬 Guidance", minimum=5.0, maximum=15.0, step=0.5, value=8.0)
652
- c_steps = gr.Slider(label="🔢 Steps", minimum=10, maximum=30, step=5, value=20)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
653
 
654
- c_strength = gr.Slider(
655
- label="🎛️ Line Preservation",
656
  minimum=0.5,
657
  maximum=1.5,
658
  step=0.1,
659
- value=1.0
 
660
  )
661
 
662
  colorize_button.click(
663
  fn=colorize_lineart,
664
  inputs=[
665
- input_image, model_selector, lineart_type, content_type, style,
666
- custom_prompt, quality_tags, nsfw_level, c_seed, c_random,
667
- c_guidance, c_steps, c_strength
668
  ],
669
- outputs=[colorized_output, lineart_output, c_seed, colorized_prompt]
670
  )
671
 
672
- # Tab 2: Text to Image
673
- with gr.Tab("✍️ Text to Image"):
674
- gr.HTML("""
675
- <div class="feature-box">
676
- <h3>✨ Generate Images from Text</h3>
677
- <ul style="color:white; font-size:1.1em;">
678
- <li>🎨 <strong>6 Models Including SDXL</strong> - Maximum variety</li>
679
- <li>✍️ <strong>Pure Text Generation</strong> - No line art needed</li>
680
- <li>📐 <strong>Flexible Sizes</strong> - From 512x512 to 1024x1024</li>
681
- <li>🎭 <strong>All Content Types</strong> - SFW to NSFW</li>
682
- </ul>
683
- </div>
684
- """)
685
-
686
  with gr.Row():
687
  with gr.Column(scale=1):
688
- t2i_model = gr.Dropdown(
689
- choices=list(AVAILABLE_MODELS.keys()),
690
- label="🤖 Select Model",
691
- value="Anything V3 (Anime)",
692
- info="All models available for txt2img"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
693
  )
694
 
695
- t2i_content = gr.Dropdown(
 
 
696
  choices=list(CONTENT_TEMPLATES.keys()),
697
- label="📋 Content Template",
698
- value="Character Portrait"
 
699
  )
700
 
701
- t2i_prompt = gr.Textbox(
702
- label="✍️ Your Description",
703
- placeholder="Describe what you want to see...",
704
- lines=4
 
705
  )
706
 
707
  gr.HTML("""
708
  <div class="info-box">
709
- <strong>💡 Example Prompts:</strong><br>
710
- • "girl with long silver hair, blue eyes, wearing kimono, cherry blossoms"<br>
711
- • "cyberpunk girl, neon lights, futuristic city, night scene"<br>
712
- • "maid outfit, cat ears, blushing, bedroom"
 
713
  </div>
714
  """)
715
 
716
  with gr.Row():
717
  t2i_style = gr.Dropdown(
718
  choices=list(COLOR_STYLES.keys()),
719
- label="🎨 Style",
720
  value="Anime Style"
721
  )
722
 
723
- t2i_nsfw = gr.Dropdown(
724
  choices=["Safe", "Suggestive", "Mild", "Moderate", "Explicit"],
725
- label="🔞 Content Level",
726
- value="Safe"
 
727
  )
728
 
729
- t2i_quality = gr.Textbox(
730
- label="⭐ Quality Tags",
 
731
  value="masterpiece, best quality, highly detailed"
732
  )
733
 
734
- with gr.Row():
735
- t2i_width = gr.Slider(label="📏 Width", minimum=512, maximum=1024, step=64, value=512)
736
- t2i_height = gr.Slider(label="📏 Height", minimum=512, maximum=1024, step=64, value=512)
737
-
738
- generate_button = gr.Button("🎨 Generate!", variant="primary", size="lg")
739
 
740
  with gr.Column(scale=2):
741
- t2i_output = gr.Image(label="🎨 Generated Image", type="pil", height=600)
 
 
 
 
742
 
743
- t2i_prompt_output = gr.Textbox(
744
- label="📝 Full Prompt Used",
745
  lines=3,
746
  interactive=False,
747
  show_copy_button=True
748
  )
749
 
750
- with gr.Accordion("⚙️ Advanced Settings", open=False):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
751
  with gr.Row():
752
- t2i_seed = gr.Slider(label="🎲 Seed", minimum=0, maximum=
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
753
  """)
754
 
755
  demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
 
6
  import torch
7
  from gradio.themes import Soft
8
  from gradio.themes.utils import colors, fonts, sizes
9
+ import gc
10
 
11
  colors.steel_blue = colors.Color(
12
  name="steel_blue",
 
37
  steel_blue_theme = SteelBlueTheme()
38
 
39
  print("=" * 50)
40
+ print("🎨 Style2Paints - Uncensored Line Art Colorization & Text-to-Image")
41
  print("=" * 50)
42
 
43
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
44
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
45
 
46
+ from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, DDIMScheduler, StableDiffusionPipeline, StableDiffusionXLPipeline, EulerDiscreteScheduler
47
  from controlnet_aux import LineartDetector, LineartAnimeDetector
48
 
49
+ # ===== 模型配置 =====
50
  AVAILABLE_MODELS = {
51
+ "LineArt Colorization": ["Anything V3 (ControlNet)"],
52
+ "Text-to-Image": [
53
+ "Linaqruf/anything-v3.0",
54
+ "digiplay/ChikMix_V3",
55
+ "digiplay/chilloutmix_NiPrunedFp16Fix",
56
+ "LyliaEngine/Pony_Diffusion_V6_XL",
57
+ "wootwoot/abyssorangemix3-popupparade-fp16",
58
+ "John6666/wai-nsfw-illustrious-v80-sdxl"
59
+ ]
60
+ }
61
+
62
+ MODEL_CONFIGS = {
63
+ "Linaqruf/anything-v3.0": {
64
  "type": "sd15",
65
+ "description": "Anything V3 - 全能模型",
66
+ "default_resolution": (512, 768)
67
  },
68
+ "digiplay/ChikMix_V3": {
 
69
  "type": "sd15",
70
+ "description": "ChikMix V3 - 高质量动漫模型",
71
+ "default_resolution": (512, 768)
72
  },
73
+ "digiplay/chilloutmix_NiPrunedFp16Fix": {
 
74
  "type": "sd15",
75
+ "description": "ChilloutMix - 真人风格",
76
+ "default_resolution": (512, 768)
77
  },
78
+ "LyliaEngine/Pony_Diffusion_V6_XL": {
 
79
  "type": "sdxl",
80
+ "description": "Pony Diffusion V6 XL - SDXL动漫模型",
81
+ "default_resolution": (1024, 1024)
82
  },
83
+ "wootwoot/abyssorangemix3-popupparade-fp16": {
 
84
  "type": "sd15",
85
+ "description": "AbyssOrangeMix3 - 色彩鲜艳",
86
+ "default_resolution": (512, 768)
87
  },
88
+ "John6666/wai-nsfw-illustrious-v80-sdxl": {
 
89
  "type": "sdxl",
90
+ "description": "WAI NSFW Illustrious - SDXL成人内容优化",
91
+ "default_resolution": (1024, 1024)
92
  }
93
  }
94
 
95
+ # ===== 全局模型变量 =====
96
+ pipe_standard = None
97
+ pipe_anime = None
 
98
  lineart_detector = None
99
  lineart_anime_detector = None
100
+ current_t2i_model = None
101
+ current_t2i_pipe = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
+ def load_text_to_image_model(model_name, progress=gr.Progress()):
104
+ """动态加载文本到图像模型"""
105
+ global current_t2i_model, current_t2i_pipe
 
 
106
 
107
+ if model_name == current_t2i_model and current_t2i_pipe is not None:
108
+ print(f" 模型 {model_name} 已加载,跳过重新加载")
109
+ return True
 
 
 
 
 
 
 
 
110
 
111
  try:
112
+ # 清理之前的模型
113
+ if current_t2i_pipe is not None:
114
+ del current_t2i_pipe
115
+ current_t2i_pipe = None
116
+ current_t2i_model = None
117
+ gc.collect()
118
+ if torch.cuda.is_available():
119
+ torch.cuda.empty_cache()
120
 
121
+ print(f"🔄 正在加载模型: {model_name}")
122
+ progress(0.3, desc=f"正在加载 {model_name}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
+ model_config = MODEL_CONFIGS.get(model_name, {})
125
+ model_type = model_config.get("type", "sd15")
 
126
 
127
+ if model_type == "sdxl":
128
+ # SDXL 模型
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  pipe = StableDiffusionXLPipeline.from_pretrained(
130
+ model_name,
131
  torch_dtype=dtype,
132
  safety_checker=None,
133
+ requires_safety_checker=False,
134
+ use_safetensors=True,
135
+ variant="fp16" if dtype == torch.float16 else None
136
  ).to(device)
137
+
138
+ # SDXL 推荐使用 Euler scheduler
139
+ pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
140
+
141
  else:
142
+ # SD1.5 模型
143
  pipe = StableDiffusionPipeline.from_pretrained(
144
+ model_name,
145
  torch_dtype=dtype,
146
  safety_checker=None,
147
+ requires_safety_checker=False,
148
+ use_safetensors=True,
149
+ variant="fp16" if dtype == torch.float16 else None
150
  ).to(device)
151
+
152
+ # SD1.5 使用 DDIM
153
+ pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
154
 
155
+ # 优化设置
 
156
  if device.type == "cuda":
157
  pipe.enable_model_cpu_offload()
158
  try:
159
  pipe.enable_xformers_memory_efficient_attention()
160
  except:
161
+ print("⚠️ XFormers 不可用,跳过")
162
+ pipe.enable_attention_slicing()
163
+
164
+ current_t2i_model = model_name
165
+ current_t2i_pipe = pipe
166
+
167
+ print(f"✅ 模型 {model_name} 加载成功!")
168
+ progress(1.0, desc="模型加载完成")
169
+ return True
170
+
171
+ except Exception as e:
172
+ import traceback
173
+ print(f"❌ 加载模型失败: {str(e)}")
174
+ print(f"详细错误: {traceback.format_exc()}")
175
+ return False
176
+
177
+ def load_lineart_models():
178
+ """加载线稿着色模型"""
179
+ global pipe_standard, pipe_anime, lineart_detector, lineart_anime_detector
180
+
181
+ try:
182
+ print("🔄 加载线稿着色模型...")
183
+
184
+ # Load STANDARD lineart ControlNet
185
+ print("📦 加载标准线稿模型...")
186
+ controlnet_standard = ControlNetModel.from_pretrained(
187
+ "lllyasviel/control_v11p_sd15_lineart",
188
+ torch_dtype=dtype
189
+ ).to(device)
190
+
191
+ # Load ANIME lineart ControlNet
192
+ print("📦 加载动漫线稿模型...")
193
+ controlnet_anime = ControlNetModel.from_pretrained(
194
+ "lllyasviel/control_v11p_sd15s2_lineart_anime",
195
+ torch_dtype=dtype
196
+ ).to(device)
197
+
198
+ # 使用 Anything V3 作为基础模型
199
+ print("📦 加载基础模型 (Anything V3)...")
200
+ pipe_standard = StableDiffusionControlNetPipeline.from_pretrained(
201
+ "Linaqruf/anything-v3.0",
202
+ controlnet=controlnet_standard,
203
+ torch_dtype=dtype,
204
+ safety_checker=None,
205
+ requires_safety_checker=False
206
+ ).to(device)
207
+
208
+ pipe_anime = StableDiffusionControlNetPipeline.from_pretrained(
209
+ "Linaqruf/anything-v3.0",
210
+ controlnet=controlnet_anime,
211
+ torch_dtype=dtype,
212
+ safety_checker=None,
213
+ requires_safety_checker=False
214
+ ).to(device)
215
+
216
+ # 配置两个管道
217
+ for pipe in [pipe_standard, pipe_anime]:
218
+ pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
219
+
220
+ if device.type == "cuda":
221
+ pipe.enable_model_cpu_offload()
222
+ try:
223
+ pipe.enable_xformers_memory_efficient_attention()
224
+ except:
225
+ pass
226
+ pipe.enable_attention_slicing()
227
+
228
+ # 加载线稿检测器
229
+ print("📦 加载线稿检测器...")
230
+ lineart_detector = LineartDetector.from_pretrained("lllyasviel/Annotators")
231
+ lineart_anime_detector = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
232
 
233
+ print("✅ 线稿着色模型加载成功!")
234
+ return True
 
 
235
 
236
  except Exception as e:
237
+ print(f"❌ 加载线稿模型失败: {e}")
238
+ return False
239
 
240
+ # 加载线稿着色模型
241
+ load_lineart_models()
242
 
243
  COLOR_STYLES = {
244
  "Anime Style": "anime, masterpiece, best quality, highly detailed, vibrant colors",
 
262
  }
263
 
264
  def is_already_lineart(image):
265
+ """检查图像是否已经是线稿"""
266
  if isinstance(image, Image.Image):
267
  image = np.array(image)
268
 
 
273
  return black_white_ratio > 0.7 or unique_vals < 30
274
 
275
  def extract_lineart(image, lineart_type="Standard", skip_if_lineart=True):
276
+ """从图像中提取线稿"""
277
  if isinstance(image, np.ndarray):
278
  image = Image.fromarray(image)
279
 
280
  if skip_if_lineart and is_already_lineart(image):
281
+ print("✅ 已经是线稿,跳过提取")
282
  return image.convert('RGB')
283
 
284
+ print(f"🔄 提取线稿 ({lineart_type})...")
285
 
286
  if lineart_type == "Anime" and lineart_anime_detector is not None:
287
  lineart = lineart_anime_detector(image, detect_resolution=512, image_resolution=512)
 
295
 
296
  def colorize_lineart(
297
  sketch_image,
 
298
  lineart_type,
299
  content_type,
300
  style,
 
308
  controlnet_strength,
309
  progress=gr.Progress(track_tqdm=True)
310
  ):
311
+ """线稿着色"""
312
  if sketch_image is None:
313
+ raise gr.Error("请上传线稿图像")
314
 
315
+ # 根据线稿类型选择管道
316
+ if lineart_type == "Anime" and pipe_anime is not None:
317
+ pipe = pipe_anime
318
+ print("🎨 使用动漫线稿模型")
319
+ else:
320
+ pipe = pipe_standard
321
+ print("🎨 使用标准线稿模型")
322
 
323
+ if pipe is None:
324
+ raise gr.Error("模型未加载")
325
+
326
+ # 转换数值输入
327
  seed = int(seed)
328
  guidance_scale = float(guidance_scale)
329
  num_steps = int(num_steps)
 
335
 
336
  generator = torch.Generator(device=device).manual_seed(seed)
337
 
338
+ # 转换和调整大小
339
  if isinstance(sketch_image, np.ndarray):
340
  sketch_image = Image.fromarray(sketch_image)
341
 
 
354
  new_height = (new_height // 8) * 8
355
  sketch_image = sketch_image.resize((new_width, new_height), Image.LANCZOS)
356
 
357
+ # 提取线稿
358
  lineart = extract_lineart(sketch_image, lineart_type=lineart_type, skip_if_lineart=True)
359
 
360
+ # 构建提示词
361
  prompt_parts = []
362
 
363
+ # 内容模板
364
  content_template = CONTENT_TEMPLATES.get(content_type, "")
365
  if content_template:
366
  prompt_parts.append(content_template)
367
 
368
+ # 自定义提示词
369
  if custom_prompt.strip():
370
  prompt_parts.append(custom_prompt.strip())
371
 
372
+ # 质量标签
373
  if quality_tags:
374
  prompt_parts.append(quality_tags)
375
 
376
+ # 风格
377
  style_prompt = COLOR_STYLES.get(style, COLOR_STYLES["Anime Style"])
378
  prompt_parts.append(style_prompt)
379
 
380
+ # NSFW 级别标签
381
  if nsfw_level == "Safe":
382
  nsfw_tags = "sfw, safe for work"
383
  elif nsfw_level == "Suggestive":
 
393
 
394
  full_prompt = ", ".join(prompt_parts)
395
 
396
+ # 负面提示词
397
  negative_prompt = "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, jpeg artifacts, signature, watermark, username, blurry, artist name, black and white, monochrome"
398
 
399
  if nsfw_level in ["Moderate", "Explicit"]:
400
  negative_prompt = negative_prompt.replace("nsfw, ", "")
401
 
402
+ print(f"🎨 提示词: {full_prompt}")
403
+ print(f"🎛️ ControlNet 强度: {controlnet_strength}")
404
+ print(f"🖼️ 线稿类型: {lineart_type}")
405
 
406
  try:
407
+ progress(0.3, desc="正在生成颜色...")
408
 
409
  result = pipe(
410
  prompt=full_prompt,
 
423
 
424
  except Exception as e:
425
  import traceback
426
+ print(f"❌ 完整错误: {traceback.format_exc()}")
427
+ raise gr.Error(f"错误: {str(e)}")
428
 
429
+ def generate_text_to_image(
430
  model_name,
431
  content_type,
432
  style,
 
441
  height,
442
  progress=gr.Progress(track_tqdm=True)
443
  ):
444
+ """文本到图像生成"""
445
+ # 加载模型
446
+ if not load_text_to_image_model(model_name, progress):
447
+ raise gr.Error(f"无法加载模型: {model_name}")
448
 
449
+ # 转换数值输入
450
  seed = int(seed)
451
  guidance_scale = float(guidance_scale)
452
  num_steps = int(num_steps)
 
459
 
460
  generator = torch.Generator(device=device).manual_seed(seed)
461
 
462
+ # 构建提示词
463
  prompt_parts = []
464
 
465
+ # 内容模板
466
  content_template = CONTENT_TEMPLATES.get(content_type, "")
467
  if content_template:
468
  prompt_parts.append(content_template)
469
 
470
+ # 自定义提示词
471
  if custom_prompt.strip():
472
  prompt_parts.append(custom_prompt.strip())
473
 
474
+ # 质量标签
475
  if quality_tags:
476
  prompt_parts.append(quality_tags)
477
 
478
+ # 风格
479
  style_prompt = COLOR_STYLES.get(style, COLOR_STYLES["Anime Style"])
480
  prompt_parts.append(style_prompt)
481
 
482
+ # NSFW 级别标签
483
  if nsfw_level == "Safe":
484
  nsfw_tags = "sfw, safe for work"
485
  elif nsfw_level == "Suggestive":
 
495
 
496
  full_prompt = ", ".join(prompt_parts)
497
 
498
+ # 负面提示词
499
+ negative_prompt = "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, jpeg artifacts, signature, watermark, username, blurry, artist name, black and white, monochrome"
500
 
501
  if nsfw_level in ["Moderate", "Explicit"]:
502
  negative_prompt = negative_prompt.replace("nsfw, ", "")
503
 
504
+ print(f"���� 提示词: {full_prompt}")
505
+ print(f"📐 分辨率: {width}x{height}")
506
+ print(f"🎛️ 引导尺度: {guidance_scale}")
507
+ print(f"🔄 步数: {num_steps}")
508
 
509
  try:
510
+ progress(0.5, desc="正在生成图像...")
 
 
 
511
 
512
+ result = current_t2i_pipe(
513
  prompt=full_prompt,
514
  negative_prompt=negative_prompt,
 
 
515
  width=width,
516
  height=height,
517
+ num_inference_steps=num_steps,
518
+ guidance_scale=guidance_scale,
519
  generator=generator,
520
  ).images[0]
521
 
 
526
 
527
  except Exception as e:
528
  import traceback
529
+ print(f"❌ 完整错误: {traceback.format_exc()}")
530
+ raise gr.Error(f"错误: {str(e)}")
531
+
532
+ def update_resolution_from_model(model_name):
533
+ """根据选择的模型更新推荐分辨率"""
534
+ config = MODEL_CONFIGS.get(model_name, {})
535
+ default_res = config.get("default_resolution", (512, 768))
536
+ description = config.get("description", "通用模型")
537
+
538
+ width, height = default_res
539
+ return (
540
+ gr.update(value=width, minimum=256, maximum=2048, step=8),
541
+ gr.update(value=height, minimum=256, maximum=2048, step=8),
542
+ gr.update(value=f"📊 推荐分辨率: {width}x{height} ({description})")
543
+ )
544
 
545
  css="""
546
  #col-container {
 
586
  font-size: 0.9em;
587
  margin: 5px;
588
  }
589
+ .tab-buttons {
590
+ margin-bottom: 20px;
591
+ }
592
+ .tab-nav {
593
+ border-bottom: 2px solid #e0e0e0;
594
+ }
595
  """
596
 
597
  with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
598
  with gr.Column(elem_id="col-container"):
599
+ gr.Markdown("# 🎨 Style2Paints - 全能图像生成", elem_id="main-title")
600
+ gr.Markdown("### ✨ 专业线稿着色与文本生成图像")
601
 
602
  gr.HTML("""
603
  <div class="warning-box">
604
+ <strong>⚠️ 内容警告:</strong> 此工具支持所有类型内容的生成,包括 NSFW/成人内容。
605
+ 请负责任地使用,并确保符合当地法律。生成成人内容必须年满18岁。
606
  </div>
607
  """)
608
 
609
+ gr.HTML("""
610
+ <div class="feature-box">
611
+ <h3>✨ 核心功能</h3>
612
+ <ul style="color:white; font-size:1.1em;">
613
+ <li>🎨 <strong>双线稿模型</strong> - 标准和动漫专用线稿检测</li>
614
+ <li>🖼️ <strong>文本生成图像</strong> - 从文本描述直接生成图像</li>
615
+ <li>🎭 <strong>多模型支持</strong> - 6种不同风格的模型可选</li>
616
+ <li>📝 <strong>内容模板</strong> - 常见场景的预设提示词</li>
617
+ <li>🎚️ <strong>NSFW级别控制</strong> - 精确的内容级别控制</li>
618
+ <li>⚡ <strong>智能模型加载</strong> - 按需加载,节省显存</li>
619
+ </ul>
620
+ <div style="margin-top:15px;">
621
+ <span class="model-badge">6种文本生成模型</span>
622
+ <span class="model-badge">2种线稿模型</span>
623
+ </div>
624
+ </div>
625
+ """)
626
+
627
+ with gr.Tabs() as tabs:
628
+ # ===== 标签页 1: 线稿着色 =====
629
+ with gr.TabItem("🎨 线稿着色"):
630
  with gr.Row():
631
  with gr.Column(scale=1):
632
  input_image = gr.Image(
633
+ label="📤 上传线稿",
634
  type="pil",
635
  height=400
636
  )
637
 
638
+ gr.Markdown("### 🎨 内容设置")
 
 
 
 
 
 
 
 
639
 
640
  lineart_type = gr.Radio(
641
  choices=["Standard", "Anime"],
642
+ label="🖊️ 线稿模型",
643
  value="Anime",
644
+ info="动漫模型更适合动漫/漫画风格"
645
  )
646
 
647
  content_type = gr.Dropdown(
648
  choices=list(CONTENT_TEMPLATES.keys()),
649
+ label="📋 内容模板",
650
+ value="Character Portrait",
651
+ info="提示词起点"
652
  )
653
 
654
  custom_prompt = gr.Textbox(
655
+ label="✍️ 详细描述 (重要)",
656
+ placeholder="描述您想要的内容:发色、服装、姿势、背景、身体特征等",
657
+ lines=3,
658
+ info="请具体描述!这是最重要的字段。"
659
  )
660
 
661
+ gr.HTML("""
662
+ <div class="info-box">
663
+ <strong>💡 提示词示例:</strong><br>
664
+ • "金发,蓝眼睛,女仆装,丰满"<br>
665
+ • "红色马尾辫,校服,短裙,过膝袜"<br>
666
+ • "白发,猫耳,裸体,躺在床上"<br>
667
+ • "两个女孩,接吻,亲密,卧室"
668
+ </div>
669
+ """)
670
+
671
  with gr.Row():
672
  style = gr.Dropdown(
673
  choices=list(COLOR_STYLES.keys()),
674
+ label="🎨 颜色风格",
675
  value="Anime Style"
676
  )
677
 
678
  nsfw_level = gr.Dropdown(
679
  choices=["Safe", "Suggestive", "Mild", "Moderate", "Explicit"],
680
+ label="🔞 内容级别",
681
+ value="Moderate",
682
+ info="内容明确程度"
683
  )
684
 
685
  quality_tags = gr.Textbox(
686
+ label="⭐ 质量标签 (可选)",
687
+ placeholder="masterpiece, best quality, highly detailed",
688
  value="masterpiece, best quality, highly detailed"
689
  )
690
 
691
+ colorize_button = gr.Button("✨ 开始着色!", variant="primary", size="lg")
692
 
693
  with gr.Column(scale=2):
694
  with gr.Row():
695
+ lineart_output = gr.Image(
696
+ label="🖊️ 提取的线稿",
697
+ type="pil",
698
+ height=380
699
+ )
700
+ output_image = gr.Image(
701
+ label="🎨 着色结果",
702
+ type="pil",
703
+ height=380
704
+ )
705
 
706
+ generated_prompt = gr.Textbox(
707
+ label="📝 生成的提示词",
708
  lines=3,
709
  interactive=False,
710
  show_copy_button=True
711
  )
712
 
713
+ with gr.Accordion("⚙️ 高级设置", open=True):
714
  with gr.Row():
715
+ seed = gr.Slider(
716
+ label="🎲 种子",
717
+ minimum=0,
718
+ maximum=2**32-1,
719
+ step=1,
720
+ value=42
721
+ )
722
+ randomize_seed = gr.Checkbox(
723
+ label="🔀 随机种子",
724
+ value=True
725
+ )
726
 
727
  with gr.Row():
728
+ guidance_scale = gr.Slider(
729
+ label="💬 引导尺度",
730
+ minimum=5.0,
731
+ maximum=15.0,
732
+ step=0.5,
733
+ value=8.0,
734
+ info="7-9 推荐用于 NSFW"
735
+ )
736
+
737
+ num_steps = gr.Slider(
738
+ label="🔢 步数",
739
+ minimum=10,
740
+ maximum=30,
741
+ step=5,
742
+ value=20,
743
+ info="20 是良好平衡"
744
+ )
745
 
746
+ controlnet_strength = gr.Slider(
747
+ label="🎛️ 线稿保留强度",
748
  minimum=0.5,
749
  maximum=1.5,
750
  step=0.1,
751
+ value=1.0,
752
+ info="严格遵循线稿的程度"
753
  )
754
 
755
  colorize_button.click(
756
  fn=colorize_lineart,
757
  inputs=[
758
+ input_image, lineart_type, content_type, style, custom_prompt, quality_tags, nsfw_level,
759
+ seed, randomize_seed, guidance_scale, num_steps, controlnet_strength
 
760
  ],
761
+ outputs=[output_image, lineart_output, seed, generated_prompt]
762
  )
763
 
764
+ # ===== 标签页 2: 文本生成图像 =====
765
+ with gr.TabItem("🖼️ 文本生成图像"):
 
 
 
 
 
 
 
 
 
 
 
 
766
  with gr.Row():
767
  with gr.Column(scale=1):
768
+ gr.Markdown("### 🤖 模型选择")
769
+
770
+ model_selector = gr.Dropdown(
771
+ choices=AVAILABLE_MODELS["Text-to-Image"],
772
+ label="🎯 选择模型",
773
+ value="Linaqruf/anything-v3.0",
774
+ info="选择要使用的生成模型"
775
+ )
776
+
777
+ model_info = gr.Textbox(
778
+ label="📊 模型信息",
779
+ value="📊 推荐分辨率: 512x768 (Anything V3 - 全能模型)",
780
+ interactive=False
781
+ )
782
+
783
+ load_model_btn = gr.Button("🔄 加载模型", variant="secondary")
784
+ model_status = gr.Textbox(
785
+ label="✅ 状态",
786
+ value="✅ 模型已就绪",
787
+ interactive=False
788
  )
789
 
790
+ gr.Markdown("### 🎨 内容设置")
791
+
792
+ t2i_content_type = gr.Dropdown(
793
  choices=list(CONTENT_TEMPLATES.keys()),
794
+ label="📋 内容模板",
795
+ value="Character Portrait",
796
+ info="提示词起点"
797
  )
798
 
799
+ t2i_custom_prompt = gr.Textbox(
800
+ label="✍️ 详细描述 (重要)",
801
+ placeholder="详细描述您想要生成的图像:角色特征、服装、姿势、场景等",
802
+ lines=3,
803
+ info="描述越详细,生成效果越好"
804
  )
805
 
806
  gr.HTML("""
807
  <div class="info-box">
808
+ <strong>💡 提示词示例:</strong><br>
809
+ • "美丽的女孩,金色长发,蓝色眼睛,穿着白色连衣裙,站在花园里"<br>
810
+ • "性感的女战士,红色铠甲,手持长剑,动态姿势,战场背景"<br>
811
+ • "两个女孩在咖啡馆约会,温馨的氛围,详细的面部表情"<br>
812
+ • "幻想风格的女精灵,尖耳朵,魔法光效,森林背景"
813
  </div>
814
  """)
815
 
816
  with gr.Row():
817
  t2i_style = gr.Dropdown(
818
  choices=list(COLOR_STYLES.keys()),
819
+ label="🎨 艺术风格",
820
  value="Anime Style"
821
  )
822
 
823
+ t2i_nsfw_level = gr.Dropdown(
824
  choices=["Safe", "Suggestive", "Mild", "Moderate", "Explicit"],
825
+ label="🔞 内容级别",
826
+ value="Moderate",
827
+ info="内容明确程度"
828
  )
829
 
830
+ t2i_quality_tags = gr.Textbox(
831
+ label="⭐ 质量标签 (可选)",
832
+ placeholder="masterpiece, best quality, highly detailed",
833
  value="masterpiece, best quality, highly detailed"
834
  )
835
 
836
+ generate_button = gr.Button("✨ 生成图像!", variant="primary", size="lg")
 
 
 
 
837
 
838
  with gr.Column(scale=2):
839
+ t2i_output_image = gr.Image(
840
+ label="🖼️ 生成的图像",
841
+ type="pil",
842
+ height=500
843
+ )
844
 
845
+ t2i_generated_prompt = gr.Textbox(
846
+ label="📝 生成的提示词",
847
  lines=3,
848
  interactive=False,
849
  show_copy_button=True
850
  )
851
 
852
+ with gr.Accordion("⚙️ 高级设置", open=True):
853
+ with gr.Row():
854
+ t2i_seed = gr.Slider(
855
+ label="🎲 种子",
856
+ minimum=0,
857
+ maximum=2**32-1,
858
+ step=1,
859
+ value=42
860
+ )
861
+ t2i_randomize_seed = gr.Checkbox(
862
+ label="🔀 随机种子",
863
+ value=True
864
+ )
865
+
866
+ with gr.Row():
867
+ t2i_guidance_scale = gr.Slider(
868
+ label="💬 引导尺度",
869
+ minimum=5.0,
870
+ maximum=15.0,
871
+ step=0.5,
872
+ value=7.5,
873
+ info="控制提示词影响力"
874
+ )
875
+
876
+ t2i_num_steps = gr.Slider(
877
+ label="🔢 生成步数",
878
+ minimum=10,
879
+ maximum=50,
880
+ step=5,
881
+ value=30,
882
+ info="步数越多质量越高但越慢"
883
+ )
884
+
885
  with gr.Row():
886
+ t2i_width = gr.Slider(
887
+ label="📏 宽度",
888
+ minimum=256,
889
+ maximum=2048,
890
+ step=8,
891
+ value=512,
892
+ info="图像宽度"
893
+ )
894
+
895
+ t2i_height = gr.Slider(
896
+ label="📐 高度",
897
+ minimum=256,
898
+ maximum=2048,
899
+ step=8,
900
+ value=768,
901
+ info="图像高度"
902
+ )
903
+
904
+ # 事件处理
905
+ model_selector.change(
906
+ fn=update_resolution_from_model,
907
+ inputs=[model_selector],
908
+ outputs=[t2i_width, t2i_height, model_info]
909
+ )
910
+
911
+ load_model_btn.click(
912
+ fn=lambda model_name: (
913
+ load_text_to_image_model(model_name, gr.Progress()) and
914
+ gr.update(value=f"✅ {model_name} 加载成功")
915
+ ),
916
+ inputs=[model_selector],
917
+ outputs=[model_status]
918
+ )
919
+
920
+ generate_button.click(
921
+ fn=generate_text_to_image,
922
+ inputs=[
923
+ model_selector,
924
+ t2i_content_type,
925
+ t2i_style,
926
+ t2i_custom_prompt,
927
+ t2i_quality_tags,
928
+ t2i_nsfw_level,
929
+ t2i_seed,
930
+ t2i_randomize_seed,
931
+ t2i_guidance_scale,
932
+ t2i_num_steps,
933
+ t2i_width,
934
+ t2i_height
935
+ ],
936
+ outputs=[t2i_output_image, t2i_seed, t2i_generated_prompt]
937
+ )
938
+
939
+ gr.Markdown("""
940
+ ---
941
+ ## 📚 快速开始指南
942
+
943
+ ### 🆕 **新功能: 文本生成图像**
944
+
945
+ 此版本新增文本生成图像功能,支持6种不同的模型:
946
+
947
+ #### 🤖 **可用模型:**
948
+
949
+ 1. **Anything V3** (`Linaqruf/anything-v3.0`) - 全能动漫模型
950
+ 2. **ChikMix V3** (`digiplay/ChikMix_V3`) - 高质量动漫模型
951
+ 3. **ChilloutMix** (`digiplay/chilloutmix_NiPrunedFp16Fix`) - 真人风格模型
952
+ 4. **Pony Diffusion V6 XL** (`LyliaEngine/Pony_Diffusion_V6_XL`) - SDXL动漫模型 (高分辨率)
953
+ 5. **AbyssOrangeMix3** (`wootwoot/abyssorangemix3-popupparade-fp16`) - 色彩鲜艳的动漫模型
954
+ 6. **WAI NSFW Illustrious** (`John6666/wai-nsfw-illustrious-v80-sdxl`) - SDXL成人内容优化模型
955
+
956
+ ### 🎨 **线稿着色模型**
957
+
958
+ 线稿着色功能提供两种线稿模型:
959
+ - **标准线稿** (`control_v11p_sd15_lineart`) - 适合一般艺术作品
960
+ - **动漫线稿** (`control_v11p_sd15s2_lineart_anime`) - 专为动漫/漫画风格优化 ✨
961
+
962
+ ### ✅ **如何使用**
963
+
964
+ #### **线稿着色:**
965
+ 1. 上传您的线稿(黑白线条在白底上效果最好)
966
+ 2. 选择线稿模型 - 动漫风格使用"Anime"模型
967
+ 3. 选择内容模板作为起点
968
+ 4. 编写详细描述 - 具体说明颜色、特征、服装等
969
+ 5. 设置NSFW级别以匹配您的内容
970
+ 6. 点击"开始着色!"
971
+
972
+ #### **文本生成图像:**
973
+ 1. 选择您想要使用的模型
974
+ 2. 点击"加载模型"按钮(首次使用或切换模型时需要)
975
+ 3. 编写详细描述您想要生成的图像
976
+ 4. 调整分辨率和生成参数
977
+ 5. 点击"生成图像!"
978
+
979
+ ### 💡 **最佳实践提示**
980
+
981
+ - **模型选择**: SDXL模型需要更多显存但生成质量更高
982
+ - **详细描述**: 描述越详细,生成效果越好
983
+ - **分辨率设置**: SDXL模型推荐使用1024x1024,SD1.5模型推荐512x768
984
+ - **显存管理**: 模型按需加载,切换模型时会自动清理之前的模型
985
+
986
+ ---
987
+
988
+ <div style="text-align:center; color:#666; padding:20px;">
989
+ <strong>🔞 负责任使用</strong><br>
990
+ 此工具用于艺术创作目的。用户必须年满18岁。<br>
991
+ 请尊重版权、同意和当地法律。<br>
992
+ <em>由 Stable Diffusion + ControlNet + 多种生成模型驱动</em>
993
+ </div>
994
  """)
995
 
996
  demo.launch(server_name="0.0.0.0", server_port=7860, share=False)