Alyafeai committed on
Commit
aa69e53
1 Parent(s): e25ebef

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -24
app.py CHANGED
@@ -110,14 +110,22 @@ def chat_accordion():
110
 
111
 
112
  def format_chat_prompt(
113
- message: str, chat_history, instructions: str, user_name: str, bot_name: str
 
 
 
 
 
114
  ) -> str:
115
  instructions = instructions.strip()
116
  prompt = instructions
117
- for turn in chat_history:
118
- user_message, bot_message = turn
119
- prompt = f"{prompt}\n{user_name}: {user_message}\n{bot_name}: {bot_message}"
120
- prompt = f"{prompt}\n{user_name}: {message}\n{bot_name}:"
 
 
 
121
  return prompt
122
 
123
 
@@ -156,29 +164,26 @@ def chat_tab():
156
  session_id: str,
157
  ):
158
  prompt = format_chat_prompt(message, history, instructions, user_name, bot_name)
159
- generated_response = ""
160
 
161
- payload = json.dumps(
162
- {
163
- "endpoint": MODEL_NAME,
164
- "data": {
165
- "inputs": prompt,
166
- "parameters": {
167
- "max_new_tokens": 1024,
168
- "do_sample": True,
169
- "top_p": top_p,
170
- "stop": ["User:"],
171
- },
172
- "stream": True,
173
- "session_id": session_id,
174
  },
175
- }
176
- )
 
 
177
 
178
  sess = requests.Session()
179
  full_output = ""
180
  with sess.post(
181
- ENDPOINT_URL, headers=HEADERS, data=payload, stream=True
182
  ) as response:
183
  if response.status_code == 200:
184
  for chunk in response.iter_content(chunk_size=4):
@@ -191,7 +196,22 @@ def chat_tab():
191
  else:
192
  yield full_output
193
  if full_output == "":
194
- yield "I am sorry, I did not understand your query. Could you please rephrase it?"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  return ""
196
 
197
  with gr.Column():
@@ -268,4 +288,4 @@ def start_demo():
268
 
269
 
270
  if __name__ == "__main__":
271
- start_demo()
 
110
 
111
 
112
  def format_chat_prompt(
113
+ message: str,
114
+ chat_history,
115
+ instructions: str,
116
+ user_name: str,
117
+ bot_name: str,
118
+ include_chat_history: bool = True,
119
  ) -> str:
120
  instructions = instructions.strip()
121
  prompt = instructions
122
+ if include_chat_history:
123
+ for turn in chat_history:
124
+ user_message, bot_message = turn
125
+ prompt = f"{prompt}\n{user_name}: {user_message}\n{bot_name}: {bot_message}"
126
+ prompt = f"{prompt}\n{user_name}: {message}\n{bot_name}:"
127
+ else:
128
+ prompt = f"{prompt}\n{user_name}: {message}\n{bot_name}:"
129
  return prompt
130
 
131
 
 
164
  session_id: str,
165
  ):
166
  prompt = format_chat_prompt(message, history, instructions, user_name, bot_name)
 
167
 
168
+ payload = {
169
+ "endpoint": MODEL_NAME,
170
+ "data": {
171
+ "inputs": prompt,
172
+ "parameters": {
173
+ "max_new_tokens": 1024,
174
+ "do_sample": True,
175
+ "top_p": top_p,
176
+ "stop": ["User:"],
 
 
 
 
177
  },
178
+ "stream": True,
179
+ "session_id": session_id,
180
+ },
181
+ }
182
 
183
  sess = requests.Session()
184
  full_output = ""
185
  with sess.post(
186
+ ENDPOINT_URL, headers=HEADERS, json=payload, stream=True
187
  ) as response:
188
  if response.status_code == 200:
189
  for chunk in response.iter_content(chunk_size=4):
 
196
  else:
197
  yield full_output
198
  if full_output == "":
199
+ payload["data"]["inputs"] = format_chat_prompt(
200
+ message, history, instructions, user_name, bot_name, False
201
+ )
202
+ with sess.post(
203
+ ENDPOINT_URL, headers=HEADERS, json=payload, stream=True
204
+ ) as response:
205
+ if response.status_code == 200:
206
+ for chunk in response.iter_content(chunk_size=4):
207
+ if chunk:
208
+ decoded = chunk.decode("utf-8")
209
+ full_output += decoded
210
+ if full_output.endswith("User:"):
211
+ yield full_output[:-5]
212
+ break
213
+ else:
214
+ yield full_output
215
  return ""
216
 
217
  with gr.Column():
 
288
 
289
 
290
  if __name__ == "__main__":
291
+ start_demo()