moctardiallo committed on
Commit
53953f7
1 Parent(s): ef93b68

refactor '.predict' add SystemMessage

Browse files
Files changed (2) hide show
  1. app.py +2 -2
  2. model.py +10 -5
app.py CHANGED
@@ -18,8 +18,8 @@ with gr.Blocks() as demo:
18
  with gr.Column():
19
  url = gr.Textbox(value="https://www.gradio.app/docs/gradio/chatinterface", label="Docs URL", render=True)
20
  chat = gr.ChatInterface(
21
- model.respond,
22
- # model.predict,
23
  # model.rag,
24
  additional_inputs=[
25
  url,
 
18
  with gr.Column():
19
  url = gr.Textbox(value="https://www.gradio.app/docs/gradio/chatinterface", label="Docs URL", render=True)
20
  chat = gr.ChatInterface(
21
+ # model.respond,
22
+ model.predict,
23
  # model.rag,
24
  additional_inputs=[
25
  url,
model.py CHANGED
@@ -1,7 +1,8 @@
1
  import os
2
 
3
  from huggingface_hub import InferenceClient
4
- from langchain.schema import AIMessage, HumanMessage
 
5
  from langchain.chains import RetrievalQA
6
  from langchain.prompts import PromptTemplate
7
 
@@ -56,16 +57,20 @@ class Model:
56
  )
57
 
58
  def predict(self, message, history, url, max_tokens, temperature, top_p):
59
- history_langchain_format = []
60
  for msg in history:
61
  if msg['role'] == "user":
62
  history_langchain_format.append(HumanMessage(content=msg['content']))
63
  elif msg['role'] == "assistant":
64
  history_langchain_format.append(AIMessage(content=msg['content']))
65
  history_langchain_format.append(HumanMessage(content=message))
66
- # gpt_response = self.chat_model(history_langchain_format)
67
- # return gpt_response.content
68
- return self._retrieval_qa(url).invoke({"query": message})['result']
 
 
 
 
69
 
70
  def respond(
71
  self,
 
1
  import os
2
 
3
  from huggingface_hub import InferenceClient
4
+ from langchain.schema import SystemMessage, AIMessage, HumanMessage
5
+
6
  from langchain.chains import RetrievalQA
7
  from langchain.prompts import PromptTemplate
8
 
 
57
  )
58
 
59
  def predict(self, message, history, url, max_tokens, temperature, top_p):
60
+ history_langchain_format = [SystemMessage(content="You're a helpful python developer assistant")]
61
  for msg in history:
62
  if msg['role'] == "user":
63
  history_langchain_format.append(HumanMessage(content=msg['content']))
64
  elif msg['role'] == "assistant":
65
  history_langchain_format.append(AIMessage(content=msg['content']))
66
  history_langchain_format.append(HumanMessage(content=message))
67
+
68
+ # ai_msg = self.chat_model.invoke(history_langchain_format)
69
+ # return ai_msg.content
70
+
71
+ ret = self._retrieval_qa(url)
72
+ return ret.invoke({"query": message})['result']
73
+
74
 
75
  def respond(
76
  self,