Commit e7afcc5
Duplicate from ritikjain51/PDF-experimentation
Co-authored-by: Ritik Jain <ritikjain51@users.noreply.huggingface.co>
- .gitattributes +34 -0
- .gitignore +6 -0
- Dockerfile +19 -0
- LICENSE +0 -0
- README.md +35 -0
- __init__.py +0 -0
- app.py +171 -0
- backend.py +146 -0
- configs.py +4 -0
- qna.py +0 -0
- requirements.txt +9 -0
- schema.py +63 -0
.gitattributes
ADDED
@@ -0,0 +1,34 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,6 @@
+.Chroma
+.chroma
+*.ipynb
+*.pyc
+__pycache__
+.faiss
Dockerfile
ADDED
@@ -0,0 +1,19 @@
+FROM python:3.10-slim
+
+WORKDIR /code
+
+ENV PYTHONUNBUFFERED=1 \
+    PYTHONDONTWRITEBYTECODE=1 \
+    PIP_NO_CACHE_DIR=off \
+    PIP_DISABLE_PIP_VERSION_CHECK=on \
+    PIP_DEFAULT_TIMEOUT=100 \
+    HNSWLIB_NO_NATIVE=1
+
+RUN apt-get update && apt-get install -y python3-dev libprotobuf-dev build-essential
+
+COPY . .
+RUN pip install --upgrade pip
+RUN pip install duckdb
+RUN pip install -r requirements.txt
+EXPOSE 8071
+CMD ["streamlit", "run", "app.py", "--server.port=8071", "--server.address=0.0.0.0"]
LICENSE
ADDED
File without changes
README.md
ADDED
@@ -0,0 +1,35 @@
+---
+license: mit
+title: PDF Experimentation
+sdk: streamlit
+emoji: 🚀
+colorFrom: purple
+colorTo: gray
+pinned: true
+app_file: app.py
+duplicated_from: ritikjain51/PDF-experimentation
+---
+
+## Next Steps
+
+- [x] Build UI using Streamlit
+- [x] Add Advanced Settings in sidebar
+- [x] Build backend using LangChain
+- [x] Dockerize
+- [ ] Add Docs
+
+
+### UI Components
+
+- [x] Add Upload PDF Tab
+- [x] Show PDF Tab
+- [x] Question Answer Tab
+- [x] Conversational Tab
+- [x] Advanced Settings
+- [x] Model Settings
+
+### Backend Components
+- [x] Read PDF and ingest
+- [x] Fetch Configuration
+- [x] Vector DB Indexing
+- [ ]
__init__.py
ADDED
File without changes
app.py
ADDED
@@ -0,0 +1,171 @@
+import base64
+
+import streamlit as st
+from streamlit_chat import message
+from streamlit_extras.colored_header import colored_header
+
+from backend import QnASystem
+from schema import TransformType, EmbeddingTypes, IndexerType, BotType
+
+kwargs = {}
+source_docs = []
+st.set_page_config(page_title="PDFChat - An LLM-powered experimentation app")
+
+if "qna_system" not in st.session_state:
+    st.session_state.qna_system = QnASystem()
+
+
+def show_pdf(f):
+    f.seek(0)
+    base64_pdf = base64.b64encode(f.read()).decode('utf-8')
+    pdf_display = f'<iframe src="data:application/pdf;base64,{base64_pdf}" width="700" height="800" ' \
+                  f'type="application/pdf"></iframe>'
+    st.markdown(pdf_display, unsafe_allow_html=True)
+
+
+def model_settings():
+    kwargs["temperature"] = st.slider("Temperature", max_value=1.0, min_value=0.0)
+    kwargs["max_tokens"] = st.number_input("Max Tokens", min_value=0, value=512)
+
+
+st.title("PDF Question Answering")
+
+tab1, tab2, tab3 = st.tabs(["Upload and Ingest PDF", "Ask", "Show PDF"])
+
+with st.sidebar:
+    st.header("Advanced Settings ⚙️")
+    require_pdf = st.checkbox("Show PDF", value=True)
+    st.markdown('---')
+    kwargs["bot_type"] = st.selectbox("Bot Type", options=BotType)
+    st.markdown("---")
+    st.text("Model Parameters")
+    kwargs["return_documents"] = st.checkbox("Require Source Documents", value=True)
+    text_transform = st.selectbox("Text Transformer", options=TransformType)
+    st.markdown("---")
+    selected_model = st.selectbox("Select Model", options=EmbeddingTypes)
+    match selected_model:
+        case EmbeddingTypes.OPENAI:
+            api_key = st.text_input("OpenAI API Key", placeholder="sk-...", type="password")
+            if not api_key.startswith('sk-'):
+                st.warning('Please enter your OpenAI API key!', icon='⚠')
+            model_settings()
+        case EmbeddingTypes.HUGGING_FACE:
+            api_key = st.text_input("Hugging Face API Key", placeholder="hf_...", type="password")
+            if not api_key.startswith('hf_'):
+                st.warning('Please enter your Hugging Face API key!', icon='⚠')
+            kwargs["model_name"] = st.selectbox("Choose Model", options=["google/flan-t5-xl"])
+            model_settings()
+        case EmbeddingTypes.COHERE:
+            api_key = st.text_input("Cohere API Key", placeholder="...", type="password")
+            if not api_key:
+                st.warning('Please enter your Cohere API key!', icon='⚠')
+            model_settings()
+        case _:
+            api_key = None
+    kwargs["api_key"] = api_key
+    st.markdown("---")
+
+    vector_indexer = st.selectbox("Vector Indexer", options=IndexerType)
+    match vector_indexer:
+        case IndexerType.ELASTICSEARCH:
+            kwargs["elasticsearch_url"] = st.text_input("Elasticsearch URL: ")
+            if not kwargs.get("elasticsearch_url"):
+                st.warning("Please enter your Elasticsearch URL", icon='⚠')
+            kwargs["elasticsearch_index"] = st.text_input("Elasticsearch Index: ")
+            if not kwargs.get("elasticsearch_index"):
+                st.warning("Please enter your Elasticsearch index", icon='⚠')
+
+    st.markdown("---")
+    st.text("Chain Settings")
+    kwargs["chain_type"] = st.selectbox("Chain Type", options=["stuff", "map_reduce"])
+    kwargs["search_type"] = st.selectbox("Search Type", options=["similarity"])
+    st.markdown("---")
+
+with tab1:
+    uploaded_file = st.file_uploader("Upload and Ingest PDF 🚀", type="pdf")
+    if uploaded_file:
+        with st.spinner("Uploading and Ingesting"):
+            documents = st.session_state.qna_system.read_and_load_pdf(uploaded_file)
+            if selected_model == EmbeddingTypes.NA:
+                st.warning("Please select a model", icon='⚠')
+            else:
+                st.session_state.qna_system.build_chain(transform_type=text_transform, embedding_type=selected_model,
+                                                        indexer_type=vector_indexer, **kwargs)
+
+
+def generate_response(prompt):
+    if prompt and uploaded_file:
+        response = st.session_state.qna_system.ask_question(prompt)
+        return response.get("answer", response.get("result", "")), response.get("source_documents")
+    return "", []
+
+
+with tab2:
+    if not uploaded_file:
+        st.warning("Please upload a PDF", icon='⚠')
+    else:
+        match kwargs["bot_type"]:
+            case BotType.qna:
+                with st.container():
+                    with st.form('my_form'):
+                        text = st.text_area("", placeholder='Ask me...')
+                        submitted = st.form_submit_button('Submit')
+                        if submitted and text:
+                            st.write(f"Question:\n{text}")
+                            response, source_docs = generate_response(text)
+                            st.write(response)
+            case BotType.conversational:
+                # Initialize chat history lists in session state.
+                # generated stores AI-generated responses
+                if 'generated' not in st.session_state:
+                    st.session_state['generated'] = ["Hi! I'm PDF Assistant 🤖, How may I help you?"]
+                # past stores the user's questions
+                if 'past' not in st.session_state:
+                    st.session_state['past'] = ['Hi!']
+
+                input_container = st.container()
+                colored_header(label='', description='', color_name='blue-30')
+                response_container = st.container()
+                response = ""
+
+
+                def get_text():
+                    input_text = st.text_input("You: ", "", key="input")
+                    return input_text
+
+
+                with input_container:
+                    user_input = get_text()
+                    if st.button("Clear"):
+                        st.session_state.generated.clear()
+                        st.session_state.past.clear()
+
+                with response_container:
+                    if user_input:
+                        response, source_docs = generate_response(user_input)
+                        st.session_state.past.append(user_input)
+                        st.session_state.generated.append(response)
+
+                    if st.session_state['generated']:
+                        for i in range(len(st.session_state['generated'])):
+                            message(st.session_state['past'][i], is_user=True, key=str(i) + '_user')
+                            message(st.session_state["generated"][i], key=str(i))
+
+        require_document = st.container()
+        if kwargs["return_documents"]:
+            with require_document:
+                with st.expander("Related Documents", expanded=False):
+                    for source in source_docs:
+                        metadata = source.metadata
+                        st.write("{source} - {page_no}".format(source=metadata.get("source"),
+                                                               page_no=metadata.get("page_no")))
+                        st.write(source.page_content)
+                        st.markdown("---")
+
+with tab3:
+    if require_pdf and uploaded_file:
+        show_pdf(uploaded_file)
+    elif uploaded_file:
+        st.warning("Feature not enabled.", icon='⚠')
+    else:
+        st.warning("Please upload a PDF", icon='⚠')
backend.py
ADDED
@@ -0,0 +1,146 @@
+import os
+
+from langchain import FAISS, OpenAI, HuggingFaceHub, Cohere, PromptTemplate
+from langchain.chains import RetrievalQA, ConversationalRetrievalChain
+from langchain.embeddings import OpenAIEmbeddings, HuggingFaceEmbeddings, CohereEmbeddings
+from langchain.memory import ConversationBufferMemory
+from langchain.schema import Document
+from langchain.text_splitter import RecursiveCharacterTextSplitter, CharacterTextSplitter, NLTKTextSplitter, \
+    SpacyTextSplitter
+from langchain.vectorstores import Chroma, ElasticVectorSearch
+from pypdf import PdfReader
+
+from schema import EmbeddingTypes, IndexerType, TransformType, BotType
+
+
+class QnASystem:
+
+    def read_and_load_pdf(self, f_data):
+        pdf_data = PdfReader(f_data)
+        documents = []
+        for idx, page in enumerate(pdf_data.pages):
+            documents.append(Document(page_content=page.extract_text(),
+                                      metadata={"page_no": idx, "source": f_data.name}))
+        self.documents = documents
+        return documents
+
+    def document_transformer(self, transform_type: TransformType):
+        match transform_type:
+            case TransformType.CharacterTransform:
+                t_type = CharacterTextSplitter(chunk_size=1000, chunk_overlap=20)
+            case TransformType.RecursiveTransform:
+                t_type = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=20)
+            case TransformType.NLTKTransform:
+                t_type = NLTKTextSplitter()
+            case TransformType.SpacyTransform:
+                t_type = SpacyTextSplitter()
+
+            case _:
+                raise IndexError("Invalid Transformer Type")
+
+        self.transformed_documents = t_type.split_documents(documents=self.documents)
+
+    def generate_embeddings(self, embedding_type: EmbeddingTypes = EmbeddingTypes.OPENAI,
+                            indexer_type: IndexerType = IndexerType.FAISS, **kwargs):
+        temperature = kwargs.get("temperature", 0)
+        max_tokens = kwargs.get("max_tokens", 512)
+        match embedding_type:
+            case EmbeddingTypes.OPENAI:
+                os.environ["OPENAI_API_KEY"] = kwargs.get("api_key") or os.getenv("OPENAI_API_KEY", "")
+                embeddings = OpenAIEmbeddings()
+                llm = OpenAI(temperature=temperature, max_tokens=max_tokens)
+            case EmbeddingTypes.HUGGING_FACE:
+                embeddings = HuggingFaceEmbeddings(model_name=kwargs.get("model_name"))
+                llm = HuggingFaceHub(repo_id=kwargs.get("model_name"),
+                                     model_kwargs={"temperature": temperature, "max_length": max_tokens})
+            case EmbeddingTypes.COHERE:
+                embeddings = CohereEmbeddings(model=kwargs.get("model_name"), cohere_api_key=kwargs.get("api_key"))
+                llm = Cohere(model=kwargs.get("model_name"), cohere_api_key=kwargs.get("api_key"),
+                             model_kwargs={"temperature": temperature,
+                                           "max_tokens": max_tokens})
+            case _:
+                raise IndexError("Invalid Embedding Type")
+
+        match indexer_type:
+            case IndexerType.FAISS:
+                indexer, indexer_kwargs = FAISS, {}
+            case IndexerType.CHROMA:
+                indexer, indexer_kwargs = Chroma, {}
+            case IndexerType.ELASTICSEARCH:
+                indexer = ElasticVectorSearch
+                indexer_kwargs = {"elasticsearch_url": kwargs.get("elasticsearch_url"), "index_name": kwargs.get("elasticsearch_index")}
+            case _:
+                raise IndexError("Invalid Indexer Type")
+
+        self.llm = llm
+        self.indexer = indexer
+        self.vector_store = indexer.from_documents(documents=self.transformed_documents, embedding=embeddings, **indexer_kwargs)
+
+    def get_retriever(self, search_type="similarity", top_k=5, **kwargs):
+        retriever = self.vector_store.as_retriever(search_type=search_type, search_kwargs={"k": top_k})
+        self.retriever = retriever
+
+    def get_prompt(self, bot_type: BotType, **kwargs):
+        match bot_type:
+            case BotType.qna:
+                prompt = """
+                You are a smart and helpful AI assistant who answers the question from the given context.
+                {context}
+                Question: {question}
+                """
+            case BotType.conversational:
+                prompt = """
+                Given the following conversation and a follow up question,
+                rephrase the follow up question to be a standalone question, in its original language.
+                \nChat History:\n{chat_history}\nFollow Up Input: {question}\nStandalone question:
+                """
+        return PromptTemplate.from_template(prompt)
+
+    def build_qa(self, qa_type: BotType, chain_type="stuff",
+                 return_documents: bool = True, **kwargs):
+        match qa_type:
+            case BotType.qna:
+                self.chain = RetrievalQA.from_chain_type(llm=self.llm, retriever=self.retriever, chain_type=chain_type,
+                                                         return_source_documents=return_documents, verbose=True)
+            case BotType.conversational:
+                # output_key="answer" tells the memory which chain output to store, since the chain also returns source_documents.
+                self.memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True,
+                                                       output_key="answer")
+                self.chain = ConversationalRetrievalChain.from_llm(llm=self.llm, retriever=self.retriever,
+                                                                   chain_type=chain_type,
+                                                                   return_source_documents=return_documents,
+                                                                   memory=self.memory, verbose=True)
+            case _:
+                raise IndexError("Invalid QA Type")
+        return self.chain
+
+    def ask_question(self, query):
+        if isinstance(self.chain, RetrievalQA):
+            data = {"query": query}
+        else:
+            data = {"question": query}
+        return self.chain(data)
+
+    def build_chain(self, transform_type, embedding_type, indexer_type, **kwargs):
+        if hasattr(self, "chain"):
+            return self.chain
+        self.document_transformer(transform_type)
+        self.generate_embeddings(embedding_type=embedding_type,
+                                 indexer_type=indexer_type, **kwargs)
+        self.get_retriever(**kwargs)
+        qa = self.build_qa(qa_type=kwargs.get("bot_type"), **kwargs)
+        return qa
+
+
+if __name__ == "__main__":
+    qna = QnASystem()
+    with open("../docs/Doc A.pdf", "rb") as f:
+        qna.read_and_load_pdf(f)
+    chain = qna.build_chain(
+        transform_type=TransformType.RecursiveTransform,
+        embedding_type=EmbeddingTypes.OPENAI, indexer_type=IndexerType.FAISS,
+        chain_type="map_reduce", bot_type=BotType.conversational, return_documents=True
+    )
+    answer = qna.ask_question(query="Hi! Summarize the document.")
+    answer = qna.ask_question(query="What happened from June 1984 to September 1996?")
+    print(answer)
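Note on the FAISS path above: the vector store is rebuilt, and the PDF re-embedded, on every ingest. A minimal sketch of persisting the index between runs with langchain's FAISS save_local/load_local; the "faiss_index" folder name and the qna object (from the __main__ block above) are illustrative, not part of this commit:

    from langchain.embeddings import OpenAIEmbeddings
    from langchain.vectorstores import FAISS

    embeddings = OpenAIEmbeddings()

    # Persist the index built by generate_embeddings() after the first ingest.
    qna.vector_store.save_local("faiss_index")  # illustrative folder name

    # Later runs can load the saved index instead of re-embedding the PDF.
    vector_store = FAISS.load_local("faiss_index", embeddings)
    retriever = vector_store.as_retriever(search_type="similarity", search_kwargs={"k": 5})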
configs.py
ADDED
@@ -0,0 +1,4 @@
+from schema import EmbeddingTypes, IndexerType
+
+indexer_type = IndexerType.FAISS
+embedding_type = EmbeddingTypes.OPENAI
qna.py
ADDED
File without changes
requirements.txt
ADDED
@@ -0,0 +1,9 @@
+langchain
+openai
+chromadb
+streamlit
+streamlit-extras
+streamlit-chat
+faiss-cpu
+pypdf
+tiktoken
schema.py
ADDED
@@ -0,0 +1,63 @@
+from enum import Enum
+from typing import Optional
+
+
+class EnumMetaClass(Enum):
+
+    def __eq__(self, other):
+        if self.__class__ is other.__class__:
+            return self.value.upper() == other.value.upper()
+        return self.value == other
+
+    def __hash__(self):
+        return hash(self._name_)
+
+    def __str__(self):
+        return self.value
+
+    @classmethod
+    def get_enum(cls, value: str) -> Optional["EnumMetaClass"]:
+        return next(
+            (
+                enum_val
+                for enum_val in cls
+                if (enum_val.value == value)
+                or (
+                    isinstance(value, str)
+                    and isinstance(enum_val.value, str)
+                    and (value.lower() == enum_val.value.lower() or value.upper() == enum_val.name.upper())
+                )
+            ),
+            None,
+        )
+
+    @classmethod
+    def _missing_(cls, name):
+        for member in cls:
+            if isinstance(member.name, str) and isinstance(name, str) and member.name.lower() == name.lower():
+                return member
+
+
+class EmbeddingTypes(EnumMetaClass):
+    NA = "NA"
+    OPENAI = "OpenAI"
+    HUGGING_FACE = "Hugging Face"
+    COHERE = "Cohere"
+
+
+class TransformType(EnumMetaClass):
+    RecursiveTransform = "Recursive Text Splitter"
+    CharacterTransform = "Character Text Splitter"
+    SpacyTransform = "Spacy Text Splitter"
+    NLTKTransform = "NLTK Text Splitter"
+
+
+class IndexerType(EnumMetaClass):
+    FAISS = "FAISS"
+    CHROMA = "Chroma"
+    ELASTICSEARCH = "Elastic Search"
+
+
+class BotType(EnumMetaClass):
+    qna = "Question Answering Bot ❓"
+    conversational = "Chatbot 🤖"
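The EnumMetaClass helpers above drive the string/enum interop used throughout app.py and backend.py. A quick illustrative sketch of the behavior, assuming the schema module above is importable:

    from schema import EmbeddingTypes

    # _missing_ makes Enum construction case-insensitive on the member name.
    assert EmbeddingTypes("openai") is EmbeddingTypes.OPENAI

    # get_enum also resolves the display value, ignoring case.
    assert EmbeddingTypes.get_enum("hugging face") is EmbeddingTypes.HUGGING_FACE

    # __eq__ lets a member compare equal to its raw string value.
    assert EmbeddingTypes.OPENAI == "OpenAI"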