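# Evaluation script for BBEH (BigBench Extra Hard) with vLLM: loads the
# MrLight/bbeh-eval dataset, generates answers with a local model, scores them
# with fuzzy matching, and reports per-category and overall accuracy.
#
# Example invocation (hypothetical filename; assumes a 4-GPU node, since
# tensor_parallel_size=4 is hard-coded below):
#   python bbeh_eval.py --model_path /path/to/model --output_file outputs.json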
import datasets
import json
import re
import argparse

from transformers import AutoTokenizer
from vllm import LLM, SamplingParams


def extract_last_boxed(text):
    """Return the contents of the last \\boxed{...} expression in the text, or None."""
    pattern = r'\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}'
    matches = list(re.finditer(pattern, text))
    if matches:
        return matches[-1].group(1)
    return None
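
# Quick sanity example (assumed, not from the original source):
#   extract_last_boxed("so the answer is \\boxed{42}.") -> "42"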


def extract_last_final_answer(text):
    """Fallback extraction: return the text after the last "Final Answer:" or
    "The answer is:" marker. Both patterns require a trailing newline."""
    pattern1 = r'Final Answer:((?:[^<]|<[^<])*?)\n'
    pattern2 = r'The answer is:((?:[^<]|<[^<])*?)\n'
    matches1 = list(re.finditer(pattern1, text))
    matches2 = list(re.finditer(pattern2, text))
    if matches1:
        return matches1[-1].group(1)
    if matches2:
        return matches2[-1].group(1)
    return None


def extract_solution(solution_str):
    """Strip chat scaffolding from a generation, then extract the final answer."""
    # Keep only the assistant turn if a full chat transcript was passed in.
    if '<|im_start|>user' in solution_str:
        model_output = re.sub(r'^.*?<\|im_start\|>assistant', '<|im_start|>assistant',
                              solution_str, flags=re.DOTALL, count=1)
    elif 'Assistant:' in solution_str:
        model_output = solution_str.split('Assistant:')[-1].strip()
    else:
        model_output = solution_str
    # Truncate at the first stop token, if any survived generation.
    stop_words = ["</s>", "<|im_end|>", "<|endoftext|>"]
    for stop_word in stop_words:
        if stop_word in model_output:
            model_output = model_output.split(stop_word)[0].strip()
    # Prefer a \boxed{} answer; fall back to "Final Answer:"-style markers.
    boxed_answer = extract_last_boxed(model_output)
    if boxed_answer:
        return boxed_answer
    return extract_last_final_answer(model_output)


def strip_latex(response: str) -> str:
    r"""Remove common LaTeX wrappers: $...$, \boxed{...}, \text{...}, \texttt{...}."""
    if response.startswith("$") and response.endswith("$"):
        response = response[1:-1]
    if "boxed{" in response and response.endswith("}"):
        response = response[0:-1].split("boxed{")[1]
    if "text{" in response and response.endswith("}"):
        response = response[0:-1].split("text{")[1]
    if "texttt{" in response and response.endswith("}"):
        response = response[0:-1].split("texttt{")[1]
    return response
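
# Example of the stripping above (assumed, for illustration):
#   strip_latex("$\\text{foo}$") -> "foo"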


def extract_answer(sample: str) -> str:
    """Extracts the final answer from the sample."""
    if sample is None:
        sample = ""
    answer_prefixes = [
        "The answer is:",
        "The final answer is ",
        "The final answer is: ",
        "The answer is "
    ]
    answer = sample
    for answer_prefix in answer_prefixes:
        if answer_prefix in answer:
            answer = answer.split(answer_prefix)[-1].strip()
    if answer.endswith("."):
        answer = answer[:-1]
    return strip_latex(answer)


def fuzzy_match(prediction: str, reference: str) -> bool:
    """Fuzzy match function for BigBench Extra Hard."""
    if prediction == reference:
        return True
    # (a) vs a
    if len(prediction) == 3 and prediction[0] == "(" and prediction[-1] == ")":
        return prediction[1] == reference
    if len(reference) == 3 and reference[0] == "(" and reference[-1] == ")":
        return reference[1] == prediction
    # Numbers
    try:
        if float(prediction) == float(reference):
            return True
    except ValueError:
        pass
    # Quote issues
    if prediction.replace("'", "") == reference.replace("'", ""):
        return True
    # Bracket issues
    if f"[{reference}]" == prediction or f"[{prediction}]" == reference:
        return True
    # Question mark issues
    if prediction.endswith("?") and prediction[:-1] == reference:
        return True
    return False
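
# A few cases the matcher above accepts (assumed examples):
#   fuzzy_match("(a)", "a") -> True   # parenthesized option letter vs. bare letter
#   fuzzy_match("1.0", "1") -> True   # numeric equivalence
#   fuzzy_match("[x]", "x") -> True   # bracket wrapping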


def preprocess_sample(sample: str) -> str:
    """Normalize a model prediction before fuzzy matching."""
    if sample is None:
        sample = ""
    prediction = extract_answer(sample.strip()).lower()
    prediction = prediction.replace(", ", ",").replace("**", "")
    prediction = prediction.split("\n")[0]
    prediction = prediction[0:-1] if prediction.endswith(".") else prediction
    return prediction


def preprocess_reference(reference: str) -> str:
    """Normalize a gold reference before fuzzy matching."""
    reference = reference.strip().lower()
    reference = reference.replace(", ", ",")
    return reference


def evaluate_correctness(sample: str, reference: str) -> bool:
    prediction = preprocess_sample(sample)
    reference = preprocess_reference(reference)
    return fuzzy_match(prediction, reference)
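
# End-to-end sanity check of extraction + scoring (assumed example):
#   raw = "Reasoning... so the answer is \\boxed{(A)}"
#   evaluate_correctness(extract_solution(raw), "(a)") -> True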


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, required=True, help="Path to the model directory")
    parser.add_argument("--output_file", type=str, default="outputs.json", help="File to save results")
    args = parser.parse_args()

    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    llm = LLM(model=args.model_path, tensor_parallel_size=4, gpu_memory_utilization=0.85)

    dataset = datasets.load_dataset('MrLight/bbeh-eval')
    categories = sorted(set(dataset['train']['task']))
    print("Categories:", categories)
    # Each category maps to [correct, incorrect] counts.
    per_category_accuracy = {c: [0, 0] for c in categories}
    success, fail = 0, 0
    answers = []

    print('----------------- Start Answering -------------------')
    for category in categories:
        category_entries = [entry for entry in dataset['train'] if entry['task'] == category]
        prompts = []
        for entry in category_entries:
            query = entry['question'] + '\n'
            messages = [{
                "role": "user",
                "content": query + '\nPlease reason step by step, and put your final answer option within \\boxed{}.'
            }]
            if tokenizer.chat_template:
                prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, enable_thinking=False)
            else:
                prompt = "user: " + query + '\nPlease reason step by step, and put your final answer option within \\boxed{}. Only put the letter in the box, e.g. \\boxed{A}. There is only one correct answer.'
            prompts.append(prompt)
        # Greedy decoding for a reproducible evaluation.
        sampling_params = SamplingParams(temperature=0, top_p=1, max_tokens=8192)
        outputs = llm.generate(prompts, sampling_params)
        for entry, output in zip(category_entries, outputs):
            raw_answer = output.outputs[0].text
            entry['solution'] = raw_answer
            answers.append(entry)
            answer = extract_solution(raw_answer)
            if evaluate_correctness(answer, entry['answer']):
                success += 1
                per_category_accuracy[category][0] += 1
            else:
                fail += 1
                per_category_accuracy[category][1] += 1
        print(f"{category}: {per_category_accuracy[category][0] / (per_category_accuracy[category][0] + per_category_accuracy[category][1]):.4f}")

    with open(args.output_file, 'w') as f:
        json.dump(answers, f, indent=2)
    # Append one summary record per run; JSONL needs exactly one JSON object
    # per line, so no indent and an explicit trailing newline.
    with open('final_results.jsonl', 'a') as f:
        f.write(json.dumps({"dataset": "bbeh", "model": args.model_path, "accuracy": round(success / (success + fail) * 100, 2)}) + "\n")
    print("Overall Accuracy:", success / (success + fail))