Update app.py
Browse files
app.py
CHANGED
|
@@ -20,9 +20,9 @@ if torch.cuda.is_available():
|
|
| 20 |
|
| 21 |
# Device
|
| 22 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 23 |
-
|
| 24 |
|
| 25 |
-
print(f"🖥️ Device: {device} | dtype: {
|
| 26 |
|
| 27 |
# Lazy import (to avoid long startup if unused)
|
| 28 |
from diffusers import (
|
|
@@ -234,24 +234,25 @@ def load_florence2():
|
|
| 234 |
|
| 235 |
print("📥 Loading Microsoft/Florence-2-base...")
|
| 236 |
|
| 237 |
-
#
|
| 238 |
-
FLORENCE2_PROCESSOR = AutoProcessor.from_pretrained(
|
| 239 |
-
"microsoft/Florence-2-base",
|
| 240 |
-
trust_remote_code=True
|
| 241 |
-
)
|
| 242 |
-
|
| 243 |
-
# Use the older loading approach to avoid compatibility issues
|
| 244 |
FLORENCE2_MODEL = AutoModelForCausalLM.from_pretrained(
|
| 245 |
-
"microsoft/Florence-2-base",
|
| 246 |
-
torch_dtype=
|
| 247 |
trust_remote_code=True
|
| 248 |
).to(device)
|
| 249 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 250 |
print("✅ Florence-2 model loaded successfully")
|
| 251 |
return FLORENCE2_PROCESSOR, FLORENCE2_MODEL
|
| 252 |
|
| 253 |
except Exception as e:
|
| 254 |
print(f"❌ Error loading Florence-2: {e}")
|
|
|
|
|
|
|
| 255 |
return None, None
|
| 256 |
|
| 257 |
def analyze_with_florence2(image, task_prompt):
|
|
@@ -271,8 +272,6 @@ def analyze_with_florence2(image, task_prompt):
|
|
| 271 |
try:
|
| 272 |
if isinstance(image, np.ndarray):
|
| 273 |
image = Image.fromarray(image)
|
| 274 |
-
elif hasattr(image, 'shape'): # 可能是 torch tensor
|
| 275 |
-
image = Image.fromarray(image.cpu().numpy())
|
| 276 |
else:
|
| 277 |
return "❌ Invalid image format. Please upload a valid image."
|
| 278 |
except Exception as e:
|
|
@@ -289,49 +288,61 @@ def analyze_with_florence2(image, task_prompt):
|
|
| 289 |
new_size = (int(image.width * ratio), int(image.height * ratio))
|
| 290 |
image = image.resize(new_size, Image.Resampling.LANCZOS)
|
| 291 |
|
| 292 |
-
#
|
| 293 |
try:
|
| 294 |
inputs = processor(
|
| 295 |
text=task_prompt,
|
| 296 |
images=image,
|
| 297 |
return_tensors="pt"
|
| 298 |
-
).to(device)
|
| 299 |
except Exception as e:
|
| 300 |
print(f"❌ Error processing image: {e}")
|
| 301 |
return f"❌ Error processing image: {str(e)}"
|
| 302 |
|
| 303 |
-
#
|
| 304 |
-
if inputs is None or 'pixel_values' not in inputs:
|
| 305 |
-
return "❌ Failed to process image for analysis."
|
| 306 |
-
|
| 307 |
-
# Generate
|
| 308 |
try:
|
| 309 |
generated_ids = model.generate(
|
| 310 |
input_ids=inputs["input_ids"],
|
| 311 |
pixel_values=inputs["pixel_values"],
|
| 312 |
-
max_new_tokens=
|
| 313 |
-
|
| 314 |
-
|
| 315 |
)
|
| 316 |
except Exception as e:
|
| 317 |
print(f"❌ Error generating text: {e}")
|
| 318 |
return f"❌ Error during analysis: {str(e)}"
|
| 319 |
|
| 320 |
-
#
|
| 321 |
try:
|
| 322 |
generated_text = processor.batch_decode(
|
| 323 |
generated_ids,
|
| 324 |
-
skip_special_tokens=
|
| 325 |
)[0]
|
| 326 |
except Exception as e:
|
| 327 |
print(f"❌ Error decoding text: {e}")
|
| 328 |
return f"❌ Error decoding result: {str(e)}"
|
| 329 |
|
| 330 |
-
#
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 335 |
|
| 336 |
except Exception as e:
|
| 337 |
print(f"❌ Error in Florence-2 analysis: {e}")
|
|
@@ -404,17 +415,17 @@ def get_pipeline(model_name: str, controlnet_type: str = "lineart", lora_model:
|
|
| 404 |
controlnet_model_name = get_controlnet_model(controlnet_type)
|
| 405 |
controlnet = ControlNetModel.from_pretrained(
|
| 406 |
controlnet_model_name,
|
| 407 |
-
torch_dtype=
|
| 408 |
).to(device)
|
| 409 |
|
| 410 |
pipe = StableDiffusionXLPipeline.from_pretrained(
|
| 411 |
model_name,
|
| 412 |
controlnet=controlnet,
|
| 413 |
-
torch_dtype=
|
| 414 |
safety_checker=None,
|
| 415 |
requires_safety_checker=False,
|
| 416 |
use_safetensors=True,
|
| 417 |
-
variant="fp16" if
|
| 418 |
).to(device)
|
| 419 |
else:
|
| 420 |
raise ValueError(f"SDXL model {model_name} only supports limited ControlNet types: {list(SDXL_CONTROLNET_MODELS.keys())}")
|
|
@@ -423,17 +434,17 @@ def get_pipeline(model_name: str, controlnet_type: str = "lineart", lora_model:
|
|
| 423 |
controlnet_model_name = get_controlnet_model(controlnet_type)
|
| 424 |
controlnet = ControlNetModel.from_pretrained(
|
| 425 |
controlnet_model_name,
|
| 426 |
-
torch_dtype=
|
| 427 |
).to(device)
|
| 428 |
|
| 429 |
pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
| 430 |
model_name,
|
| 431 |
controlnet=controlnet,
|
| 432 |
-
torch_dtype=
|
| 433 |
safety_checker=None,
|
| 434 |
requires_safety_checker=False,
|
| 435 |
use_safetensors=True,
|
| 436 |
-
variant="fp16" if
|
| 437 |
).to(device)
|
| 438 |
|
| 439 |
# Apply LoRA if specified
|
|
@@ -540,20 +551,20 @@ def load_t2i_model(model_name: str, lora_model: str = None, lora_weight: float =
|
|
| 540 |
# Load base and refiner
|
| 541 |
CURRENT_T2I_PIPE = StableDiffusionXLPipeline.from_pretrained(
|
| 542 |
"stabilityai/stable-diffusion-xl-base-1.0",
|
| 543 |
-
torch_dtype=
|
| 544 |
safety_checker=None,
|
| 545 |
requires_safety_checker=False,
|
| 546 |
use_safetensors=True,
|
| 547 |
-
variant="fp16" if
|
| 548 |
).to(device)
|
| 549 |
|
| 550 |
CURRENT_SDXL_REFINER = StableDiffusionXLPipeline.from_pretrained(
|
| 551 |
model_name,
|
| 552 |
-
torch_dtype=
|
| 553 |
safety_checker=None,
|
| 554 |
requires_safety_checker=False,
|
| 555 |
use_safetensors=True,
|
| 556 |
-
variant="fp16" if
|
| 557 |
text_encoder_2=CURRENT_T2I_PIPE.text_encoder_2,
|
| 558 |
vae=CURRENT_T2I_PIPE.vae
|
| 559 |
).to(device)
|
|
@@ -561,22 +572,22 @@ def load_t2i_model(model_name: str, lora_model: str = None, lora_weight: float =
|
|
| 561 |
else:
|
| 562 |
CURRENT_T2I_PIPE = StableDiffusionXLPipeline.from_pretrained(
|
| 563 |
model_name,
|
| 564 |
-
torch_dtype=
|
| 565 |
safety_checker=None,
|
| 566 |
requires_safety_checker=False,
|
| 567 |
use_safetensors=True,
|
| 568 |
-
variant="fp16" if
|
| 569 |
).to(device)
|
| 570 |
print(f"✅ Loaded SDXL model: {model_name}")
|
| 571 |
else:
|
| 572 |
# Load SD1.5 model
|
| 573 |
CURRENT_T2I_PIPE = StableDiffusionPipeline.from_pretrained(
|
| 574 |
model_name,
|
| 575 |
-
torch_dtype=
|
| 576 |
safety_checker=None,
|
| 577 |
requires_safety_checker=False,
|
| 578 |
use_safetensors=True,
|
| 579 |
-
variant="fp16" if
|
| 580 |
).to(device)
|
| 581 |
print(f"✅ Loaded SD1.5 model: {model_name}")
|
| 582 |
|
|
@@ -631,14 +642,14 @@ def load_t2i_model(model_name: str, lora_model: str = None, lora_weight: float =
|
|
| 631 |
if is_sdxl_model(model_name):
|
| 632 |
CURRENT_T2I_PIPE = StableDiffusionXLPipeline.from_pretrained(
|
| 633 |
model_name,
|
| 634 |
-
torch_dtype=
|
| 635 |
safety_checker=None,
|
| 636 |
requires_safety_checker=False
|
| 637 |
).to(device)
|
| 638 |
else:
|
| 639 |
CURRENT_T2I_PIPE = StableDiffusionPipeline.from_pretrained(
|
| 640 |
model_name,
|
| 641 |
-
torch_dtype=
|
| 642 |
safety_checker=None,
|
| 643 |
requires_safety_checker=False
|
| 644 |
).to(device)
|
|
@@ -1128,15 +1139,24 @@ with gr.Blocks(title="🎨 Advanced Image Generation Suite", theme=gr.themes.Sof
|
|
| 1128 |
gr.Markdown("""
|
| 1129 |
### Microsoft Florence-2 Vision Language Model
|
| 1130 |
**Pre-trained Tasks:**
|
| 1131 |
-
- `<OCR>`: Text recognition
|
| 1132 |
-
- `<CAPTION>`: Image captioning
|
| 1133 |
-
- `<DETAILED_CAPTION>`: Detailed caption
|
| 1134 |
-
- `<MORE_DETAILED_CAPTION>`: More detailed caption
|
| 1135 |
-
- `<OD>`: Object detection
|
| 1136 |
- `<OPEN_VOCABULARY_DETECTION>`: Open-vocabulary detection
|
| 1137 |
- `<REGION_PROPOSAL>`: Region proposal
|
| 1138 |
|
| 1139 |
-
**
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1140 |
""")
|
| 1141 |
|
| 1142 |
with gr.Row():
|
|
|
|
| 20 |
|
| 21 |
# Device
|
| 22 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 23 |
+
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
| 24 |
|
| 25 |
+
print(f"🖥️ Device: {device} | dtype: {torch_dtype}")
|
| 26 |
|
| 27 |
# Lazy import (to avoid long startup if unused)
|
| 28 |
from diffusers import (
|
|
|
|
| 234 |
|
| 235 |
print("📥 Loading Microsoft/Florence-2-base...")
|
| 236 |
|
| 237 |
+
# Load the model as described in the official documentation
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 238 |
FLORENCE2_MODEL = AutoModelForCausalLM.from_pretrained(
|
| 239 |
+
"microsoft/Florence-2-base",
|
| 240 |
+
torch_dtype=torch_dtype,
|
| 241 |
trust_remote_code=True
|
| 242 |
).to(device)
|
| 243 |
|
| 244 |
+
FLORENCE2_PROCESSOR = AutoProcessor.from_pretrained(
|
| 245 |
+
"microsoft/Florence-2-base",
|
| 246 |
+
trust_remote_code=True
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
print("✅ Florence-2 model loaded successfully")
|
| 250 |
return FLORENCE2_PROCESSOR, FLORENCE2_MODEL
|
| 251 |
|
| 252 |
except Exception as e:
|
| 253 |
print(f"❌ Error loading Florence-2: {e}")
|
| 254 |
+
import traceback
|
| 255 |
+
traceback.print_exc()
|
| 256 |
return None, None
|
| 257 |
|
| 258 |
def analyze_with_florence2(image, task_prompt):
|
|
|
|
| 272 |
try:
|
| 273 |
if isinstance(image, np.ndarray):
|
| 274 |
image = Image.fromarray(image)
|
|
|
|
|
|
|
| 275 |
else:
|
| 276 |
return "❌ Invalid image format. Please upload a valid image."
|
| 277 |
except Exception as e:
|
|
|
|
| 288 |
new_size = (int(image.width * ratio), int(image.height * ratio))
|
| 289 |
image = image.resize(new_size, Image.Resampling.LANCZOS)
|
| 290 |
|
| 291 |
+
# Prepare the inputs as described in the official documentation
|
| 292 |
try:
|
| 293 |
inputs = processor(
|
| 294 |
text=task_prompt,
|
| 295 |
images=image,
|
| 296 |
return_tensors="pt"
|
| 297 |
+
).to(device, torch_dtype)
|
| 298 |
except Exception as e:
|
| 299 |
print(f"❌ Error processing image: {e}")
|
| 300 |
return f"❌ Error processing image: {str(e)}"
|
| 301 |
|
| 302 |
+
# Generate the output as described in the official documentation
|
|
|
|
|
|
|
|
|
|
|
|
|
| 303 |
try:
|
| 304 |
generated_ids = model.generate(
|
| 305 |
input_ids=inputs["input_ids"],
|
| 306 |
pixel_values=inputs["pixel_values"],
|
| 307 |
+
max_new_tokens=1024,
|
| 308 |
+
do_sample=False,
|
| 309 |
+
num_beams=3,
|
| 310 |
)
|
| 311 |
except Exception as e:
|
| 312 |
print(f"❌ Error generating text: {e}")
|
| 313 |
return f"❌ Error during analysis: {str(e)}"
|
| 314 |
|
| 315 |
+
# Decode the generated token IDs
|
| 316 |
try:
|
| 317 |
generated_text = processor.batch_decode(
|
| 318 |
generated_ids,
|
| 319 |
+
skip_special_tokens=False
|
| 320 |
)[0]
|
| 321 |
except Exception as e:
|
| 322 |
print(f"❌ Error decoding text: {e}")
|
| 323 |
return f"❌ Error decoding result: {str(e)}"
|
| 324 |
|
| 325 |
+
# Parse the result with post_process_generation
|
| 326 |
+
try:
|
| 327 |
+
parsed_answer = processor.post_process_generation(
|
| 328 |
+
generated_text,
|
| 329 |
+
task=task_prompt,
|
| 330 |
+
image_size=(image.width, image.height)
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
# Convert the result into a human-readable string
|
| 334 |
+
if isinstance(parsed_answer, dict):
|
| 335 |
+
result_str = ""
|
| 336 |
+
for key, value in parsed_answer.items():
|
| 337 |
+
result_str += f"{key}:\n{value}\n\n"
|
| 338 |
+
return result_str.strip()
|
| 339 |
+
else:
|
| 340 |
+
return str(parsed_answer)
|
| 341 |
+
|
| 342 |
+
except Exception as e:
|
| 343 |
+
print(f"❌ Error in post-processing: {e}")
|
| 344 |
+
# If post-processing fails, fall back to the raw generated text
|
| 345 |
+
return f"Raw output: {generated_text}"
|
| 346 |
|
| 347 |
except Exception as e:
|
| 348 |
print(f"❌ Error in Florence-2 analysis: {e}")
|
|
|
|
| 415 |
controlnet_model_name = get_controlnet_model(controlnet_type)
|
| 416 |
controlnet = ControlNetModel.from_pretrained(
|
| 417 |
controlnet_model_name,
|
| 418 |
+
torch_dtype=torch_dtype
|
| 419 |
).to(device)
|
| 420 |
|
| 421 |
pipe = StableDiffusionXLPipeline.from_pretrained(
|
| 422 |
model_name,
|
| 423 |
controlnet=controlnet,
|
| 424 |
+
torch_dtype=torch_dtype,
|
| 425 |
safety_checker=None,
|
| 426 |
requires_safety_checker=False,
|
| 427 |
use_safetensors=True,
|
| 428 |
+
variant="fp16" if torch_dtype == torch.float16 else None
|
| 429 |
).to(device)
|
| 430 |
else:
|
| 431 |
raise ValueError(f"SDXL model {model_name} only supports limited ControlNet types: {list(SDXL_CONTROLNET_MODELS.keys())}")
|
|
|
|
| 434 |
controlnet_model_name = get_controlnet_model(controlnet_type)
|
| 435 |
controlnet = ControlNetModel.from_pretrained(
|
| 436 |
controlnet_model_name,
|
| 437 |
+
torch_dtype=torch_dtype
|
| 438 |
).to(device)
|
| 439 |
|
| 440 |
pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
| 441 |
model_name,
|
| 442 |
controlnet=controlnet,
|
| 443 |
+
torch_dtype=torch_dtype,
|
| 444 |
safety_checker=None,
|
| 445 |
requires_safety_checker=False,
|
| 446 |
use_safetensors=True,
|
| 447 |
+
variant="fp16" if torch_dtype == torch.float16 else None
|
| 448 |
).to(device)
|
| 449 |
|
| 450 |
# Apply LoRA if specified
|
|
|
|
| 551 |
# Load base and refiner
|
| 552 |
CURRENT_T2I_PIPE = StableDiffusionXLPipeline.from_pretrained(
|
| 553 |
"stabilityai/stable-diffusion-xl-base-1.0",
|
| 554 |
+
torch_dtype=torch_dtype,
|
| 555 |
safety_checker=None,
|
| 556 |
requires_safety_checker=False,
|
| 557 |
use_safetensors=True,
|
| 558 |
+
variant="fp16" if torch_dtype == torch.float16 else None
|
| 559 |
).to(device)
|
| 560 |
|
| 561 |
CURRENT_SDXL_REFINER = StableDiffusionXLPipeline.from_pretrained(
|
| 562 |
model_name,
|
| 563 |
+
torch_dtype=torch_dtype,
|
| 564 |
safety_checker=None,
|
| 565 |
requires_safety_checker=False,
|
| 566 |
use_safetensors=True,
|
| 567 |
+
variant="fp16" if torch_dtype == torch.float16 else None,
|
| 568 |
text_encoder_2=CURRENT_T2I_PIPE.text_encoder_2,
|
| 569 |
vae=CURRENT_T2I_PIPE.vae
|
| 570 |
).to(device)
|
|
|
|
| 572 |
else:
|
| 573 |
CURRENT_T2I_PIPE = StableDiffusionXLPipeline.from_pretrained(
|
| 574 |
model_name,
|
| 575 |
+
torch_dtype=torch_dtype,
|
| 576 |
safety_checker=None,
|
| 577 |
requires_safety_checker=False,
|
| 578 |
use_safetensors=True,
|
| 579 |
+
variant="fp16" if torch_dtype == torch.float16 else None
|
| 580 |
).to(device)
|
| 581 |
print(f"✅ Loaded SDXL model: {model_name}")
|
| 582 |
else:
|
| 583 |
# Load SD1.5 model
|
| 584 |
CURRENT_T2I_PIPE = StableDiffusionPipeline.from_pretrained(
|
| 585 |
model_name,
|
| 586 |
+
torch_dtype=torch_dtype,
|
| 587 |
safety_checker=None,
|
| 588 |
requires_safety_checker=False,
|
| 589 |
use_safetensors=True,
|
| 590 |
+
variant="fp16" if torch_dtype == torch.float16 else None
|
| 591 |
).to(device)
|
| 592 |
print(f"✅ Loaded SD1.5 model: {model_name}")
|
| 593 |
|
|
|
|
| 642 |
if is_sdxl_model(model_name):
|
| 643 |
CURRENT_T2I_PIPE = StableDiffusionXLPipeline.from_pretrained(
|
| 644 |
model_name,
|
| 645 |
+
torch_dtype=torch_dtype,
|
| 646 |
safety_checker=None,
|
| 647 |
requires_safety_checker=False
|
| 648 |
).to(device)
|
| 649 |
else:
|
| 650 |
CURRENT_T2I_PIPE = StableDiffusionPipeline.from_pretrained(
|
| 651 |
model_name,
|
| 652 |
+
torch_dtype=torch_dtype,
|
| 653 |
safety_checker=None,
|
| 654 |
requires_safety_checker=False
|
| 655 |
).to(device)
|
|
|
|
| 1139 |
gr.Markdown("""
|
| 1140 |
### Microsoft Florence-2 Vision Language Model
|
| 1141 |
**Pre-trained Tasks:**
|
| 1142 |
+
- `<OCR>`: Text recognition (Extract text from image)
|
| 1143 |
+
- `<CAPTION>`: Image captioning (Generate a caption)
|
| 1144 |
+
- `<DETAILED_CAPTION>`: Detailed caption (More detailed description)
|
| 1145 |
+
- `<MORE_DETAILED_CAPTION>`: More detailed caption (Even more details)
|
| 1146 |
+
- `<OD>`: Object detection (Detect objects with bounding boxes)
|
| 1147 |
- `<OPEN_VOCABULARY_DETECTION>`: Open-vocabulary detection
|
| 1148 |
- `<REGION_PROPOSAL>`: Region proposal
|
| 1149 |
|
| 1150 |
+
**How to use:**
|
| 1151 |
+
1. Upload an image
|
| 1152 |
+
2. Select a task from the dropdown
|
| 1153 |
+
3. Click "Analyze Image"
|
| 1154 |
+
4. Results will be displayed in the text box
|
| 1155 |
+
|
| 1156 |
+
**Example tasks:**
|
| 1157 |
+
- Extract text from a document: `<OCR>`
|
| 1158 |
+
- Describe what's in the image: `<CAPTION>`
|
| 1159 |
+
- Detect objects in the image: `<OD>`
|
| 1160 |
""")
|
| 1161 |
|
| 1162 |
with gr.Row():
|