import os
import base64

from unstructured.partition.pdf import partition_pdf
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.messages import HumanMessage
from PIL import Image
import pytesseract
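
# Note: partition_pdf's "hi_res" strategy needs poppler and a layout-detection
# model installed (see unstructured's docs), and pytesseract needs the Tesseract
# OCR binary available on the system.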

# Partition a PDF into elements; extracted images and tables are written
# to the extracted_data/ directory as image files.
def partition_pdf_elements(filename):
    raw_pdf_elements = partition_pdf(
        filename=filename,
        strategy="hi_res",
        extract_images_in_pdf=True,
        extract_image_block_types=["Image", "Table"],
        extract_image_block_to_payload=False,
        extract_image_block_output_dir="extracted_data",
    )
    return raw_pdf_elements
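
# Example usage (hypothetical path; "hi_res" parsing can be slow on large PDFs):
#   elements = partition_pdf_elements("docs/report.pdf")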

# Sort partitioned elements into per-type buckets. The img and tab buckets hold
# the elements' text representations; the cropped images themselves were already
# saved to disk by partition_pdf above.
def classify_elements(raw_pdf_elements):
    Header, Footer, Title, NarrativeText, Text, ListItem, img, tab = [], [], [], [], [], [], [], []
    for element in raw_pdf_elements:
        if "unstructured.documents.elements.Header" in str(type(element)):
            Header.append(str(element))
        elif "unstructured.documents.elements.Footer" in str(type(element)):
            Footer.append(str(element))
        elif "unstructured.documents.elements.Title" in str(type(element)):
            Title.append(str(element))
        elif "unstructured.documents.elements.NarrativeText" in str(type(element)):
            NarrativeText.append(str(element))
        elif "unstructured.documents.elements.Text" in str(type(element)):
            Text.append(str(element))
        elif "unstructured.documents.elements.ListItem" in str(type(element)):
            ListItem.append(str(element))
        elif "unstructured.documents.elements.Image" in str(type(element)):
            img.append(str(element))
        elif "unstructured.documents.elements.Table" in str(type(element)):
            tab.append(str(element))
    return Header, Footer, Title, NarrativeText, Text, ListItem, img, tab
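
# Example: unpack the buckets after partitioning (local names mirror the
# return order):
#   header, footer, title, narrative, text, list_items, img, tab = classify_elements(elements)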

# Summarize table elements with Gemini; the summaries are meant to be embedded
# and used to retrieve the raw tables.
def summarize_tables(tab, google_api_key):
    prompt_text = """You are an assistant tasked with summarizing tables for retrieval. \
These summaries will be embedded and used to retrieve the raw table elements. \
Give a concise summary of the table that is well optimized for retrieval. Table {element}"""
    prompt = ChatPromptTemplate.from_template(prompt_text)
    model = ChatGoogleGenerativeAI(model="gemini-1.5-flash", google_api_key=google_api_key)
    summarize_chain = {"element": lambda x: x} | prompt | model | StrOutputParser()
    table_summaries = summarize_chain.batch(tab, {"max_concurrency": 5})
    return table_summaries
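
# Example (assumes the key is supplied via the GOOGLE_API_KEY env var):
#   table_summaries = summarize_tables(tab, os.environ["GOOGLE_API_KEY"])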

# Read an image file and return its contents as a base64-encoded string.
def encode_image(image_path):
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")
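
# Example (hypothetical file written by partition_pdf into extracted_data/):
#   b64 = encode_image("extracted_data/figure-1-1.jpg")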

# Send a base64-encoded image plus a text prompt to Gemini and return its reply.
def image_summarize(img_base64, prompt, google_api_key):
    chat = ChatGoogleGenerativeAI(model="gemini-1.5-flash", google_api_key=google_api_key, max_output_tokens=512)
    msg = chat.invoke(
        [
            HumanMessage(
                content=[
                    {"type": "text", "text": prompt},
                    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{img_base64}"}},
                ]
            )
        ]
    )
    return msg.content
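
# Example, reusing the b64 string from the encode_image sketch above:
#   summary = image_summarize(b64, "Describe this image for retrieval.", google_api_key)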

# Encode and summarize a single image; returns one-element lists so the results
# can be accumulated across multiple images.
def generate_img_summaries(path, google_api_key):
    img_base64_list = []
    image_summaries = []
    prompt = """You are an assistant tasked with summarizing images for retrieval. \
These summaries will be embedded and used to retrieve the raw image. \
Give a concise summary of the image that is well optimized for retrieval.
Also give the image output if possible."""
    base64_image = encode_image(path)
    img_base64_list.append(base64_image)
    image_summaries.append(image_summarize(base64_image, prompt, google_api_key))
    return img_base64_list, image_summaries
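
# Example (a single image path, not a directory, despite the plural returns):
#   imgs_b64, summaries = generate_img_summaries("extracted_data/figure-1-1.jpg", google_api_key)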

# Answer a text query grounded in the given text elements.
def handle_query(query, google_api_key, text_elements):
    prompt_text = f"You are an assistant tasked with answering the following query based on the provided text elements:\n\n{query}\n\nText elements: {text_elements}"
    model = ChatGoogleGenerativeAI(model="gemini-1.5-flash", google_api_key=google_api_key)
    msg = model.invoke([HumanMessage(content=prompt_text)])
    return msg.content
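
# Example, using the narrative bucket from the classify_elements sketch above:
#   answer = handle_query("What are the key findings?", google_api_key, narrative)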

# OCR an image with Tesseract and return the extracted text.
def extract_text_from_image(image_path):
    image = Image.open(image_path)
    text = pytesseract.image_to_string(image)
    return text
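
# Example (requires the Tesseract binary on PATH):
#   ocr_text = extract_text_from_image("extracted_data/figure-1-1.jpg")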

# Answer a query grounded in text OCR'd from an image.
def handle_image_query(image_path, query, google_api_key):
    extracted_text = extract_text_from_image(image_path)
    prompt_text = f"You are an assistant tasked with answering the following query based on the extracted text from the image:\n\n{query}\n\nExtracted text: {extracted_text}"
    model = ChatGoogleGenerativeAI(model="gemini-1.5-flash", google_api_key=google_api_key)
    msg = model.invoke([HumanMessage(content=prompt_text)])
    return msg.content
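
# A minimal end-to-end sketch tying the functions above together. The PDF name
# and the query are hypothetical placeholders, and the Google API key is assumed
# to be provided via the GOOGLE_API_KEY environment variable; adjust for your setup.
if __name__ == "__main__":
    google_api_key = os.environ["GOOGLE_API_KEY"]
    elements = partition_pdf_elements("sample.pdf")  # hypothetical input file
    header, footer, title, narrative, text, list_items, img, tab = classify_elements(elements)
    if tab:
        print(summarize_tables(tab, google_api_key))
    if narrative:
        print(handle_query("What is this document about?", google_api_key, narrative))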