K1Z3M1112 commited on
Commit
47dcfc2
·
verified ·
1 Parent(s): d0b221c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +435 -271
app.py CHANGED
@@ -36,106 +36,186 @@ class SteelBlueTheme(Soft):
36
  steel_blue_theme = SteelBlueTheme()
37
 
38
  print("=" * 50)
39
- print("🎨 Style2Paints - Multi-Model Edition")
40
  print("=" * 50)
41
 
42
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
43
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
44
 
45
- from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, DDIMScheduler
46
  from controlnet_aux import LineartDetector, LineartAnimeDetector
47
 
48
- # Global variables for pipes
49
- current_pipe = None
50
- current_base_model = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  controlnet_standard = None
52
  controlnet_anime = None
53
  lineart_detector = None
54
  lineart_anime_detector = None
 
 
55
 
56
- BASE_MODELS = {
57
- "Anything V3": "Linaqruf/anything-v3.0",
58
- "AbyssOrangeMix": "WarriorMama777/OrangeMixs"
59
- }
60
-
61
- try:
62
- print("🔄 Loading ControlNet Models...")
63
-
64
- # Load STANDARD lineart ControlNet
65
- print("📦 Loading standard lineart ControlNet...")
66
- controlnet_standard = ControlNetModel.from_pretrained(
67
- "lllyasviel/control_v11p_sd15_lineart",
68
- torch_dtype=dtype
69
- ).to(device)
70
-
71
- # Load ANIME lineart ControlNet
72
- print("📦 Loading anime lineart ControlNet...")
73
- controlnet_anime = ControlNetModel.from_pretrained(
74
- "lllyasviel/control_v11p_sd15s2_lineart_anime",
75
- torch_dtype=dtype
76
- ).to(device)
77
-
78
- # Load detectors
79
- print("📦 Loading line art detectors...")
80
- lineart_detector = LineartDetector.from_pretrained("lllyasviel/Annotators")
81
- lineart_anime_detector = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
82
-
83
- print("✅ ControlNet models loaded successfully!")
84
-
85
- except Exception as e:
86
- print(f"❌ Error loading ControlNet: {e}")
87
 
88
- def load_base_model(base_model_name, lineart_type):
89
- """Load base model on demand"""
90
- global current_pipe, current_base_model
91
-
92
- # Check if we already have the right model loaded
93
- if current_pipe is not None and current_base_model == (base_model_name, lineart_type):
94
- print(f"✅ Model already loaded: {base_model_name} ({lineart_type})")
95
- return current_pipe
96
-
97
- # Clear existing pipe
98
- if current_pipe is not None:
99
- print("🗑️ Clearing previous model...")
100
- del current_pipe
101
- if device.type == "cuda":
102
- torch.cuda.empty_cache()
103
- current_pipe = None
104
 
105
- # Select ControlNet
 
 
 
 
 
 
106
  controlnet = controlnet_anime if lineart_type == "Anime" else controlnet_standard
107
- model_path = BASE_MODELS[base_model_name]
108
 
109
- print(f"📦 Loading {base_model_name} with {lineart_type} ControlNet...")
 
 
 
110
 
111
  try:
112
- current_pipe = StableDiffusionControlNetPipeline.from_pretrained(
113
- model_path,
 
 
 
 
114
  controlnet=controlnet,
115
  torch_dtype=dtype,
116
  safety_checker=None,
117
  requires_safety_checker=False
118
  ).to(device)
119
 
120
- # Configure pipe
121
- current_pipe.scheduler = DDIMScheduler.from_config(current_pipe.scheduler.config)
122
 
123
  if device.type == "cuda":
124
- current_pipe.enable_model_cpu_offload()
125
  try:
126
- current_pipe.enable_xformers_memory_efficient_attention()
127
  except:
128
  pass
129
- current_pipe.enable_attention_slicing()
130
 
131
- current_base_model = (base_model_name, lineart_type)
132
- print(f"✅ {base_model_name} loaded successfully!")
 
133
 
134
- return current_pipe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
  except Exception as e:
137
- print(f"❌ Error loading {base_model_name}: {e}")
138
- raise gr.Error(f"Failed to load {base_model_name}: {str(e)}")
 
 
 
139
 
140
  COLOR_STYLES = {
141
  "Anime Style": "anime, masterpiece, best quality, highly detailed, vibrant colors",
@@ -192,7 +272,7 @@ def extract_lineart(image, lineart_type="Standard", skip_if_lineart=True):
192
 
193
  def colorize_lineart(
194
  sketch_image,
195
- base_model,
196
  lineart_type,
197
  content_type,
198
  style,
@@ -206,16 +286,12 @@ def colorize_lineart(
206
  controlnet_strength,
207
  progress=gr.Progress(track_tqdm=True)
208
  ):
209
- """Colorize with model selection"""
210
  if sketch_image is None:
211
  raise gr.Error("Please upload a sketch/line art image")
212
 
213
  # Load the selected model
214
- progress(0.1, desc=f"Loading {base_model}...")
215
- pipe = load_base_model(base_model, lineart_type)
216
-
217
- if pipe is None:
218
- raise gr.Error("Failed to load model")
219
 
220
  # Convert all numeric inputs to proper types
221
  seed = int(seed)
@@ -248,7 +324,7 @@ def colorize_lineart(
248
  new_height = (new_height // 8) * 8
249
  sketch_image = sketch_image.resize((new_width, new_height), Image.LANCZOS)
250
 
251
- progress(0.2, desc="Extracting lineart...")
252
  lineart = extract_lineart(sketch_image, lineart_type=lineart_type, skip_if_lineart=True)
253
 
254
  # Build prompt intelligently
@@ -293,8 +369,7 @@ def colorize_lineart(
293
  if nsfw_level in ["Moderate", "Explicit"]:
294
  negative_prompt = negative_prompt.replace("nsfw, ", "")
295
 
296
- print(f"🎨 Base Model: {base_model}")
297
- print(f"🖼️ Lineart Type: {lineart_type}")
298
  print(f"🎨 Prompt: {full_prompt}")
299
  print(f"🎛️ ControlNet Strength: {controlnet_strength}")
300
 
@@ -314,8 +389,111 @@ def colorize_lineart(
314
  if device.type == "cuda":
315
  torch.cuda.empty_cache()
316
 
317
- model_info = f"Model: {base_model} | Lineart: {lineart_type}"
318
- return result, lineart, seed, f"{model_info}\n{full_prompt}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
319
 
320
  except Exception as e:
321
  import traceback
@@ -371,221 +549,207 @@ css="""
371
  with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
372
  with gr.Column(elem_id="col-container"):
373
  gr.Markdown("# 🎨 Style2Paints - Multi-Model Edition", elem_id="main-title")
374
- gr.Markdown("### ✨ Professional Line Art Colorization with Model Selection")
375
 
376
  gr.HTML("""
377
  <div class="warning-box">
378
- <strong>⚠️ Content Warning:</strong> This tool supports colorization of all content types including NSFW/adult material.
379
  Use responsibly and ensure compliance with local laws. Users must be 18+ for explicit content.
380
  </div>
381
  """)
382
 
383
- gr.HTML("""
384
- <div class="feature-box">
385
- <h3>✨ Key Features</h3>
386
- <ul style="color:white; font-size:1.1em;">
387
- <li>🎨 <strong>Multiple Base Models</strong> - Choose between Anything V3 and AbyssOrangeMix</li>
388
- <li>🚀 <strong>Lazy Loading</strong> - Models load only when needed to save memory</li>
389
- <li>🎭 <strong>Dual ControlNet</strong> - Standard and Anime-specific lineart detection</li>
390
- <li>📝 <strong>Content Templates</strong> - Pre-built prompts for common scenarios</li>
391
- <li>🎚️ <strong>NSFW Level Control</strong> - Explicit tags for better accuracy</li>
392
- </ul>
393
- <div style="margin-top:15px;">
394
- <span class="model-badge">Anything V3</span>
395
- <span class="model-badge">AbyssOrangeMix</span>
396
- <span class="model-badge">Standard Lineart</span>
397
- <span class="model-badge">Anime Lineart</span>
398
- </div>
399
- </div>
400
- """)
401
-
402
- with gr.Row():
403
- with gr.Column(scale=1):
404
- input_image = gr.Image(
405
- label="📤 Upload Line Art",
406
- type="pil",
407
- height=400
408
- )
409
-
410
- gr.Markdown("### 🎨 Model & Content Settings")
411
-
412
- base_model = gr.Radio(
413
- choices=list(BASE_MODELS.keys()),
414
- label="🤖 Base Model",
415
- value="Anything V3",
416
- info="Model loads when you click Colorize"
417
- )
418
-
419
- lineart_type = gr.Radio(
420
- choices=["Standard", "Anime"],
421
- label="🖊️ Lineart ControlNet",
422
- value="Anime",
423
- info="Anime model works better for anime/manga style"
424
- )
425
-
426
- content_type = gr.Dropdown(
427
- choices=list(CONTENT_TEMPLATES.keys()),
428
- label="📋 Content Template",
429
- value="Character Portrait",
430
- info="Starting point for your prompt"
431
- )
432
-
433
- custom_prompt = gr.Textbox(
434
- label="✍️ Detailed Description (IMPORTANT)",
435
- placeholder="Describe what you want: hair color, outfit, pose, background, body features, etc.",
436
- lines=3,
437
- info="Be specific! This is the most important field."
438
- )
439
-
440
  gr.HTML("""
441
- <div class="info-box">
442
- <strong>💡 Base Model Characteristics:</strong><br>
443
- <strong>Anything V3</strong>: Clean anime style, good for general use<br>
444
- <strong>AbyssOrangeMix</strong>: More detailed, painterly style
 
 
 
 
445
  </div>
446
  """)
447
 
448
  with gr.Row():
449
- style = gr.Dropdown(
450
- choices=list(COLOR_STYLES.keys()),
451
- label="🎨 Color Style",
452
- value="Anime Style"
453
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
454
 
455
- nsfw_level = gr.Dropdown(
456
- choices=["Safe", "Suggestive", "Mild", "Moderate", "Explicit"],
457
- label="🔞 Content Level",
458
- value="Moderate",
459
- info="How explicit?"
460
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
461
 
462
- quality_tags = gr.Textbox(
463
- label="⭐ Quality Tags (Optional)",
464
- placeholder="masterpiece, best quality, highly detailed",
465
- value="masterpiece, best quality, highly detailed"
 
 
 
 
466
  )
467
-
468
- colorize_button = gr.Button("✨ Colorize!", variant="primary", size="lg")
469
 
470
- with gr.Column(scale=2):
471
- with gr.Row():
472
- lineart_output = gr.Image(
473
- label="🖊️ Line Art",
474
- type="pil",
475
- height=380
476
- )
477
- output_image = gr.Image(
478
- label="🎨 Colorized Result",
479
- type="pil",
480
- height=380
481
- )
482
-
483
- generated_prompt = gr.Textbox(
484
- label="📝 Model & Prompt Info",
485
- lines=3,
486
- interactive=False,
487
- show_copy_button=True
488
- )
489
 
490
- with gr.Accordion("⚙️ Advanced Settings", open=True):
491
- with gr.Row():
492
- seed = gr.Slider(
493
- label="🎲 Seed",
494
- minimum=0,
495
- maximum=2**32-1,
496
- step=1,
497
- value=42
498
  )
499
- randomize_seed = gr.Checkbox(
500
- label="🔀 Random Seed",
501
- value=True
 
 
502
  )
503
-
504
- with gr.Row():
505
- guidance_scale = gr.Slider(
506
- label="💬 Guidance Scale",
507
- minimum=5.0,
508
- maximum=15.0,
509
- step=0.5,
510
- value=8.0,
511
- info="7-9 recommended for NSFW"
512
  )
513
 
514
- num_steps = gr.Slider(
515
- label="🔢 Steps",
516
- minimum=10,
517
- maximum=30,
518
- step=5,
519
- value=20,
520
- info="20 is good balance"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
521
  )
 
 
 
 
 
 
522
 
523
- controlnet_strength = gr.Slider(
524
- label="🎛️ Line Preservation",
525
- minimum=0.5,
526
- maximum=1.5,
527
- step=0.1,
528
- value=1.0,
529
- info="How strictly to follow lines"
530
- )
531
-
532
- colorize_button.click(
533
- fn=colorize_lineart,
534
- inputs=[
535
- input_image, base_model, lineart_type, content_type, style, custom_prompt,
536
- quality_tags, nsfw_level, seed, randomize_seed, guidance_scale, num_steps,
537
- controlnet_strength
538
- ],
539
- outputs=[output_image, lineart_output, seed, generated_prompt]
540
- )
541
-
542
- gr.Markdown("""
543
- ---
544
- ## 📚 Quick Start Guide
545
-
546
- ### 🆕 **Model Selection Feature**
547
-
548
- Choose between two popular anime models:
549
- - **Anything V3** - Clean, versatile anime style. Good all-rounder.
550
- - **AbyssOrangeMix** - More detailed and painterly. Better for complex scenes.
551
-
552
- **💡 Memory Efficient**: Models load only when you click "Colorize", saving RAM!
553
-
554
- ### ✅ **How to Use**
555
-
556
- 1. **Upload your line art** (black lines on white background works best)
557
- 2. **Select base model** - Try both to see which style you prefer!
558
- 3. **Choose lineart ControlNet** - Use "Anime" for anime/manga style
559
- 4. **Choose a content template** to get started
560
- 5. **Write a detailed description** - be specific about colors, features, clothing, pose
561
- 6. **Set the NSFW level** to match your content
562
- 7. **Click Colorize!** - Model will load automatically
563
-
564
- ### 🎨 **Model Comparison**
565
-
566
- | Feature | Anything V3 | AbyssOrangeMix |
567
- |---------|-------------|----------------|
568
- | Style | Clean anime | Painterly, detailed |
569
- | Best for | General use | Complex scenes |
570
- | Detail level | Standard | High |
571
- | Colors | Vibrant | Rich, layered |
572
-
573
- ### 💡 **Tips for Best Results**
574
-
575
- - **First time**: Models download on first use (may take a few minutes)
576
- - **Model switching**: Automatic - just select and click Colorize
577
- - **Be detailed in prompts**: "blonde long hair, blue eyes, red dress" beats "girl"
578
- - **Match NSFW level**: Use "Explicit" for nude, "Safe" for SFW
579
- - **Clean line art is key**: Black lines on white background = best results
580
-
581
- ---
582
-
583
- <div style="text-align:center; color:#666; padding:20px;">
584
- <strong>🔞 Responsible Use</strong><br>
585
- This tool is for artistic purposes. Users must be 18+.<br>
586
- Respect copyright, consent, and local laws.<br>
587
- <em>Powered by Stable Diffusion + ControlNet (Multi-Model Support)</em>
588
- </div>
589
  """)
590
 
591
  demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
 
36
  steel_blue_theme = SteelBlueTheme()
37
 
38
  print("=" * 50)
39
+ print("🎨 Style2Paints - Enhanced Multi-Model Edition")
40
  print("=" * 50)
41
 
42
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
43
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
44
 
45
+ from diffusers import StableDiffusionControlNetPipeline, StableDiffusionXLControlNetPipeline, ControlNetModel, DDIMScheduler, StableDiffusionPipeline, StableDiffusionXLPipeline
46
  from controlnet_aux import LineartDetector, LineartAnimeDetector
47
 
48
+ # Model configurations
49
+ AVAILABLE_MODELS = {
50
+ "Anything V3 (Anime)": {
51
+ "repo": "Linaqruf/anything-v3.0",
52
+ "type": "sd15",
53
+ "description": "Original anime model"
54
+ },
55
+ "ChikMix V3": {
56
+ "repo": "digiplay/ChikMix_V3",
57
+ "type": "sd15",
58
+ "description": "High quality anime/realistic mix"
59
+ },
60
+ "ChilloutMix": {
61
+ "repo": "digiplay/chilloutmix_NiPrunedFp16Fix",
62
+ "type": "sd15",
63
+ "description": "Popular realistic/anime hybrid"
64
+ },
65
+ "Pony Diffusion V6 XL": {
66
+ "repo": "LyliaEngine/Pony_Diffusion_V6_XL",
67
+ "type": "sdxl",
68
+ "description": "SDXL anime model"
69
+ },
70
+ "AbyssOrangeMix3": {
71
+ "repo": "wootwoot/abyssorangemix3-popupparade-fp16",
72
+ "type": "sd15",
73
+ "description": "Vibrant anime style"
74
+ },
75
+ "WAI-NSFW Illustrious XL": {
76
+ "repo": "John6666/wai-nsfw-illustrious-v80-sdxl",
77
+ "type": "sdxl",
78
+ "description": "SDXL NSFW-focused model"
79
+ }
80
+ }
81
+
82
+ # Global variables for models
83
+ loaded_models = {}
84
  controlnet_standard = None
85
  controlnet_anime = None
86
  lineart_detector = None
87
  lineart_anime_detector = None
88
+ txt2img_pipe = None
89
+ current_txt2img_model = None
90
 
91
+ def load_controlnet_models():
92
+ """Load ControlNet models"""
93
+ global controlnet_standard, controlnet_anime, lineart_detector, lineart_anime_detector
94
+
95
+ print("📦 Loading ControlNet models...")
96
+
97
+ try:
98
+ controlnet_standard = ControlNetModel.from_pretrained(
99
+ "lllyasviel/control_v11p_sd15_lineart",
100
+ torch_dtype=dtype
101
+ ).to(device)
102
+ print("✅ Standard ControlNet loaded")
103
+ except Exception as e:
104
+ print(f"❌ Failed to load standard ControlNet: {e}")
105
+
106
+ try:
107
+ controlnet_anime = ControlNetModel.from_pretrained(
108
+ "lllyasviel/control_v11p_sd15s2_lineart_anime",
109
+ torch_dtype=dtype
110
+ ).to(device)
111
+ print("✅ Anime ControlNet loaded")
112
+ except Exception as e:
113
+ print(f"❌ Failed to load anime ControlNet: {e}")
114
+
115
+ try:
116
+ lineart_detector = LineartDetector.from_pretrained("lllyasviel/Annotators")
117
+ lineart_anime_detector = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
118
+ print("✅ Line art detectors loaded")
119
+ except Exception as e:
120
+ print(f"❌ Failed to load detectors: {e}")
 
121
 
122
+ def load_model_for_controlnet(model_name, lineart_type):
123
+ """Load a specific model with ControlNet"""
124
+ global loaded_models
 
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
+ cache_key = f"{model_name}_{lineart_type}"
127
+
128
+ if cache_key in loaded_models:
129
+ print(f"♻️ Using cached model: {cache_key}")
130
+ return loaded_models[cache_key]
131
+
132
+ model_info = AVAILABLE_MODELS[model_name]
133
  controlnet = controlnet_anime if lineart_type == "Anime" else controlnet_standard
 
134
 
135
+ if controlnet is None:
136
+ raise gr.Error("ControlNet model not loaded")
137
+
138
+ print(f"📦 Loading {model_name}...")
139
 
140
  try:
141
+ if model_info["type"] == "sdxl":
142
+ # SDXL models don't support SD1.5 ControlNet currently
143
+ raise gr.Error("SDXL models are not yet supported for line art colorization. Use SD1.5 models.")
144
+
145
+ pipe = StableDiffusionControlNetPipeline.from_pretrained(
146
+ model_info["repo"],
147
  controlnet=controlnet,
148
  torch_dtype=dtype,
149
  safety_checker=None,
150
  requires_safety_checker=False
151
  ).to(device)
152
 
153
+ pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
 
154
 
155
  if device.type == "cuda":
156
+ pipe.enable_model_cpu_offload()
157
  try:
158
+ pipe.enable_xformers_memory_efficient_attention()
159
  except:
160
  pass
161
+ pipe.enable_attention_slicing()
162
 
163
+ loaded_models[cache_key] = pipe
164
+ print(f"✅ {model_name} loaded successfully")
165
+ return pipe
166
 
167
+ except Exception as e:
168
+ print(f"❌ Error loading {model_name}: {e}")
169
+ raise gr.Error(f"Failed to load {model_name}: {str(e)}")
170
+
171
+ def load_txt2img_model(model_name):
172
+ """Load model for text-to-image"""
173
+ global txt2img_pipe, current_txt2img_model
174
+
175
+ if current_txt2img_model == model_name and txt2img_pipe is not None:
176
+ print(f"♻️ Using cached txt2img model: {model_name}")
177
+ return txt2img_pipe
178
+
179
+ model_info = AVAILABLE_MODELS[model_name]
180
+ print(f"📦 Loading {model_name} for text-to-image...")
181
+
182
+ try:
183
+ if model_info["type"] == "sdxl":
184
+ pipe = StableDiffusionXLPipeline.from_pretrained(
185
+ model_info["repo"],
186
+ torch_dtype=dtype,
187
+ safety_checker=None,
188
+ requires_safety_checker=False
189
+ ).to(device)
190
+ else:
191
+ pipe = StableDiffusionPipeline.from_pretrained(
192
+ model_info["repo"],
193
+ torch_dtype=dtype,
194
+ safety_checker=None,
195
+ requires_safety_checker=False
196
+ ).to(device)
197
+
198
+ pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
199
+
200
+ if device.type == "cuda":
201
+ pipe.enable_model_cpu_offload()
202
+ try:
203
+ pipe.enable_xformers_memory_efficient_attention()
204
+ except:
205
+ pass
206
+ pipe.enable_attention_slicing()
207
+
208
+ txt2img_pipe = pipe
209
+ current_txt2img_model = model_name
210
+ print(f"✅ {model_name} loaded for txt2img")
211
+ return pipe
212
 
213
  except Exception as e:
214
+ print(f"❌ Error loading {model_name}: {e}")
215
+ raise gr.Error(f"Failed to load {model_name}: {str(e)}")
216
+
217
+ # Load ControlNet models at startup
218
+ load_controlnet_models()
219
 
220
  COLOR_STYLES = {
221
  "Anime Style": "anime, masterpiece, best quality, highly detailed, vibrant colors",
 
272
 
273
  def colorize_lineart(
274
  sketch_image,
275
+ model_name,
276
  lineart_type,
277
  content_type,
278
  style,
 
286
  controlnet_strength,
287
  progress=gr.Progress(track_tqdm=True)
288
  ):
289
+ """Colorize with explicit content support"""
290
  if sketch_image is None:
291
  raise gr.Error("Please upload a sketch/line art image")
292
 
293
  # Load the selected model
294
+ pipe = load_model_for_controlnet(model_name, lineart_type)
 
 
 
 
295
 
296
  # Convert all numeric inputs to proper types
297
  seed = int(seed)
 
324
  new_height = (new_height // 8) * 8
325
  sketch_image = sketch_image.resize((new_width, new_height), Image.LANCZOS)
326
 
327
+ # Extract lineart with selected type
328
  lineart = extract_lineart(sketch_image, lineart_type=lineart_type, skip_if_lineart=True)
329
 
330
  # Build prompt intelligently
 
369
  if nsfw_level in ["Moderate", "Explicit"]:
370
  negative_prompt = negative_prompt.replace("nsfw, ", "")
371
 
372
+ print(f"🎨 Model: {model_name}")
 
373
  print(f"🎨 Prompt: {full_prompt}")
374
  print(f"🎛️ ControlNet Strength: {controlnet_strength}")
375
 
 
389
  if device.type == "cuda":
390
  torch.cuda.empty_cache()
391
 
392
+ return result, lineart, seed, full_prompt
393
+
394
+ except Exception as e:
395
+ import traceback
396
+ print(f"❌ Full error: {traceback.format_exc()}")
397
+ raise gr.Error(f"Error: {str(e)}")
398
+
399
+ def generate_txt2img(
400
+ model_name,
401
+ content_type,
402
+ style,
403
+ custom_prompt,
404
+ quality_tags,
405
+ nsfw_level,
406
+ seed,
407
+ randomize_seed,
408
+ guidance_scale,
409
+ num_steps,
410
+ width,
411
+ height,
412
+ progress=gr.Progress(track_tqdm=True)
413
+ ):
414
+ """Generate image from text only"""
415
+
416
+ # Load the selected model
417
+ pipe = load_txt2img_model(model_name)
418
+
419
+ # Convert all numeric inputs to proper types
420
+ seed = int(seed)
421
+ guidance_scale = float(guidance_scale)
422
+ num_steps = int(num_steps)
423
+ width = int(width)
424
+ height = int(height)
425
+
426
+ if randomize_seed:
427
+ import random
428
+ seed = random.randint(0, 2**32-1)
429
+
430
+ generator = torch.Generator(device=device).manual_seed(seed)
431
+
432
+ # Build prompt
433
+ prompt_parts = []
434
+
435
+ # Content template
436
+ content_template = CONTENT_TEMPLATES.get(content_type, "")
437
+ if content_template:
438
+ prompt_parts.append(content_template)
439
+
440
+ # Custom prompt
441
+ if custom_prompt.strip():
442
+ prompt_parts.append(custom_prompt.strip())
443
+
444
+ # Quality tags
445
+ if quality_tags:
446
+ prompt_parts.append(quality_tags)
447
+
448
+ # Style
449
+ style_prompt = COLOR_STYLES.get(style, COLOR_STYLES["Anime Style"])
450
+ prompt_parts.append(style_prompt)
451
+
452
+ # NSFW level tags
453
+ if nsfw_level == "Safe":
454
+ nsfw_tags = "sfw, safe for work"
455
+ elif nsfw_level == "Suggestive":
456
+ nsfw_tags = "suggestive, slightly revealing"
457
+ elif nsfw_level == "Mild":
458
+ nsfw_tags = "nsfw, ecchi, revealing clothing"
459
+ elif nsfw_level == "Moderate":
460
+ nsfw_tags = "nsfw, nude, explicit"
461
+ else: # Explicit
462
+ nsfw_tags = "nsfw, explicit, uncensored"
463
+
464
+ prompt_parts.append(nsfw_tags)
465
+
466
+ full_prompt = ", ".join(prompt_parts)
467
+
468
+ # Negative prompt
469
+ negative_prompt = "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, jpeg artifacts, signature, watermark, username, blurry, artist name"
470
+
471
+ if nsfw_level in ["Moderate", "Explicit"]:
472
+ negative_prompt = negative_prompt.replace("nsfw, ", "")
473
+
474
+ print(f"🎨 Model: {model_name}")
475
+ print(f"🎨 Prompt: {full_prompt}")
476
+
477
+ try:
478
+ progress(0.3, desc="Generating image...")
479
+
480
+ # Check if SDXL model
481
+ model_info = AVAILABLE_MODELS[model_name]
482
+
483
+ result = pipe(
484
+ prompt=full_prompt,
485
+ negative_prompt=negative_prompt,
486
+ num_inference_steps=num_steps,
487
+ guidance_scale=guidance_scale,
488
+ width=width,
489
+ height=height,
490
+ generator=generator,
491
+ ).images[0]
492
+
493
+ if device.type == "cuda":
494
+ torch.cuda.empty_cache()
495
+
496
+ return result, seed, full_prompt
497
 
498
  except Exception as e:
499
  import traceback
 
549
  with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
550
  with gr.Column(elem_id="col-container"):
551
  gr.Markdown("# 🎨 Style2Paints - Multi-Model Edition", elem_id="main-title")
552
+ gr.Markdown("### ✨ Professional AI Art Generation with Multiple Models")
553
 
554
  gr.HTML("""
555
  <div class="warning-box">
556
+ <strong>⚠️ Content Warning:</strong> This tool supports generation of all content types including NSFW/adult material.
557
  Use responsibly and ensure compliance with local laws. Users must be 18+ for explicit content.
558
  </div>
559
  """)
560
 
561
+ with gr.Tabs():
562
+ # Tab 1: Line Art Colorization
563
+ with gr.Tab("🖊️ Line Art Colorization"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
564
  gr.HTML("""
565
+ <div class="feature-box">
566
+ <h3>✨ Colorize Your Line Art</h3>
567
+ <ul style="color:white; font-size:1.1em;">
568
+ <li>🎨 <strong>6 Anime Models</strong> - Choose your favorite style</li>
569
+ <li>🖊️ <strong>Dual Lineart Detection</strong> - Standard & Anime optimized</li>
570
+ <li>🎭 <strong>Smart Prompt Building</strong> - Easy content creation</li>
571
+ <li>🎚️ <strong>NSFW Support</strong> - All content levels supported</li>
572
+ </ul>
573
  </div>
574
  """)
575
 
576
  with gr.Row():
577
+ with gr.Column(scale=1):
578
+ input_image = gr.Image(
579
+ label="📤 Upload Line Art",
580
+ type="pil",
581
+ height=400
582
+ )
583
+
584
+ # Get SD1.5 models only for ControlNet
585
+ sd15_models = [name for name, info in AVAILABLE_MODELS.items() if info["type"] == "sd15"]
586
+
587
+ model_selector = gr.Dropdown(
588
+ choices=sd15_models,
589
+ label="🤖 Select Model",
590
+ value="Anything V3 (Anime)",
591
+ info="Choose your preferred model (SD1.5 only for line art)"
592
+ )
593
+
594
+ lineart_type = gr.Radio(
595
+ choices=["Standard", "Anime"],
596
+ label="🖊️ Lineart Model",
597
+ value="Anime",
598
+ info="Anime works better for anime/manga"
599
+ )
600
+
601
+ content_type = gr.Dropdown(
602
+ choices=list(CONTENT_TEMPLATES.keys()),
603
+ label="📋 Content Template",
604
+ value="Character Portrait"
605
+ )
606
+
607
+ custom_prompt = gr.Textbox(
608
+ label="✍️ Detailed Description",
609
+ placeholder="hair color, outfit, pose, features...",
610
+ lines=3
611
+ )
612
+
613
+ with gr.Row():
614
+ style = gr.Dropdown(
615
+ choices=list(COLOR_STYLES.keys()),
616
+ label="🎨 Style",
617
+ value="Anime Style"
618
+ )
619
+
620
+ nsfw_level = gr.Dropdown(
621
+ choices=["Safe", "Suggestive", "Mild", "Moderate", "Explicit"],
622
+ label="🔞 Content Level",
623
+ value="Moderate"
624
+ )
625
+
626
+ quality_tags = gr.Textbox(
627
+ label="⭐ Quality Tags",
628
+ value="masterpiece, best quality, highly detailed"
629
+ )
630
+
631
+ colorize_button = gr.Button("✨ Colorize!", variant="primary", size="lg")
632
 
633
+ with gr.Column(scale=2):
634
+ with gr.Row():
635
+ lineart_output = gr.Image(label="🖊️ Line Art", type="pil", height=380)
636
+ colorized_output = gr.Image(label="🎨 Result", type="pil", height=380)
637
+
638
+ colorized_prompt = gr.Textbox(
639
+ label="📝 Generated Prompt",
640
+ lines=3,
641
+ interactive=False,
642
+ show_copy_button=True
643
+ )
644
+
645
+ with gr.Accordion("⚙️ Advanced Settings", open=False):
646
+ with gr.Row():
647
+ c_seed = gr.Slider(label="🎲 Seed", minimum=0, maximum=2**32-1, step=1, value=42)
648
+ c_random = gr.Checkbox(label="🔀 Random", value=True)
649
+
650
+ with gr.Row():
651
+ c_guidance = gr.Slider(label="💬 Guidance", minimum=5.0, maximum=15.0, step=0.5, value=8.0)
652
+ c_steps = gr.Slider(label="🔢 Steps", minimum=10, maximum=30, step=5, value=20)
653
+
654
+ c_strength = gr.Slider(
655
+ label="🎛️ Line Preservation",
656
+ minimum=0.5,
657
+ maximum=1.5,
658
+ step=0.1,
659
+ value=1.0
660
+ )
661
 
662
+ colorize_button.click(
663
+ fn=colorize_lineart,
664
+ inputs=[
665
+ input_image, model_selector, lineart_type, content_type, style,
666
+ custom_prompt, quality_tags, nsfw_level, c_seed, c_random,
667
+ c_guidance, c_steps, c_strength
668
+ ],
669
+ outputs=[colorized_output, lineart_output, c_seed, colorized_prompt]
670
  )
 
 
671
 
672
+ # Tab 2: Text to Image
673
+ with gr.Tab("✍️ Text to Image"):
674
+ gr.HTML("""
675
+ <div class="feature-box">
676
+ <h3>✨ Generate Images from Text</h3>
677
+ <ul style="color:white; font-size:1.1em;">
678
+ <li>🎨 <strong>6 Models Including SDXL</strong> - Maximum variety</li>
679
+ <li>✍️ <strong>Pure Text Generation</strong> - No line art needed</li>
680
+ <li>📐 <strong>Flexible Sizes</strong> - From 512x512 to 1024x1024</li>
681
+ <li>🎭 <strong>All Content Types</strong> - SFW to NSFW</li>
682
+ </ul>
683
+ </div>
684
+ """)
 
 
 
 
 
 
685
 
686
+ with gr.Row():
687
+ with gr.Column(scale=1):
688
+ t2i_model = gr.Dropdown(
689
+ choices=list(AVAILABLE_MODELS.keys()),
690
+ label="🤖 Select Model",
691
+ value="Anything V3 (Anime)",
692
+ info="All models available for txt2img"
 
693
  )
694
+
695
+ t2i_content = gr.Dropdown(
696
+ choices=list(CONTENT_TEMPLATES.keys()),
697
+ label="📋 Content Template",
698
+ value="Character Portrait"
699
  )
700
+
701
+ t2i_prompt = gr.Textbox(
702
+ label="✍️ Your Description",
703
+ placeholder="Describe what you want to see...",
704
+ lines=4
 
 
 
 
705
  )
706
 
707
+ gr.HTML("""
708
+ <div class="info-box">
709
+ <strong>💡 Example Prompts:</strong><br>
710
+ • "girl with long silver hair, blue eyes, wearing kimono, cherry blossoms"<br>
711
+ • "cyberpunk girl, neon lights, futuristic city, night scene"<br>
712
+ • "maid outfit, cat ears, blushing, bedroom"
713
+ </div>
714
+ """)
715
+
716
+ with gr.Row():
717
+ t2i_style = gr.Dropdown(
718
+ choices=list(COLOR_STYLES.keys()),
719
+ label="🎨 Style",
720
+ value="Anime Style"
721
+ )
722
+
723
+ t2i_nsfw = gr.Dropdown(
724
+ choices=["Safe", "Suggestive", "Mild", "Moderate", "Explicit"],
725
+ label="🔞 Content Level",
726
+ value="Safe"
727
+ )
728
+
729
+ t2i_quality = gr.Textbox(
730
+ label="⭐ Quality Tags",
731
+ value="masterpiece, best quality, highly detailed"
732
  )
733
+
734
+ with gr.Row():
735
+ t2i_width = gr.Slider(label="📏 Width", minimum=512, maximum=1024, step=64, value=512)
736
+ t2i_height = gr.Slider(label="📏 Height", minimum=512, maximum=1024, step=64, value=512)
737
+
738
+ generate_button = gr.Button("🎨 Generate!", variant="primary", size="lg")
739
 
740
+ with gr.Column(scale=2):
741
+ t2i_output = gr.Image(label="🎨 Generated Image", type="pil", height=600)
742
+
743
+ t2i_prompt_output = gr.Textbox(
744
+ label="📝 Full Prompt Used",
745
+ lines=3,
746
+ interactive=False,
747
+ show_copy_button=True
748
+ )
749
+
750
+ with gr.Accordion("⚙️ Advanced Settings", open=False):
751
+ with gr.Row():
752
+ t2i_seed = gr.Slider(label="🎲 Seed", minimum=0, maximum=
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
753
  """)
754
 
755
  demo.launch(server_name="0.0.0.0", server_port=7860, share=False)