vincentmin committed on
Commit
aedbcbd
·
1 Parent(s): cb83f9d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -29
app.py CHANGED
@@ -1,8 +1,12 @@
1
  import argparse
2
  import os
 
3
 
4
  import gradio as gr
5
- from text_generation import Client
 
 
 
6
 
7
  TITLE = """<h2 align="center">🚀 Falcon-Chat demo</h2>"""
8
  USER_NAME = "User"
@@ -13,9 +17,20 @@ RETRY_COMMAND = "/retry"
13
  STOP_STR = f"\n{USER_NAME}:"
14
  STOP_SUSPECT_LIST = [":", "\n", "User"]
15
 
16
- INFERENCE_ENDPOINT = os.environ.get("INFERENCE_ENDPOINT")
17
- INFERENCE_AUTH = os.environ.get("INFERENCE_AUTH")
18
-
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  def chat_accordion():
21
  with gr.Accordion("Parameters", open=False):
@@ -48,7 +63,7 @@ def format_chat_prompt(message: str, chat_history, instructions: str) -> str:
48
  return prompt
49
 
50
 
51
- def chat(client: Client):
52
  with gr.Column(elem_id="chat_container"):
53
  with gr.Row():
54
  chatbot = gr.Chatbot(elem_id="chatbot")
@@ -106,34 +121,32 @@ def chat(client: Client):
106
 
107
  prompt = format_chat_prompt(message, chat_history, instructions)
108
  chat_history = chat_history + [[message, ""]]
109
- stream = client.generate_stream(
110
  prompt,
111
- do_sample=True,
112
- max_new_tokens=1024,
113
- stop_sequences=[STOP_STR, "<|endoftext|>"],
114
  temperature=temperature,
115
  top_p=top_p,
116
  )
117
- acc_text = ""
118
- for idx, response in enumerate(stream):
119
- text_token = response.token.text
 
120
 
121
- if response.details:
122
- return
123
 
124
- if text_token in STOP_SUSPECT_LIST:
125
- acc_text += text_token
126
- continue
127
 
128
- if idx == 0 and text_token.startswith(" "):
129
- text_token = text_token[1:]
130
 
131
- acc_text += text_token
132
- last_turn = list(chat_history.pop(-1))
133
- last_turn[-1] += acc_text
134
- chat_history = chat_history + [last_turn]
135
- yield chat_history
136
- acc_text = ""
137
 
138
  def delete_last_turn(chat_history):
139
  if chat_history:
@@ -163,7 +176,7 @@ def chat(client: Client):
163
  clear_chat_button.click(clear_chat, [], chatbot)
164
 
165
 
166
- def get_demo(client: Client):
167
  with gr.Blocks(
168
  # css=None
169
  # css="""#chat_container {width: 700px; margin-left: auto; margin-right: auto;}
@@ -195,7 +208,7 @@ def get_demo(client: Client):
195
  """
196
  )
197
 
198
- chat(client)
199
 
200
  return demo
201
 
@@ -209,7 +222,6 @@ if __name__ == "__main__":
209
  default=INFERENCE_ENDPOINT,
210
  )
211
  args = parser.parse_args()
212
- client = Client(args.addr, headers={"Authorization": f"Basic {INFERENCE_AUTH}"})
213
- demo = get_demo(client)
214
  demo.queue(max_size=128, concurrency_count=16)
215
  demo.launch()
 
1
  import argparse
2
  import os
3
+ import requests
4
 
5
  import gradio as gr
6
+
7
+ MODEL = "HuggingFaceH4/starchat-beta"
8
+ API_URL = f"https://api-inference.huggingface.co/models/{MODEL}"
9
+ HEADERS = {"Authorization": f"Bearer {os.environ['HUB_TOKEN']}"}
10
 
11
  TITLE = """<h2 align="center">🚀 Falcon-Chat demo</h2>"""
12
  USER_NAME = "User"
 
17
  STOP_STR = f"\n{USER_NAME}:"
18
  STOP_SUSPECT_LIST = [":", "\n", "User"]
19
 
20
+ def run_model(prompt, temperature, top_p):
21
+ payload = {
22
+ "inputs": prompt,
23
+ "parameters": {
24
+ "max_new_tokens": 128,
25
+ "do_sample": True,
26
+ "temperature": temperature,
27
+ "top_p": top_p
28
+ }
29
+ response = requests.post(API_URL, headers=HEADERS, json=payload)
30
+ return response.json()[0]['generated_text']
31
+
32
+ def get_stream(string: str):
33
+ return enumerate(iter(string.split(" ")))
34
 
35
  def chat_accordion():
36
  with gr.Accordion("Parameters", open=False):
 
63
  return prompt
64
 
65
 
66
+ def chat():
67
  with gr.Column(elem_id="chat_container"):
68
  with gr.Row():
69
  chatbot = gr.Chatbot(elem_id="chatbot")
 
121
 
122
  prompt = format_chat_prompt(message, chat_history, instructions)
123
  chat_history = chat_history + [[message, ""]]
124
+ model_output = run_model(
125
  prompt,
 
 
 
126
  temperature=temperature,
127
  top_p=top_p,
128
  )
129
+ yield model_output
130
+ # acc_text = ""
131
+ # for idx, response in enumerate(stream):
132
+ # text_token = response.token.text
133
 
134
+ # if response.details:
135
+ # return
136
 
137
+ # if text_token in STOP_SUSPECT_LIST:
138
+ # acc_text += text_token
139
+ # continue
140
 
141
+ # if idx == 0 and text_token.startswith(" "):
142
+ # text_token = text_token[1:]
143
 
144
+ # acc_text += text_token
145
+ # last_turn = list(chat_history.pop(-1))
146
+ # last_turn[-1] += acc_text
147
+ # chat_history = chat_history + [last_turn]
148
+ # yield chat_history
149
+ # acc_text = ""
150
 
151
  def delete_last_turn(chat_history):
152
  if chat_history:
 
176
  clear_chat_button.click(clear_chat, [], chatbot)
177
 
178
 
179
+ def get_demo():
180
  with gr.Blocks(
181
  # css=None
182
  # css="""#chat_container {width: 700px; margin-left: auto; margin-right: auto;}
 
208
  """
209
  )
210
 
211
+ chat()
212
 
213
  return demo
214
 
 
222
  default=INFERENCE_ENDPOINT,
223
  )
224
  args = parser.parse_args()
225
+ demo = get_demo()
 
226
  demo.queue(max_size=128, concurrency_count=16)
227
  demo.launch()