import gradio as gr
from huggingface_hub import login
import re

# from vllm import LLM, SamplingParams
import pandas as pd
from collections import Counter
from datasets import load_dataset, Dataset, concatenate_datasets
from dataclasses import dataclass
from concurrent.futures import ThreadPoolExecutor, TimeoutError
import os
from typing import Dict, Any, List

# code execution
import os
import re
import signal
import subprocess
import tempfile
from contextlib import contextmanager
from typing import Tuple
from transformers import PreTrainedTokenizer, set_seed
import torch
from tqdm import tqdm
import time
from sympy import N, simplify
from sympy.parsing.latex import parse_latex
import random
from pathlib import Path
from openai import OpenAI

# OpenAI-compatible client pointed at a Hugging Face inference endpoint; the
# API key is read from the HF_TOKEN environment variable.
client = OpenAI(
    base_url="https://ji0rhe7rvh6wrfmq.us-east-1.aws.endpoints.huggingface.cloud/v1/",
    api_key=os.environ.get("HF_TOKEN"),
)


@dataclass
class Config:
    """Run configuration: model selection, sampling parameters and run flags."""

    model_id: str  # SELECT MODEL
    revision: str  # SELECT REVISION
    # Append an optional system prompt to each problem
    system_prompt: str
    # Number of samples to generate per problem
    num_samples: int
    num_generations: int
    # Generation parameters
    do_sample: bool
    temperature: float
    top_p: float
    top_k: int
    max_new_tokens: int
    restart_on_fail: bool
    # Enable 4-bit quantization
    is_quantized: bool
    # Run on train or test data?
    # Defaults to True when KAGGLE_IS_COMPETITION_RERUN is set to a non-empty value.
    is_submission: bool = bool(os.getenv("KAGGLE_IS_COMPETITION_RERUN"))
    validation_set: str = "kaggle-validation-set-medium"
    notebook_time_limit: int = 9 * 60 * 60 - 15 * 60  # 9 hours - 15 minute buffer
    # Debug by solving only the first problem
    debug: bool = False
    # Push solutions to the Hub
    push_to_hub: bool = False


class PythonREPL:
    """Execute short Python snippets in a separate ``python3`` subprocess.

    Each snippet is written to a temporary file, prefixed with imports of
    ``math``, ``numpy`` (as ``np``) and ``sympy`` (as ``sp``), and run with a
    wall-clock limit of ``timeout`` seconds.
    """

    def __init__(self, timeout=5):
        self.timeout = timeout

    @staticmethod
    def _filter_traceback(stderr: str, temp_file_path: str) -> str:
        """Reduce a traceback to its header, script frames and final line."""
        lines = stderr.strip().split("\n")
        kept = []
        want_next = False
        for line in lines:
            if "Traceback" in line:
                kept.append(line)
            elif line == lines[-1]:
                # The last line carries the exception type and message.
                kept.append(line)
            elif temp_file_path in line:
                start = line.index('"/') + 1 if '"/' in line else 0
                # NOTE(review): the original computed an end index but, because
                # that index is always truthy in this branch, it always removed
                # everything from `start` to the end of the line (dropping the
                # temp path together with the line number). That behaviour is
                # preserved here; confirm whether only the path should go.
                tail = line[start:]
                kept.append(line.replace(tail, ""))
                # Also keep the source line printed right after the frame.
                want_next = True
            elif want_next:
                kept.append(line)
                want_next = False
        return "\n".join(kept).strip()

    def execute(self, query: str) -> Tuple[bool, str]:
        """Run ``query`` as a script and return ``(success, output)``.

        If the last line of the snippet does not contain ``print(``, it is
        wrapped in ``print(...)`` (after stripping any trailing ``#`` comment)
        so that the value of a final expression is shown.

        Args:
            query: Python source code to execute.

        Returns:
            ``(True, stdout)`` when the process exits with code 0, otherwise
            ``(False, filtered_stderr)``.

        Raises:
            subprocess.TimeoutExpired: if the subprocess exceeds
                ``self.timeout`` seconds.
        """
        query = "import math\nimport numpy as np\nimport sympy as sp\n" + query
        lines = query.strip().split("\n")
        if "print(" not in lines[-1]:
            if "#" in lines[-1]:
                lines[-1] = lines[-1].split("#")[0]
            lines[-1] = "print(" + lines[-1] + ")"
        query = "\n".join(lines)

        with tempfile.TemporaryDirectory() as temp_dir:
            temp_file_path = os.path.join(temp_dir, "tmp.py")
            # Python reads source files as UTF-8 by default, so write the
            # script as UTF-8 regardless of the locale's preferred encoding.
            with open(temp_file_path, "w", encoding="utf-8") as f:
                f.write(query)

            result = subprocess.run(
                ["python3", temp_file_path],
                capture_output=True,
                check=False,
                text=True,
                timeout=self.timeout,
            )

            if result.returncode == 0:
                return True, result.stdout.strip()
            return False, self._filter_traceback(result.stderr, temp_file_path)

    def __call__(self, query: str) -> Tuple[bool, str]:
        """Run ``execute`` in a worker thread, mapping timeouts to a failure.

        Returns:
            The result of ``execute``, or ``(False, "Timed out after N
            seconds.")`` when either the worker future or the subprocess
            itself times out.
        """
        with ThreadPoolExecutor() as executor:
            future = executor.submit(self.execute, query)
            try:
                return future.result(timeout=self.timeout)
            except (TimeoutError, subprocess.TimeoutExpired):
                # subprocess.run uses the same timeout and may expire first;
                # its TimeoutExpired is re-raised by future.result().
                return False, f"Timed out after {self.timeout} seconds."
def execute_completion(
    executor: PythonREPL,
    completion: str,
    return_status: bool = False,
    last_code_block: bool = False,
) -> str | Tuple[str, bool]:
    """Run the ```python fenced blocks found in ``completion``.

    Args:
        executor: Callable REPL returning ``(success, output)`` for a snippet.
        completion: Model output possibly containing ```python code fences.
        return_status: When True, return ``(output, success)`` instead of
            only ``output``.
        last_code_block: When True, execute only the last code block.

    Returns:
        The stripped output of the last executed block (and its success flag
        when ``return_status`` is True). If no code block is present, the
        completion itself is returned (with ``False`` as the status).
    """
    executions = re.findall(r"```python(.*?)```", completion, re.DOTALL)

    if not executions:
        # Directly return the chain-of-thought result. The parentheses matter:
        # without them a tuple is returned regardless of `return_status`.
        return (completion, False) if return_status else completion

    if last_code_block:
        executions = [executions[-1]]

    execution_outputs = []
    successes = []
    for code in executions:
        success = False
        # Refuse snippets that mention process spawning or virtualenvs.
        if "subprocess" in code:
            execution_outputs.append("subprocess is not allowed")
            successes.append(success)
            continue
        if "venv" in code:
            execution_outputs.append("venv is not allowed")
            successes.append(success)
            continue

        try:
            success, output = executor(code)
        except (TimeoutError, subprocess.TimeoutExpired) as e:
            print("time out")
            output = e

        if not success and not return_status:
            output = ""

        execution_outputs.append(output)
        successes.append(success)

    output = str(execution_outputs[-1]).strip()
    success = successes[-1]

    if return_status:
        return output, success
    return output


def postprocess_completion(
    text: str, return_status: bool = False, last_code_block=False, timeout=5
) -> str | Tuple[str, bool]:
    """Execute the code in ``text`` with a fresh ``PythonREPL``.

    See ``execute_completion`` for the meaning of the flags and the return
    value; ``timeout`` is the per-snippet limit in seconds.
    """
    executor = PythonREPL(timeout=timeout)
    result = execute_completion(executor, text, return_status=return_status, last_code_block=last_code_block)
    del executor
    return result


def apply_template(example: Dict[str, Any], prompt: str) -> str:
    """Format ``prompt`` with ``example["prompt"]`` and a literal ``"{}"``."""
    return prompt.format(example["prompt"], "{}")


def last_boxed_only_string(string):
    """
    Extracts the last LaTeX boxed or framed expression from a string.

    Args:
        string (str): The input string containing LaTeX expressions.

    Returns:
        str or None: The last boxed or framed expression, if found; otherwise, None.
    """
    idx = string.rfind("\\boxed")
    if idx < 0:
        idx = string.rfind("\\fbox")
        if idx < 0:
            return None

    # Scan forward from the command and stop at the brace that closes it.
    depth = 0
    for pos in range(idx, len(string)):
        char = string[pos]
        if char == "{":
            depth += 1
        if char == "}":
            depth -= 1
            if depth == 0:
                return string[idx : pos + 1]
    return None


def remove_boxed(s):
    """
    Removes the LaTeX boxed command, returning the content inside the braces.

    Args:
        s (str): The string containing a LaTeX boxed expression.

    Returns:
        str or None: The content inside the boxed command, if valid; otherwise, None.
    """
    left = "\\boxed{"
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if not isinstance(s, str) or not s.startswith(left) or not s.endswith("}"):
        return None
    return s[len(left) : -1]


def extract_boxed_answer(pred_str, strip_double_curly_brace=False):
    """
    Extracts the answer from a LaTeX boxed expression within a prediction string.

    Args:
        pred_str (str): The string containing one or more LaTeX boxed expressions.
        strip_double_curly_brace (bool): If True, removes an additional layer of braces.

    Returns:
        str or None: The extracted answer, if any; otherwise, None.
    """
    boxed_str = last_boxed_only_string(pred_str)
    if boxed_str is None:
        return None
    answer = remove_boxed(boxed_str)
    if answer is None:
        return None
    if strip_double_curly_brace:
        match = re.match(r"^\{(.*)\}$", answer)
        if match:
            answer = match.group(1)
    return answer


def normalize_final_answer(final_answer: str) -> str:
    """
    Normalizes a final answer to a quantitative reasoning question by removing
    or replacing various LaTeX and text elements.

    Args:
        final_answer (str): The answer string to normalize.

    Returns:
        str: The normalized answer string.
    """
    match = re.search(r"(.*?)Problem:", final_answer, flags=re.S)
    if match:
        # Keep only the first matched part, i.e. all text before "Problem:".
        final_answer = match.group(1)

    SUBSTITUTIONS = [
        ("an ", ""),
        ("a ", ""),
        (".$", "$"),
        ("\\$", ""),
        (r"\ ", ""),
        (" ", ""),
        ("mbox", "text"),
        (",\\text{and}", ","),
        ("\\text{and}", ","),
        ("\\text{m}", "\\text{}"),
        ("\\le", "<"),
    ]
    REMOVED_EXPRESSIONS = [
        "square",
        "ways",
        "integers",
        "dollars",
        "mph",
        "inches",
        "ft",
        "hours",
        "km",
        "units",
        "\\ldots",
        "sue",
        "points",
        "feet",
        "minutes",
        "digits",
        "cents",
        "degrees",
        "cm",
        "gm",
        "pounds",
        "meters",
        "meals",
        "edges",
        "students",
        "childrentickets",
        "multiples",
        "\\text{s}",
        "\\text{.}",
        "\\text{\ns}",
        "\\text{}^2",
        "\\text{}^3",
        "\\text{\n}",
        "\\text{}",
        r"\mathrm{th}",
        r"^\circ",
        r"^{\circ}",
        r"\;",
        r",\!",
        "{,}",
        '"',
        "\\dots",
        "\n",
        "\r",
        "\f",
        "\\%",
    ]
    for before, after in SUBSTITUTIONS:
        final_answer = final_answer.replace(before, after)
    for expr in REMOVED_EXPRESSIONS:
        final_answer = final_answer.replace(expr, "")

    # Extract answer that is in LaTeX math, is bold,
    # is surrounded by a box, etc.
    final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer)
    final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer)
    final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer)
    final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer)
    # Line-break characters were removed above, so these are pure invariants.
    assert "\n" not in final_answer
    assert "\r" not in final_answer
    assert "\f" not in final_answer

    # Progressively narrow down to the last answer-like fragment.
    for pattern in (
        r"finalansweris(.*)",
        r"answer?is:?(.*)",
        r"oxed\{(.*?)\}",
        r"\$(.*?)\$",
    ):
        found = re.findall(pattern, final_answer)
        if found:
            final_answer = found[-1]
    final_answer = final_answer.strip()

    # Repair a "\frac" whose leading characters were eaten by the steps above.
    if "rac" in final_answer and "\\frac" not in final_answer:
        final_answer = final_answer.replace("rac", "\\frac")

    # Add braces to shorthand such as "frac12" and "sqrt2".
    final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer)
    final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer)
    final_answer = final_answer.replace("$", "")

    # Drop thousands separators from plain integers.
    if final_answer.replace(",", "").isdigit():
        final_answer = final_answer.replace(",", "")

    return final_answer


def naive_parse(answer: str) -> str:
    """
    Extracts and returns the numeric digits from the input string, processing them in reverse order
    until a non-numeric character is encountered after encountering the first numeric character.

    Args:
        answer (str): The input string to parse.

    Returns:
        str: A string consisting of the numeric digits extracted from the input, in their original order.

    Example:
        >>> naive_parse("abc123def")
        '123'
        >>> naive_parse("def456ghi")
        '456'
        >>> naive_parse("no numbers here")
        ''
    """
    digits = []
    started = False
    ended = False
    for char in reversed(answer):
        if char in "0123456789" and not ended:
            started = True
            digits.append(char)
        elif started:
            ended = True
    return "".join(reversed(digits))


def validate_answer_is_numeric(x: str | int | float) -> int:
    """Round ``x`` to an integer, or return -1 if it is not close to one.

    Args:
        x: A value convertible with ``float``.

    Returns:
        The nearest integer when ``x`` is within ``FLOAT_TOLERANCE`` of it;
        -1 when ``x`` is further away or cannot be converted.
    """
    FLOAT_TOLERANCE = 0.2
    try:
        value = float(x)
        rounded = round(value)
    except (TypeError, ValueError, OverflowError):
        return -1
    # Compare against the unrounded value; comparing the rounded value with
    # itself (as before) made this tolerance check a no-op.
    if abs(rounded - value) > FLOAT_TOLERANCE:
        return -1
    return rounded


def get_majority_vote(responses: List[int]) -> int:
    """Return the most common response, or 0 for an empty list."""
    if len(responses) < 1:
        return 0
    return Counter(responses).most_common(1)[0][0]


def filter_answers(answers: List[str]) -> List[int]:
    """Convert raw answers to integers in ``[0, 999]``, dropping invalid ones."""
    formatted_answers = [validate_answer_is_numeric(a) for a in answers]

    # Filter for non-negative answers
    formatted_answers = [a for a in formatted_answers if a >= 0]
    # Compute modulo
    formatted_answers = [a % 1_000 for a in formatted_answers]
    # Defensive upper bound; always satisfied after the modulo above
    formatted_answers = [a for a in formatted_answers if a <= 999]
    return formatted_answers


def check_sympy_equivalence(ref_answer: str, model_answer: str) -> bool:
    """Return True if the two LaTeX answers are equal according to sympy.

    Any exception raised while parsing or comparing is printed and treated
    as "not equivalent".
    """

    def do_answers_match(ref_answer: str, model_answer: str) -> bool:
        ref_sympy = parse_latex(ref_answer)
        model_sympy = parse_latex(model_answer)
        diff = simplify(ref_sympy - model_sympy)
        return bool(-1e-12 < N(diff) < 1e-12 or diff.is_zero)

    try:
        return do_answers_match(ref_answer, model_answer)
    except Exception as e:
        print(e)
        return False


def check_string_match(ref_answer: str, model_answer: str) -> bool:
    """Return True if the two answers are equal as plain values."""
    try:
        return ref_answer == model_answer
    except Exception as e:
        print(e)
        return False


def check_answer(ref_answer: str, model_answer: str) -> bool:
    """Return True if the answers match as strings or as sympy expressions."""
    # check if strings are the same
    if check_string_match(ref_answer, model_answer):
        return True

    # use the sympy library to check if the expressions are the same
    return check_sympy_equivalence(ref_answer, model_answer)


debug = False
model_id = "Numina-Math-7B"
revision = "main"
system_prompt = "{}"
validation_set = "kaggle-validation-set-medium"
is_submission = True
num_samples = 4
num_generations = 4
temperature = 0.8
is_quantized = False
restart_on_fail = False
top_p = 1.0
top_k = 0
max_new_tokens = 2048
# Papermill related variables
push_to_hub = False
notebook_name = ""

config = Config(
    debug=debug,
    push_to_hub=push_to_hub,
    model_id=model_id,
    revision=revision,
    system_prompt=system_prompt,
    validation_set=validation_set,
    is_quantized=is_quantized,
    restart_on_fail=restart_on_fail,
    is_submission=is_submission,
    num_samples=num_samples,
    num_generations=num_generations,
    do_sample=True,
    temperature=temperature,
    top_p=top_p,
    top_k=top_k,
    max_new_tokens=max_new_tokens,
)
print(f"=== Running submission with config ===\n\n{config}")


def generate(message, temperature=None):
    """Stream a chat completion from the endpoint, yielding content deltas.

    Args:
        message: List of chat messages (``role``/``content`` dicts).
        temperature: Sampling temperature. Defaults to ``config.temperature``
            when omitted, so the single-argument call form keeps working.

    Yields:
        The ``delta.content`` of each streamed chunk (which may be None).
    """
    if temperature is None:
        temperature = config.temperature
    chat_completion = client.chat.completions.create(
        model="tgi",
        messages=message,
        stream=True,
        max_tokens=1024,
        # Stop as soon as the model opens an output block so that the code
        # can be executed and its real output inserted.
        stop=["```output\n"],
        temperature=temperature,
    )
    for chunk in chat_completion:
        yield chunk.choices[0].delta.content


def get_majority_text(data):
    """Return the generated text belonging to the most frequent answer.

    Args:
        data: Mapping with parallel lists ``"model_answers"`` and
            ``"gen_texts"``.

    Returns:
        The entry of ``data["gen_texts"]`` at the index of the first
        occurrence of the most common value in ``data["model_answers"]``.
    """
    # Count the frequency of each answer in model_answers
    answer_counts = Counter(data["model_answers"])
    # Find the majority response
    majority_response = answer_counts.most_common(1)[0][0]
    # Find the index of the first occurrence of the majority response
    majority_index = data["model_answers"].index(majority_response)
    # Return the corresponding text in gen_texts
    return data["gen_texts"][majority_index]


def extract_solution(text):
    """Return the text after the first "### Solution:" marker, or ""."""
    _, marker, solution = text.partition("### Solution:")
    return solution.strip() if marker else ""


def process_code(
    example: Dict[str, Any],
    config: Config,
    restart_on_fail: bool = False,
    last_step: bool = False,
) -> Dict[str, Any]:
    """Advance one generation step: extract an answer or run the last code block.

    Args:
        example: Sample state. Reads ``"gen_texts"``, ``"text"`` and
            ``"ground_truth"``; updates ``"gen_texts"``, ``"should_prune"``,
            ``"has_code"``, ``"model_answers"`` and ``"corrects"``.
        config: Run configuration (``is_submission`` is consulted).
        restart_on_fail: Reset the generated text instead of pruning when the
            generation is unusable.
        last_step: True on the final generation step; no code is executed.

    Returns:
        The same ``example`` dict, updated in place.
    """
    gen_text = example["gen_texts"]
    num_python_blocks = len(re.findall(r"```python(.*?)```", gen_text, re.DOTALL))

    if num_python_blocks == 0:
        if restart_on_fail:
            print("no code has ever been generated, RESTARTING")
            # reset the text to the original
            example["gen_texts"] = example["text"]
        else:
            print("no code has ever been generated, STOP")
            example["should_prune"] = True
        example["has_code"] = False
        return example

    tail = gen_text[-100:]
    awaiting_output = gen_text.endswith("```output\n")

    if not awaiting_output and ("answer is" in tail or "\\boxed" in tail):
        num_output_blocks = len(re.findall(r"```output(.*?)```", gen_text, re.DOTALL))
        if num_output_blocks == 0:
            print("the model hallucinated the code answer")
            example["should_prune"] = True
            return example

        if "boxed" in tail:
            try:
                answer = normalize_final_answer(extract_boxed_answer(tail))
            except Exception:
                # e.g. no complete boxed expression within the tail
                answer = "-1"
        else:
            answer = normalize_final_answer(tail)

        example["model_answers"] = answer
        if not config.is_submission:
            example["corrects"] = check_answer(example["ground_truth"], answer)
        # NOTE(review): pruning and returning for every found answer (not only
        # outside submission mode) is reconstructed from the collapsed source;
        # confirm against the original layout.
        example["should_prune"] = True
        print("Answer is: ", answer, example["ground_truth"], example["corrects"])
        return example

    if last_step:
        # no point in continuing if we are at the last step
        return example

    if not awaiting_output:
        # something else has gone wrong with the generation
        print("warning: output block not found: ", gen_text[-40:])
        if restart_on_fail:
            example["gen_texts"] = example["text"]
        else:
            example["should_prune"] = True
        return example

    code_result, status = postprocess_completion(gen_text, return_status=True, last_code_block=True)
    # add the code result for the next round of generation
    TRUNCATION_LIMIT = 200
    if len(code_result) > TRUNCATION_LIMIT:
        code_result = code_result[:TRUNCATION_LIMIT] + " ... (output truncated)"
    example["gen_texts"] = gen_text + f"{code_result}\n```"

    return example


# load the vllm instance and set sampling parameters
# vllm = build_vllm(config)


def solve_problem(problem, temperature, progress=gr.Progress()):
    """Solve ``problem`` step by step, yielding the growing solution text.

    Each step streams a model continuation, then ``process_code`` either
    extracts the final answer or runs the last code block and appends its
    output. The loop ends after ``config.num_generations`` steps or as soon
    as the sample is marked ``should_prune``.

    Args:
        problem: Problem statement entered by the user.
        temperature: Sampling temperature forwarded to ``generate``.
        progress: Gradio progress tracker.

    Yields:
        The accumulated solution text after every streamed increment.
    """
    problem = apply_template({"prompt": problem}, prompt=config.system_prompt)
    print(f"Problem: {problem}")

    sample = {
        "problem": problem,  # not used for the submission TODO Remove
        "ground_truth": "unknown",  # not used for the submission TODO Remove
        "text": "### Solution:\n",
        "gen_texts": "### Solution:\n",  # used to store all the generated text
        "should_prune": False,
        "problem_index": -1,  # not used for the submission TODO Remove
        "model_answers": "-1",
        "has_code": True,
        "corrects": False,  # not used for the submission TODO Remove
    }

    for step in progress.tqdm(
        range(config.num_generations), desc="Generating candidates"
    ):  # Depth of the tree (e.g. 6 steps = 5 code blocks)
        step_response = sample["gen_texts"]

        messages = [
            {"role": "user", "content": sample["problem"]},
            {"role": "assistant", "content": sample["gen_texts"]},
        ]

        for response_message in generate(messages, temperature):
            if response_message is not None:
                step_response += response_message
                yield step_response

        sample["gen_texts"] = step_response

        # TODO: Maybe it should just return the result of running the code
        sample = process_code(
            sample,
            config=config,
            restart_on_fail=config.restart_on_fail,
            last_step=(step == (config.num_generations - 1)),
        )
        sample["gen_texts"] = sample["gen_texts"] + "\n"

        # Stream whatever process_code appended, one character at a time.
        run_code_response = sample["gen_texts"].replace(step_response, "")
        for output_char in run_code_response:
            step_response += output_char
            yield step_response

        if sample["should_prune"]:
            break

    yield sample["gen_texts"]


with gr.Blocks() as demo:
    with gr.Row():
        inp = gr.Textbox(placeholder="Problem", label="Problem", lines=5)
    with gr.Accordion("Advanced Options", open=False):
        # Named distinctly so the module-level `temperature` float above is
        # not rebound to a UI component.
        temperature_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.8, step=0.1, label="Temperature")
    with gr.Row():
        out = gr.Markdown()

    btn = gr.Button("Run")
    btn.click(fn=solve_problem, inputs=[inp, temperature_slider], outputs=out)

if __name__ == "__main__":
    demo.queue(default_concurrency_limit=5).launch()