Daniel Fried committed
Commit 0b73ae7
1 Parent(s): 56871fa
Files changed (2):
  1. modules/app.py +31 -11
  2. static/index.html +56 -22
modules/app.py CHANGED
@@ -15,7 +15,7 @@ import json
 # from flask import Flask, request, render_template
 # from flask_cors import CORS
 # app = Flask(__name__, static_folder='static')
-# app.config['TEMPLATES_AUTO_RELOAD'] = True
+# app.config['TEMPLATES_AUTO_RELOAD'] = Tru
 # CORS(app, resources= {
 #     r"/generate": {"origins": origins},
 #     r"/infill": {"origins": origins},
@@ -25,9 +25,12 @@ import json
 PORT = 7860
 VERBOSE = False
 
+MAX_LENGTH = 256+64
+TRUNCATION_MESSAGE = f'warning: This demo is limited to {MAX_LENGTH} tokens in the document for efficiency.'
+
 if BIG_MODEL:
     CUDA = True
-    model_name = "facebook/incoder-6B"
+    model_name = "./incoder-6B"
 else:
     CUDA = False
     model_name = "facebook/incoder-1B"
@@ -60,21 +63,28 @@ def generate(input, length_limit=None, temperature=None):
     input_ids = tokenizer(input, return_tensors="pt").input_ids
     if CUDA:
         input_ids = input_ids.cuda()
-    max_length = length_limit + input_ids.flatten().size(0)
-    if max_length > 256:
-        max_length = 256
+    current_length = input_ids.flatten().size(0)
+    max_length = length_limit + current_length
+    truncated = False
+    if max_length > MAX_LENGTH:
+        max_length = MAX_LENGTH
+        truncated = True
+    if max_length == current_length:
+        return input, True
     output = model.generate(input_ids=input_ids, do_sample=True, top_p=0.95, temperature=temperature, max_length=max_length)
     detok_hypo_str = tokenizer.decode(output.flatten())
     if detok_hypo_str.startswith(BOS):
         detok_hypo_str = detok_hypo_str[len(BOS):]
-    return detok_hypo_str
+    return detok_hypo_str, truncated
 
 def infill(parts: List[str], length_limit=None, temperature=None, extra_sentinel=False, max_retries=1):
     assert isinstance(parts, list)
     retries_attempted = 0
     done = False
 
+
     while (not done) and (retries_attempted < max_retries):
+        any_truncated = False
         retries_attempted += 1
         if VERBOSE:
             print(f"retry {retries_attempted}")
@@ -98,7 +108,8 @@ def infill(parts: List[str], length_limit=None, temperature=None, extra_sentinel
         for sentinel_ix, part in enumerate(parts[:-1]):
             complete.append(part)
             prompt += make_sentinel(sentinel_ix)
-            completion = generate(prompt, length_limit, temperature)
+            completion, this_truncated = generate(prompt, length_limit, temperature)
+            any_truncated |= this_truncated
             completion = completion[len(prompt):]
             if EOM not in completion:
                 if VERBOSE:
@@ -133,6 +144,7 @@ def infill(parts: List[str], length_limit=None, temperature=None, extra_sentinel
         'parts': parts,
         'infills': infills,
         'retries_attempted': retries_attempted,
+        'truncated': any_truncated,
     }
 
 
@@ -151,11 +163,15 @@ async def generate_maybe(info: str):
     if VERBOSE:
         print(prompt)
     try:
-        generation = generate(prompt, length_limit, temperature)
-        return {'result': 'success', 'type': 'generate', 'prompt': prompt, 'text': generation}
+        generation, truncated = generate(prompt, length_limit, temperature)
+        if truncated:
+            message = TRUNCATION_MESSAGE
+        else:
+            message = ''
+        return {'result': 'success', 'type': 'generate', 'prompt': prompt, 'text': generation, 'message': message}
     except Exception as e:
         traceback.print_exception(*sys.exc_info())
-        return {'result': 'error', 'type': 'generate', 'prompt': prompt, 'text': f'There was an error: {e}. Tell Daniel.'}
+        return {'result': 'error', 'type': 'generate', 'prompt': prompt, 'message': f'Error: {e}.'}
 
 @app.get('/infill')
 async def infill_maybe(info: str):
@@ -169,12 +185,16 @@ async def infill_maybe(info: str):
         generation = infill(form['parts'], length_limit, temperature, extra_sentinel=extra_sentinel, max_retries=max_retries)
         generation['result'] = 'success'
         generation['type'] = 'infill'
+        if generation['truncated']:
+            generation['message'] = TRUNCATION_MESSAGE
+        else:
+            generation['message'] = ''
         return generation
         # return {'result': 'success', 'prefix': prefix, 'suffix': suffix, 'text': generation['text']}
     except Exception as e:
         traceback.print_exception(*sys.exc_info())
         print(e)
-        return {'result': 'error', 'type': 'infill', 'text': f'There was an error: {e}.'}
+        return {'result': 'error', 'type': 'infill', 'message': f'Error: {e}.'}
 
 
 if __name__ == "__main__":
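
Taken together, the app.py changes cap each request at MAX_LENGTH = 256+64 = 320 tokens (prompt plus requested continuation), have generate() and infill() report whether output was cut short, and attach a human-readable message to the endpoint responses. The sketch below is illustrative only and is not part of the commit; capped_max_length and prompt_length are hypothetical names standing in for the arithmetic done around input_ids.flatten().size(0) in generate().

# Illustrative sketch (not from the commit): reproduces the length-capping
# arithmetic added to generate(). prompt_length stands in for
# input_ids.flatten().size(0); MAX_LENGTH matches the new module constant.
MAX_LENGTH = 256 + 64

def capped_max_length(prompt_length: int, length_limit: int):
    """Return (max_length, truncated) the way the updated generate() computes them."""
    max_length = length_limit + prompt_length
    truncated = False
    if max_length > MAX_LENGTH:
        max_length = MAX_LENGTH
        truncated = True
    return max_length, truncated

if __name__ == "__main__":
    print(capped_max_length(300, 64))  # (320, True): capped, so the warning message is sent
    print(capped_max_length(320, 64))  # (320, True): no room left; generate() returns the input unchanged
    print(capped_max_length(10, 64))   # (74, False): under the cap, no warning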
static/index.html CHANGED
@@ -93,6 +93,10 @@ label {
   color: red;
   width: 100%;
 }
+#warning {
+  color: darkorange;
+  width: 100%;
+}
 #examples span {
   margin-right: 1em;
 }
@@ -117,17 +121,20 @@ label {
 <div id="about">
 <p>Refresh "Extend" will insert text at the end. "Infill" will replace all <infill> masks. (click add <infill> mask to add infill mask at the cursors or selections ) </p>
 <p id="examples">
-<span style="font-weight: bold">Examples:</span>
-<span><a href='javascript:select_example("python");'>Python</a></span>
-<span><a href='javascript:select_example("python-infill2");'>Python-infill</a></span>
-<span><a href='javascript:select_example("type-pred");'>Type-prediction</a></span>
-<span><a href='javascript:select_example("docstring");'>Doc-string</a></span>
+<span style="font-weight: bold">Infill Examples:</span>
+<span><a href='javascript:select_example("class");'>Class generation</a></span>
+<span><a href='javascript:select_example("type-pred");'>Type prediction</a></span>
+<span><a href='javascript:select_example("docstring");'>Function to docstring</a></span>
+<span><a href='javascript:select_example("python-infill2");'>Docstring to function</a></span>
+<br>
+<span style="font-weight: bold">Extend Examples:</span>
+<span><a href='javascript:select_example("javascript");'>JavaScript</a></span>
 <span><a href='javascript:select_example("javascript");'>JavaScript</a></span>
 <span><a href='javascript:select_example("jupyter");'>Jupyter</a></span>
 <span><a href='javascript:select_example("stackoverflow");'>StackOverflow</a></span>
 <span><a href='javascript:select_example("metadata-conditioning");'>Metadata Conditioning</a></span>
 <span><a href='javascript:select_example("metadata-prediction");'>Metadata Prediction</a></span>
-<span><a href='javascript:select_example("humaneval");'>Docstring->Code</a></span>
+<!-- <span><a href='javascript:select_example("humaneval");'>Docstring->Code</a></span> -->
 </div>
 </div>
 <div class="request">
@@ -135,17 +142,17 @@ label {
 <div class="leftside">
 <div>
 <label>Response Length:</label>
-<input type="range" value="64" min="16" max="512" step="16" class="slider"
+<input type="range" value="64" min="16" max="256" step="16" class="slider"
 oninput="this.nextElementSibling.value = this.value" name="length"
 id='length_slider'>
 <output class='a' id="length_slider_output">64</output>
 <div>
 <label>Temperature:</label>
-<input type="range" value="0.6" min="0.2" max="1.0" step="0.10" class="slider"
+<input type="range" value="0.6" min="0.1" max="1.0" step="0.10" class="slider"
 oninput="this.nextElementSibling.value = this.value" name="temp"
 id='temp_slider'
 >
-<output>0.6</output>
+<output class='a' id="temp_slider_output">0.6</output>
 </div>
 <div>
 <!-- <input type="submit" value="Extend" id="extend-form-button"/> -->
@@ -173,8 +180,6 @@ label {
 </div>
 </form>
 </div>
-<div id="editor-holder">
-<div>
 Syntax:
 <select name="mode" id="mode">
 <option value="text">Text</option>
@@ -210,6 +215,7 @@ Syntax:
 </div>
 </div>
 <div id="error"></div>
+<div id="warning"></div>
 
 <h3 id="debug-info">More Info</h3>
 <p>
@@ -236,25 +242,21 @@ var Range = require("ace/range").Range;
 
 // examples for the user
 var EXAMPLES = {
-  "python": {
-    "prompt": "<| file ext=.py |>\nclass Person:\n" + SPLIT_TOKEN + "\np = Person('Eren', 18, 'Male')",
-    "length": 64,
-    "mode": "python"
-  },
   "python-infill2": {
     "prompt":
 `from collections import Counter
-def <infill>(file_name):
+def <infill>
     """Count the number of occurrences of each word in the file."""
     <infill>
 `,
     "length": 64,
+    "temperature": 0.2,
     "mode": "python"
   },
 
   "type-pred": {
     "prompt":
-`def count_words(filename: str) -> <infill>:
+`def count_words(filename: str) -> <infill>
     """Count the number of occurrences of each word in the file."""
     with open(filename, 'r') as f:
         word_counts = {}
@@ -267,6 +269,7 @@ def <infill>(file_name):
     return word_counts
 `,
     "length": 4,
+    "temperature": 0.2,
     "mode": "python"
   },
   "docstring": {
@@ -285,35 +288,57 @@ def <infill>(file_name):
     return word_counts
 `,
     "length": 32,
+    "temperature": 0.2,
+    "mode": "python"
+  },
+  "python": {
+    "prompt":
+`<| file ext=.py |>
+def count_words(filename):
+    """Count the number of occurrences of each word in the file"""`,
+    "length": 64,
+    "temperature": 0.6,
+    "mode": "python"
+  },
+  "class": {
+    "prompt": "<| file ext=.py |>\nclass Person:\n" + SPLIT_TOKEN + "\np = Person('Eren', 18, 'Male')",
+    "length": 64,
+    "temperature": 0.2,
     "mode": "python"
   },
   "javascript": {
     "prompt": "<| file ext=.js |>\n // is something really happening here",
     "length": 64,
+    "temperature": 0.6,
     "mode": "javascript"
   },
   "jupyter": {
     "prompt": "<| file ext=.ipynb:python |>\n<text>\nThis notebook demonstrates using scikit-learn to perform PCA.\n</text>\n<cell>",
     "length": 64,
+    "temperature": 0.6,
     "mode": "python"
   },
   "stackoverflow": {
     "prompt": "<| q tags=regex,html |>\nParsing HTML with regular expressions\nHow do I do this? Is it a good idea?\n<|/ q dscore=3 |>\n<| a dscore=4 |>",
     "length": 64,
+    "temperature": 0.6,
     "mode": "text"
   },
   "metadata-conditioning": {
     "prompt": "<| file ext=.py filename=train_model.py source=github dstars=4 |>\n",
     "length": 64,
+    "temperature": 0.6,
     "mode": "python"
   },
   "metadata-prediction": {
     "prompt": "<| file source=github ext=.py |>\nfrom setuptools import setup\nfrom setuptools_rust import Binding, RustExtension\n\nextras = {}\nextras[\"testing\"] = [\"pytest\", \"requests\", \"numpy\", \"datasets\"]\nextras[\"docs\"] = [\"sphinx\", \"sphinx_rtd_theme\", \"setuptools_rust\"]\n\nsetup(\n name=\"tokenizers\",\n version=\"0.11\",\n description=\"Fast and Customizable Tokenizers\",\n long_description=open(\"README.md\", \"r\", encoding=\"utf-8\").read(),\n)\n\n<|/ file filename=",
     "length": 1,
+    "temperature": 0.2,
     "mode": "python"
   },
   "humaneval": {
     "prompt": "from typing import List, Optional\n\n\ndef longest(strings: List[str]) -> Optional[str]:\n \"\"\" Out of list of strings, return the longest one. Return the first one in case of multiple\n strings of the same length. Return None in case the input list is empty.\n >>> longest([])\n\n >>> longest(['a', 'b', 'c'])\n 'a'\n >>> longest(['a', 'bb', 'ccc'])\n 'ccc'\n \"\"\"\n",
+    "temperature": 0.6,
     "length": 64,
     "mode": "python"
   },
@@ -392,6 +417,8 @@ function set_selection(data) {
 function select_example(name) {
     $("#length_slider").val(EXAMPLES[name]["length"]);
     $("#length_slider_output").text(EXAMPLES[name]["length"]);
+    $("#temp_slider").val(EXAMPLES[name]["temperature"]);
+    $("#temp_slider_output").text(EXAMPLES[name]["temperature"]);
     set_text(EXAMPLES[name]["prompt"])
     var mode = EXAMPLES[name]["mode"];
 
@@ -442,11 +469,11 @@ function convert_string_index_to_location(string_index, lines) {
     return null;
 }
 
-function get_infill_parts() {
+function get_infill_parts(warn_on_single) {
     var lines = editor.getSession().doc.$lines;
     var lines_flat = join_lines(lines);
     parts = lines_flat.split(SPLIT_TOKEN)
-    if (parts.length == 1) {
+    if (warn_on_single && parts.length == 1) {
         window.alert('There are no infill masks, add some <infill> masks before requesting an infill')
     }
     return parts
@@ -477,7 +504,7 @@ function make_generate_listener(url) {
         temperature: $("#temp_slider").val(),
         extra_sentinel: $('#extra_sentinel_checkbox').is(":checked"),
         max_retries: $('#max_retries_slider').val(),
-        parts: get_infill_parts(),
+        parts: get_infill_parts(url == "infill"),
         prompt: editor.getSession().getValue(),
     }
     console.log("send_data:");
@@ -503,6 +530,11 @@ function make_generate_listener(url) {
            set_text(receive_data["text"]);
            set_selection(receive_data);
            $("#error").text("");
+           if (receive_data["message"] != "") {
+               $("#warning").text(receive_data["message"]);
+           } else {
+               $("#warning").text("");
+           }
        } else {
            // set_text(data["prompt"])
            $("#error").text(receive_data["text"]);
@@ -522,8 +554,10 @@ function make_generate_listener(url) {
        const response = await fetch(`${url}?info=${encoded_data}`);
        if (response.status >= 400) {
            error(response.statusText);
+           complete();
+       } else {
+           response.json().then(success).catch(error).finally(complete);
        }
-       response.json().then(success).catch(error).finally(complete);
     } catch (e) {
        error(e);
     } finally {
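
On the page side, the new #warning element and the updated success handler decide where the server's message lands. The Python sketch below is illustrative only: route_response is a hypothetical helper that mirrors the success handling added in this diff and the updated error payload from app.py; the sample dict is invented, but the field names (result, text, message) follow the updated API.

# Illustrative sketch (not from the commit): mirrors how the updated page routes
# a /generate or /infill response into the editor, error, and warning areas.
def route_response(receive_data: dict) -> dict:
    if receive_data.get("result") == "success":
        return {
            "editor": receive_data.get("text"),          # shown in the editor
            "error": "",
            "warning": receive_data.get("message", ""),  # TRUNCATION_MESSAGE when the server hit MAX_LENGTH
        }
    # error responses now carry 'message' rather than the old 'text' field
    return {"editor": None, "error": receive_data.get("message", ""), "warning": ""}

if __name__ == "__main__":
    resp = {"result": "success", "type": "generate", "text": "def f():\n    pass",
            "message": "warning: This demo is limited to 320 tokens in the document for efficiency."}
    print(route_response(resp)["warning"])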