# Dyve_plus_RL_copy / llm_as_judge.py
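# Runs an LLM-as-judge pass over step-split math solutions: for each problem whose
# step labels contain a 0, a judge model is asked for the index of the earliest
# erroneous paragraph (-1 if none), and the collected results are written to JSONL.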
import asyncio
import json
import re
from typing import Any, List, Optional

from openai import AsyncOpenAI, OpenAI
from tqdm import tqdm
def read_jsonl(file_path):
    """Read a JSONL file and return its records as a list of dicts."""
    data = []
    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            data.append(json.loads(line.strip()))
    return data
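# Example (hypothetical file): each input line holds one JSON object such as
#   {"question": "...", "process": "step 1\n\nstep 2", "label": [1, 0]}
# records = read_jsonl("data.jsonl")  # -> list of dicts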
def extract_answer_judge(solution_text: str):
    """Return the content of the last \\boxed{...} in the text, or None if absent."""
    boxed_pattern = r'\\boxed\{([^}]*)\}'
    matches = re.findall(boxed_pattern, solution_text)
    if matches:
        return matches[-1].strip()
    return None
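# Example: extract_answer_judge(r"... the earliest error is in \boxed{3}.") -> "3".
# The return value is a string (or None if no \boxed{...} is present), so callers
# cast it to int themselves.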
def separate_steps(steps: List[str], mode: str = 'join') -> Any:
    """Join a list of steps with blank lines, or split a string back into steps."""
    delimiter = "\n\n"
    if mode == 'join':
        if not isinstance(steps, list):
            raise TypeError("For 'join' mode, 'steps' must be a list of strings.")
        return delimiter.join(steps)
    elif mode == 'split':
        if not isinstance(steps, str):
            raise TypeError("For 'split' mode, 'steps' must be a string.")
        return steps.split(delimiter)
    else:
        raise ValueError("Mode should be either 'join' or 'split'.")
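# Example round trip (this helper is not used in the main flow below):
#   separate_steps(["step 1", "step 2"], mode='join')  -> "step 1\n\nstep 2"
#   separate_steps("step 1\n\nstep 2", mode='split')   -> ["step 1", "step 2"]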
def evaluate_llm_as_judge(problem: str, steps: list, output_type: str = 'bool') -> Optional[dict]:
    """Ask an LLM judge for the index of the earliest erroneous solution step.

    Returns a dict with the question, reasoning steps, judged error index
    (-1 if no error is found) and the raw judge response, or None when no
    index could be parsed. 'output_type' is kept for call compatibility and
    is currently unused.
    """
    global client
    client = OpenAI(
        base_url="http://localhost:8014/v1",
        api_key="token-abc123"
    )
    model_name = "DeepSeek-R1-Distill-Qwen-14B"
    # model_name = 'deepseek-v3-241226'  # alternative judge model
    messages = []
judge_prompt = f"""
The following is a math problem and a solution (split into paragraphs, enclosed with tags and indexed from 0):
[Math Problem]
{problem}
[Solution]
{steps}
Your task is to review and critique the solution paragraph by paragraph. Once you identify an error in a paragraph, return the index of the paragraph where the earliest error occurs. Otherwise, return the index of -1 (which typically denotes "not found").
Please put your final answer (i.e., the index) in \\boxed{{}}.
"""
messages.append({
'role': 'user',
'content': judge_prompt
})
completion = client.chat.completions.create(
model=model_name,
messages=messages,
n=1,
temperature=0.6,
max_tokens=8192,
)
response = completion.choices[0].message.content
# print('*****step*****',steps)
# print("*****Verification*****:", response)
    content = response.strip().lower()
    last_words = ' '.join(content.split()[-10:])  # last 10 whitespace-separated tokens
    # print('last_words:', last_words)
    error_step_index = extract_answer_judge(last_words)
    if error_step_index is None:
        # No \boxed{...} index found near the end of the response.
        return None
    try:
        error_step_index = int(error_step_index)
    except ValueError:
        # The boxed content is not a parseable integer.
        return None
    print('error_step_index', error_step_index)
    merged_data = {
        'question': problem,
        'reasoning_steps': steps,
        'error_step_index': error_step_index,
        'response': response
    }
    return merged_data
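
# --- Optional async variant (illustrative sketch) -----------------------------
# A minimal sketch of the same judging call issued through the AsyncOpenAI client
# imported above, e.g. for running many judgments concurrently with asyncio.gather.
# The endpoint, API key and model name simply mirror the synchronous client; this
# helper is not used by the script below.
async def evaluate_llm_as_judge_async(problem: str, steps: list) -> Optional[dict]:
    async_client = AsyncOpenAI(
        base_url="http://localhost:8014/v1",
        api_key="token-abc123"
    )
    judge_prompt = f"""
The following is a math problem and a solution (split into paragraphs, enclosed with tags and indexed from 0):
[Math Problem]
{problem}
[Solution]
{steps}
Your task is to review and critique the solution paragraph by paragraph. Once you identify an error in a paragraph, return the index of the paragraph where the earliest error occurs. Otherwise, return the index of -1 (which typically denotes "not found").
Please put your final answer (i.e., the index) in \\boxed{{}}.
"""
    completion = await async_client.chat.completions.create(
        model="DeepSeek-R1-Distill-Qwen-14B",
        messages=[{'role': 'user', 'content': judge_prompt}],
        n=1,
        temperature=0.6,
        max_tokens=8192,
    )
    response = completion.choices[0].message.content
    error_step_index = extract_answer_judge(' '.join(response.strip().lower().split()[-10:]))
    if error_step_index is None:
        return None
    try:
        index = int(error_step_index)
    except ValueError:
        return None
    return {
        'question': problem,
        'reasoning_steps': steps,
        'error_step_index': index,
        'response': response,
    }
# Example usage (hypothetical):
#     async def main(pairs):
#         return await asyncio.gather(*(evaluate_llm_as_judge_async(q, s) for q, s in pairs))
#     results = asyncio.run(main(pairs))
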
new_file_path = '0312_training_new_processed_16w.jsonl'
data_all = read_jsonl(new_file_path)
print(len(data_all))
output = []
for idx, data in enumerate(tqdm(data_all[0:10])):
    print(idx)
    problem = data['question']
    steps_ori = data['process']
    labels = data['label']
    steps = steps_ori.split('\n\n')
    steps[0] = problem + ' ' + steps[0]
    # print('steps:', steps)
    steps_updated = steps[:-1]  # drop the final step
    if 0 in labels:
        # At least one step is labelled 0: ask the LLM judge for the earliest error index.
        merged_data = evaluate_llm_as_judge(problem=problem, steps=steps_updated, output_type='bool')
        if merged_data is not None:
            output.append(merged_data)
    else:
        # No step is labelled 0: record error index -1 without calling the judge.
        merged_data = {
            'question': problem,
            'reasoning_steps': steps_updated,
            'error_step_index': -1,
            'response': '<think>\n\n</think>-1'
        }
        output.append(merged_data)
    if len(output) % 100 == 0:
        # Periodic checkpoint: rewrite the full output file every 100 collected records.
        output_path = '0312_training_fast_slow_thinking.jsonl'
        with open(output_path, 'w', encoding='utf-8') as output_file:
            for entry in output:
                output_file.write(json.dumps(entry, ensure_ascii=False) + '\n')
        print(f"Data successfully written to {output_path}")