moctardiallo committed
Commit bc4c927 (2 parents: 50cd04e, 53953f7)

Merge branch 'rag'

Files changed (4)
  1. app.py +3 -1
  2. data.py +30 -22
  3. model.py +66 -4
  4. requirements.txt +4 -1
app.py CHANGED
@@ -23,7 +23,9 @@ with gr.Blocks(fill_height=True) as demo:
     with gr.Column():
         url = gr.Textbox(value="https://www.gradio.app/docs/gradio/chatinterface", label="Docs URL", render=True)
         chat = gr.ChatInterface(
-            model.respond,
+            # model.respond,
+            model.predict,
+            # model.rag,
             additional_inputs=[
                 url,
                 max_tokens,
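
gr.ChatInterface calls its handler as fn(message, history, *additional_inputs), so swapping model.respond for model.predict only requires matching that calling convention. Below is a minimal, hypothetical sketch of the wiring; the temperature/top_p sliders and the type="messages" setting are assumptions, since the hunk above only shows url and max_tokens before it is cut off.

# Hypothetical sketch of the calling convention (not the Space's full app.py).
import gradio as gr

def predict(message, history, url, max_tokens, temperature, top_p):
    # With type="messages", history is a list of {"role", "content"} dicts,
    # matching what model.predict iterates over in model.py.
    return f"Would answer {message!r} from the docs at {url}"

with gr.Blocks(fill_height=True) as demo:
    with gr.Column():
        url = gr.Textbox(value="https://www.gradio.app/docs/gradio/chatinterface", label="Docs URL")
        max_tokens = gr.Slider(1, 2048, value=512, label="Max new tokens")  # assumed
        temperature = gr.Slider(0.1, 4.0, value=0.7, label="Temperature")   # assumed
        top_p = gr.Slider(0.1, 1.0, value=0.95, label="Top-p")              # assumed
        chat = gr.ChatInterface(
            predict,
            type="messages",
            additional_inputs=[url, max_tokens, temperature, top_p],
        )

demo.launch()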
data.py CHANGED
@@ -1,26 +1,34 @@
 from langchain_community.document_loaders import UnstructuredURLLoader
 
+from langchain_community.document_loaders import PyPDFLoader
+from langchain_community.document_loaders import PyPDFDirectoryLoader
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_community.vectorstores import FAISS
+from langchain_community.embeddings import HuggingFaceBgeEmbeddings
+
+from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
+
 class Data:
-    def __init__(self, url):
-        self.url = url
-
-    def get_context(self):
-        urls = [
-            self.url,
-        ]
-        loader = UnstructuredURLLoader(urls=urls)
+    def __init__(self, urls):
+        self.urls = urls
+        ## Embedding Using Huggingface
+        self.huggingface_embeddings = HuggingFaceBgeEmbeddings(
+            model_name="BAAI/bge-small-en-v1.5", #sentence-transformers/all-MiniLM-l6-v2
+            model_kwargs={'device':'cpu'},
+            encode_kwargs={'normalize_embeddings':True}
+        )
+
+    @property
+    def retriever(self):
+        loader = UnstructuredURLLoader(urls=self.urls)
         data = loader.load()
-
-        context = data[0].page_content # will come from 'url'
-
-        return context
-
-    def build_prompt(self, question):
-        prompt = f"""
-        Use the following piece of context to answer the question asked.
-        Please try to provide the answer only based on the context
-        {self.get_context()}
-        Question:{question}
-        Helpful Answers:
-        """
-        return prompt
+
+        ## VectorStore Creation
+        vectorstore = FAISS.from_documents(data, self.huggingface_embeddings)
+
+        retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k":3})
+
+        return retriever
+
+
+
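Note that retriever is exposed as a @property, so every access re-downloads the URLs, re-embeds them with the BGE model, and rebuilds the FAISS index; callers should grab it once per question. A short usage sketch under that reading (the query string is made up):

# Usage sketch for the reworked Data class: build a FAISS retriever over one
# docs page and fetch the most similar chunks for a query.
from data import Data

data = Data(["https://www.gradio.app/docs/gradio/chatinterface"])
retriever = data.retriever  # property access: loads, embeds, and indexes the URLs

docs = retriever.invoke("How do I pass additional inputs to ChatInterface?")
for doc in docs:  # at most 3 Documents, per search_kwargs={"k":3}
    print(doc.page_content[:200])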
model.py CHANGED
@@ -1,12 +1,76 @@
 import os
 
 from huggingface_hub import InferenceClient
+from langchain.schema import SystemMessage, AIMessage, HumanMessage
+
+from langchain.chains import RetrievalQA
+from langchain.prompts import PromptTemplate
+
+from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
 
 from data import Data
 
 class Model:
     def __init__(self, model_id="meta-llama/Llama-3.2-1B-Instruct"):
         self.client = InferenceClient(model_id, token=os.getenv("HUGGINGFACEHUB_API_TOKEN"))
+        self.llm = HuggingFaceEndpoint(
+            repo_id="HuggingFaceH4/zephyr-7b-beta",
+            task="text-generation",
+            max_new_tokens=512,
+            do_sample=False,
+            repetition_penalty=1.03,
+        )
+        self.chat_model = ChatHuggingFace(llm=self.llm, token=os.getenv("HUGGINGFACEHUB_API_TOKEN"))
+
+    def build_prompt(self, question, context_urls):
+        data = Data(context_urls)
+        context = data.retriever.invoke(f"{question}")[0].page_content
+        prompt = f"""
+        Use the following piece of context to answer the question asked.
+        Please try to provide the answer only based on the context
+        {context}
+        Question:{question}
+        Helpful Answers:
+        """
+        return prompt
+
+    def _build_prompt_rag(self):
+        prompt_template="""
+        Use the following piece of context to answer the question asked.
+        Please try to provide the answer only based on the context
+        {context}
+        Question:{question}
+        Helpful Answers:
+        """
+        prompt=PromptTemplate(template=prompt_template,input_variables=["context","question"])
+        return prompt
+
+    def _retrieval_qa(self, url):
+        data = Data([url])
+        prompt = self._build_prompt_rag()
+        return RetrievalQA.from_chain_type(
+            llm=self.chat_model,
+            chain_type="stuff",
+            retriever=data.retriever,
+            return_source_documents=True,
+            chain_type_kwargs={"prompt":prompt}
+        )
+
+    def predict(self, message, history, url, max_tokens, temperature, top_p):
+        history_langchain_format = [SystemMessage(content="You're a helpful python developer assistant")]
+        for msg in history:
+            if msg['role'] == "user":
+                history_langchain_format.append(HumanMessage(content=msg['content']))
+            elif msg['role'] == "assistant":
+                history_langchain_format.append(AIMessage(content=msg['content']))
+        history_langchain_format.append(HumanMessage(content=message))
+
+        # ai_msg = self.chat_model.invoke(history_langchain_format)
+        # return ai_msg.content
+
+        ret = self._retrieval_qa(url)
+        return ret.invoke({"query": message})['result']
+
 
     def respond(
         self,
@@ -17,9 +81,7 @@ class Model:
         temperature,
         top_p,
     ):
-
-        data = Data(url)
-
+
         messages = [{"role": "system", "content": url}]
 
         for val in history:
@@ -28,7 +90,7 @@ class Model:
             if val[1]:
                 messages.append({"role": "assistant", "content": val[1]})
 
-        messages.append({"role": "user", "content": data.build_prompt(message)})
+        messages.append({"role": "user", "content": self.build_prompt(message, [url])})
 
         response = ""
 
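In the new predict path, answers come from the RetrievalQA chain rebuilt per request from the URL; the LangChain-formatted history is currently consumed only by the commented-out chat_model.invoke path, and max_tokens, temperature, and top_p are accepted solely to satisfy the Gradio handler signature. A hypothetical driver, assuming HUGGINGFACEHUB_API_TOKEN is set in the environment:

# Hypothetical driver for the RAG path added in this commit.
from model import Model

model = Model()
history = [  # messages-format history, as model.predict expects
    {"role": "user", "content": "What is gr.ChatInterface?"},
    {"role": "assistant", "content": "A high-level Gradio chat component."},
]
answer = model.predict(
    message="How do I pass extra inputs to the handler?",
    history=history,
    url="https://www.gradio.app/docs/gradio/chatinterface",
    max_tokens=512,   # unused by the RetrievalQA path
    temperature=0.7,  # unused
    top_p=0.95,       # unused
)
print(answer)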
requirements.txt CHANGED
@@ -1,4 +1,7 @@
 huggingface_hub==0.25.2
 langchain-community==0.3.3
+langchain-core==0.3.12
+langchain-huggingface==0.1.0
 unstructured==0.16.0
-unstructured-client==0.26.1
+unstructured-client==0.26.1
+faiss-cpu==1.9.0