chrisjcc commited on
Commit
a7ea9b8
·
verified ·
1 Parent(s): 5872bd4

Minor update

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -8,7 +8,8 @@ import gradio as gr
8
 
9
  hf_api_key = os.environ['HF_API_KEY']
10
 
11
- get_completion = pipeline("ner", model="Salesforce/blip-image-captioning-base")
 
12
 
13
  def image_to_base64_str(pil_image):
14
  byte_arr = io.BytesIO()
@@ -18,21 +19,20 @@ def image_to_base64_str(pil_image):
18
  return str(base64.b64encode(byte_arr).decode('utf-8'))
19
 
20
  def captioner(image):
21
- base64_image = image_to_base64_str(image)
22
- result = get_completion(base64_image)
23
 
24
  return result[0]['generated_text']
25
 
26
- gr.close_all()
27
  demo = gr.Interface(fn=captioner,
28
  inputs=[gr.Image(label="Upload image", type="pil")],
29
  outputs=[gr.Textbox(label="Caption")],
30
  title="Image Captioning with BLIP",
31
  description="Caption any image using the BLIP model",
32
- allow_flagging="never",
33
  examples=["images/christmas_dog.jpg", "images/bird_flight.jpg", "images/cow.jpg"])
34
 
35
  demo.launch(
36
  share=True,
37
- #server_port=int(os.environ['PORT1'])
38
  )
 
8
 
9
  hf_api_key = os.environ['HF_API_KEY']
10
 
11
+ # Load the image-to-text pipeline with BLIP model
12
+ get_completion = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
13
 
14
  def image_to_base64_str(pil_image):
15
  byte_arr = io.BytesIO()
 
19
  return str(base64.b64encode(byte_arr).decode('utf-8'))
20
 
21
  def captioner(image):
22
+ # The BLIP model expects a PIL image directly
23
+ result = get_completion(image)
24
 
25
  return result[0]['generated_text']
26
 
 
27
  demo = gr.Interface(fn=captioner,
28
  inputs=[gr.Image(label="Upload image", type="pil")],
29
  outputs=[gr.Textbox(label="Caption")],
30
  title="Image Captioning with BLIP",
31
  description="Caption any image using the BLIP model",
32
+ flagging_mode="never", # Updated from allow_flagging
33
  examples=["images/christmas_dog.jpg", "images/bird_flight.jpg", "images/cow.jpg"])
34
 
35
  demo.launch(
36
  share=True,
37
+ # server_port=int(os.environ.get('PORT3', 7860)) # Uncomment if needed
38
  )