lhoestq HF staff commited on
Commit
063480a
·
1 Parent(s): 1fa40b1

better error message

Browse files
Files changed (1) hide show
  1. app.py +19 -12
app.py CHANGED
@@ -35,7 +35,7 @@ MAX_NUM_ROWS_TO_REWRITE = int(os.environ.get("MAX_NUM_ROWS_TO_REWRITE") or 1000)
35
  assert MAX_NUM_ROWS_TO_REWRITE in PARTIAL_SUFFIX, "allowed max num rows are 100, 1000, 10000, 100000 and 1000000"
36
 
37
  NUM_PARALLEL_CALLS = 10
38
- NUM_ROWS_PER_CALL = 5
39
  MAX_PROGRESS_UPDATES_PER_SECOND = 4
40
  REWRITE_DATASET_PREVIEW = (
41
  "A Machine Learning practitioner is looking for a dataset similar to '{dataset}' but slightly different. "
@@ -171,19 +171,23 @@ with gr.Blocks(css=css) as demo:
171
  while batch := list(islice(it, n)):
172
  yield batch
173
 
 
 
174
 
175
- def stream_reponse(messages: list[dict[str: str]], response_format=None) -> Iterator[str]:
176
  for _ in range(3):
177
  message = None
178
  try:
179
  for message in client.chat_completion(
180
  messages=messages,
181
- max_tokens=5000,
182
  stream=True,
183
  top_p=0.8,
184
  seed=42,
185
  response_format=response_format
186
  ):
 
 
187
  yield message.choices[0].delta.content
188
  except requests.exceptions.ConnectionError as e:
189
  if message:
@@ -217,7 +221,7 @@ with gr.Blocks(css=css) as demo:
217
  response_format = {"type": "json", "value": {"properties": {"data": {"type": "array", "items": format, "minItems": len(rows), "maxItems": len(rows)}}, "required": ["data"]}}
218
  try:
219
  yield from ijson.items(StringIteratorIO(stream_reponse(messages, response_format=response_format)), "data.item", buf_size=4, use_float=True)
220
- except ijson.IncompleteJSONError as e:
221
  print(f"{type(e).__name__}: {e}")
222
  print("Warning: Some rows were missing during ReWriting.")
223
 
@@ -389,14 +393,17 @@ with gr.Blocks(css=css) as demo:
389
 
390
  current = 0
391
  _last_time = time.time()
392
- for step in iflatmap_unordered(run, kwargs_iterable=[{"i": i} for i in range(num_parallel_calls)]):
393
- current += step
394
- if _last_time + 1 / MAX_PROGRESS_UPDATES_PER_SECOND < time.time():
395
- _last_time = time.time()
396
- yield {
397
- full_dataset_generation_label: gr.Label({f"⚙️ ReWriting {dataset}": current / total}),
398
- pretty_full_dataset_generation_output: gr.DataFrame(pd.DataFrame([row for rows in parallel_output_rows for row in rows]))
399
- }
 
 
 
400
  yield {
401
  full_dataset_generation_label: gr.Label({f"⚙️ ReWriting {dataset}": current / total}),
402
  pretty_full_dataset_generation_output: gr.DataFrame(pd.DataFrame([row for rows in parallel_output_rows for row in rows]))
 
35
  assert MAX_NUM_ROWS_TO_REWRITE in PARTIAL_SUFFIX, "allowed max num rows are 100, 1000, 10000, 100000 and 1000000"
36
 
37
  NUM_PARALLEL_CALLS = 10
38
+ NUM_ROWS_PER_CALL = 3
39
  MAX_PROGRESS_UPDATES_PER_SECOND = 4
40
  REWRITE_DATASET_PREVIEW = (
41
  "A Machine Learning practitioner is looking for a dataset similar to '{dataset}' but slightly different. "
 
171
  while batch := list(islice(it, n)):
172
  yield batch
173
 
174
+ class ContextTooLongError(ValueError):
175
+ pass
176
 
177
+ def stream_reponse(messages: list[dict[str: str]], response_format=None, max_tokens=5000) -> Iterator[str]:
178
  for _ in range(3):
179
  message = None
180
  try:
181
  for message in client.chat_completion(
182
  messages=messages,
183
+ max_tokens=max_tokens,
184
  stream=True,
185
  top_p=0.8,
186
  seed=42,
187
  response_format=response_format
188
  ):
189
+ if message is None or not message.choices or message.choices[0] is None or message.choices[0].delta is None or message.choices[0].delta.content is None:
190
+ raise ContextTooLongError(f"messages: {sum(len(message['content']) for message in messages)} chars, max_tokens: {max_tokens}")
191
  yield message.choices[0].delta.content
192
  except requests.exceptions.ConnectionError as e:
193
  if message:
 
221
  response_format = {"type": "json", "value": {"properties": {"data": {"type": "array", "items": format, "minItems": len(rows), "maxItems": len(rows)}}, "required": ["data"]}}
222
  try:
223
  yield from ijson.items(StringIteratorIO(stream_reponse(messages, response_format=response_format)), "data.item", buf_size=4, use_float=True)
224
+ except (ijson.IncompleteJSONError) as e:
225
  print(f"{type(e).__name__}: {e}")
226
  print("Warning: Some rows were missing during ReWriting.")
227
 
 
393
 
394
  current = 0
395
  _last_time = time.time()
396
+ try:
397
+ for step in iflatmap_unordered(run, kwargs_iterable=[{"i": i} for i in range(num_parallel_calls)]):
398
+ current += step
399
+ if _last_time + 1 / MAX_PROGRESS_UPDATES_PER_SECOND < time.time():
400
+ _last_time = time.time()
401
+ yield {
402
+ full_dataset_generation_label: gr.Label({f"⚙️ ReWriting {dataset}": current / total}),
403
+ pretty_full_dataset_generation_output: gr.DataFrame(pd.DataFrame([row for rows in parallel_output_rows for row in rows]))
404
+ }
405
+ except ContextTooLongError:
406
+ raise gr.Error("Input dataset has too long context for the model")
407
  yield {
408
  full_dataset_generation_label: gr.Label({f"⚙️ ReWriting {dataset}": current / total}),
409
  pretty_full_dataset_generation_output: gr.DataFrame(pd.DataFrame([row for rows in parallel_output_rows for row in rows]))