nickmuchi commited on
Commit
a9fd3e5
1 Parent(s): 7b09818

Update functions.py

Browse files
Files changed (1) hide show
  1. functions.py +16 -9
functions.py CHANGED
@@ -31,8 +31,10 @@ from langchain.vectorstores import FAISS
31
  from langchain.text_splitter import RecursiveCharacterTextSplitter
32
  from langchain.chat_models import ChatOpenAI
33
  from langchain.callbacks import StdOutCallbackHandler
34
- from langchain.chains import ConversationalRetrievalChain, QAGenerationChain, RetrievalQA
35
  from langchain.memory import ConversationBufferMemory
 
 
36
 
37
  from langchain.prompts.chat import (
38
  ChatPromptTemplate,
@@ -87,7 +89,11 @@ def load_prompt():
87
  ----------------
88
  {context}"""
89
 
90
- prompt = system_template
 
 
 
 
91
 
92
  return prompt
93
 
@@ -566,13 +572,14 @@ def embed_text(query,_docsearch):
566
  # retriever=_docsearch.as_retriever(),
567
  # return_source_documents=True)
568
 
569
- chain = ConversationalRetrievalChain.from_llm(chat_llm,
570
- retriever= _docsearch.as_retriever(search_kwargs={"k": 3}),
571
- get_chat_history=lambda h : h,
572
- memory = memory,
573
- return_source_documents=True)
574
-
575
- chain.combine_docs_chain.llm_chain.prompt.messages[0] = SystemMessagePromptTemplate.from_template(load_prompt())
 
576
 
577
  answer = chain({"question": query})
578
 
 
31
  from langchain.text_splitter import RecursiveCharacterTextSplitter
32
  from langchain.chat_models import ChatOpenAI
33
  from langchain.callbacks import StdOutCallbackHandler
34
+ from langchain.chains import ConversationalRetrievalChain, QAGenerationChain, LLMChain
35
  from langchain.memory import ConversationBufferMemory
36
+ from langchain.chains.question_answering import load_qa_chain
37
+ from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
38
 
39
  from langchain.prompts.chat import (
40
  ChatPromptTemplate,
 
89
  ----------------
90
  {context}"""
91
 
92
+ messages = [
93
+ SystemMessagePromptTemplate.from_template(system_template),
94
+ HumanMessagePromptTemplate.from_template("{question}")
95
+ ]
96
+ prompt = ChatPromptTemplate.from_messages(messages)
97
 
98
  return prompt
99
 
 
572
  # retriever=_docsearch.as_retriever(),
573
  # return_source_documents=True)
574
 
575
+ question_generator = LLMChain(llm=chat_llm, prompt=CONDENSE_QUESTION_PROMPT)
576
+ doc_chain = load_qa_chain(llm=chat_llm,chain_type="stuff",prompt=load_prompt())
577
+ chain = ConversationalRetrievalChain(retriever=cfa_db.as_retriever(search_kwags={"k": 3}),
578
+ question_generator=question_generator,
579
+ combine_docs_chain=doc_chain,
580
+ memory=memory,
581
+ return_source_documents=True,
582
+ get_chat_history=lambda h :h)
583
 
584
  answer = chain({"question": query})
585