anilkumar-kanasani commited on
Commit
e6f8d33
1 Parent(s): 21fce51

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +109 -0
  2. requirements.txt +21 -0
  3. utils.py +109 -0
app.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PyPDF2 import PdfReader
3
+ from langchain.vectorstores import FAISS
4
+ from langchain.chains import LLMChain, ConversationalRetrievalChain
5
+ from utils import (get_hf_embeddings,
6
+ get_openAI_chat_model,
7
+ get_hf_model,
8
+ get_local_gpt4_model,
9
+ set_LangChain_tracking,
10
+ check_password)
11
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
12
+ from langchain.memory import ConversationBufferMemory
13
+ from langchain.docstore.document import Document
14
+
15
+ embeddings = get_hf_embeddings()
16
+ openai_chat_model = get_openAI_chat_model()
17
+ #local_model = get_local_gpt4_model(model = "GPT4All-13B-snoozy.ggmlv3.q4_0.bin")
18
+ hf_chat_model = get_hf_model(repo_id = "tiiuae/falcon-40b")
19
+
20
+ ## Preparing Prompt
21
+ from langchain.prompts import PromptTemplate
22
+ entity_extraction_template = """
23
+ Extract all top 10 important entites from the following context \
24
+ return as python list \
25
+ {input_text} \
26
+ List of entities:"""
27
+ ENTITY_EXTRACTION_PROMPT = PromptTemplate.from_template(entity_extraction_template)
28
+
29
+ def get_qa_prompt(List_of_entities):
30
+ qa_template = """
31
+ Use the following pieces of context to answer the question at the end. \
32
+ Use the following list of entities as your working scope. \
33
+ If the question is out of given list of entities, just say that your question \
34
+ is out of scope and give them the list of entities as your working scope \
35
+ If you dont know the answer, just say that you don't know and tell \
36
+ the user to seach web for more information, don't try to make up \
37
+ an answer. Use three sentences maximum and keep the answer as \
38
+ concise as possible.\
39
+ list of entities: \
40
+ """ + str(List_of_entities) + """ \
41
+ context: {context} \
42
+ Question: {question} \
43
+ Helpful Answer:"""
44
+ print(qa_template)
45
+ QA_CHAIN_PROMPT = PromptTemplate.from_template(qa_template)
46
+
47
+ return QA_CHAIN_PROMPT
48
+
49
+ if check_password():
50
+ st.title("Chat with your PDF ")
51
+ st.session_state.file_tracking = "new_run"
52
+ with st.expander("Upload your PDF : ", expanded=True):
53
+ st.session_state.lc_tracking = st.text_input("Please give a name to your session?")
54
+ input_file = st.file_uploader(label = "Upload a file",
55
+ accept_multiple_files=False,
56
+ type=["pdf"],
57
+ )
58
+ if st.button("Process the file"):
59
+ st.session_state.file_tracking = "req_to_process"
60
+ try:
61
+ set_LangChain_tracking(project=str(st.session_state.lc_tracking))
62
+ except:
63
+ set_LangChain_tracking(project="default")
64
+ if st.session_state.file_tracking == "req_to_process" and input_file is not None:
65
+ # Load Text Data
66
+ input_text = ''
67
+ bytes_data = PdfReader(input_file)
68
+ for page in bytes_data.pages:
69
+ input_text += page.extract_text()
70
+
71
+ st.session_state.ner_chain = LLMChain(llm=openai_chat_model, prompt=ENTITY_EXTRACTION_PROMPT)
72
+ st.session_state.ners = st.session_state.ner_chain.run(input_text=input_text, verbose=True)
73
+
74
+ input_text = input_text.replace('\n', '')
75
+ text_doc_chunks = [Document(page_content=x, metadata={}) for x in input_text.split('.')]
76
+
77
+ # Embed and VectorStore
78
+ vector_store = FAISS.from_documents(text_doc_chunks, embeddings)
79
+ st.session_state.chat_history = []
80
+ st.session_state.formatted_prompt = get_qa_prompt(st.session_state.ners)
81
+ st.session_state.chat_chain = ConversationalRetrievalChain.from_llm(
82
+ openai_chat_model,
83
+ chain_type="stuff", # "stuff", "map_reduce", "refine", "map_rerank"
84
+ verbose=True,
85
+ retriever=vector_store.as_retriever(),
86
+ # search_type="mmr"
87
+ # search_kwargs={"k": 1}
88
+ # search_type="similarity_score_threshold", search_kwargs={"score_threshold": .5}
89
+ combine_docs_chain_kwargs={"prompt": st.session_state.formatted_prompt},
90
+ )
91
+ if "chat_chain" in st.session_state:
92
+ st.header("We are ready to start chat with your pdf")
93
+ st.subheader("The scope of your PDF is: ")
94
+ st.markdown(st.session_state.ners)
95
+ else:
96
+ st.header("Upload and Process your file first")
97
+
98
+
99
+ if "chat_chain" in st.session_state and st.session_state.chat_history is not None:
100
+ if question := st.chat_input("Please type some thing here?"):
101
+ response = st.session_state.chat_chain({"question": question, "chat_history": st.session_state.chat_history})
102
+ st.session_state.chat_history.append((question, response["answer"]))
103
+
104
+ # Display chat messages from history on app rerun
105
+ for message in st.session_state.chat_history:
106
+ with st.chat_message("user"):
107
+ st.markdown(message[0])
108
+ with st.chat_message("assistant"):
109
+ st.markdown(message[1])
requirements.txt ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # APIs
2
+ gpt4all
3
+ openai
4
+ huggingface_hub
5
+ # LLM Framework
6
+ langchain
7
+ # Chunking Dependencies
8
+ tiktoken
9
+ transformers
10
+ # Embedding Dependencies
11
+ InstructorEmbedding
12
+ torch
13
+ # Loading Dependencies
14
+ PyPDF2
15
+ pypdf
16
+ # VectorStore Dependencies
17
+ faiss-cpu
18
+ # UI
19
+ streamlit==1.25.0
20
+ watchdog==3.0.0
21
+ environs
utils.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from environs import Env
2
+ env = Env()
3
+
4
+ try:
5
+ env.read_env("/Users/kanasani/Documents/api_keys/.env.llm")
6
+ print("Using local .env.llm file")
7
+ except:
8
+ env.read_env()
9
+ print(".env file from repo secrets is used")
10
+
11
+ import openai
12
+ openai.api_type = env("API_TYPE")
13
+ openai.api_base = env("API_BASE")
14
+ openai.api_version = env("API_VERSION")
15
+ openai.api_key = env("AZURE_OPENAI_KEY")
16
+
17
+ def check_password():
18
+ import streamlit as st
19
+ """Returns `True` if the user had the correct password."""
20
+
21
+ def password_entered():
22
+ """Checks whether a password entered by the user is correct."""
23
+ if st.session_state["password"] == env("st_password"):
24
+ st.session_state["password_correct"] = True
25
+ del st.session_state["password"] # don't store password
26
+ else:
27
+ st.session_state["password_correct"] = False
28
+
29
+ if "password_correct" not in st.session_state:
30
+ # First run, show input for password.
31
+ st.text_input(
32
+ "Password", type="password", on_change=password_entered, key="password"
33
+ )
34
+ return False
35
+ elif not st.session_state["password_correct"]:
36
+ # Password not correct, show input + error.
37
+ st.text_input(
38
+ "Password", type="password", on_change=password_entered, key="password"
39
+ )
40
+ st.error("😕 Password incorrect")
41
+ return False
42
+ else:
43
+ # Password correct.
44
+ return True
45
+
46
+ def submit_prompt_to_gpt(input_list_of_prompts):
47
+ response = openai.ChatCompletion.create(
48
+ engine=env("DEPLOYMENT_NAME"),
49
+ messages=input_list_of_prompts,
50
+ temperature=1,
51
+ max_tokens=256,
52
+ top_p=1,
53
+ frequency_penalty=0,
54
+ presence_penalty=0,
55
+ )
56
+ response_content = response["choices"][0]["message"]["content"]
57
+ return response_content
58
+
59
+
60
+ def get_hf_embeddings():
61
+ from langchain.embeddings import HuggingFaceHubEmbeddings
62
+
63
+ embeddings = HuggingFaceHubEmbeddings(
64
+ repo_id="sentence-transformers/all-mpnet-base-v2",
65
+ task="feature-extraction",
66
+ huggingfacehub_api_token=env("HUGGINGFACEHUB_API_TOKEN"),
67
+ )
68
+ return embeddings
69
+
70
+ def get_openAI_chat_model():
71
+ import openai
72
+ from langchain.chat_models.azure_openai import AzureChatOpenAI
73
+ chat_model = AzureChatOpenAI(deployment_name=env("DEPLOYMENT_NAME"),
74
+ openai_api_version=env("API_VERSION"),
75
+ openai_api_base=env("API_BASE"),
76
+ openai_api_type=env("API_TYPE"),
77
+ openai_api_key=env("AZURE_OPENAI_KEY"),
78
+ verbose=True)
79
+ return chat_model
80
+
81
+ def get_hf_model(repo_id = "google/flan-t5-xxl"):
82
+
83
+ from langchain import HuggingFaceHub
84
+
85
+ hf_llm = HuggingFaceHub(
86
+ repo_id=repo_id,
87
+ model_kwargs={"temperature": 0.1, "max_length": 1024},
88
+ huggingfacehub_api_token = env("HUGGINGFACEHUB_API_TOKEN"),
89
+ )
90
+ return hf_llm
91
+
92
+ def get_local_gpt4_model(model = "GPT4All-13B-snoozy.ggmlv3.q4_0.bin"):
93
+ from langchain.llms import GPT4All
94
+ gpt4_llm = GPT4All(model=".models/"+model,
95
+ verbose=True)
96
+ return gpt4_llm
97
+
98
+ def set_LangChain_tracking(project="Chat with your PDF"):
99
+ import os
100
+ os.environ['LANGCHAIN_PROJECT'] = project
101
+ print("LangChain tracking is set to : ", project)
102
+
103
+ def unset_LangChain_tracking():
104
+ import os
105
+ os.environ.pop('LANGCHAIN_API_KEY', None)
106
+ os.environ.pop('LANGCHAIN_TRACING_V2', None)
107
+ os.environ.pop('LANGCHAIN_ENDPOINT', None)
108
+ os.environ.pop('LANGCHAIN_PROJECT', None)
109
+ print("LangChain tracking is removed .")