Commit · e7afcc50
Parent(s):
Duplicate from ritikjain51/PDF-experimentation
Co-authored-by: Ritik Jain <[email protected]>
- .gitattributes +34 -0
- .gitignore +6 -0
- Dockerfile +19 -0
- LICENSE +0 -0
- README.md +35 -0
- __init__.py +0 -0
- app.py +171 -0
- backend.py +146 -0
- configs.py +4 -0
- qna.py +0 -0
- requirements.txt +9 -0
- schema.py +63 -0
.gitattributes ADDED

```text
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
```
.gitignore ADDED

```text
.Chroma
.chroma
*.ipynb
*.pyc
__pycache__
.faiss
```
Dockerfile ADDED

```dockerfile
FROM python:3.10-slim

WORKDIR /code

ENV PYTHONUNBUFFERED=1 \
    PYTHONDONTWRITEBYTECODE=1 \
    PIP_NO_CACHE_DIR=off \
    PIP_DISABLE_PIP_VERSION_CHECK=on \
    PIP_DEFAULT_TIMEOUT=100 \
    HNSWLIB_NO_NATIVE=1

RUN apt-get update && apt-get install -y python3-dev libprotobuf-dev build-essential

COPY . .
RUN pip install --upgrade pip
RUN pip install duckdb
RUN pip install -r requirements.txt
EXPOSE 8071
# app.py is a Streamlit app (the README declares sdk: streamlit), so launch it
# with streamlit rather than the original ["gradio", "app.py"] entrypoint.
CMD ["streamlit", "run", "app.py", "--server.port", "8071"]
```
LICENSE ADDED
File without changes
README.md ADDED

```markdown
---
license: mit
title: PDF Experimentation
sdk: streamlit
emoji: 🚀
colorFrom: purple
colorTo: gray
pinned: true
app_file: app.py
duplicated_from: ritikjain51/PDF-experimentation
---

## Next Steps

- [x] Build UI using Streamlit
- [x] Add Advanced Settings in sidebar
- [x] Build backend using LangChain
- [x] Dockerize
- [ ] Add Docs

### UI Components

- [x] Add Upload PDF Tab
- [x] Show PDF Tab
- [x] Question Answer Tab
- [x] Conversational Tab
- [x] Advanced Settings
- [x] Model Settings

### Backend Components

- [x] Read PDF and ingest
- [x] Fetch Configuration
- [x] Vector DB Indexing
- [ ]
```
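The checklist above maps directly onto the `QnASystem` backend added below. A minimal sketch of that flow, assuming `OPENAI_API_KEY` is set in the environment; `sample.pdf` is a placeholder path, not a file in this repo:

```python
# Sketch of the upload -> ingest -> ask flow described by the checklist.
from backend import QnASystem
from schema import TransformType, EmbeddingTypes, IndexerType, BotType

qna = QnASystem()
with open("sample.pdf", "rb") as f:          # placeholder PDF path
    qna.read_and_load_pdf(f)                 # PDF pages -> langchain Documents
qna.build_chain(transform_type=TransformType.RecursiveTransform,
                embedding_type=EmbeddingTypes.OPENAI,   # reads OPENAI_API_KEY
                indexer_type=IndexerType.FAISS,
                bot_type=BotType.qna)        # split, embed, index, build chain
print(qna.ask_question("What is this document about?"))
```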
__init__.py ADDED
File without changes
app.py ADDED

```python
import base64

import streamlit as st
from streamlit_chat import message
from streamlit_extras.colored_header import colored_header

from backend import QnASystem
from schema import TransformType, EmbeddingTypes, IndexerType, BotType

kwargs = {}
source_docs = []
st.set_page_config(page_title="PDFChat - An LLM-powered experimentation app")

if "qna_system" not in st.session_state:
    st.session_state.qna_system = QnASystem()


def show_pdf(f):
    # Embed the uploaded PDF as a base64 data URI inside an iframe.
    f.seek(0)
    base64_pdf = base64.b64encode(f.read()).decode('utf-8')
    pdf_display = f'<iframe src="data:application/pdf;base64,{base64_pdf}" width="700" height="800" ' \
                  f'type="application/pdf"></iframe>'
    st.markdown(pdf_display, unsafe_allow_html=True)


def model_settings():
    kwargs["temperature"] = st.slider("Temperature", max_value=1.0, min_value=0.0)
    kwargs["max_tokens"] = st.number_input("Max Token", min_value=0, value=512)


st.title("PDF Question and Answering")

tab1, tab2, tab3 = st.tabs(["Upload and Ingest PDF", "Ask", "Show PDF"])

with st.sidebar:
    st.header("Advanced Settings ⚙️")
    require_pdf = st.checkbox("Show PDF", value=1)
    st.markdown('---')
    kwargs["bot_type"] = st.selectbox("Bot Type", options=BotType)
    st.markdown("---")
    st.text("Model Parameters")
    kwargs["return_documents"] = st.checkbox("Require Source Documents", value=True)
    text_transform = st.selectbox("Text Transformer", options=TransformType)
    st.markdown("---")
    selected_model = st.selectbox("Select Model", options=EmbeddingTypes)
    match selected_model:
        case EmbeddingTypes.OPENAI:
            api_key = st.text_input("OpenAI API Key", placeholder="sk-...", type="password")
            if not api_key.startswith('sk-'):
                st.warning('Please enter your OpenAI API key!', icon='⚠')
            model_settings()
        case EmbeddingTypes.HUGGING_FACE:
            # Hugging Face tokens start with "hf_" (the original checked for "hg-").
            api_key = st.text_input("Hugging Face API Key", placeholder="hf_...", type="password")
            if not api_key.startswith('hf_'):
                st.warning('Please enter your Hugging Face API key!', icon='⚠')
            # Pass the chosen repo id through to the backend, which reads kwargs["model_name"].
            kwargs["model_name"] = st.selectbox("Choose Model", options=["google/flan-t5-xl"])
            model_settings()
        case EmbeddingTypes.COHERE:
            api_key = st.text_input("Cohere API Key", placeholder="...", type="password")
            if not api_key:
                st.warning('Please enter your Cohere API key!', icon='⚠')
            model_settings()
        case _:
            api_key = None
    kwargs["api_key"] = api_key
    st.markdown("---")

    vector_indexer = st.selectbox("Vector Indexer", options=IndexerType)
    match vector_indexer:
        case IndexerType.ELASTICSEARCH:
            kwargs["elasticsearch_url"] = st.text_input("Elasticsearch URL: ")
            if not kwargs.get("elasticsearch_url"):
                st.warning("Please enter your Elasticsearch URL", icon='⚠')
            kwargs["elasticsearch_index"] = st.text_input("Elasticsearch Index: ")
            if not kwargs.get("elasticsearch_index"):
                st.warning("Please enter your Elasticsearch index", icon='⚠')

    st.markdown("---")
    st.text("Chain Settings")
    kwargs["chain_type"] = st.selectbox("Chain Type", options=["stuff", "map_reduce"])
    kwargs["search_type"] = st.selectbox("Search Type", options=["similarity"])
    st.markdown("---")

with tab1:
    uploaded_file = st.file_uploader("Upload and Ingest PDF 🚀", type="pdf")
    if uploaded_file:
        with st.spinner("Uploading and Ingesting"):
            documents = st.session_state.qna_system.read_and_load_pdf(uploaded_file)
            if selected_model == EmbeddingTypes.NA:
                st.warning("Please select the model", icon='⚠')
            else:
                st.session_state.qna_system.build_chain(transform_type=text_transform, embedding_type=selected_model,
                                                        indexer_type=vector_indexer, **kwargs)


def generate_response(prompt):
    if prompt and uploaded_file:
        response = st.session_state.qna_system.ask_question(prompt)
        # RetrievalQA returns "result"; ConversationalRetrievalChain returns "answer".
        return response.get("answer", response.get("result", "")), response.get("source_documents")
    return "", []


with tab2:
    if not uploaded_file:
        st.warning("Please upload PDF", icon='⚠')
    else:
        match kwargs["bot_type"]:
            case BotType.qna:
                with st.container():
                    with st.form('my_form'):
                        text = st.text_area("", placeholder='Ask me...')
                        submitted = st.form_submit_button('Submit')
                    if text:
                        st.write(f"Question:\n{text}")
                        response, source_docs = generate_response(text)
                        st.write(response)
            case BotType.conversational:
                # generated stores AI-generated responses; past stores the user's questions.
                if 'generated' not in st.session_state:
                    st.session_state['generated'] = ["Hi! I'm PDF Assistant 🤖, How may I help you?"]
                if 'past' not in st.session_state:
                    st.session_state['past'] = ['Hi!']

                input_container = st.container()
                colored_header(label='', description='', color_name='blue-30')
                response_container = st.container()
                response = ""


                def get_text():
                    input_text = st.text_input("You: ", "", key="input")
                    return input_text


                with input_container:
                    user_input = get_text()
                    if st.button("Clear"):
                        st.session_state.generated.clear()
                        st.session_state.past.clear()

                with response_container:
                    if user_input:
                        response, source_docs = generate_response(user_input)
                        st.session_state.past.append(user_input)
                        st.session_state.generated.append(response)

                    if st.session_state['generated']:
                        for i in range(len(st.session_state['generated'])):
                            message(st.session_state['past'][i], is_user=True, key=str(i) + '_user')
                            message(st.session_state["generated"][i], key=str(i))

        require_document = st.container()
        if kwargs["return_documents"]:
            with require_document:
                with st.expander("Related Documents", expanded=False):
                    for source in source_docs:
                        metadata = source.metadata
                        st.write("{source} - {page_no}".format(source=metadata.get("source"),
                                                               page_no=metadata.get("page_no")))
                        st.write(source.page_content)
                        st.markdown("---")

with tab3:
    if require_pdf and uploaded_file:
        show_pdf(uploaded_file)
    elif uploaded_file:
        st.warning("Feature not enabled.", icon='⚠')
    else:
        st.warning("Please upload PDF", icon='⚠')
```
backend.py ADDED

```python
import os

from langchain import FAISS, OpenAI, HuggingFaceHub, Cohere, PromptTemplate
from langchain.chains import RetrievalQA, ConversationalRetrievalChain
from langchain.embeddings import OpenAIEmbeddings, HuggingFaceEmbeddings, CohereEmbeddings
from langchain.memory import ConversationBufferMemory
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter, CharacterTextSplitter, NLTKTextSplitter, \
    SpacyTextSplitter
from langchain.vectorstores import Chroma, ElasticVectorSearch
from pypdf import PdfReader

from schema import EmbeddingTypes, IndexerType, TransformType, BotType


class QnASystem:

    def read_and_load_pdf(self, f_data):
        pdf_data = PdfReader(f_data)
        documents = []
        for idx, page in enumerate(pdf_data.pages):
            documents.append(Document(page_content=page.extract_text(),
                                      metadata={"page_no": idx, "source": f_data.name}))

        self.documents = documents
        return documents

    def document_transformer(self, transform_type: TransformType):
        match transform_type:
            case TransformType.CharacterTransform:
                t_type = CharacterTextSplitter(chunk_size=1000, chunk_overlap=20)
            case TransformType.RecursiveTransform:
                t_type = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=20)
            case TransformType.NLTKTransform:
                t_type = NLTKTextSplitter()
            case TransformType.SpacyTransform:
                t_type = SpacyTextSplitter()
            case _:
                raise IndexError("Invalid Transformer Type")

        self.transformed_documents = t_type.split_documents(documents=self.documents)

    def generate_embeddings(self, embedding_type: EmbeddingTypes = EmbeddingTypes.OPENAI,
                            indexer_type: IndexerType = IndexerType.FAISS, **kwargs):
        temperature = kwargs.get("temperature", 0)
        max_tokens = kwargs.get("max_tokens", 512)
        match embedding_type:
            case EmbeddingTypes.OPENAI:
                os.environ["OPENAI_API_KEY"] = kwargs.get("api_key") or os.getenv("OPENAI_API_KEY", "")
                embeddings = OpenAIEmbeddings()
                llm = OpenAI(temperature=temperature, max_tokens=max_tokens)
            case EmbeddingTypes.HUGGING_FACE:
                # HuggingFaceHub reads the token from this environment variable.
                os.environ["HUGGINGFACEHUB_API_TOKEN"] = kwargs.get("api_key") or \
                    os.getenv("HUGGINGFACEHUB_API_TOKEN", "")
                embeddings = HuggingFaceEmbeddings(model_name=kwargs.get("model_name"))
                llm = HuggingFaceHub(repo_id=kwargs.get("model_name"),
                                     model_kwargs={"temperature": temperature, "max_tokens": max_tokens})
            case EmbeddingTypes.COHERE:
                embeddings = CohereEmbeddings(model=kwargs.get("model_name"), cohere_api_key=kwargs.get("api_key"))
                llm = Cohere(model=kwargs.get("model_name"), cohere_api_key=kwargs.get("api_key"),
                             model_kwargs={"temperature": temperature,
                                           "max_tokens": max_tokens})
            case _:
                raise IndexError("Invalid Embedding Type")

        match indexer_type:
            case IndexerType.FAISS:
                indexer = FAISS
            case IndexerType.CHROMA:
                # Keep the class (the original instantiated Chroma()); from_documents is a classmethod.
                indexer = Chroma
            case IndexerType.ELASTICSEARCH:
                indexer = ElasticVectorSearch
            case _:
                raise IndexError("Invalid Indexer Function")

        self.llm = llm
        self.indexer = indexer
        if indexer is ElasticVectorSearch:
            # Elasticsearch needs its connection details when the index is built.
            self.vector_store = indexer.from_documents(documents=self.transformed_documents, embedding=embeddings,
                                                       elasticsearch_url=kwargs.get("elasticsearch_url"),
                                                       index_name=kwargs.get("elasticsearch_index"))
        else:
            self.vector_store = indexer.from_documents(documents=self.transformed_documents, embedding=embeddings)

    def get_retriever(self, search_type="similarity", top_k=5, **kwargs):
        retriever = self.vector_store.as_retriever(search_type=search_type, search_kwargs={"k": top_k})
        self.retriever = retriever

    def get_prompt(self, bot_type: BotType, **kwargs):
        match bot_type:
            case BotType.qna:
                prompt = """
                You are a smart and helpful AI assistant that answers the question given the context.
                {context}
                Question: {question}
                """
            case BotType.conversational:
                prompt = """
                Given the following conversation and a follow up question,
                rephrase the follow up question to be a standalone question, in its original language.
                \nChat History:\n{chat_history}\nFollow Up Input: {question}\nStandalone question:
                """
        return PromptTemplate(input_variables=["context", "question", "chat_history"], template=prompt)

    def build_qa(self, qa_type: BotType, chain_type="stuff",
                 return_documents: bool = True, **kwargs):
        match qa_type:
            case BotType.qna:
                self.chain = RetrievalQA.from_chain_type(llm=self.llm, retriever=self.retriever, chain_type=chain_type,
                                                         return_source_documents=return_documents, verbose=True)
            case BotType.conversational:
                self.memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True,
                                                       output_key="answer")
                self.chain = ConversationalRetrievalChain.from_llm(llm=self.llm, retriever=self.retriever,
                                                                   chain_type=chain_type,
                                                                   return_source_documents=return_documents,
                                                                   memory=self.memory, verbose=True)
            case _:
                raise IndexError("Invalid QA Type")

    def ask_question(self, query):
        # RetrievalQA expects "query"; ConversationalRetrievalChain expects "question".
        if isinstance(self.chain, RetrievalQA):
            data = {"query": query}
        else:
            data = {"question": query}
        return self.chain(data)

    def build_chain(self, transform_type, embedding_type, indexer_type, **kwargs):
        if hasattr(self, "llm"):
            return self.chain
        self.document_transformer(transform_type)
        self.generate_embeddings(embedding_type=embedding_type,
                                 indexer_type=indexer_type, **kwargs)
        self.get_retriever(**kwargs)
        qa = self.build_qa(qa_type=kwargs.get("bot_type"), **kwargs)
        return qa


if __name__ == "__main__":
    qna = QnASystem()
    with open("../docs/Doc A.pdf", "rb") as f:
        qna.read_and_load_pdf(f)
    chain = qna.build_chain(
        transform_type=TransformType.RecursiveTransform,
        embedding_type=EmbeddingTypes.OPENAI, indexer_type=IndexerType.FAISS,
        chain_type="map_reduce", bot_type=BotType.conversational, return_documents=True
    )
    question = qna.ask_question(query="Hi! Summarize the document.")
    question = qna.ask_question(query="What happened from June 1984 to September 1996")
    print(question)
```
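A note on the two chain types wired up in `build_qa`: the legacy LangChain `RetrievalQA` takes `{"query": ...}` and returns its answer under `"result"`, while `ConversationalRetrievalChain` takes `{"question": ...}` and returns `"answer"`. `ask_question` papers over the input side; the UI papers over the output side. A small sketch of consuming the response dict, assuming a `qna` built as in the `__main__` block above:

```python
# Sketch: normalizing the response dict, mirroring generate_response() in app.py.
response = qna.ask_question(query="When was the report filed?")
answer = response.get("answer", response.get("result", ""))   # conversational vs. QnA key
sources = response.get("source_documents", [])                # set by return_source_documents=True
for doc in sources:
    print(doc.metadata.get("source"), doc.metadata.get("page_no"))
```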
configs.py ADDED

```python
# The original imported from a qna_retrival package; schema.py lives at the repo root.
from schema import EmbeddingTypes, IndexerType

indexer_type = IndexerType.FAISS
embedding_type = EmbeddingTypes.OPENAI
```
qna.py ADDED
File without changes
requirements.txt ADDED

```text
langchain
openai
chromadb  # the PyPI package for Chroma is chromadb; "chroma" is an unrelated package
streamlit
streamlit-extras
streamlit-chat
faiss-cpu
pypdf
tiktoken
```
schema.py ADDED

```python
from enum import Enum
from typing import Union


class EnumMetaClass(Enum):

    def __eq__(self, other):
        if self.__class__ is other.__class__:
            return self.value.upper() == other.value.upper()
        return self.value == other

    def __hash__(self):
        return hash(self._name_)

    def __str__(self):
        return self.value

    @classmethod
    def get_enum(cls, value: str) -> Union["EnumMetaClass", None]:
        # Case-insensitive lookup by value or by member name; returns None if no match.
        return next(
            (
                enum_val
                for enum_val in cls
                if (enum_val.value == value)
                or (
                    isinstance(value, str)
                    and isinstance(enum_val.value, str)
                    and (value.lower() == enum_val.value.lower() or value.upper() == enum_val.name.upper())
                )
            ),
            None,
        )

    @classmethod
    def _missing_(cls, name):
        # Lets e.g. EmbeddingTypes("openai") resolve by case-insensitive member name.
        for member in cls:
            if isinstance(member.name, str) and isinstance(name, str) and member.name.lower() == name.lower():
                return member


class EmbeddingTypes(EnumMetaClass):
    NA = "NA"
    OPENAI = "OpenAI"
    HUGGING_FACE = "Hugging Face"
    COHERE = "Cohere"


class TransformType(EnumMetaClass):
    RecursiveTransform = "Recursive Text Splitter"
    CharacterTransform = "Character Text Splitter"
    SpacyTransform = "Spacy Text Splitter"
    NLTKTransform = "NLTK Text Splitter"


class IndexerType(EnumMetaClass):
    FAISS = "FAISS"
    CHROMA = "Chroma"
    ELASTICSEARCH = "Elastic Search"


class BotType(EnumMetaClass):
    qna = "Question Answering Bot ❓"
    conversational = "Chatbot 🤖"
```
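These enums do double duty: their values are the human-readable labels Streamlit's selectbox renders (via `__str__`), and `_missing_`/`get_enum` make lookup case-insensitive. A quick illustration of that behavior:

```python
# Sketch: how EnumMetaClass behaves (enum values double as UI labels).
from schema import EmbeddingTypes, IndexerType

print(str(EmbeddingTypes.OPENAI))         # "OpenAI" -- the label shown in the sidebar
print(EmbeddingTypes("openai"))           # resolves to OPENAI via _missing_ (prints "OpenAI")
print(IndexerType.get_enum("faiss"))      # resolves to FAISS, matched case-insensitively
print(EmbeddingTypes.OPENAI == "OpenAI")  # True: __eq__ falls back to the raw value
```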