vilsonrodrigues commited on
Commit
fe42288
1 Parent(s): 2aac694

add conversational chat mode

Browse files
Files changed (2) hide show
  1. qa/chains.py +15 -1
  2. qa/manager.py +15 -4
qa/chains.py CHANGED
@@ -3,4 +3,18 @@ from typing import Callable
3
  def retrieval_qa(llm: Callable, retriever: Callable) -> Callable:
4
  from langchain.chains import RetrievalQA
5
  qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever)
6
- return qa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  def retrieval_qa(llm: Callable, retriever: Callable) -> Callable:
4
  from langchain.chains import RetrievalQA
5
  qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever)
6
+ return qa
7
+
8
+ def conversational_retrieval_qa(llm: Callable, retriever: Callable) -> Callable:
9
+ from langchain.memory import ConversationBufferMemory
10
+ from langchain.chains import ConversationalRetrievalChain
11
+ memory = ConversationBufferMemory(
12
+ memory_key="chat_history",
13
+ return_messages=True
14
+ )
15
+ qa = ConversationalRetrievalChain.from_llm(
16
+ llm,
17
+ retriever=retriever,
18
+ memory=memory
19
+ )
20
+ return qa
qa/manager.py CHANGED
@@ -1,4 +1,4 @@
1
- from qa.chains import retrieval_qa
2
  from qa.loader import youtube_doc_loader
3
  from qa.model import load_llm
4
  from qa.split import split_document
@@ -7,7 +7,10 @@ from qa.vector_store import create_vector_store
7
  class YoutubeQA:
8
 
9
  def __init__(self):
10
- pass
 
 
 
11
 
12
  def load_model(self) -> None:
13
  self.llm = load_llm()
@@ -18,7 +21,15 @@ class YoutubeQA:
18
  self.retriver = create_vector_store(docs=docs)
19
 
20
  def load_retriever(self) -> None:
21
- self.retrieval_qa = retrieval_qa(llm=self.llm, retriever=self.retriver)
 
 
 
 
 
22
 
23
  def run(self, question: str) -> str:
24
- return self.retrieval_qa.run(question)
 
 
 
 
1
+ from qa.chains import conversational_retrieval_qa, retrieval_qa
2
  from qa.loader import youtube_doc_loader
3
  from qa.model import load_llm
4
  from qa.split import split_document
 
7
  class YoutubeQA:
8
 
9
  def __init__(self):
10
+ self.CHAT_MODE = 'normal'
11
+
12
+ def change_chat_mode(self, mode: str) -> None:
13
+ self.CHAT_MODE = mode
14
 
15
  def load_model(self) -> None:
16
  self.llm = load_llm()
 
21
  self.retriver = create_vector_store(docs=docs)
22
 
23
  def load_retriever(self) -> None:
24
+ if self.CHAT_MODE == 'normal':
25
+ self.retrieval_qa = retrieval_qa(llm=self.llm, retriever=self.retriver)
26
+ elif self.CHAT_MODE == 'conversational':
27
+ self.retrieval_qa = conversational_retrieval_qa(llm=self.llm, retriever=self.retriver)
28
+ else:
29
+ raise ValueError('Chat Mode not implemented')
30
 
31
  def run(self, question: str) -> str:
32
+ if self.CHAT_MODE == 'normal':
33
+ return self.retrieval_qa.run(question)
34
+ elif self.CHAT_MODE == 'conversational':
35
+ return self.retrieval_qa({'question': question})['answer']