mckplus commited on
Commit
3133187
1 Parent(s): 2e14a79

Update DocuChat.py

Browse files
Files changed (1) hide show
  1. DocuChat.py +3 -15
DocuChat.py CHANGED
@@ -41,28 +41,16 @@ class LangchainConversation:
41
  lines = re.split(r'\r\n|\r|\n', text)
42
  return '\n'.join([line.strip() for line in lines if line.strip()])
43
 
44
-
45
- def get_chat_history(self, inputs):
46
- chat_history_str = ""
47
- for human, ai in inputs:
48
- chat_history_str += f"User: {human}\nAI: {ai}\n"
49
- return chat_history_str
50
-
51
  def qa(self, file, query):
52
  loader = PyPDFLoader(file)
53
  documents = loader.load()
54
- text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0, context_aware=True)
55
  texts = text_splitter.split_documents(documents)
56
  embeddings = OpenAIEmbeddings()
57
  db = Chroma.from_documents(texts, embeddings)
58
  retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 3})
59
- question_generator = LLMChain(llm=LangchainOpenAI(), prompt="Your Prompt Here")
60
- doc_chain = RetrievalQA.from_chain_type(llm=LangchainOpenAI(), chain_type="stuff", retriever=retriever, return_source_documents=True)
61
- qa = ConversationalRetrievalChain(retriever=retriever, combine_docs_chain=doc_chain, question_generator=question_generator)
62
- chat_history = self.chat_history if hasattr(self, 'chat_history') else []
63
- result = qa({"question": query, "chat_history": chat_history})
64
- chat_history.append((query, result["result"]))
65
- self.chat_history = chat_history
66
  return result['result']
67
 
68
  def view(self):
 
41
  lines = re.split(r'\r\n|\r|\n', text)
42
  return '\n'.join([line.strip() for line in lines if line.strip()])
43
 
 
 
 
 
 
 
 
44
  def qa(self, file, query):
45
  loader = PyPDFLoader(file)
46
  documents = loader.load()
47
+ text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
48
  texts = text_splitter.split_documents(documents)
49
  embeddings = OpenAIEmbeddings()
50
  db = Chroma.from_documents(texts, embeddings)
51
  retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 3})
52
+ qa = RetrievalQA.from_chain_type(llm=LangchainOpenAI(), chain_type="stuff", retriever=retriever, return_source_documents=True)
53
+ result = qa({"query": query})
 
 
 
 
 
54
  return result['result']
55
 
56
  def view(self):