moctardiallo committed on
Commit
ef93b68
·
1 Parent(s): e42468d

Use similarity retriever to provide context for '.respond'

Browse files
Files changed (3) hide show
  1. app.py +2 -2
  2. data.py +0 -21
  3. model.py +14 -4
app.py CHANGED
@@ -18,8 +18,8 @@ with gr.Blocks() as demo:
18
  with gr.Column():
19
  url = gr.Textbox(value="https://www.gradio.app/docs/gradio/chatinterface", label="Docs URL", render=True)
20
  chat = gr.ChatInterface(
21
- # model.respond,
22
- model.predict,
23
  # model.rag,
24
  additional_inputs=[
25
  url,
 
18
  with gr.Column():
19
  url = gr.Textbox(value="https://www.gradio.app/docs/gradio/chatinterface", label="Docs URL", render=True)
20
  chat = gr.ChatInterface(
21
+ model.respond,
22
+ # model.predict,
23
  # model.rag,
24
  additional_inputs=[
25
  url,
data.py CHANGED
@@ -18,27 +18,6 @@ class Data:
18
  encode_kwargs={'normalize_embeddings':True}
19
  )
20
 
21
- def get_context(self):
22
- urls = [
23
- self.url,
24
- ]
25
- loader = UnstructuredURLLoader(urls=urls)
26
- data = loader.load()
27
-
28
- context = data[0].page_content # will come from 'url'
29
-
30
- return context
31
-
32
- def build_prompt(self, question):
33
- prompt = f"""
34
- Use the following piece of context to answer the question asked.
35
- Please try to provide the answer only based on the context
36
- {self.get_context()}
37
- Question:{question}
38
- Helpful Answers:
39
- """
40
- return prompt
41
-
42
  @property
43
  def retriever(self):
44
  loader = UnstructuredURLLoader(urls=self.urls)
 
18
  encode_kwargs={'normalize_embeddings':True}
19
  )
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  @property
22
  def retriever(self):
23
  loader = UnstructuredURLLoader(urls=self.urls)
model.py CHANGED
@@ -21,6 +21,18 @@ class Model:
21
  )
22
  self.chat_model = ChatHuggingFace(llm=self.llm, token=os.getenv("HUGGINGFACEHUB_API_TOKEN"))
23
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  def _build_prompt_rag(self):
25
  prompt_template="""
26
  Use the following piece of context to answer the question asked.
@@ -64,9 +76,7 @@ class Model:
64
  temperature,
65
  top_p,
66
  ):
67
-
68
- data = Data(url)
69
-
70
  messages = [{"role": "system", "content": url}]
71
 
72
  for val in history:
@@ -75,7 +85,7 @@ class Model:
75
  if val[1]:
76
  messages.append({"role": "assistant", "content": val[1]})
77
 
78
- messages.append({"role": "user", "content": data.build_prompt(message)})
79
 
80
  response = ""
81
 
 
21
  )
22
  self.chat_model = ChatHuggingFace(llm=self.llm, token=os.getenv("HUGGINGFACEHUB_API_TOKEN"))
23
 
24
+ def build_prompt(self, question, context_urls):
25
+ data = Data(context_urls)
26
+ context = data.retriever.invoke(f"{question}")[0].page_content
27
+ prompt = f"""
28
+ Use the following piece of context to answer the question asked.
29
+ Please try to provide the answer only based on the context
30
+ {context}
31
+ Question:{question}
32
+ Helpful Answers:
33
+ """
34
+ return prompt
35
+
36
  def _build_prompt_rag(self):
37
  prompt_template="""
38
  Use the following piece of context to answer the question asked.
 
76
  temperature,
77
  top_p,
78
  ):
79
+
 
 
80
  messages = [{"role": "system", "content": url}]
81
 
82
  for val in history:
 
85
  if val[1]:
86
  messages.append({"role": "assistant", "content": val[1]})
87
 
88
+ messages.append({"role": "user", "content": self.build_prompt(message, [url])})
89
 
90
  response = ""
91