Update functions.py
functions.py CHANGED (+50 -17)
@@ -25,10 +25,27 @@ from langchain.docstore.document import Document
 from langchain.embeddings import HuggingFaceEmbeddings,HuggingFaceInstructEmbeddings
 from langchain.vectorstores import FAISS
 from langchain.chains.qa_with_sources import load_qa_with_sources_chain
-from langchain.text_splitter import
+from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.llms import OpenAI
+from langchain.callbacks.base import CallbackManager
+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+from langchain.chains.chat_vector_db.prompts import CONDENSE_QUESTION_PROMPT
+
+from langchain.chat_models import ChatOpenAI
+from langchain.prompts.chat import (
+    ChatPromptTemplate,
+    SystemMessagePromptTemplate,
+    AIMessagePromptTemplate,
+    HumanMessagePromptTemplate,
+)
+from langchain.schema import (
+    AIMessage,
+    HumanMessage,
+    SystemMessage
+)
 from langchain import VectorDBQA
 from langchain.chains.question_answering import load_qa_chain
+
 from langchain.prompts import PromptTemplate
 from langchain.prompts.base import RegexParser
 
@@ -48,7 +65,7 @@ output_parser = RegexParser(
     output_keys=["answer", "score"],
 )
 
-template = """Given the following extracted parts of a long document and a question, create a final answer with references ("SOURCES").
+system_template = """Given the following extracted parts of a long document and a question, create a final answer with references ("SOURCES").
 If you don't know the answer, just say that you don't know. Don't try to make up an answer.
 ALWAYS return a "SOURCES" part in your answer.
 
@@ -64,8 +81,13 @@ Context:
 ---------
 {summaries}
 ---------
-
-
+"""
+
+messages = [
+    SystemMessagePromptTemplate.from_template(system_template),
+    HumanMessagePromptTemplate.from_template("{question}")
+]
+prompt = ChatPromptTemplate.from_messages(messages)
 
 #Refine Chain Type Prompt Template
 refine_prompt_template = (
@@ -85,7 +107,6 @@ refine_prompt = PromptTemplate(
     template=refine_prompt_template,
 )
 
-
 initial_qa_template = (
     "Context information is below. \n"
     "---------------------\n"
@@ -123,11 +144,11 @@ def load_asr_model(asr_model_name):
     return asr_model
 
 @st.experimental_singleton(suppress_st_warning=True)
-def process_corpus(corpus,
+def process_corpus(corpus, title, embedding_model, chunk_size=1000, overlap=50):
 
     '''Process text for Semantic Search'''
 
-    text_splitter =
+    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size,chunk_overlap=overlap)
 
     texts = text_splitter.split_text(corpus)
 
@@ -181,10 +202,13 @@ def gen_embeddings(embedding_model):
     return embeddings
 
 @st.experimental_memo(suppress_st_warning=True)
-def embed_text(query,title,embedding_model,_emb_tok,_docsearch,chain_type):
+def embed_text(query,title,embedding_model,_docsearch,chain_type):
 
     '''Embed text and generate semantic search scores'''
 
+    llm = OpenAI(temperature=0)
+    chat_llm = ChatOpenAI(streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0)
+
     title = title.split()[0].lower()
 
     docs = _docsearch.similarity_search_with_score(query, k=3)
@@ -193,16 +217,25 @@ def embed_text(query,title,embedding_model,_emb_tok,_docsearch,chain_type):
 
     docs = [d[0] for d in docs]
 
-        PROMPT = PromptTemplate(template=template,
-                                input_variables=["summaries", "question"],
-                                output_parser=output_parser)
-
-        chain = load_qa_with_sources_chain(OpenAI(temperature=0),
-                                           chain_type="stuff",
-                                           prompt=PROMPT,
-                                           )
+        # PROMPT = PromptTemplate(template=template,
+        #                         input_variables=["summaries", "question"],
+        #                         output_parser=output_parser)
 
-        answer = chain({"input_documents": docs, "question": query}, return_only_outputs=False)
+        chain_type_kwargs = {"prompt": prompt}
+        chain = VectorDBQAWithSourcesChain.from_chain_type(
+            streaming_llm,
+            chain_type="stuff",
+            vectorstore=_docsearch,
+            chain_type_kwargs=chain_type_kwargs
+        )
+        answer = chain({"question": query}, return_only_outputs=True)
+        # chain = load_qa_with_sources_chain(OpenAI(temperature=0),
+        #                                    chain_type="stuff",
+        #                                    prompt=PROMPT,
+        #                                    )
+
+
+        # answer = chain({"input_documents": docs, "question": query}, return_only_outputs=False)
 
 
     elif chain_type == 'Refined':
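
A note on the reworked process_corpus: it now chunks the corpus with RecursiveCharacterTextSplitter before indexing, with chunk_size=1000 and overlap=50 as defaults. A minimal standalone sketch of that step (the sample corpus string is a placeholder):

from langchain.text_splitter import RecursiveCharacterTextSplitter

corpus = "A long ASR transcript to index. " * 200   # stand-in for the real transcript
splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=50)
texts = splitter.split_text(corpus)
# Each chunk is at most ~1000 characters; consecutive chunks overlap by
# ~50 characters, so a sentence cut at a chunk boundary survives intact in one of them.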
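
The chat-style prompt introduced in this commit replaces the old single-string template: the instructions and the retrieved {summaries} travel in a system message, while the user's {question} becomes a human message. A quick sanity check of that wiring, using placeholder values:

msgs = prompt.format_prompt(
    summaries="<retrieved chunks with sources>",
    question="What topics were discussed?",
).to_messages()
# msgs[0] is a SystemMessage carrying the instructions plus the summaries;
# msgs[1] is a HumanMessage carrying the question.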
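
One thing to flag in the new "stuff" branch: streaming_llm is never defined (the function only defines llm and chat_llm), and VectorDBQAWithSourcesChain is never imported, so this path raises a NameError as committed. A sketch of the presumably intended wiring, assuming chat_llm (the streaming ChatOpenAI defined above) is the model meant here:

from langchain.chains import VectorDBQAWithSourcesChain  # not imported anywhere in the diff

chain_type_kwargs = {"prompt": prompt}
chain = VectorDBQAWithSourcesChain.from_chain_type(
    chat_llm,                 # assumption: the streaming ChatOpenAI, not the undefined streaming_llm
    chain_type="stuff",
    vectorstore=_docsearch,
    chain_type_kwargs=chain_type_kwargs,
)
answer = chain({"question": query}, return_only_outputs=True)
# Tokens stream to stdout through StreamingStdOutCallbackHandler while the
# returned dict carries "answer" and "sources" keys.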