ranWang commited on
Commit
3b33e85
1 Parent(s): ef99e2b

Use the API to complete the code

Browse files
Files changed (2) hide show
  1. README.md +6 -6
  2. app.py +725 -0
README.md CHANGED
@@ -1,12 +1,12 @@
1
  ---
2
- title: Math Olympiad Solver By Api
3
- emoji: 📊
4
- colorFrom: indigo
5
- colorTo: blue
6
  sdk: gradio
7
- sdk_version: 4.37.2
8
  app_file: app.py
9
  pinned: false
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Math Olympiad Solver
3
+ emoji: ♾️
4
+ colorFrom: yellow
5
+ colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 4.36.1
8
  app_file: app.py
9
  pinned: false
10
  ---
11
 
12
+ An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
app.py ADDED
@@ -0,0 +1,725 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from huggingface_hub import login
3
+
4
+ import re
5
+
6
+ # from vllm import LLM, SamplingParams
7
+ import pandas as pd
8
+ from collections import Counter
9
+ from datasets import load_dataset, Dataset, concatenate_datasets
10
+ from dataclasses import dataclass
11
+ from concurrent.futures import ThreadPoolExecutor, TimeoutError
12
+ import os
13
+ from typing import Dict, Any, List
14
+
15
+ # code execution
16
+ import os
17
+ import re
18
+ import signal
19
+ import subprocess
20
+ import tempfile
21
+ from contextlib import contextmanager
22
+ from typing import Tuple
23
+ from transformers import PreTrainedTokenizer, set_seed
24
+ import torch
25
+ from tqdm import tqdm
26
+ import time
27
+ from sympy import N, simplify
28
+ from sympy.parsing.latex import parse_latex
29
+ import random
30
+ from pathlib import Path
31
+ from openai import OpenAI
32
+
33
+ client = OpenAI(
34
+ base_url="https://ji0rhe7rvh6wrfmq.us-east-1.aws.endpoints.huggingface.cloud/v1/",
35
+ api_key=os.environ.get("HF_TOKEN"),
36
+ )
37
+
38
+
39
+ @dataclass
40
+ class Config:
41
+ model_id: str # SELECT MODEL
42
+ revision: str # SELECT REVISION
43
+
44
+ # Append an optional system prompt to each problem
45
+ system_prompt: str
46
+
47
+ # Number of samples to generate per problem
48
+ num_samples: int
49
+ num_generations: int
50
+ # Generation parameters
51
+ do_sample: bool
52
+ temperature: float
53
+ top_p: float
54
+ top_k: int
55
+ max_new_tokens: int
56
+ restart_on_fail: bool
57
+
58
+ # Enable 4-bit quantization
59
+ is_quantized: bool
60
+
61
+ # Run on train or test data?
62
+ is_submission: bool = True if os.getenv("KAGGLE_IS_COMPETITION_RERUN") else False
63
+ validation_set: str = "kaggle-validation-set-medium"
64
+
65
+ notebook_time_limit: int = 9 * 60 * 60 - 15 * 60 # 9 hours - 15 minute buffer
66
+
67
+ # Debug by solving only the first problem
68
+ debug: bool = False
69
+
70
+ # Push solutions to the Hub
71
+ push_to_hub: bool = False
72
+
73
+
74
+ class PythonREPL:
75
+ def __init__(self, timeout=5):
76
+ self.timeout = timeout
77
+
78
+ def execute(self, query: str) -> Tuple[bool, str]:
79
+ query = "import math\nimport numpy as np\nimport sympy as sp\n" + query
80
+ query = query.strip().split("\n")
81
+ if "print(" not in query[-1]:
82
+ if "#" in query[-1]:
83
+ query[-1] = query[-1].split("#")[0]
84
+ query[-1] = "print(" + query[-1] + ")"
85
+ query = "\n".join(query)
86
+
87
+ with tempfile.TemporaryDirectory() as temp_dir:
88
+ temp_file_path = os.path.join(temp_dir, "tmp.py")
89
+
90
+ with open(temp_file_path, "w") as f:
91
+ f.write(query)
92
+
93
+ result = subprocess.run(
94
+ ["python3", temp_file_path],
95
+ capture_output=True,
96
+ check=False,
97
+ text=True,
98
+ timeout=self.timeout,
99
+ )
100
+
101
+ if result.returncode == 0:
102
+ output = result.stdout
103
+ return True, output.strip()
104
+ else:
105
+ error_msg = result.stderr.strip()
106
+ msgs = error_msg.split("\n")
107
+ new_msgs = []
108
+ want_next = False
109
+ for m in msgs:
110
+ if "Traceback" in m:
111
+ new_msgs.append(m)
112
+ elif m == msgs[-1]:
113
+ new_msgs.append(m)
114
+ elif temp_file_path in m:
115
+ st = m.index('"/') + 1 if '"/' in m else 0
116
+ ed = m.index(temp_file_path) + 1 if temp_file_path in m else None
117
+ clr = m[st:ed] if not ed else m[st:]
118
+ m = m.replace(clr, "")
119
+ new_msgs.append(m)
120
+ want_next = True
121
+ elif want_next:
122
+ new_msgs.append(m)
123
+ want_next = False
124
+ error_msg = "\n".join(new_msgs)
125
+ return False, error_msg.strip()
126
+
127
+ def __call__(self, query: str) -> Tuple[bool, str]:
128
+ with ThreadPoolExecutor() as executor:
129
+ future = executor.submit(self.execute, query)
130
+ try:
131
+ return future.result(timeout=self.timeout)
132
+ except TimeoutError:
133
+ return False, f"Timed out after {self.timeout} seconds."
134
+
135
+
136
+ def execute_completion(
137
+ executor: PythonREPL,
138
+ completion: str,
139
+ return_status: bool = False,
140
+ last_code_block: bool = False,
141
+ ) -> str | Tuple[str, bool]:
142
+ # executions = ["!" + code for code in re.findall(r"```bash(.*?)```", completion, re.DOTALL) if "!" not in code]
143
+ executions = re.findall(r"```python(.*?)```", completion, re.DOTALL)
144
+
145
+ if len(executions) == 0: # directly return cot result
146
+ return completion, False if return_status else completion
147
+ else:
148
+ if last_code_block:
149
+ executions = [executions[-1]]
150
+
151
+ # Python
152
+ execution_outputs = []
153
+ successes = []
154
+ for code in executions:
155
+ success = False
156
+
157
+ if "subprocess" in code:
158
+ output = "subprocess is not allowed"
159
+ execution_outputs.append(output)
160
+ successes.append(success)
161
+ continue
162
+
163
+ if "venv" in code:
164
+ output = "venv is not allowed"
165
+ execution_outputs.append(output)
166
+ successes.append(success)
167
+ continue
168
+
169
+ try:
170
+ success, output = executor(code)
171
+ except TimeoutError as e:
172
+ print("time out")
173
+ output = e
174
+
175
+ if not success and not return_status:
176
+ output = ""
177
+
178
+ execution_outputs.append(output)
179
+ successes.append(success)
180
+
181
+ output = str(execution_outputs[-1]).strip()
182
+ success = successes[-1]
183
+
184
+ if return_status:
185
+ return output, success
186
+ else:
187
+ return output
188
+
189
+
190
+ def postprocess_completion(
191
+ text: str, return_status: bool = False, last_code_block=False, timeout=5
192
+ ) -> str | Tuple[str, bool]:
193
+ executor = PythonREPL(timeout=timeout)
194
+
195
+ result = execute_completion(executor, text, return_status=return_status, last_code_block=last_code_block)
196
+ del executor
197
+
198
+ return result
199
+
200
+
201
+ def apply_template(example: Dict[str, Any], prompt: str) -> Dict[str, Any]:
202
+ return prompt.format(example["prompt"], "{}")
203
+
204
+
205
+ def last_boxed_only_string(string):
206
+ """
207
+ Extracts the last LaTeX boxed or framed expression from a string.
208
+ Args:
209
+ string (str): The input string containing LaTeX expressions.
210
+ Returns:
211
+ str or None: The last boxed or framed expression, if found;
212
+ otherwise, None.
213
+ """
214
+
215
+ idx = string.rfind("\\boxed")
216
+ if idx < 0:
217
+ idx = string.rfind("\\fbox")
218
+ if idx < 0:
219
+ return None
220
+
221
+ i = idx
222
+ right_brace_idx = None
223
+ num_left_braces_open = 0
224
+ while i < len(string):
225
+ if string[i] == "{":
226
+ num_left_braces_open += 1
227
+ if string[i] == "}":
228
+ num_left_braces_open -= 1
229
+ if num_left_braces_open == 0:
230
+ right_brace_idx = i
231
+ break
232
+ i += 1
233
+
234
+ if right_brace_idx is None:
235
+ retval = None
236
+ else:
237
+ retval = string[idx : right_brace_idx + 1]
238
+
239
+ return retval
240
+
241
+
242
+ def remove_boxed(s):
243
+ """
244
+ Removes the LaTeX boxed command, returning the content inside the braces.
245
+ Args:
246
+ s (str): The string containing a LaTeX boxed expression.
247
+ Returns:
248
+ str or None: The content inside the boxed command, if valid;
249
+ otherwise, None.
250
+ """
251
+
252
+ left = "\\boxed{"
253
+ try:
254
+ assert s[: len(left)] == left
255
+ assert s[-1] == "}"
256
+ length = len(left)
257
+ return s[length:-1]
258
+ except Exception:
259
+ return None
260
+
261
+
262
+ def extract_boxed_answer(pred_str, strip_double_curly_brace=False):
263
+ """
264
+ Extracts the answer from a LaTeX boxed expression within
265
+ a prediction string.
266
+ Args:
267
+ pred_str (str): The string containing one or more LaTeX
268
+ boxed expressions.
269
+ strip_double_curly_brace (bool): If True, removes an additional
270
+ layer of braces.
271
+ Returns:
272
+ str or None: The extracted answer, if any; otherwise, None.
273
+ """
274
+
275
+ boxed_str = last_boxed_only_string(pred_str)
276
+ if boxed_str is None:
277
+ return None
278
+ answer = remove_boxed(boxed_str)
279
+ if answer is None:
280
+ return None
281
+ if strip_double_curly_brace:
282
+ match = re.match("^\{(.*)\}$", answer) # noqa: W605
283
+ if match:
284
+ answer = match.group(1)
285
+ return answer
286
+
287
+
288
+ def normalize_final_answer(final_answer: str) -> str:
289
+ """
290
+ Normalizes a final answer string by removing or replacing various LaTeX
291
+ and text elements.
292
+ Args:
293
+ final_answer (str): The answer string to normalize.
294
+ Returns:
295
+ str: The normalized answer string.
296
+ """
297
+
298
+ match = re.search(r"(.*?)Problem:", final_answer, flags=re.S)
299
+ if match:
300
+ final_answer = match.group(1) # 返回匹配的第一部分,即"Problem"之前的所有文本
301
+ """Normalize a final answer to a quantitative reasoning question."""
302
+ # final_answer = final_answer.split('=')[-1]
303
+ SUBSTITUTIONS = [
304
+ ("an ", ""),
305
+ ("a ", ""),
306
+ (".$", "$"),
307
+ ("\\$", ""),
308
+ (r"\ ", ""),
309
+ (" ", ""),
310
+ ("mbox", "text"),
311
+ (",\\text{and}", ","),
312
+ ("\\text{and}", ","),
313
+ ("\\text{m}", "\\text{}"),
314
+ ("\\le", "<"),
315
+ ]
316
+ REMOVED_EXPRESSIONS = [
317
+ "square",
318
+ "ways",
319
+ "integers",
320
+ "dollars",
321
+ "mph",
322
+ "inches",
323
+ "ft",
324
+ "hours",
325
+ "km",
326
+ "units",
327
+ "\\ldots",
328
+ "sue",
329
+ "points",
330
+ "feet",
331
+ "minutes",
332
+ "digits",
333
+ "cents",
334
+ "degrees",
335
+ "cm",
336
+ "gm",
337
+ "pounds",
338
+ "meters",
339
+ "meals",
340
+ "edges",
341
+ "students",
342
+ "childrentickets",
343
+ "multiples",
344
+ "\\text{s}",
345
+ "\\text{.}",
346
+ "\\text{\ns}",
347
+ "\\text{}^2",
348
+ "\\text{}^3",
349
+ "\\text{\n}",
350
+ "\\text{}",
351
+ r"\mathrm{th}",
352
+ r"^\circ",
353
+ r"^{\circ}",
354
+ r"\;",
355
+ r",\!",
356
+ "{,}",
357
+ '"',
358
+ "\\dots",
359
+ "\n",
360
+ "\r",
361
+ "\f",
362
+ "\%",
363
+ ]
364
+ for before, after in SUBSTITUTIONS:
365
+ final_answer = final_answer.replace(before, after)
366
+ for expr in REMOVED_EXPRESSIONS:
367
+ final_answer = final_answer.replace(expr, "")
368
+
369
+ # Extract answer that is in LaTeX math, is bold,
370
+ # is surrounded by a box, etc.
371
+ final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer)
372
+ final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer)
373
+ final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer)
374
+ final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer)
375
+ assert "\n" not in final_answer
376
+ assert "\r" not in final_answer
377
+ assert "\f" not in final_answer
378
+ if len(re.findall(r"finalansweris(.*)", final_answer)) > 0:
379
+ final_answer = re.findall(r"finalansweris(.*)", final_answer)[-1]
380
+
381
+ if len(re.findall(r"answer?is:?(.*)", final_answer)) > 0:
382
+ final_answer = re.findall(r"answer?is:?(.*)", final_answer)[-1]
383
+
384
+ if len(re.findall(r"oxed\{(.*?)\}", final_answer)) > 0:
385
+ final_answer = re.findall(r"oxed\{(.*?)\}", final_answer)[-1]
386
+
387
+ if len(re.findall(r"\$(.*?)\$", final_answer)) > 0:
388
+ final_answer = re.findall(r"\$(.*?)\$", final_answer)[-1]
389
+ final_answer = final_answer.strip()
390
+ if "rac" in final_answer and "\\frac" not in final_answer:
391
+ final_answer = final_answer.replace("rac", "\\frac")
392
+
393
+ final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer)
394
+ final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer)
395
+ final_answer = final_answer.replace("$", "")
396
+
397
+ if final_answer.replace(",", "").isdigit():
398
+ final_answer = final_answer.replace(",", "")
399
+
400
+ return final_answer
401
+
402
+
403
+ def naive_parse(answer: str) -> str:
404
+ """
405
+ Extracts and returns the numeric digits from the input string, processing them in reverse order
406
+ until a non-numeric character is encountered after encountering the first numeric character.
407
+
408
+ Args:
409
+ answer (str): The input string to parse.
410
+
411
+ Returns:
412
+ str: A string consisting of the numeric digits extracted from the input, in their original order.
413
+
414
+ Example:
415
+ >>> naive_parse("abc123def")
416
+ '123'
417
+ >>> naive_parse("def456ghi")
418
+ '456'
419
+ >>> naive_parse("no numbers here")
420
+ ''
421
+ """
422
+ out = []
423
+ start = False
424
+ end = False
425
+ for l in reversed(list(answer)):
426
+ if l in "0123456789" and not end:
427
+ start = True
428
+ out.append(l)
429
+ else:
430
+ if start:
431
+ end = True
432
+
433
+ out = reversed(out)
434
+ return "".join(out)
435
+
436
+
437
+ def validate_answer_is_numeric(x: str | int | float) -> int:
438
+ FLOAT_TOLERANCE = 0.2
439
+ try:
440
+ x = round(float(x))
441
+ f = float(x)
442
+ if abs(x - f) > FLOAT_TOLERANCE:
443
+ x = -1
444
+ except Exception:
445
+ x = -1
446
+ return x
447
+
448
+
449
+ def get_majority_vote(responses: List[int]) -> int:
450
+ if len(responses) < 1:
451
+ return 0
452
+ else:
453
+ c = Counter(responses)
454
+ value, count = c.most_common()[0]
455
+ return value
456
+
457
+
458
+ def filter_answers(answers: List[str]) -> List[int]:
459
+ formatted_answers = [validate_answer_is_numeric(a) for a in answers]
460
+
461
+ # Filter for non-negative answers
462
+ formatted_answers = [a for a in formatted_answers if a >= 0]
463
+ # Compute modulo
464
+ formatted_answers = [a % 1_000 for a in formatted_answers]
465
+ # less than 2.1 billion or cannot convert to C int (32-bit)
466
+ formatted_answers = [a for a in formatted_answers if a <= 999]
467
+ return formatted_answers
468
+
469
+
470
+ def check_sympy_equivalence(ref_answer: str, model_answer: str) -> bool:
471
+ def do_answers_match(ref_answer: str, model_answer: str) -> bool:
472
+ ref_sympy = parse_latex(ref_answer)
473
+ model_sympy = parse_latex(model_answer)
474
+ diff = simplify(ref_sympy - model_sympy)
475
+ return True if -1e-12 < N(diff) < 1e-12 or diff.is_zero else False
476
+
477
+ try:
478
+ result = do_answers_match(ref_answer, model_answer)
479
+ return result
480
+ except Exception as e:
481
+ print(e)
482
+ return False
483
+
484
+
485
+ def check_string_match(ref_answer: str, model_answer: str) -> bool:
486
+ try:
487
+ return ref_answer == model_answer
488
+ except Exception as e:
489
+ print(e)
490
+ return False
491
+
492
+
493
+ def check_answer(ref_answer: str, model_answer: str) -> bool:
494
+ # check if strings are the same
495
+ correct = check_string_match(ref_answer, model_answer)
496
+ if correct:
497
+ return True
498
+
499
+ # use the sympy library to check if the expressions are the same
500
+ correct = check_sympy_equivalence(ref_answer, model_answer)
501
+ if correct:
502
+ return True
503
+
504
+ return False
505
+
506
+
507
+ debug = False
508
+ model_id = "Numina-Math-7B"
509
+ revision = "main"
510
+ system_prompt = "{}"
511
+ validation_set = "kaggle-validation-set-medium"
512
+ is_submission = True
513
+ num_samples = 4
514
+ num_generations = 4
515
+ temperature = 0.8
516
+ is_quantized = False
517
+ restart_on_fail = False
518
+ top_p = 1.0
519
+ top_k = 0
520
+ max_new_tokens = 2048
521
+ # Papermill related variables
522
+ push_to_hub = False
523
+ notebook_name = ""
524
+
525
+ config = Config(
526
+ debug=debug,
527
+ push_to_hub=push_to_hub,
528
+ model_id=model_id,
529
+ revision=revision,
530
+ system_prompt=system_prompt,
531
+ validation_set=validation_set,
532
+ is_quantized=is_quantized,
533
+ restart_on_fail=restart_on_fail,
534
+ is_submission=is_submission,
535
+ num_samples=num_samples,
536
+ num_generations=num_generations,
537
+ do_sample=True,
538
+ temperature=temperature,
539
+ top_p=top_p,
540
+ top_k=top_k,
541
+ max_new_tokens=max_new_tokens,
542
+ )
543
+ print(f"=== Running submission with config ===\n\n{config}")
544
+
545
+
546
+ def generate(message):
547
+ chat_completion = client.chat.completions.create(
548
+ model="tgi",
549
+ messages=message,
550
+ stream=True,
551
+ max_tokens=1024,
552
+ stop=["```output\n"],
553
+ temperature=temperature,
554
+ )
555
+
556
+ for message in chat_completion:
557
+ yield message.choices[0].delta.content
558
+
559
+
560
+ def get_majority_text(data):
561
+ from collections import Counter
562
+
563
+ # Count the frequency of each answer in model_answers
564
+ answer_counts = Counter(data["model_answers"])
565
+
566
+ # Find the majority response
567
+ majority_response = answer_counts.most_common(1)[0][0]
568
+
569
+ # Find the index of the first occurrence of the majority response
570
+ majority_index = data["model_answers"].index(majority_response)
571
+
572
+ # Return the corresponding text in gen_texts
573
+ return data["gen_texts"][majority_index]
574
+
575
+
576
+ def extract_solution(text):
577
+ # Split the text at "### Solution:"
578
+ parts = text.split("### Solution:", 1)
579
+ if len(parts) > 1:
580
+ # Return everything after "### Solution:"
581
+ return parts[1].strip()
582
+ else:
583
+ # Return an empty string if "### Solution:" is not found
584
+ return ""
585
+
586
+
587
+ def process_code(
588
+ example: Dict[str, Any],
589
+ config: Config,
590
+ restart_on_fail: bool = False,
591
+ last_step: bool = False,
592
+ ) -> Dict[str, Any]:
593
+ gen_text = example["gen_texts"]
594
+ num_python_blocks = len(re.findall(r"```python(.*?)```", gen_text, re.DOTALL))
595
+
596
+ if num_python_blocks == 0:
597
+ if restart_on_fail:
598
+ print("no code has ever been generated, RESTARTING")
599
+ # reset the text to the original
600
+ example["gen_texts"] = example["text"]
601
+ else:
602
+ print("no code has ever been generated, STOP")
603
+ example["should_prune"] = True
604
+ example["has_code"] = False
605
+ return example
606
+
607
+ if gen_text[-10:] != "```output\n" and ("answer is" in gen_text[-100:] or "\\boxed" in gen_text[-100:]):
608
+ num_output_blocks = len(re.findall(r"```output(.*?)```", gen_text, re.DOTALL))
609
+ if num_output_blocks == 0:
610
+ print("the model hallucinated the code answer")
611
+ example["should_prune"] = True
612
+ return example
613
+
614
+ if "boxed" in gen_text[-100:]:
615
+ try:
616
+ answer = normalize_final_answer(extract_boxed_answer(gen_text[-100:]))
617
+ except Exception:
618
+ answer = "-1"
619
+ else:
620
+ answer = normalize_final_answer(gen_text[-100:])
621
+
622
+ example["model_answers"] = answer
623
+ if not config.is_submission:
624
+ example["corrects"] = check_answer(example["ground_truth"], answer)
625
+ example["should_prune"] = True
626
+ print("Answer is: ", answer, example["ground_truth"], example["corrects"])
627
+ return example
628
+
629
+ if last_step:
630
+ # no point in continuing if we are at the last step
631
+ return example
632
+
633
+ if gen_text[-10:] != "```output\n":
634
+ # something else has gone wrong with the generation
635
+ print("warning: output block not found: ", gen_text[-40:])
636
+ if restart_on_fail:
637
+ example["gen_texts"] = example["text"]
638
+ else:
639
+ example["should_prune"] = True
640
+ return example
641
+
642
+ code_result, status = postprocess_completion(gen_text, return_status=True, last_code_block=True)
643
+ # add the code result for the next round of generation
644
+ TRUNCATION_LIMIT = 200
645
+ if len(code_result) > TRUNCATION_LIMIT:
646
+ code_result = code_result[:TRUNCATION_LIMIT] + " ... (output truncated)"
647
+ example["gen_texts"] = gen_text + f"{code_result}\n```"
648
+
649
+ return example
650
+
651
+
652
+ # load the vllm instance and set sampling parameters
653
+ # vllm = build_vllm(config)
654
+
655
+
656
+ def solve_problem(problem, temperature, progress=gr.Progress()):
657
+ problem = apply_template({"prompt": problem}, prompt=config.system_prompt)
658
+ print(f"Problem: {problem}")
659
+
660
+ sample = {
661
+ "problem": problem, # not used for the submission TODO Remove
662
+ "ground_truth": "unknown", # not used for the submission TODO Remove
663
+ "text": "### Solution:\n",
664
+ "gen_texts": "### Solution:\n", # used to store all the generated text
665
+ "should_prune": False,
666
+ "problem_index": -1, # not used for the submission TODO Remove
667
+ "model_answers": "-1",
668
+ "has_code": True,
669
+ "corrects": False, # not used for the submission TODO Remove
670
+ }
671
+
672
+ for step in progress.tqdm(
673
+ range(config.num_generations), desc="Generating candidates"
674
+ ): # Depth of the tree (e.g. 6 steps = 5 code blocks)
675
+
676
+ step_reponse = sample["gen_texts"]
677
+
678
+ messages = [
679
+ {"role": "user", "content": sample["problem"]},
680
+ {"role": "assistant", "content": sample["gen_texts"]},
681
+ ]
682
+
683
+ for reponse_message in generate(messages, temperature):
684
+ if reponse_message is not None:
685
+ step_reponse += reponse_message
686
+ yield step_reponse
687
+
688
+ sample["gen_texts"] = step_reponse
689
+
690
+ # TODO: Maybe it should just return the result of running the code
691
+ sample = process_code(
692
+ sample,
693
+ config=config,
694
+ restart_on_fail=config.restart_on_fail,
695
+ last_step=(step == (config.num_generations - 1)),
696
+ )
697
+ sample["gen_texts"] = sample["gen_texts"] + "\n"
698
+
699
+ run_code_reponse = sample["gen_texts"].replace(step_reponse, "")
700
+
701
+ for output_mseeage in run_code_reponse:
702
+ if output_mseeage is not None:
703
+ step_reponse += output_mseeage
704
+ yield step_reponse
705
+
706
+ if sample["should_prune"]:
707
+ break
708
+
709
+ yield sample["gen_texts"]
710
+
711
+
712
+ with gr.Blocks() as demo:
713
+ with gr.Row():
714
+ inp = gr.Textbox(placeholder="Problem", label="Problem", lines=5)
715
+ with gr.Accordion("Advanced Options", open=False):
716
+ temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.8, step=0.1, label="Temperature")
717
+ with gr.Row():
718
+ out = gr.Markdown()
719
+
720
+ btn = gr.Button("Run")
721
+ btn.click(fn=solve_problem, inputs=[inp, temperature], outputs=out)
722
+
723
+
724
+ if __name__ == "__main__":
725
+ demo.queue(default_concurrency_limit=5).launch()