Spaces:
Sleeping
Sleeping
better error message
Browse files
app.py
CHANGED
@@ -35,7 +35,7 @@ MAX_NUM_ROWS_TO_REWRITE = int(os.environ.get("MAX_NUM_ROWS_TO_REWRITE") or 1000)
|
|
35 |
assert MAX_NUM_ROWS_TO_REWRITE in PARTIAL_SUFFIX, "allowed max num rows are 100, 1000, 10000, 100000 and 1000000"
|
36 |
|
37 |
NUM_PARALLEL_CALLS = 10
|
38 |
-
NUM_ROWS_PER_CALL =
|
39 |
MAX_PROGRESS_UPDATES_PER_SECOND = 4
|
40 |
REWRITE_DATASET_PREVIEW = (
|
41 |
"A Machine Learning practitioner is looking for a dataset similar to '{dataset}' but slightly different. "
|
@@ -171,19 +171,23 @@ with gr.Blocks(css=css) as demo:
|
|
171 |
while batch := list(islice(it, n)):
|
172 |
yield batch
|
173 |
|
|
|
|
|
174 |
|
175 |
-
def stream_reponse(messages: list[dict[str: str]], response_format=None) -> Iterator[str]:
|
176 |
for _ in range(3):
|
177 |
message = None
|
178 |
try:
|
179 |
for message in client.chat_completion(
|
180 |
messages=messages,
|
181 |
-
max_tokens=
|
182 |
stream=True,
|
183 |
top_p=0.8,
|
184 |
seed=42,
|
185 |
response_format=response_format
|
186 |
):
|
|
|
|
|
187 |
yield message.choices[0].delta.content
|
188 |
except requests.exceptions.ConnectionError as e:
|
189 |
if message:
|
@@ -217,7 +221,7 @@ with gr.Blocks(css=css) as demo:
|
|
217 |
response_format = {"type": "json", "value": {"properties": {"data": {"type": "array", "items": format, "minItems": len(rows), "maxItems": len(rows)}}, "required": ["data"]}}
|
218 |
try:
|
219 |
yield from ijson.items(StringIteratorIO(stream_reponse(messages, response_format=response_format)), "data.item", buf_size=4, use_float=True)
|
220 |
-
except ijson.IncompleteJSONError as e:
|
221 |
print(f"{type(e).__name__}: {e}")
|
222 |
print("Warning: Some rows were missing during ReWriting.")
|
223 |
|
@@ -389,14 +393,17 @@ with gr.Blocks(css=css) as demo:
|
|
389 |
|
390 |
current = 0
|
391 |
_last_time = time.time()
|
392 |
-
|
393 |
-
|
394 |
-
|
395 |
-
_last_time
|
396 |
-
|
397 |
-
|
398 |
-
|
399 |
-
|
|
|
|
|
|
|
400 |
yield {
|
401 |
full_dataset_generation_label: gr.Label({f"⚙️ ReWriting {dataset}": current / total}),
|
402 |
pretty_full_dataset_generation_output: gr.DataFrame(pd.DataFrame([row for rows in parallel_output_rows for row in rows]))
|
|
|
35 |
assert MAX_NUM_ROWS_TO_REWRITE in PARTIAL_SUFFIX, "allowed max num rows are 100, 1000, 10000, 100000 and 1000000"
|
36 |
|
37 |
NUM_PARALLEL_CALLS = 10
|
38 |
+
NUM_ROWS_PER_CALL = 3
|
39 |
MAX_PROGRESS_UPDATES_PER_SECOND = 4
|
40 |
REWRITE_DATASET_PREVIEW = (
|
41 |
"A Machine Learning practitioner is looking for a dataset similar to '{dataset}' but slightly different. "
|
|
|
171 |
while batch := list(islice(it, n)):
|
172 |
yield batch
|
173 |
|
174 |
+
class ContextTooLongError(ValueError):
|
175 |
+
pass
|
176 |
|
177 |
+
def stream_reponse(messages: list[dict[str: str]], response_format=None, max_tokens=5000) -> Iterator[str]:
|
178 |
for _ in range(3):
|
179 |
message = None
|
180 |
try:
|
181 |
for message in client.chat_completion(
|
182 |
messages=messages,
|
183 |
+
max_tokens=max_tokens,
|
184 |
stream=True,
|
185 |
top_p=0.8,
|
186 |
seed=42,
|
187 |
response_format=response_format
|
188 |
):
|
189 |
+
if message is None or not message.choices or message.choices[0] is None or message.choices[0].delta is None or message.choices[0].delta.content is None:
|
190 |
+
raise ContextTooLongError(f"messages: {sum(len(message['content']) for message in messages)} chars, max_tokens: {max_tokens}")
|
191 |
yield message.choices[0].delta.content
|
192 |
except requests.exceptions.ConnectionError as e:
|
193 |
if message:
|
|
|
221 |
response_format = {"type": "json", "value": {"properties": {"data": {"type": "array", "items": format, "minItems": len(rows), "maxItems": len(rows)}}, "required": ["data"]}}
|
222 |
try:
|
223 |
yield from ijson.items(StringIteratorIO(stream_reponse(messages, response_format=response_format)), "data.item", buf_size=4, use_float=True)
|
224 |
+
except (ijson.IncompleteJSONError) as e:
|
225 |
print(f"{type(e).__name__}: {e}")
|
226 |
print("Warning: Some rows were missing during ReWriting.")
|
227 |
|
|
|
393 |
|
394 |
current = 0
|
395 |
_last_time = time.time()
|
396 |
+
try:
|
397 |
+
for step in iflatmap_unordered(run, kwargs_iterable=[{"i": i} for i in range(num_parallel_calls)]):
|
398 |
+
current += step
|
399 |
+
if _last_time + 1 / MAX_PROGRESS_UPDATES_PER_SECOND < time.time():
|
400 |
+
_last_time = time.time()
|
401 |
+
yield {
|
402 |
+
full_dataset_generation_label: gr.Label({f"⚙️ ReWriting {dataset}": current / total}),
|
403 |
+
pretty_full_dataset_generation_output: gr.DataFrame(pd.DataFrame([row for rows in parallel_output_rows for row in rows]))
|
404 |
+
}
|
405 |
+
except ContextTooLongError:
|
406 |
+
raise gr.Error("Input dataset has too long context for the model")
|
407 |
yield {
|
408 |
full_dataset_generation_label: gr.Label({f"⚙️ ReWriting {dataset}": current / total}),
|
409 |
pretty_full_dataset_generation_output: gr.DataFrame(pd.DataFrame([row for rows in parallel_output_rows for row in rows]))
|