lara1510 committed
Commit
4f570b0
1 Parent(s): 2bb21b0

Update chatbot.py

Files changed (1)
  1. chatbot.py +99 -102
chatbot.py CHANGED
@@ -1,102 +1,99 @@
- import os
- import dotenv
- from langchain.text_splitter import RecursiveCharacterTextSplitter
- from langchain_community.llms import HuggingFaceEndpoint
- from langchain_community.embeddings import HuggingFaceEmbeddings
- from langchain_community.vectorstores import Chroma
- from langchain_core.prompts import ChatPromptTemplate
- from langchain.chains.combine_documents import create_stuff_documents_chain
- from langchain.chains import create_retrieval_chain, create_history_aware_retriever
- from langchain_community.document_loaders import PyMuPDFLoader
- from langchain_community.llms import Ollama
- from langchain_core.messages import HumanMessage, AIMessage
- from langchain_core.prompts import ChatPromptTemplate
- from langchain_core.prompts import MessagesPlaceholder
-
- dotenv.load_dotenv()
- os.environ["HUGGINGFACEHUB_API_TOKEN"] = os.getenv('HUGGINGFACEHUB_API_TOKEN')
-
-
- class AdjustedHuggingFaceEmbeddings(HuggingFaceEmbeddings):
-     def __call__(self, input):
-         return super().__call__(input)
-
-
- def create_chain(chains, pdf_doc, use_local_model=True):
-     if pdf_doc is None:
-         return 'You must convert or upload a pdf first'
-     db = create_vector_db(pdf_doc)
-     llm = create_model(use_local_model)
-     prompt_search_query = ChatPromptTemplate.from_messages([
-         MessagesPlaceholder(
-             variable_name="chat_history"),
-         ("user", "{input}"),
-         ("user",
-          "Given the above conversation, generate a search query to look up to get information relevant to the conversation")
-     ])
-     retriever_chain = create_history_aware_retriever(llm, db.as_retriever(), prompt_search_query)
-     prompt_get_answer = ChatPromptTemplate.from_messages([
-         ("system", "Answer the user's questions based on the below context:\\n\\n{context}"),
-         MessagesPlaceholder(variable_name="chat_history"),
-         ("user", "{input}"),
-     ])
-     combine_docs_chain = create_stuff_documents_chain(llm=llm, prompt=prompt_get_answer)
-     chains[0] = create_retrieval_chain(retriever_chain, combine_docs_chain)
-     return 'Document has successfully been loaded'
-
-
- def create_model(local: bool):
-     if local:
-         llm = Ollama(model='phi')
-     else:
-         llm = HuggingFaceEndpoint(
-             repo_id="OpenAssistant/oasst-sft-1-pythia-12b",
-             model_kwargs={"max_length": 256},
-             temperature=1.0
-         )
-     return llm
-
-
- def create_vector_db(doc):
-     document = load_document(doc)
-     text = split_document(document)
-     embedding = AdjustedHuggingFaceEmbeddings()
-     db = Chroma.from_documents(text, embedding)
-     return db
-
-
- def load_document(doc):
-     loader = PyMuPDFLoader(doc.name)
-     document = loader.load()
-     return document
-
-
- def split_document(doc):
-     text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
-     text = text_splitter.split_documents(doc)
-     return text
-
-
- def save_history(history):
-     with open('history.txt', 'w') as file:
-         for s in history:
-             file.write(f'- {s.content}\n')
-
-
- def answer_query(chain, query: str, chat_history=None) -> str:
-     if chain:
-         # run the given chain with the given query and history
-         chat_history.append(HumanMessage(content=query))
-         response = chain.invoke({
-             'chat_history': chat_history,
-             'input': query
-         })
-         answer = response['answer']
-         print('RESPONSE: ', answer, '\n\n')
-         # add the current question and answer to history
-         chat_history.append(AIMessage(content=answer))
-         # save chat history to text file
-         save_history(chat_history)
-         return answer
-     else:
-         return "Please load a document first."
 
+
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_community.llms import HuggingFaceEndpoint
+ from langchain_community.embeddings import HuggingFaceEmbeddings
+ from langchain_community.vectorstores import Chroma
+ from langchain_core.prompts import ChatPromptTemplate
+ from langchain.chains.combine_documents import create_stuff_documents_chain
+ from langchain.chains import create_retrieval_chain, create_history_aware_retriever
+ from langchain_community.document_loaders import PyMuPDFLoader
+ from langchain_community.llms import Ollama
+ from langchain_core.messages import HumanMessage, AIMessage
+ from langchain_core.prompts import ChatPromptTemplate
+ from langchain_core.prompts import MessagesPlaceholder
+
+
+
+ class AdjustedHuggingFaceEmbeddings(HuggingFaceEmbeddings):
+     def __call__(self, input):
+         return super().__call__(input)
+
+
+ def create_chain(chains, pdf_doc, use_local_model=True):
+     if pdf_doc is None:
+         return 'You must convert or upload a pdf first'
+     db = create_vector_db(pdf_doc)
+     llm = create_model(use_local_model)
+     prompt_search_query = ChatPromptTemplate.from_messages([
+         MessagesPlaceholder(
+             variable_name="chat_history"),
+         ("user", "{input}"),
+         ("user",
+          "Given the above conversation, generate a search query to look up to get information relevant to the conversation")
+     ])
+     retriever_chain = create_history_aware_retriever(llm, db.as_retriever(), prompt_search_query)
+     prompt_get_answer = ChatPromptTemplate.from_messages([
+         ("system", "Answer the user's questions based on the below context:\\n\\n{context}"),
+         MessagesPlaceholder(variable_name="chat_history"),
+         ("user", "{input}"),
+     ])
+     combine_docs_chain = create_stuff_documents_chain(llm=llm, prompt=prompt_get_answer)
+     chains[0] = create_retrieval_chain(retriever_chain, combine_docs_chain)
+     return 'Document has successfully been loaded'
+
+
+ def create_model(local: bool):
+     if local:
+         llm = Ollama(model='phi')
+     else:
+         llm = HuggingFaceEndpoint(
+             repo_id="OpenAssistant/oasst-sft-1-pythia-12b",
+             model_kwargs={"max_length": 256},
+             temperature=1.0
+         )
+     return llm
+
+
+ def create_vector_db(doc):
+     document = load_document(doc)
+     text = split_document(document)
+     embedding = AdjustedHuggingFaceEmbeddings()
+     db = Chroma.from_documents(text, embedding)
+     return db
+
+
+ def load_document(doc):
+     loader = PyMuPDFLoader(doc.name)
+     document = loader.load()
+     return document
+
+
+ def split_document(doc):
+     text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
+     text = text_splitter.split_documents(doc)
+     return text
+
+
+ def save_history(history):
+     with open('history.txt', 'w') as file:
+         for s in history:
+             file.write(f'- {s.content}\n')
+
+
+ def answer_query(chain, query: str, chat_history=None) -> str:
+     if chain:
+         # run the given chain with the given query and history
+         chat_history.append(HumanMessage(content=query))
+         response = chain.invoke({
+             'chat_history': chat_history,
+             'input': query
+         })
+         answer = response['answer']
+         print('RESPONSE: ', answer, '\n\n')
+         # add the current question and answer to history
+         chat_history.append(AIMessage(content=answer))
+         # save chat history to text file
+         save_history(chat_history)
+         return answer
+     else:
+         return "Please load a document first."