wb-droid committed on
Commit
8fe2d96
1 Parent(s): d06d8c3

First commit.

Files changed (2)
  1. app.py +143 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,143 @@
+ # 1. Using the LangChain vector store abstraction
+ #    https://python.langchain.com/v0.1/docs/modules/data_connection/vectorstores/
+ #    VectorStore: FAISS
+ # 2. Embeddings: HuggingFaceInferenceAPIEmbeddings with "BAAI/bge-base-en-v1.5"
+ # 3. LLMs: Mistral and Llama, served via the Hugging Face Inference API:
+ #    "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.2"
+ #    "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-8B-Instruct"
+
+ import gradio as gr
+ import os
+ from langchain.prompts import ChatPromptTemplate
+ from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
+ from langchain.schema.runnable import RunnablePassthrough
+ from langchain_community.document_loaders import TextLoader
+ from langchain_text_splitters import CharacterTextSplitter
+ from langchain_community.vectorstores import FAISS
+ from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
+
+ HF_API_KEY = os.environ.get('HUGGINGFACE_API_KEY')
+
+ llm_urls = {
+     "Mistral 7B": "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.2",
+     "Llama 8B": "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-8B-Instruct",
+ }
+
+ def initialize_vector_store_retriever(file):
+     # Load the document, split it into chunks, embed each chunk, and index the chunks in FAISS.
+     #raw_documents = TextLoader('./llm.txt').load()
+     raw_documents = TextLoader(file).load()
+     text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
+     documents = text_splitter.split_documents(raw_documents)
+
+     # HuggingFaceInferenceAPIEmbeddings takes the model name and derives the
+     # Inference API URL from it.
+     embeddings = HuggingFaceInferenceAPIEmbeddings(
+         model_name="BAAI/bge-base-en-v1.5",
+         api_key=HF_API_KEY,
+     )
+     db = FAISS.from_documents(documents, embeddings)
+     retriever = db.as_retriever()
+     return retriever
+
+ def generate_llm_rag_prompt() -> ChatPromptTemplate:
+     # Llama-2/Mistral instruction format: the system prompt sits inside <<SYS>>
+     # tags, followed by the retrieved context and the question in [INST] ... [/INST].
+     #template = "<s>[INST] {context} {prompt} [/INST]"
+     template = "<s>[INST] <<SYS>>{system}<</SYS>>{context} {prompt} [/INST]"
+
+     prompt_template = ChatPromptTemplate.from_template(template)
+     return prompt_template
+
+
+ def create_chain(retriever, llm):
+     url = llm_urls[llm]
+     # Both models are decoder-only, so the endpoint task is "text-generation".
+     model_endpoint = HuggingFaceEndpoint(
+         endpoint_url=url,
+         huggingfacehub_api_token=HF_API_KEY,
+         task="text-generation",
+         max_new_tokens=200,
+     )
+
+     if retriever is not None:
+         def get_system(input):
+             return "You are a helpful and honest assistant. Please, respond concisely and truthfully."
+
+         # Fan the question out: the retriever fills {context}, the raw question
+         # passes through as {prompt}, and get_system supplies {system}.
+         retrieval = {"context": retriever, "prompt": RunnablePassthrough(), "system": get_system}
+         chain = retrieval | generate_llm_rag_prompt() | model_endpoint
+         return chain, model_endpoint
+     else:
+         return None, model_endpoint
+
+
+ def query(question_text, llm, session_data):
+     if question_text == "":
+         without_rag_text = "Query result without RAG is not available. Enter a question first."
+         rag_text = "Query result with RAG is not available. Enter a question first."
+         return without_rag_text, rag_text
+
+     retriever = session_data[0] if len(session_data) > 0 else None
+     chain, model_endpoint = create_chain(retriever, llm)
+     without_rag_text = "Query result without RAG:\n\n" + model_endpoint.invoke(question_text).strip()
+     if retriever is None:
+         rag_text = "Query result with RAG is not available. Load the Vector Store first."
+     else:
+         ans = chain.invoke(question_text).strip()
+         # Strip any prompt scaffolding the model echoes back, keeping only the
+         # text between the [INST] <<SYS>> markers; fall back to the raw answer.
+         parts = [p.split("[INST] <<SYS>>")[1] for p in ans.split("[/SYS]>[/INST]") if p.find("[INST] <<SYS>>") >= 0]
+         if len(parts) >= 2:
+             parts = parts[1:-1]
+         else:
+             parts = [ans]
+         rag_text = "Query result with RAG:\n\n" + "".join(parts)
+     return without_rag_text, rag_text
+
+ def upload_file(file, session_data):
+     # Rebuild the vector store from the uploaded file and stash the retriever
+     # in the per-session state.
+     #file_paths = [file.name for file in files]
+     #file = files[0]
+     session_data = [initialize_vector_store_retriever(file)]
+     return gr.File(value=file, visible=True), session_data
+
+
+ def initialize_vector_store(session_data):
+     # Unused: belongs to the commented-out "Load text file to Vector Store"
+     # button below, and would need a file path argument to work.
+     session_data = [initialize_vector_store_retriever()]
+     return session_data
+
+ with gr.Blocks() as demo:
+     gr.HTML("""<h1 align="center">Retrieval Augmented Generation</h1>""")
+     session_data = gr.State([])
+
+     file_output = gr.File(visible=False)
+     upload_button = gr.UploadButton("Click to Upload a text File to Vector Store", file_types=["text"], file_count="single")
+     upload_button.upload(upload_file, [upload_button, session_data], [file_output, session_data])
+
+     #initialize_VS_button = gr.Button("Load text file to Vector Store")
+     with gr.Row():
+         with gr.Column(scale=4):
+             question_text = gr.Textbox(show_label=False, placeholder="Ask a question", lines=2)
+         with gr.Column(scale=1):
+             llm_Choice = gr.Radio(["Llama 8B", "Mistral 7B"], value="Mistral 7B", label="Select language model:", info="")
+             query_Button = gr.Button("Query")
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             without_rag_text = gr.Textbox(show_label=False, placeholder="Query result without using RAG", lines=15)
+         with gr.Column(scale=1):
+             rag_text = gr.Textbox(show_label=False, placeholder="Query result with RAG", lines=15)
+
+     #initialize_VS_button.click(
+     #    initialize_vector_store,
+     #    [session_data],
+     #    [session_data],
+     #    #show_progress=True,
+     #)
+     query_Button.click(
+         query,
+         [question_text, llm_Choice, session_data],
+         [without_rag_text, rag_text],
+         #show_progress=True,
+     )
+
+ demo.queue().launch(share=False, inbrowser=True)
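
For reference, a minimal sketch of the same retrieval chain that `create_chain` assembles, runnable outside Gradio. This is a sketch under the same assumptions as `app.py` (a `HUGGINGFACE_API_KEY` environment variable is set); `sample.txt` and the final question are hypothetical placeholders.

```python
import os
from langchain.prompts import ChatPromptTemplate
from langchain.schema.runnable import RunnablePassthrough
from langchain_community.document_loaders import TextLoader
from langchain_text_splitters import CharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint

key = os.environ["HUGGINGFACE_API_KEY"]

# Build the retriever: load, chunk, embed, index
# (mirrors initialize_vector_store_retriever in app.py).
documents = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0).split_documents(
    TextLoader("sample.txt").load()  # hypothetical local file
)
embeddings = HuggingFaceInferenceAPIEmbeddings(api_key=key, model_name="BAAI/bge-base-en-v1.5")
retriever = FAISS.from_documents(documents, embeddings).as_retriever()

# Compose the chain: the dict runs its values in parallel on the question,
# the prompt template formats the combined result, and the endpoint generates.
prompt = ChatPromptTemplate.from_template(
    "<s>[INST] <<SYS>>{system}<</SYS>>{context} {prompt} [/INST]"
)
llm = HuggingFaceEndpoint(
    endpoint_url="https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.2",
    huggingfacehub_api_token=key,
    task="text-generation",
    max_new_tokens=200,
)
chain = (
    {
        "context": retriever,
        "prompt": RunnablePassthrough(),
        "system": lambda _: "You are a helpful and honest assistant.",
    }
    | prompt
    | llm
)
print(chain.invoke("What is this document about?"))  # hypothetical question
```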
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ gradio
+ faiss-cpu
+ langchain
+
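
To try this commit locally: `pip install -r requirements.txt`, export a `HUGGINGFACE_API_KEY` environment variable (`app.py` reads it via `os.environ`), then run `python app.py`; `demo.queue().launch(share=False, inbrowser=True)` should open the Gradio UI in a browser tab.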