LOpeetu committed on
Commit
19c89c6
1 Parent(s): 86bf261

first commit

Files changed (3)
  1. ragchatbot.py +84 -0
  2. ragfunctions.py +130 -0
  3. requirements.txt +3 -0
ragchatbot.py ADDED
@@ -0,0 +1,84 @@
+ import streamlit as st
+ import os
+ import ragfunctions
+ 
+ 
+ st.title("RAG Chatbot")
+ 
+ # Setting the LLM
+ with st.expander("Setting the LLM"):
+     st.markdown("This page is used to have a chat with the uploaded documents")
+     with st.form("setting"):
+         row_1 = st.columns(3)
+         with row_1[0]:
+             token = st.text_input("Hugging Face Token", type="password")
+ 
+         with row_1[1]:
+             llm_model = st.text_input("LLM model", value="tiiuae/falcon-7b-instruct")
+ 
+         with row_1[2]:
+             instruct_embeddings = st.text_input("Instruct Embeddings", value="hkunlp/instructor-xl")
+ 
+         row_2 = st.columns(3)
+         with row_2[0]:
+             vector_store_list = os.listdir("vector store/")
+             default_choice = (
+                 vector_store_list.index('naruto_snake')
+                 if 'naruto_snake' in vector_store_list
+                 else 0
+             )
+             existing_vector_store = st.selectbox("Vector Store", vector_store_list, default_choice)
+ 
+         with row_2[1]:
+             temperature = st.number_input("Temperature", value=1.0, step=0.1)
+ 
+         with row_2[2]:
+             max_length = st.number_input("Maximum character length", value=300, step=1)
+ 
+         create_chatbot = st.form_submit_button("Create chatbot")
+ 
+ 
+ # Prepare the LLM model
+ if "conversation" not in st.session_state:
+     st.session_state.conversation = None
+ 
+ if token:
+     st.session_state.conversation = ragfunctions.prepare_rag_llm(
+         token, llm_model, instruct_embeddings, existing_vector_store, temperature, max_length
+     )
+ 
+ # Chat history
+ if "history" not in st.session_state:
+     st.session_state.history = []
+ 
+ # Source documents
+ if "source" not in st.session_state:
+     st.session_state.source = []
+ 
+ # Display past chat messages
+ for message in st.session_state.history:
+     with st.chat_message(message["role"]):
+         st.markdown(message["content"])
+ 
+ # Ask a question
+ if question := st.chat_input("Ask a question"):
+     # Append the user question to the history
+     st.session_state.history.append({"role": "user", "content": question})
+     # Show the user question
+     with st.chat_message("user"):
+         st.markdown(question)
+ 
+     # Answer the question
+     answer, doc_source = ragfunctions.generate_answer(question, token)
+     with st.chat_message("assistant"):
+         st.write(answer)
+     # Append the assistant answer to the history
+     st.session_state.history.append({"role": "assistant", "content": answer})
+ 
+     # Keep the document sources for this turn
+     st.session_state.source.append({"question": question, "answer": answer, "document": doc_source})
+ 
+ 
+ # Source documents
+ with st.expander("Source documents"):
+     st.write(st.session_state.source)
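Note on running this page: os.listdir("vector store/") executes on every rerun, so the "vector store/" directory has to exist before the app is started with "streamlit run ragchatbot.py"; otherwise Streamlit stops with a FileNotFoundError. A minimal one-off bootstrap, assuming nothing has created the directory yet (the snippet below is a suggestion, not part of this commit):

    # Create the directory the app lists at startup; safe to re-run.
    import os
    os.makedirs("vector store", exist_ok=True)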
ragfunctions.py ADDED
@@ -0,0 +1,130 @@
+ import streamlit as st
+ from langchain.document_loaders import TextLoader
+ from pypdf import PdfReader
+ from langchain import HuggingFaceHub
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.embeddings import HuggingFaceInstructEmbeddings
+ from langchain.vectorstores import FAISS
+ from langchain.chains import ConversationalRetrievalChain
+ from langchain.memory import ConversationBufferWindowMemory
+ 
+ 
+ def read_pdf(file):
+     document = ""
+ 
+     reader = PdfReader(file)
+     for page in reader.pages:
+         document += page.extract_text()
+ 
+     return document
+ 
+ 
+ class llmClass:
+     # Placeholder standing in for the real RAG chain until it is wired up.
+     def __init__(self, token, llm_model, instruct_embeddings, existing_vector_store, temperature, max_length):
+         print('Created')
+ 
+     def generate_answer(self, question, token):
+         return 'LLM not connected'
+ 
+ 
+ def read_txt(file):
+     document = str(file.getvalue())
+     document = document.replace("\\n", " \\n ").replace("\\r", " \\r ")
+ 
+     return document
+ 
+ 
+ def split_doc(document, chunk_size, chunk_overlap):
+     splitter = RecursiveCharacterTextSplitter(
+         chunk_size=chunk_size,
+         chunk_overlap=chunk_overlap
+     )
+     split = splitter.split_text(document)
+     split = splitter.create_documents(split)
+ 
+     return split
+ 
+ 
+ def embedding_storing(model_name, split, create_new_vs, existing_vector_store, new_vs_name):
+     if create_new_vs is not None:
+         # Load the instructor embeddings
+         instructor_embeddings = HuggingFaceInstructEmbeddings(
+             model_name=model_name, model_kwargs={"device": "cuda"}
+         )
+ 
+         # Embed the chunks
+         db = FAISS.from_documents(split, instructor_embeddings)
+ 
+         if create_new_vs:
+             # Save db
+             db.save_local("vector store/" + new_vs_name)
+         else:
+             # Load the existing db
+             load_db = FAISS.load_local(
+                 "vector store/" + existing_vector_store,
+                 instructor_embeddings,
+                 allow_dangerous_deserialization=True
+             )
+             # Merge the two DBs and save
+             load_db.merge_from(db)
+             load_db.save_local("vector store/" + new_vs_name)
+ 
+         st.success("The document has been saved.")
+ 
+ 
+ def prepare_rag_llm(
+     token, llm_model, instruct_embeddings, vector_store_list, temperature, max_length
+ ):
+     # # Load the instructor embeddings
+     # instructor_embeddings = HuggingFaceInstructEmbeddings(
+     #     model_name=instruct_embeddings, model_kwargs={"device": "cuda"}
+     # )
+ 
+     # # Load db
+     # loaded_db = FAISS.load_local(
+     #     f"vector store/{vector_store_list}", instructor_embeddings, allow_dangerous_deserialization=True
+     # )
+ 
+     # # Load LLM
+     # llm = HuggingFaceHub(
+     #     repo_id=llm_model,
+     #     model_kwargs={"temperature": temperature, "max_length": max_length},
+     #     huggingfacehub_api_token=token
+     # )
+ 
+     # memory = ConversationBufferWindowMemory(
+     #     k=2,
+     #     memory_key="chat_history",
+     #     output_key="answer",
+     #     return_messages=True,
+     # )
+ 
+     # # Create the chatbot
+     # qa_conversation = ConversationalRetrievalChain.from_llm(
+     #     llm=llm,
+     #     chain_type="stuff",
+     #     retriever=loaded_db.as_retriever(),
+     #     return_source_documents=True,
+     #     memory=memory,
+     # )
+ 
+     # The retrieval chain above is still disabled, so hand back the placeholder.
+     llm_object = llmClass(token, llm_model, instruct_embeddings, vector_store_list, temperature, max_length)
+ 
+     return llm_object
+ 
+ 
+ def generate_answer(question, token):
+     # Early return while the LLM is not connected; the code below is the
+     # intended path once prepare_rag_llm builds a real retrieval chain.
+     return ["Not connected to an LLM", "No documentation available"]
+ 
+     answer = "An error has occurred"
+ 
+     if token == "":
+         answer = "Insert the Hugging Face token"
+         doc_source = ["no source"]
+     else:
+         response = st.session_state.conversation({"question": question})
+         answer = response.get("answer").split("Helpful Answer:")[-1].strip()
+         explanation = response.get("source_documents", [])
+         doc_source = [d.page_content for d in explanation]
+ 
+     return answer, doc_source
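For reference, here is a minimal sketch of how the helpers above chain together to build a new vector store. The PDF name, store name, and chunk sizes are placeholder assumptions. Note that embedding_storing hard-codes {"device": "cuda"}, so it expects a CUDA-capable machine, and it needs faiss-cpu (or faiss-gpu), InstructorEmbedding, and sentence-transformers installed on top of requirements.txt; its closing st.success call also means it is meant to run from inside a Streamlit page:

    import ragfunctions

    # Read and chunk a local PDF (file and store names are hypothetical).
    with open("my_doc.pdf", "rb") as f:
        document = ragfunctions.read_pdf(f)
    split = ragfunctions.split_doc(document, chunk_size=200, chunk_overlap=20)

    # Embed the chunks and save them as a new FAISS index
    # under "vector store/my_store".
    ragfunctions.embedding_storing(
        model_name="hkunlp/instructor-xl",
        split=split,
        create_new_vs=True,
        existing_vector_store="",
        new_vs_name="my_store",
    )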
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ streamlit
+ pypdf
+ langchain