vincentmin committed on
Commit
ccbc5fb
1 Parent(s): a9f61c0

Add model selection

Browse files
Files changed (1) hide show
  1. app.py +23 -16
app.py CHANGED
@@ -4,21 +4,21 @@ import requests
4
 
5
  import gradio as gr
6
 
7
- # MODEL = "HuggingFaceH4/starchat-beta"
8
- MODEL = "tiiuae/falcon-7b-instruct"
9
- API_URL = f"https://api-inference.huggingface.co/models/{MODEL}"
 
10
  HEADERS = {"Authorization": f"Bearer {os.environ['HUB_TOKEN']}"}
11
-
12
  TITLE = """<h2 align="center">🚀 Falcon-Chat demo</h2>"""
13
  USER_NAME = "User"
14
  BOT_NAME = "Falcon"
15
  DEFAULT_INSTRUCTIONS = f"""The following is a conversation between a highly knowledgeable and intelligent AI assistant, called Falcon, and a human user, called User. In the following interactions, User and Falcon will converse in natural language, and Falcon will answer User's questions. Falcon was built to be respectful, polite and inclusive. Falcon was built by the Technology Innovation Institute in Abu Dhabi. Falcon will never decline to answer a question, and always attempts to give an answer that User would be satisfied with. It knows a lot, and always tells the truth. The conversation begins.
16
  """
17
  RETRY_COMMAND = "/retry"
18
- STOP_STR = f"\n{USER_NAME}:"
19
- STOP_SUSPECT_LIST = [":", "\n", "User"]
20
 
21
- def run_model(prompt, temperature, top_p):
 
22
  payload = {
23
  "inputs": prompt,
24
  "parameters": {
@@ -28,7 +28,7 @@ def run_model(prompt, temperature, top_p):
28
  "top_p": top_p
29
  }
30
  }
31
- response = requests.post(API_URL, headers=HEADERS, json=payload)
32
  return response.json()[0]['generated_text']
33
 
34
  def get_stream(string: str):
@@ -36,6 +36,12 @@ def get_stream(string: str):
36
 
37
  def chat_accordion():
38
  with gr.Accordion("Parameters", open=False):
 
 
 
 
 
 
39
  temperature = gr.Slider(
40
  minimum=0.1,
41
  maximum=2.0,
@@ -52,7 +58,7 @@ def chat_accordion():
52
  interactive=True,
53
  label="p (nucleus sampling)",
54
  )
55
- return temperature, top_p
56
 
57
 
58
  def format_chat_prompt(message: str, chat_history, instructions: str) -> str:
@@ -98,7 +104,7 @@ def chat():
98
 
99
  with gr.Row(elem_id="param_container"):
100
  with gr.Column():
101
- temperature, top_p = chat_accordion()
102
  with gr.Column():
103
  with gr.Accordion("Instructions", open=False):
104
  instructions = gr.Textbox(
@@ -111,7 +117,7 @@ def chat():
111
  show_label=False,
112
  )
113
 
114
- def run_chat(message: str, chat_history, instructions: str, temperature: float, top_p: float):
115
  if not message or (message == RETRY_COMMAND and len(chat_history) == 0):
116
  yield chat_history
117
  return
@@ -124,10 +130,11 @@ def chat():
124
  prompt = format_chat_prompt(message, chat_history, instructions)
125
  model_output = run_model(
126
  prompt,
 
127
  temperature=temperature,
128
  top_p=top_p,
129
  )
130
- model_output = model_output[len(prompt):].split(f"\n{USER_NAME}")[0]
131
  chat_history = chat_history + [[message, model_output]]
132
  yield chat_history
133
  return
@@ -137,15 +144,15 @@ def chat():
137
  chat_history.pop(-1)
138
  return {chatbot: gr.update(value=chat_history)}
139
 
140
- def run_retry(message: str, chat_history, instructions: str, temperature: float, top_p: float):
141
- yield from run_chat(RETRY_COMMAND, chat_history, instructions, temperature, top_p)
142
 
143
  def clear_chat():
144
  return []
145
 
146
  inputs.submit(
147
  run_chat,
148
- [inputs, chatbot, instructions, temperature, top_p],
149
  outputs=[chatbot],
150
  show_progress=False,
151
  )
@@ -153,7 +160,7 @@ def chat():
153
  delete_turn_button.click(delete_last_turn, inputs=[chatbot], outputs=[chatbot])
154
  retry_button.click(
155
  run_retry,
156
- [inputs, chatbot, instructions, temperature, top_p],
157
  outputs=[chatbot],
158
  show_progress=False,
159
  )
 
4
 
5
  import gradio as gr
6
 
7
+ MODELS = [
8
+ "tiiuae/falcon-7b-instruct",
9
+ "HuggingFaceH4/starchat-beta"
10
+ ]
11
  HEADERS = {"Authorization": f"Bearer {os.environ['HUB_TOKEN']}"}
 
12
  TITLE = """<h2 align="center">🚀 Falcon-Chat demo</h2>"""
13
  USER_NAME = "User"
14
  BOT_NAME = "Falcon"
15
  DEFAULT_INSTRUCTIONS = f"""The following is a conversation between a highly knowledgeable and intelligent AI assistant, called Falcon, and a human user, called User. In the following interactions, User and Falcon will converse in natural language, and Falcon will answer User's questions. Falcon was built to be respectful, polite and inclusive. Falcon was built by the Technology Innovation Institute in Abu Dhabi. Falcon will never decline to answer a question, and always attempts to give an answer that User would be satisfied with. It knows a lot, and always tells the truth. The conversation begins.
16
  """
17
  RETRY_COMMAND = "/retry"
18
+ STOP_STR = f"\n{USER_NAME}"
 
19
 
20
+ def run_model(prompt, model, temperature, top_p):
21
+ api_url = f"https://api-inference.huggingface.co/models/{model}"
22
  payload = {
23
  "inputs": prompt,
24
  "parameters": {
 
28
  "top_p": top_p
29
  }
30
  }
31
+ response = requests.post(api_url, headers=HEADERS, json=payload)
32
  return response.json()[0]['generated_text']
33
 
34
  def get_stream(string: str):
 
36
 
37
  def chat_accordion():
38
  with gr.Accordion("Parameters", open=False):
39
+ model = gr.Dropdown(
40
+ choices = MODELS,
41
+ value = MODELS[0],
42
+ interactive=True,
43
+ label="Model",
44
+ )
45
  temperature = gr.Slider(
46
  minimum=0.1,
47
  maximum=2.0,
 
58
  interactive=True,
59
  label="p (nucleus sampling)",
60
  )
61
+ return model, temperature, top_p
62
 
63
 
64
  def format_chat_prompt(message: str, chat_history, instructions: str) -> str:
 
104
 
105
  with gr.Row(elem_id="param_container"):
106
  with gr.Column():
107
+ model, temperature, top_p = chat_accordion()
108
  with gr.Column():
109
  with gr.Accordion("Instructions", open=False):
110
  instructions = gr.Textbox(
 
117
  show_label=False,
118
  )
119
 
120
+ def run_chat(message: str, chat_history, instructions: str, model: str, temperature: float, top_p: float):
121
  if not message or (message == RETRY_COMMAND and len(chat_history) == 0):
122
  yield chat_history
123
  return
 
130
  prompt = format_chat_prompt(message, chat_history, instructions)
131
  model_output = run_model(
132
  prompt,
133
+ model=model,
134
  temperature=temperature,
135
  top_p=top_p,
136
  )
137
+ model_output = model_output[len(prompt):].split(STOP_STR)[0]
138
  chat_history = chat_history + [[message, model_output]]
139
  yield chat_history
140
  return
 
144
  chat_history.pop(-1)
145
  return {chatbot: gr.update(value=chat_history)}
146
 
147
def run_retry(message: str, chat_history, instructions: str, model: str, temperature: float, top_p: float):
    """Re-run the last exchange by delegating to run_chat with the retry command.

    The `message` argument is accepted (to match the event-handler signature)
    but ignored; RETRY_COMMAND is sent instead so run_chat replays the last turn.
    Yields updated chat_history states, same as run_chat.
    """
    # BUG FIX: the original wrote `model: str` inside the call — annotation
    # syntax is only valid in a `def` signature, so this line was a SyntaxError.
    yield from run_chat(RETRY_COMMAND, chat_history, instructions, model, temperature, top_p)
149
 
150
  def clear_chat():
151
  return []
152
 
153
  inputs.submit(
154
  run_chat,
155
+ [inputs, chatbot, instructions, model, temperature, top_p],
156
  outputs=[chatbot],
157
  show_progress=False,
158
  )
 
160
  delete_turn_button.click(delete_last_turn, inputs=[chatbot], outputs=[chatbot])
161
  retry_button.click(
162
  run_retry,
163
+ [inputs, chatbot, instructions, model, temperature, top_p],
164
  outputs=[chatbot],
165
  show_progress=False,
166
  )