K1Z3M1112 committed
Commit c8108cb · verified · 1 Parent(s): b731a45

Update app.py

Files changed (1): app.py (+72 -52)
app.py CHANGED
@@ -20,9 +20,9 @@ if torch.cuda.is_available():
 
 # Device
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-dtype = torch.float32  # use float32 to avoid compatibility issues
+torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
 
-print(f"🖥️ Device: {device} | dtype: {dtype}")
+print(f"🖥️ Device: {device} | dtype: {torch_dtype}")
 
 # Lazy import (to avoid long startup if unused)
 from diffusers import (
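The new dtype policy picks fp16 on CUDA and falls back to fp32 on CPU, and every downstream `from_pretrained` call in this commit keys its `variant` flag off the same choice. A minimal sketch of that pattern (the helper name is illustrative, not part of the app):

```python
import torch

def pick_device_and_dtype():
    """Illustrative helper mirroring the commit's policy: fp16 on CUDA
    for speed and VRAM, fp32 on CPU, where half precision is poorly
    supported."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    # Checkpoints that ship fp16 weights are selected with variant="fp16";
    # this flag must track the dtype choice, otherwise from_pretrained
    # downloads full-precision weights and casts them at load time.
    variant = "fp16" if torch_dtype == torch.float16 else None
    return device, torch_dtype, variant
```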
@@ -234,24 +234,25 @@ def load_florence2():
 
         print("📥 Loading Microsoft/Florence-2-base...")
 
-        # Load the processor
-        FLORENCE2_PROCESSOR = AutoProcessor.from_pretrained(
-            "microsoft/Florence-2-base",
-            trust_remote_code=True
-        )
-
-        # Use the older loading style to avoid compatibility issues
+        # Load the model as in the official docs
         FLORENCE2_MODEL = AutoModelForCausalLM.from_pretrained(
-            "microsoft/Florence-2-base",
-            torch_dtype=dtype,
+            "microsoft/Florence-2-base",
+            torch_dtype=torch_dtype,
             trust_remote_code=True
         ).to(device)
 
+        FLORENCE2_PROCESSOR = AutoProcessor.from_pretrained(
+            "microsoft/Florence-2-base",
+            trust_remote_code=True
+        )
+
         print("✅ Florence-2 model loaded successfully")
         return FLORENCE2_PROCESSOR, FLORENCE2_MODEL
 
     except Exception as e:
         print(f"❌ Error loading Florence-2: {e}")
+        import traceback
+        traceback.print_exc()
         return None, None
 
 def analyze_with_florence2(image, task_prompt):
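With the reordered loading (model first, then processor, both with `trust_remote_code=True`), a failed load now prints a full traceback and returns `(None, None)`. A hypothetical smoke test for the loader, assuming `load_florence2` is imported from this module:

```python
# Hypothetical smoke test; load_florence2() is the function this hunk
# modifies, and it returns (None, None) on failure.
processor, model = load_florence2()
if processor is None or model is None:
    raise SystemExit("Florence-2 failed to load; see the traceback above.")
print(f"Loaded {type(model).__name__} with {type(processor).__name__}")
```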
@@ -271,8 +272,6 @@ def analyze_with_florence2(image, task_prompt):
     try:
         if isinstance(image, np.ndarray):
             image = Image.fromarray(image)
-        elif hasattr(image, 'shape'):  # could be a torch tensor
-            image = Image.fromarray(image.cpu().numpy())
         else:
             return "❌ Invalid image format. Please upload a valid image."
     except Exception as e:
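The torch-tensor branch is dropped, so the function now converts only NumPy arrays, which is what Gradio image components typically deliver. A standalone sketch of that normalization (the `to_pil` helper is illustrative and, unlike the app, also passes through PIL images):

```python
import numpy as np
from PIL import Image

def to_pil(image):
    """Illustrative input normalization: accept a NumPy array
    (e.g. from a Gradio Image component) or an existing PIL image,
    and reject anything else."""
    if isinstance(image, np.ndarray):
        return Image.fromarray(image)
    if isinstance(image, Image.Image):
        return image
    raise TypeError("Invalid image format")
```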
@@ -289,49 +288,61 @@ def analyze_with_florence2(image, task_prompt):
             new_size = (int(image.width * ratio), int(image.height * ratio))
             image = image.resize(new_size, Image.Resampling.LANCZOS)
 
-        # Prepare input
+        # Prepare inputs as in the official docs
         try:
             inputs = processor(
                 text=task_prompt,
                 images=image,
                 return_tensors="pt"
-            ).to(device)
+            ).to(device, torch_dtype)
         except Exception as e:
             print(f"❌ Error processing image: {e}")
             return f"❌ Error processing image: {str(e)}"
 
-        # Check that the inputs are valid
-        if inputs is None or 'pixel_values' not in inputs:
-            return "❌ Failed to process image for analysis."
-
-        # Generate
+        # Generate as in the official docs
         try:
             generated_ids = model.generate(
                 input_ids=inputs["input_ids"],
                 pixel_values=inputs["pixel_values"],
-                max_new_tokens=512,  # fewer tokens for faster processing
-                num_beams=2,  # fewer beams for faster processing
-                early_stopping=True
+                max_new_tokens=1024,
+                do_sample=False,
+                num_beams=3,
             )
         except Exception as e:
             print(f"❌ Error generating text: {e}")
             return f"❌ Error during analysis: {str(e)}"
 
-        # Decode
+        # Decode
         try:
             generated_text = processor.batch_decode(
                 generated_ids,
-                skip_special_tokens=True
+                skip_special_tokens=False
             )[0]
         except Exception as e:
             print(f"❌ Error decoding text: {e}")
             return f"❌ Error decoding result: {str(e)}"
 
-        # Clean up
-        if device.type == "cuda":
-            torch.cuda.empty_cache()
-
-        return generated_text
+        # Parse the result with post_process_generation
+        try:
+            parsed_answer = processor.post_process_generation(
+                generated_text,
+                task=task_prompt,
+                image_size=(image.width, image.height)
+            )
+
+            # Convert the result into a readable string
+            if isinstance(parsed_answer, dict):
+                result_str = ""
+                for key, value in parsed_answer.items():
+                    result_str += f"{key}:\n{value}\n\n"
+                return result_str.strip()
+            else:
+                return str(parsed_answer)
+
+        except Exception as e:
+            print(f"❌ Error in post-processing: {e}")
+            # If post-processing fails, return the raw generated text
+            return f"Raw output: {generated_text}"
 
     except Exception as e:
         print(f"❌ Error in Florence-2 analysis: {e}")
@@ -404,17 +415,17 @@ def get_pipeline(model_name: str, controlnet_type: str = "lineart", lora_model:
             controlnet_model_name = get_controlnet_model(controlnet_type)
             controlnet = ControlNetModel.from_pretrained(
                 controlnet_model_name,
-                torch_dtype=dtype
+                torch_dtype=torch_dtype
             ).to(device)
 
             pipe = StableDiffusionXLPipeline.from_pretrained(
                 model_name,
                 controlnet=controlnet,
-                torch_dtype=dtype,
+                torch_dtype=torch_dtype,
                 safety_checker=None,
                 requires_safety_checker=False,
                 use_safetensors=True,
-                variant="fp16" if dtype == torch.float16 else None
+                variant="fp16" if torch_dtype == torch.float16 else None
             ).to(device)
         else:
             raise ValueError(f"SDXL model {model_name} only supports limited ControlNet types: {list(SDXL_CONTROLNET_MODELS.keys())}")
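Note that `StableDiffusionXLPipeline` does not consume a `controlnet` argument; in diffusers, ControlNet-conditioned SDXL generation goes through `StableDiffusionXLControlNetPipeline`. A hedged loading sketch with the same dtype/variant handling (the ControlNet checkpoint ID is illustrative; the app resolves its own via `get_controlnet_model()`):

```python
import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-canny-sdxl-1.0",  # illustrative SDXL ControlNet
    torch_dtype=torch_dtype,
)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    torch_dtype=torch_dtype,
    use_safetensors=True,
    variant="fp16" if torch_dtype == torch.float16 else None,
).to(device)
```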
@@ -423,17 +434,17 @@ def get_pipeline(model_name: str, controlnet_type: str = "lineart", lora_model:
         controlnet_model_name = get_controlnet_model(controlnet_type)
         controlnet = ControlNetModel.from_pretrained(
             controlnet_model_name,
-            torch_dtype=dtype
+            torch_dtype=torch_dtype
         ).to(device)
 
         pipe = StableDiffusionControlNetPipeline.from_pretrained(
             model_name,
             controlnet=controlnet,
-            torch_dtype=dtype,
+            torch_dtype=torch_dtype,
             safety_checker=None,
             requires_safety_checker=False,
             use_safetensors=True,
-            variant="fp16" if dtype == torch.float16 else None
+            variant="fp16" if torch_dtype == torch.float16 else None
         ).to(device)
 
     # Apply LoRA if specified
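Once built, the SD1.5 ControlNet pipeline takes the conditioning map as `image`. A hypothetical invocation, assuming `pipe` from the hunk above and placeholder file paths:

```python
from PIL import Image

control_image = Image.open("lineart_map.png").convert("RGB")  # placeholder
result = pipe(
    prompt="a cozy cabin in the woods, best quality",
    image=control_image,                # the conditioning map
    num_inference_steps=30,
    controlnet_conditioning_scale=1.0,  # how strongly the map constrains
).images[0]
result.save("output.png")
```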
@@ -540,20 +551,20 @@ def load_t2i_model(model_name: str, lora_model: str = None, lora_weight: float =
             # Load base and refiner
             CURRENT_T2I_PIPE = StableDiffusionXLPipeline.from_pretrained(
                 "stabilityai/stable-diffusion-xl-base-1.0",
-                torch_dtype=dtype,
+                torch_dtype=torch_dtype,
                 safety_checker=None,
                 requires_safety_checker=False,
                 use_safetensors=True,
-                variant="fp16" if dtype == torch.float16 else None
+                variant="fp16" if torch_dtype == torch.float16 else None
             ).to(device)
 
             CURRENT_SDXL_REFINER = StableDiffusionXLPipeline.from_pretrained(
                 model_name,
-                torch_dtype=dtype,
+                torch_dtype=torch_dtype,
                 safety_checker=None,
                 requires_safety_checker=False,
                 use_safetensors=True,
-                variant="fp16" if dtype == torch.float16 else None,
+                variant="fp16" if torch_dtype == torch.float16 else None,
                 text_encoder_2=CURRENT_T2I_PIPE.text_encoder_2,
                 vae=CURRENT_T2I_PIPE.vae
             ).to(device)
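Passing `text_encoder_2` and `vae` from the base pipeline into the refiner avoids loading those weights twice. For reference, the canonical diffusers base-plus-refiner pattern loads the refiner as an img2img pipeline; a sketch under that assumption:

```python
import torch
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline

torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
variant = "fp16" if torch_dtype == torch.float16 else None

base = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch_dtype, use_safetensors=True, variant=variant,
)
refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    text_encoder_2=base.text_encoder_2,  # reuse instead of loading twice
    vae=base.vae,
    torch_dtype=torch_dtype, use_safetensors=True, variant=variant,
)
# The base stage emits latents and the refiner polishes them; the
# denoising_end / denoising_start arguments split the schedule between them.
```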
@@ -561,22 +572,22 @@ def load_t2i_model(model_name: str, lora_model: str = None, lora_weight: float =
         else:
             CURRENT_T2I_PIPE = StableDiffusionXLPipeline.from_pretrained(
                 model_name,
-                torch_dtype=dtype,
+                torch_dtype=torch_dtype,
                 safety_checker=None,
                 requires_safety_checker=False,
                 use_safetensors=True,
-                variant="fp16" if dtype == torch.float16 else None
+                variant="fp16" if torch_dtype == torch.float16 else None
             ).to(device)
         print(f"✅ Loaded SDXL model: {model_name}")
     else:
         # Load SD1.5 model
         CURRENT_T2I_PIPE = StableDiffusionPipeline.from_pretrained(
             model_name,
-            torch_dtype=dtype,
+            torch_dtype=torch_dtype,
             safety_checker=None,
             requires_safety_checker=False,
             use_safetensors=True,
-            variant="fp16" if dtype == torch.float16 else None
+            variant="fp16" if torch_dtype == torch.float16 else None
         ).to(device)
         print(f"✅ Loaded SD1.5 model: {model_name}")
 
@@ -631,14 +642,14 @@ def load_t2i_model(model_name: str, lora_model: str = None, lora_weight: float =
         if is_sdxl_model(model_name):
             CURRENT_T2I_PIPE = StableDiffusionXLPipeline.from_pretrained(
                 model_name,
-                torch_dtype=dtype,
+                torch_dtype=torch_dtype,
                 safety_checker=None,
                 requires_safety_checker=False
             ).to(device)
         else:
             CURRENT_T2I_PIPE = StableDiffusionPipeline.from_pretrained(
                 model_name,
-                torch_dtype=dtype,
+                torch_dtype=torch_dtype,
                 safety_checker=None,
                 requires_safety_checker=False
             ).to(device)
@@ -1128,15 +1139,24 @@ with gr.Blocks(title="🎨 Advanced Image Generation Suite", theme=gr.themes.Sof
         gr.Markdown("""
         ### Microsoft Florence-2 Vision Language Model
         **Pre-trained Tasks:**
-        - `<OCR>`: Text recognition
-        - `<CAPTION>`: Image captioning
-        - `<DETAILED_CAPTION>`: Detailed caption
-        - `<MORE_DETAILED_CAPTION>`: More detailed caption
-        - `<OD>`: Object detection
+        - `<OCR>`: Text recognition (Extract text from image)
+        - `<CAPTION>`: Image captioning (Generate a caption)
+        - `<DETAILED_CAPTION>`: Detailed caption (More detailed description)
+        - `<MORE_DETAILED_CAPTION>`: More detailed caption (Even more details)
+        - `<OD>`: Object detection (Detect objects with bounding boxes)
         - `<OPEN_VOCABULARY_DETECTION>`: Open-vocabulary detection
         - `<REGION_PROPOSAL>`: Region proposal
 
-        **Note:** Upload an image and select a task to analyze it.
+        **How to use:**
+        1. Upload an image
+        2. Select a task from the dropdown
+        3. Click "Analyze Image"
+        4. Results will be displayed in the text box
+
+        **Example tasks:**
+        - Extract text from a document: `<OCR>`
+        - Describe what's in the image: `<CAPTION>`
+        - Detect objects in the image: `<OD>`
         """)
 
         with gr.Row():
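The documented task prompts map directly onto `analyze_with_florence2`. Hypothetical calls, assuming a placeholder image path and the function defined earlier in this file:

```python
from PIL import Image

img = Image.open("receipt.jpg").convert("RGB")   # placeholder path
print(analyze_with_florence2(img, "<OCR>"))      # extract text
print(analyze_with_florence2(img, "<CAPTION>"))  # short description
print(analyze_with_florence2(img, "<OD>"))       # objects + bounding boxes
```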
 