vwxyzjn commited on
Commit
2435439
·
1 Parent(s): 7c288c4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -39
app.py CHANGED
@@ -1,14 +1,15 @@
1
- import json
2
  import os
3
- import shutil
4
- import requests
5
 
6
  import gradio as gr
7
- from huggingface_hub import Repository
8
  from text_generation import Client
 
 
9
 
10
  from share_btn import community_icon_html, loading_icon_html, share_js, share_btn_css
11
 
 
12
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
13
  print(HF_TOKEN)
14
 
@@ -80,15 +81,39 @@ client = Client(
80
  API_URL,
81
  headers={"Authorization": f"Bearer {HF_TOKEN}"},
82
  )
83
- client_base = Client(
84
- API_URL_BASE, headers={"Authorization": f"Bearer {HF_TOKEN}"},
85
- )
86
- client_plus = Client(
87
- API_URL_PLUS, headers={"Authorization": f"Bearer {HF_TOKEN}"},
88
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
  def generate(
91
- prompt, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0, version="StarCoder",
92
  ):
93
 
94
  temperature = float(temperature)
@@ -97,6 +122,7 @@ def generate(
97
  top_p = float(top_p)
98
  fim_mode = False
99
 
 
100
  generate_kwargs = dict(
101
  temperature=temperature,
102
  max_new_tokens=max_new_tokens,
@@ -104,39 +130,49 @@ def generate(
104
  repetition_penalty=repetition_penalty,
105
  do_sample=True,
106
  seed=42,
 
107
  )
108
 
109
- if FIM_INDICATOR in prompt:
110
- fim_mode = True
111
- try:
112
- prefix, suffix = prompt.split(FIM_INDICATOR)
113
- except:
114
- raise ValueError(f"Only one {FIM_INDICATOR} allowed in prompt!")
115
- prompt = f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}"
116
-
117
- if version == "StarCoder":
118
  stream = client.generate_stream(prompt, **generate_kwargs)
119
- elif version == "StarCoderPlus":
120
- stream = client_plus.generate_stream(prompt, **generate_kwargs)
121
- else:
122
- stream = client_base.generate_stream(prompt, **generate_kwargs)
123
 
124
- if fim_mode:
125
- output = prefix
126
- else:
127
- output = prompt
128
 
 
 
129
  previous_token = ""
130
  for response in stream:
131
  if response.token.text == "<|endoftext|>":
132
- if fim_mode:
133
- output += suffix
134
- else:
135
- return output
136
  else:
137
  output += response.token.text
138
  previous_token = response.token.text
139
- yield output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  return output
141
 
142
 
@@ -168,14 +204,14 @@ css += share_btn_css + monospace_css + ".gradio-container {color: black}"
168
 
169
  description = """
170
  <div style="text-align: center;">
171
- <h1> ⭐ StarCoder <span style='color: #e6b800;'>Models</span> Playground</h1>
172
  </div>
173
  <div style="text-align: left;">
174
- <p>This is a demo to generate text and code with the following StarCoder models:</p>
175
  <ul>
176
  <li><a href="https://huggingface.co/bigcode/starcoderplus" style='color: #e6b800;'>StarCoderPlus</a>: A finetuned version of StarCoderBase on English web data, making it strong in both English text and code generation.</li>
177
  <li><a href="https://huggingface.co/bigcode/starcoderbase" style='color: #e6b800;'>StarCoderBase</a>: A code generation model trained on 80+ programming languages, providing broad language coverage for code generation tasks.</li>
178
- <li><a href="https://huggingface.co/bigcode/starcoder" style='color: #e6b800;'>StarCoder</a>: A finetuned version of StarCoderBase specifically focused on Python, while also maintaining strong performance on other programming languages.</li>
179
  </ul>
180
  <p><b>Please note:</b> These models are not designed for instruction purposes. If you're looking for instruction or want to chat with a fine-tuned model, you can visit the <a href="https://huggingface.co/spaces/HuggingFaceH4/starchat-playground">StarChat Playground</a>.</p>
181
  </div>
@@ -188,8 +224,8 @@ with gr.Blocks(theme=theme, analytics_enabled=False, css=css) as demo:
188
  gr.Markdown(description)
189
  with gr.Row():
190
  version = gr.Dropdown(
191
- ["StarCoderPlus", "StarCoderBase", "StarCoder"],
192
- value="StarCoder",
193
  label="Model",
194
  info="Choose a model from the list",
195
  )
@@ -269,4 +305,20 @@ with gr.Blocks(theme=theme, analytics_enabled=False, css=css) as demo:
269
  outputs=[output],
270
  )
271
  share_button.click(None, [], [], _js=share_js)
272
- demo.queue(concurrency_count=16).launch(debug=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ import re
3
+ import copy
4
 
5
  import gradio as gr
 
6
  from text_generation import Client
7
+ from transformers import load_tool
8
+
9
 
10
  from share_btn import community_icon_html, loading_icon_html, share_js, share_btn_css
11
 
12
+
13
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
14
  print(HF_TOKEN)
15
 
 
81
  API_URL,
82
  headers={"Authorization": f"Bearer {HF_TOKEN}"},
83
  )
84
+ tool = load_tool("vwxyzjn/pyserini-wikipedia-kilt-doc")
85
+ tool_fn = lambda x: tool(x).split("\n")[1][:600] # limit the amount if tokens
86
+ tools = {"Wiki": tool_fn}
87
+
88
+ def parse_tool_call(text, request_token="<request>", call_token="<call>"):
89
+ """
90
+ Parse request string. Expected format: <request><tool_name>query<call>
91
+ """
92
+ result = re.search(f"(?<={request_token}).*?(?={call_token})", text, re.DOTALL)
93
+
94
+ # if we can't find a <request>/<call> span we return none
95
+ if result is None:
96
+ return None, None
97
+ else:
98
+ extracted_text = result.group()
99
+
100
+ result = re.search(r"<(.*?)>", extracted_text)
101
+
102
+ # if we can't find a tool name we return none
103
+ if result is None:
104
+ return None, None
105
+ else:
106
+ tool = result.group(1)
107
+
108
+ # split off the tool name
109
+ query = ">".join(extracted_text.split(">")[1:])
110
+
111
+ return tool, query
112
+
113
+
114
 
115
  def generate(
116
+ prompt, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0, version="StarCoderBase TriviaQA",
117
  ):
118
 
119
  temperature = float(temperature)
 
122
  top_p = float(top_p)
123
  fim_mode = False
124
 
125
+ # TextEnv tool
126
  generate_kwargs = dict(
127
  temperature=temperature,
128
  max_new_tokens=max_new_tokens,
 
130
  repetition_penalty=repetition_penalty,
131
  do_sample=True,
132
  seed=42,
133
+ stop_sequences=["<call>"]
134
  )
135
 
136
+ if version == "StarCoderBase TriviaQA":
 
 
 
 
 
 
 
 
137
  stream = client.generate_stream(prompt, **generate_kwargs)
 
 
 
 
138
 
 
 
 
 
139
 
140
+ # call env phase
141
+ output = prompt
142
  previous_token = ""
143
  for response in stream:
144
  if response.token.text == "<|endoftext|>":
145
+ return output
 
 
 
146
  else:
147
  output += response.token.text
148
  previous_token = response.token.text
149
+ # text env logic:
150
+ tool, query = parse_tool_call(output[len(prompt):])
151
+ if tool is not None and query is not None:
152
+ if tool not in tools:
153
+ response = f"Unknown tool {tool}."
154
+ try:
155
+ response = tools[tool](query)
156
+ output += response + "<response>"
157
+ except Exception as error:
158
+ response = f"Tool error: {str(error)}"
159
+ yield output[len(prompt):]
160
+
161
+ call_output = copy.deepcopy(output)
162
+ # response phase
163
+ generate_kwargs["stop_sequences"] = ["<submit>"]
164
+ stream = client.generate_stream(output, **generate_kwargs)
165
+ previous_token = ""
166
+ for response in stream:
167
+ if response.token.text == "<|endoftext|>":
168
+ return output
169
+ else:
170
+ output += response.token.text
171
+ previous_token = response.token.text
172
+ yield output[len(prompt):]
173
+
174
+
175
+
176
  return output
177
 
178
 
 
204
 
205
  description = """
206
  <div style="text-align: center;">
207
+ <h1> ⭐ StarCoderBase TriviaQA <span style='color: #e6b800;'>Models</span> Playground</h1>
208
  </div>
209
  <div style="text-align: left;">
210
+ <p>This is a demo to generate text and code with the following StarCoderBase TriviaQA models:</p>
211
  <ul>
212
  <li><a href="https://huggingface.co/bigcode/starcoderplus" style='color: #e6b800;'>StarCoderPlus</a>: A finetuned version of StarCoderBase on English web data, making it strong in both English text and code generation.</li>
213
  <li><a href="https://huggingface.co/bigcode/starcoderbase" style='color: #e6b800;'>StarCoderBase</a>: A code generation model trained on 80+ programming languages, providing broad language coverage for code generation tasks.</li>
214
+ <li><a href="https://huggingface.co/bigcode/starcoder" style='color: #e6b800;'>StarCoderBase TriviaQA</a>: A finetuned version of StarCoderBase specifically focused on Python, while also maintaining strong performance on other programming languages.</li>
215
  </ul>
216
  <p><b>Please note:</b> These models are not designed for instruction purposes. If you're looking for instruction or want to chat with a fine-tuned model, you can visit the <a href="https://huggingface.co/spaces/HuggingFaceH4/starchat-playground">StarChat Playground</a>.</p>
217
  </div>
 
224
  gr.Markdown(description)
225
  with gr.Row():
226
  version = gr.Dropdown(
227
+ ["StarCoderBase TriviaQA"],
228
+ value="StarCoderBase TriviaQA",
229
  label="Model",
230
  info="Choose a model from the list",
231
  )
 
305
  outputs=[output],
306
  )
307
  share_button.click(None, [], [], _js=share_js)
308
+ demo.queue(concurrency_count=16).launch(share=True)
309
+
310
+
311
+ """
312
+ Answer the following question:
313
+
314
+ Q: In which branch of the arts is Patricia Neary famous?
315
+ A: Ballets
316
+ A2: <request><Wiki>Patricia Neary<call>Patricia Neary (born October 27, 1942) is an American ballerina, choreographer and ballet director, who has been particularly active in Switzerland. She has also been a highly successful ambassador for the Balanchine Trust, bringing George Balanchine's ballets to 60 cities around the globe.<response>
317
+ Result=Ballets<submit>
318
+
319
+ Q: Who won Super Bowl XX?
320
+ A: Chicago Bears
321
+ A2: <request><Wiki>Super Bowl XX<call>Super Bowl XX was an American football game between the National Football Conference (NFC) champion Chicago Bears and the American Football Conference (AFC) champion New England Patriots to decide the National Football League (NFL) champion for the 1985 season. The Bears defeated the Patriots by the score of 46–10, capturing their first NFL championship (and Chicago's first overall sports victory) since 1963, three years prior to the birth of the Super Bowl. Super Bowl XX was played on January 26, 1986 at the Louisiana Superdome in New Orleans.<response>
322
+ Result=Chicago Bears<submit>
323
+
324
+ Q: In what state is Philadelphia located?"""