Spaces: Runtime error
Commit e6f8d33 · Parent(s): 21fce51
Upload 3 files
- app.py +109 -0
- requirements.txt +21 -0
- utils.py +109 -0
app.py
ADDED
@@ -0,0 +1,109 @@
import streamlit as st
from PyPDF2 import PdfReader
from langchain.vectorstores import FAISS
from langchain.chains import LLMChain, ConversationalRetrievalChain
from utils import (get_hf_embeddings,
                   get_openAI_chat_model,
                   get_hf_model,
                   get_local_gpt4_model,
                   set_LangChain_tracking,
                   check_password)
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.memory import ConversationBufferMemory
from langchain.docstore.document import Document

embeddings = get_hf_embeddings()
openai_chat_model = get_openAI_chat_model()
# local_model = get_local_gpt4_model(model="GPT4All-13B-snoozy.ggmlv3.q4_0.bin")
hf_chat_model = get_hf_model(repo_id="tiiuae/falcon-40b")

## Preparing the prompts
from langchain.prompts import PromptTemplate

entity_extraction_template = """
Extract the top 10 most important entities from the following context \
and return them as a Python list. \
{input_text} \
List of entities:"""
ENTITY_EXTRACTION_PROMPT = PromptTemplate.from_template(entity_extraction_template)


def get_qa_prompt(list_of_entities):
    qa_template = """
    Use the following pieces of context to answer the question at the end. \
    Use the following list of entities as your working scope. \
    If the question falls outside the given list of entities, just say that the \
    question is out of scope and give the list of entities as your working scope. \
    If you don't know the answer, just say that you don't know and tell \
    the user to search the web for more information; don't try to make up \
    an answer. Use three sentences maximum and keep the answer as \
    concise as possible. \
    List of entities: \
    """ + str(list_of_entities) + """ \
    Context: {context} \
    Question: {question} \
    Helpful Answer:"""
    print(qa_template)
    QA_CHAIN_PROMPT = PromptTemplate.from_template(qa_template)

    return QA_CHAIN_PROMPT


if check_password():
    st.title("Chat with your PDF")
    st.session_state.file_tracking = "new_run"
    with st.expander("Upload your PDF:", expanded=True):
        st.session_state.lc_tracking = st.text_input("Please give a name to your session")
        input_file = st.file_uploader(label="Upload a file",
                                      accept_multiple_files=False,
                                      type=["pdf"],
                                      )
        if st.button("Process the file"):
            st.session_state.file_tracking = "req_to_process"
            try:
                set_LangChain_tracking(project=str(st.session_state.lc_tracking))
            except Exception:
                set_LangChain_tracking(project="default")
    if st.session_state.file_tracking == "req_to_process" and input_file is not None:
        # Load the text data
        input_text = ''
        bytes_data = PdfReader(input_file)
        for page in bytes_data.pages:
            input_text += page.extract_text()

        st.session_state.ner_chain = LLMChain(llm=openai_chat_model, prompt=ENTITY_EXTRACTION_PROMPT)
        st.session_state.ners = st.session_state.ner_chain.run(input_text=input_text, verbose=True)

        input_text = input_text.replace('\n', '')
        text_doc_chunks = [Document(page_content=x, metadata={}) for x in input_text.split('.')]

        # Embed and build the vector store
        vector_store = FAISS.from_documents(text_doc_chunks, embeddings)
        st.session_state.chat_history = []
        st.session_state.formatted_prompt = get_qa_prompt(st.session_state.ners)
        st.session_state.chat_chain = ConversationalRetrievalChain.from_llm(
            openai_chat_model,
            chain_type="stuff",  # "stuff", "map_reduce", "refine", "map_rerank"
            verbose=True,
            retriever=vector_store.as_retriever(),
            # search_type="mmr"
            # search_kwargs={"k": 1}
            # search_type="similarity_score_threshold", search_kwargs={"score_threshold": .5}
            combine_docs_chain_kwargs={"prompt": st.session_state.formatted_prompt},
        )
    if "chat_chain" in st.session_state:
        st.header("We are ready to start chatting with your PDF")
        st.subheader("The scope of your PDF is:")
        st.markdown(st.session_state.ners)
    else:
        st.header("Upload and process your file first")

    if "chat_chain" in st.session_state and st.session_state.chat_history is not None:
        if question := st.chat_input("Please type something here"):
            response = st.session_state.chat_chain({"question": question, "chat_history": st.session_state.chat_history})
            st.session_state.chat_history.append((question, response["answer"]))

        # Display chat messages from history on app rerun
        for message in st.session_state.chat_history:
            with st.chat_message("user"):
                st.markdown(message[0])
            with st.chat_message("assistant"):
                st.markdown(message[1])
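Note: app.py imports RecursiveCharacterTextSplitter but never uses it; the document is chunked by splitting the extracted text on '.'. If the imported splitter were wired in instead, a minimal sketch (chunk sizes are assumptions, not part of this commit) could look like:

    # Hypothetical alternative to the split('.') chunking above; the
    # chunk_size/chunk_overlap values are assumed, not committed code.
    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    text_doc_chunks = splitter.create_documents([input_text])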
requirements.txt
ADDED
@@ -0,0 +1,21 @@
# APIs
gpt4all
openai
huggingface_hub
# LLM Framework
langchain
# Chunking Dependencies
tiktoken
transformers
# Embedding Dependencies
InstructorEmbedding
torch
# Loading Dependencies
PyPDF2
pypdf
# VectorStore Dependencies
faiss-cpu
# UI
streamlit==1.25.0
watchdog==3.0.0
environs
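Note: only streamlit and watchdog are pinned; the remaining dependencies resolve to their latest versions at build time. To reproduce the Space environment locally, one would typically run:

    pip install -r requirements.txt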
utils.py
ADDED
@@ -0,0 +1,109 @@
from environs import Env

env = Env()

try:
    env.read_env("/Users/kanasani/Documents/api_keys/.env.llm")
    print("Using local .env.llm file")
except Exception:
    env.read_env()
    print(".env file from repo secrets is used")

import openai
openai.api_type = env("API_TYPE")
openai.api_base = env("API_BASE")
openai.api_version = env("API_VERSION")
openai.api_key = env("AZURE_OPENAI_KEY")


def check_password():
    """Returns `True` if the user entered the correct password."""
    import streamlit as st

    def password_entered():
        """Checks whether the password entered by the user is correct."""
        if st.session_state["password"] == env("st_password"):
            st.session_state["password_correct"] = True
            del st.session_state["password"]  # don't store the password
        else:
            st.session_state["password_correct"] = False

    if "password_correct" not in st.session_state:
        # First run: show the password input.
        st.text_input(
            "Password", type="password", on_change=password_entered, key="password"
        )
        return False
    elif not st.session_state["password_correct"]:
        # Password not correct: show the input plus an error.
        st.text_input(
            "Password", type="password", on_change=password_entered, key="password"
        )
        st.error("😕 Password incorrect")
        return False
    else:
        # Password correct.
        return True


def submit_prompt_to_gpt(input_list_of_prompts):
    response = openai.ChatCompletion.create(
        engine=env("DEPLOYMENT_NAME"),
        messages=input_list_of_prompts,
        temperature=1,
        max_tokens=256,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
    )
    response_content = response["choices"][0]["message"]["content"]
    return response_content


def get_hf_embeddings():
    from langchain.embeddings import HuggingFaceHubEmbeddings

    embeddings = HuggingFaceHubEmbeddings(
        repo_id="sentence-transformers/all-mpnet-base-v2",
        task="feature-extraction",
        huggingfacehub_api_token=env("HUGGINGFACEHUB_API_TOKEN"),
    )
    return embeddings


def get_openAI_chat_model():
    from langchain.chat_models.azure_openai import AzureChatOpenAI

    chat_model = AzureChatOpenAI(deployment_name=env("DEPLOYMENT_NAME"),
                                 openai_api_version=env("API_VERSION"),
                                 openai_api_base=env("API_BASE"),
                                 openai_api_type=env("API_TYPE"),
                                 openai_api_key=env("AZURE_OPENAI_KEY"),
                                 verbose=True)
    return chat_model


def get_hf_model(repo_id="google/flan-t5-xxl"):
    from langchain import HuggingFaceHub

    hf_llm = HuggingFaceHub(
        repo_id=repo_id,
        model_kwargs={"temperature": 0.1, "max_length": 1024},
        huggingfacehub_api_token=env("HUGGINGFACEHUB_API_TOKEN"),
    )
    return hf_llm


def get_local_gpt4_model(model="GPT4All-13B-snoozy.ggmlv3.q4_0.bin"):
    from langchain.llms import GPT4All

    gpt4_llm = GPT4All(model=".models/" + model,
                       verbose=True)
    return gpt4_llm


def set_LangChain_tracking(project="Chat with your PDF"):
    import os
    os.environ['LANGCHAIN_PROJECT'] = project
    print("LangChain tracking is set to:", project)


def unset_LangChain_tracking():
    import os
    os.environ.pop('LANGCHAIN_API_KEY', None)
    os.environ.pop('LANGCHAIN_TRACING_V2', None)
    os.environ.pop('LANGCHAIN_ENDPOINT', None)
    os.environ.pop('LANGCHAIN_PROJECT', None)
    print("LangChain tracking is removed.")
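Note: utils.py reads its configuration from an environment file that is not part of this commit. A sketch of the variables it expects, with names taken from the env(...) and os.environ.pop(...) calls above and all values as placeholders:

    # .env.llm (placeholder values; the real file is not committed)
    API_TYPE=azure
    API_BASE=<azure-openai-endpoint>
    API_VERSION=<azure-openai-api-version>
    AZURE_OPENAI_KEY=<azure-openai-key>
    DEPLOYMENT_NAME=<azure-chat-deployment-name>
    HUGGINGFACEHUB_API_TOKEN=<huggingface-api-token>
    st_password=<streamlit-app-password>
    # LangSmith tracing variables referenced by unset_LangChain_tracking
    LANGCHAIN_TRACING_V2=<true-or-false>
    LANGCHAIN_API_KEY=<langsmith-api-key>
    LANGCHAIN_ENDPOINT=<langsmith-endpoint>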