nickmuchi committed
Commit 3e9b436
1 Parent(s): 023f553

Update functions.py

Files changed (1):
  1. functions.py +50 -17
functions.py CHANGED
@@ -25,10 +25,27 @@ from langchain.docstore.document import Document
 from langchain.embeddings import HuggingFaceEmbeddings,HuggingFaceInstructEmbeddings
 from langchain.vectorstores import FAISS
 from langchain.chains.qa_with_sources import load_qa_with_sources_chain
-from langchain.text_splitter import CharacterTextSplitter
+from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.llms import OpenAI
+from langchain.callbacks.base import CallbackManager
+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+from langchain.chains.chat_vector_db.prompts import CONDENSE_QUESTION_PROMPT
+
+from langchain.chat_models import ChatOpenAI
+from langchain.prompts.chat import (
+    ChatPromptTemplate,
+    SystemMessagePromptTemplate,
+    AIMessagePromptTemplate,
+    HumanMessagePromptTemplate,
+)
+from langchain.schema import (
+    AIMessage,
+    HumanMessage,
+    SystemMessage
+)
 from langchain import VectorDBQA
 from langchain.chains.question_answering import load_qa_chain
+
 from langchain.prompts import PromptTemplate
 from langchain.prompts.base import RegexParser

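The import block swaps `CharacterTextSplitter` for `RecursiveCharacterTextSplitter` and pulls in the chat/streaming machinery. One gap worth flagging: `VectorDBQAWithSourcesChain` is instantiated further down but never imported. A minimal sketch of the streaming setup, assuming a langchain version contemporary with this commit (~0.0.1xx) where these module paths exist:

```python
# Sketch only; module paths assume langchain ~0.0.1xx, matching this commit.
from langchain.chains import VectorDBQAWithSourcesChain  # missing from the diff's imports
from langchain.callbacks.base import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chat_models import ChatOpenAI

# streaming=True pushes each token to the handler as it is generated,
# instead of returning the full completion in one blocking call.
chat_llm = ChatOpenAI(
    streaming=True,
    callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),
    verbose=True,
    temperature=0,
)
```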
@@ -48,7 +65,7 @@ output_parser = RegexParser(
     output_keys=["answer", "score"],
 )
 
-template = """Given the following extracted parts of a long document and a question, create a final answer with references ("SOURCES").
+system_template = """Given the following extracted parts of a long document and a question, create a final answer with references ("SOURCES").
 If you don't know the answer, just say that you don't know. Don't try to make up an answer.
 ALWAYS return a "SOURCES" part in your answer.
 
@@ -64,8 +81,13 @@ Context:
 ---------
 {summaries}
 ---------
-Question: {question}
-Helpful Answer:"""
+"""
+
+messages = [
+    SystemMessagePromptTemplate.from_template(system_template),
+    HumanMessagePromptTemplate.from_template("{question}")
+]
+prompt = ChatPromptTemplate.from_messages(messages)
 
 #Refine Chain Type Prompt Template
 refine_prompt_template = (
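The old completion prompt ended with `Question: {question}` / `Helpful Answer:`; the chat version moves the instructions and `{summaries}` context into a system message and sends the question as a human message. A self-contained illustration of the same construction, with hypothetical filler values:

```python
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)

# Same shape as the diff: instructions + {summaries} in the system slot,
# the bare question in the human slot.
demo_prompt = ChatPromptTemplate.from_messages([
    SystemMessagePromptTemplate.from_template(
        "Answer using only the context below.\n{summaries}"
    ),
    HumanMessagePromptTemplate.from_template("{question}"),
])

msgs = demo_prompt.format_prompt(
    summaries="Revenue grew 12%. SOURCE: 4-pg",  # hypothetical filler
    question="What was said about revenue?",
).to_messages()
# msgs -> [SystemMessage(instructions + context), HumanMessage(question)]
```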
@@ -85,7 +107,6 @@ refine_prompt = PromptTemplate(
     template=refine_prompt_template,
 )
 
-
 initial_qa_template = (
     "Context information is below. \n"
     "---------------------\n"
@@ -123,11 +144,11 @@ def load_asr_model(asr_model_name):
     return asr_model
 
 @st.experimental_singleton(suppress_st_warning=True)
-def process_corpus(corpus, _tokenizer, title, embedding_model, chunk_size=200, overlap=50):
+def process_corpus(corpus, title, embedding_model, chunk_size=1000, overlap=50):
 
     '''Process text for Semantic Search'''
 
-    text_splitter = CharacterTextSplitter.from_huggingface_tokenizer(_tokenizer,chunk_size=chunk_size,chunk_overlap=overlap,separator='.')
+    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size,chunk_overlap=overlap)
 
     texts = text_splitter.split_text(corpus)
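Dropping `from_huggingface_tokenizer` removes the `_tokenizer` argument entirely: `RecursiveCharacterTextSplitter` measures `chunk_size` with plain `len()`, i.e. in characters rather than tokens, which is presumably why the default grows from 200 to 1000. A small sketch of the splitting behavior:

```python
from langchain.text_splitter import RecursiveCharacterTextSplitter

transcript = "Paragraph about revenue.\n\nParagraph about margins. " * 40

splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=50)

# Tries to split on "\n\n" first, then "\n", then " ", then falls back to
# splitting anywhere, so chunks land on paragraph/word boundaries when possible.
chunks = splitter.split_text(transcript)
print(len(chunks), max(len(c) for c in chunks))  # every chunk <= ~1000 chars
```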
@@ -181,10 +202,13 @@ def gen_embeddings(embedding_model):
     return embeddings
 
 @st.experimental_memo(suppress_st_warning=True)
-def embed_text(query,title,embedding_model,_emb_tok,_docsearch,chain_type):
+def embed_text(query,title,embedding_model,_docsearch,chain_type):
 
     '''Embed text and generate semantic search scores'''
 
+    llm = OpenAI(temperature=0)
+    chat_llm = ChatOpenAI(streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0)
+
     title = title.split()[0].lower()
 
     docs = _docsearch.similarity_search_with_score(query, k=3)
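For context on the retrieval call kept above: against a FAISS store, `similarity_search_with_score` returns `(Document, score)` pairs where the score is a raw L2 distance (lower means more similar), hence the `d[0]` unpacking in the next hunk. A minimal, self-contained illustration mirroring the `process_corpus`/`embed_text` flow (assumes `faiss` and `sentence-transformers` are installed):

```python
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS

emb = HuggingFaceEmbeddings()  # default sentence-transformers model
store = FAISS.from_texts(["margins improved", "guidance was cut"], emb)

pairs = store.similarity_search_with_score("what happened to margins?", k=2)
# pairs: [(Document, score), ...] with FAISS returning L2 distance,
# which is why embed_text keeps only the documents via d[0].
docs = [doc for doc, _score in pairs]
```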
@@ -193,16 +217,25 @@ def embed_text(query,title,embedding_model,_emb_tok,_docsearch,chain_type):
 
     docs = [d[0] for d in docs]
 
-    PROMPT = PromptTemplate(template=template,
-                            input_variables=["summaries", "question"],
-                            output_parser=output_parser)
-
-    chain = load_qa_with_sources_chain(OpenAI(temperature=0),
-                                       chain_type="stuff",
-                                       prompt=PROMPT,
-                                       )
-
-    answer = chain({"input_documents": docs, "question": query}, return_only_outputs=False)
+    # PROMPT = PromptTemplate(template=template,
+    #                         input_variables=["summaries", "question"],
+    #                         output_parser=output_parser)
 
+    chain_type_kwargs = {"prompt": prompt}
+    chain = VectorDBQAWithSourcesChain.from_chain_type(
+        chat_llm,
+        chain_type="stuff",
+        vectorstore=_docsearch,
+        chain_type_kwargs=chain_type_kwargs
+    )
+    answer = chain({"question": query}, return_only_outputs=True)
+    # chain = load_qa_with_sources_chain(OpenAI(temperature=0),
+    #                                    chain_type="stuff",
+    #                                    prompt=PROMPT,
+    #                                    )
+
+    # answer = chain({"input_documents": docs, "question": query}, return_only_outputs=False)
 
 
     elif chain_type == 'Refined':
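The QA wiring changes shape here: `load_qa_with_sources_chain` consumed the pre-fetched `input_documents`, whereas `VectorDBQAWithSourcesChain` holds the vector store itself and runs its own similarity search per question, so the `docs` gathered above no longer feed the chain. A hedged sketch of the new call pattern; the import location and the `docsearch`/`chat_prompt` names are assumptions for illustration:

```python
from langchain.chains import VectorDBQAWithSourcesChain  # assumed import path
from langchain.chat_models import ChatOpenAI

# docsearch: a FAISS store like the one built in process_corpus;
# chat_prompt: the ChatPromptTemplate assembled earlier in this diff.
chain = VectorDBQAWithSourcesChain.from_chain_type(
    ChatOpenAI(temperature=0),
    chain_type="stuff",
    vectorstore=docsearch,
    chain_type_kwargs={"prompt": chat_prompt},
)

# The chain retrieves its own documents and returns separate output keys.
result = chain({"question": "What was said about guidance?"},
               return_only_outputs=True)
answer, sources = result["answer"], result["sources"]
```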