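"""Unified inference and checking for different datasets/tasks.

Generates model responses with vLLM (or the OpenAI API), grades them with the
appropriate task handler, and writes results plus token-usage summaries to JSON.
A typical invocation (the script filename here is assumed) looks like:

    python inference_and_check.py --dataset MATH500 --model Qwen/QwQ-32B-Preview --tp 8 \
        --result-dir ./results --temperatures 0.7
"""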
import argparse
import concurrent.futures
import json
import os

import numpy as np
from concurrent.futures import ProcessPoolExecutor, as_completed
from functools import partial
from openai import OpenAI
from tqdm import tqdm
from vllm import LLM, SamplingParams

from util.task_handlers import *
from util.model_utils import *


class NumpyEncoder(json.JSONEncoder):
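    """JSON encoder that serializes numpy arrays as plain Python lists."""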

    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)


def fetch_response_openai(llm, model_name, max_tokens, temp, prompt):
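    """Send one chat completion request to the OpenAI API.

    o1-series models are special-cased: every message is sent with the user
    role, temperature is pinned to 1, and the token limit is passed as
    max_completion_tokens.
    """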
    model_name = model_name.replace("openai/", "")
    if "o1" in model_name:
        # o1 models: collapse every message into a user turn and pin temperature to 1.
        for p in prompt:
            p["role"] = "user"

        response = llm.chat.completions.create(
            model=model_name,
            messages=prompt,
            n=1,
            temperature=1,
            max_completion_tokens=max_tokens,
        )
    else:
        response = llm.chat.completions.create(
            model=model_name,
            messages=prompt,
            n=1,
            temperature=temp,
            max_tokens=max_tokens,
        )
    return response


def perform_inference_and_check(handler: TaskHandler, temperatures, max_tokens, result_file, llm, system_prompt, args):
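    """Generate responses for unfinished problems at each temperature, grade them in parallel, and persist results plus token-usage summaries."""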
    results = handler.load_existing_results(result_file)
    print(f"Loaded {len(results)} existing results.")
    train_data = handler.load_and_filter_dataset(args.start, args.end, split=args.split, source=args.source, filter_difficulty=args.filter_difficulty)
    remaining_data = handler.process_remaining_data(train_data, results)
    conversations = handler.make_conversations(remaining_data, system_prompt)

    for temp in temperatures:
        if args.model.startswith("openai"):
            fetch_partial = partial(fetch_response_openai, llm, args.model, max_tokens, temp)

            with concurrent.futures.ThreadPoolExecutor(max_workers=16) as e:
                responses = list(e.map(fetch_partial, conversations))
        else:
            sampling_params = SamplingParams(max_tokens=max_tokens, temperature=temp)
            responses = llm.chat(messages=conversations, sampling_params=sampling_params, use_tqdm=True)

        total_correct = 0
        total_finish = 0
        with ProcessPoolExecutor(max_workers=32) as executor:
            future_to_task = {}
            token_usages = {}
            for idx, response in enumerate(responses):
                if args.model.startswith("openai"):
                    response_str = response.choices[0].message.content.strip()
                else:
                    response_str = response.outputs[0].text.strip()
                # Grade each response in a worker process; remember which problem it belongs to.
                future_to_task[executor.submit(handler.update_results, remaining_data[idx], response_str)] = idx

                if args.model.startswith("openai"):
                    token_usages[idx] = response.usage
                else:
                    token_usages[idx] = {
                        "completion_tokens": len(response.outputs[0].token_ids),
                        "prompt_tokens": len(response.prompt_token_ids),
                    }

            for future in tqdm(as_completed(future_to_task), total=len(future_to_task), desc="Processing Generations"):
                idx = future_to_task[future]
                response_entry = future.result()
                total_correct += response_entry["correctness"]
                total_finish += 1

                problem_key = remaining_data[idx][handler.get_question_key()]
                if problem_key not in results:
                    results[problem_key] = remaining_data[idx]
                    if isinstance(handler, NUMINATaskHandler):
                        results[problem_key]["messages"] = ""
                    results[problem_key]["responses"] = {}
                    results[problem_key]["token_usages"] = {}
                    prompt = conversations[idx][1]["content"]
                    results[problem_key]["prompt"] = prompt

                results[problem_key]["responses"][str(temp)] = response_entry

                if args.model.startswith("openai"):
                    results[problem_key]["token_usages"][str(temp)] = {
                        "completion_tokens": token_usages[idx].completion_tokens,
                        "prompt_tokens": token_usages[idx].prompt_tokens,
                    }
                else:
                    results[problem_key]["token_usages"][str(temp)] = token_usages[idx]

        print(f"Final acc: {total_correct}/{total_finish}")
        acc = round(total_correct / total_finish, 4) if total_finish > 0 else 0
        print(json.dumps({"acc": acc}))

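    # Token usage is aggregated over every stored result and temperature and written beside the results.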
    completion_tokens = [
        results[key].get("token_usages", {}).get(str(temp), {}).get("completion_tokens", 0)
        for key in results for temp in temperatures
    ]
    prompt_tokens = [
        results[key].get("token_usages", {}).get(str(temp), {}).get("prompt_tokens", 0)
        for key in results for temp in temperatures
    ]

    result_dir, result_name = os.path.split(result_file)
    token_usage_dir = os.path.join(result_dir, "token_usage")
    os.makedirs(token_usage_dir, exist_ok=True)

    token_usage_result_file = os.path.join(token_usage_dir, result_name)

    token_dict = {
        "completion_tokens": sum(completion_tokens),
        "prompt_tokens": sum(prompt_tokens),
        "avg_completion_tokens": round(sum(completion_tokens) / len(completion_tokens), 3) if completion_tokens else 0,
        "avg_prompt_tokens": round(sum(prompt_tokens) / len(prompt_tokens), 3) if prompt_tokens else 0,
    }

    with open(token_usage_result_file, "w") as f:
        json.dump(token_dict, f, indent=4)

    print(f"Token usage saved to {token_usage_result_file}")

    with open(result_file, 'w', encoding='utf-8') as file:
        json.dump(results, file, ensure_ascii=False, indent=4, cls=NumpyEncoder)


def perform_check(handler: TaskHandler, temperatures, result_file, args):
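    """Re-grade saved responses whose correctness is still None and update the result file in place."""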
    results = handler.load_existing_results(result_file)
    print(f"Loaded {len(results)} existing results.")

    train_data = handler.load_and_filter_dataset(args.start, args.end, split=args.split, source=args.source, filter_difficulty=args.filter_difficulty)
    remaining_data = handler.process_remaining_data(train_data, {})

    tasks = []
    for item in remaining_data:
        problem_key = item[handler.get_question_key()]
        # Only re-check responses that were saved but never graded.
        if problem_key in results and "responses" in results[problem_key]:
            for temp in temperatures:
                if str(temp) in results[problem_key]["responses"]:
                    response_entry = results[problem_key]["responses"][str(temp)]
                    if response_entry["correctness"] is None:
                        processed = "processed_content" in response_entry
                        tasks.append((item, temp, response_entry["processed_content"] if processed else response_entry["content"]))

    print(f"Found {len(tasks)} responses requiring reject sampling...")

    total_correct = 0
    total_finish = 0

    with ProcessPoolExecutor(max_workers=32) as executor:
        future_to_task = {
            executor.submit(handler.update_results, item, content): (item, temp)
            for (item, temp, content) in tasks
        }

        for future in tqdm(as_completed(future_to_task), total=len(future_to_task), desc="Processing Reject Sampling"):
            item, temp = future_to_task[future]
            new_response_entry = future.result()
            total_correct += new_response_entry["correctness"]
            total_finish += 1

            problem_key = item[handler.get_question_key()]
            assert problem_key in results and "responses" in results[problem_key] and str(temp) in results[problem_key]["responses"]
            response_entry = results[problem_key]["responses"][str(temp)]
            response_entry["correctness"] = new_response_entry["correctness"]
            response_entry["reason"] = new_response_entry["reason"]
            results[problem_key]["responses"][str(temp)] = response_entry

    print(f"Final reject-sampling accuracy: {total_correct}/{total_finish}")

    with open(result_file, 'w', encoding='utf-8') as file:
        json.dump(results, file, ensure_ascii=False, indent=4, cls=NumpyEncoder)


def perform_inference_and_save(handler: TaskHandler, temperatures, max_tokens, result_file, llm, system_prompt, args):
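    """Generate and save responses without grading them; correctness is left as None for a later --check pass."""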
    results = handler.load_existing_results(result_file)
    print(f"Loaded {len(results)} existing results.")
    train_data = handler.load_and_filter_dataset(args.start, args.end, split=args.split, source=args.source, filter_difficulty=args.filter_difficulty)
    remaining_data = handler.process_remaining_data(train_data, results)
    conversations = handler.make_conversations(remaining_data, system_prompt)

    for temp in temperatures:
        if args.model.startswith("openai"):
            fetch_partial = partial(fetch_response_openai, llm, args.model, max_tokens, temp)

            with concurrent.futures.ThreadPoolExecutor(max_workers=16) as e:
                responses = list(e.map(fetch_partial, conversations))
        else:
            sampling_params = SamplingParams(max_tokens=max_tokens, temperature=temp)
            responses = llm.chat(messages=conversations, sampling_params=sampling_params, use_tqdm=True)

        for idx, response in enumerate(responses):
            response_entry = {
                "content": response.choices[0].message.content.strip() if args.model.startswith("openai") else response.outputs[0].text.strip(),
                "correctness": None,
                "reason": None,
            }
            problem_key = remaining_data[idx][handler.get_question_key()]
            if problem_key not in results:
                results[problem_key] = remaining_data[idx]
                if isinstance(handler, NUMINATaskHandler):
                    results[problem_key]["messages"] = ""
                results[problem_key]["responses"] = {}
                results[problem_key]["token_usages"] = {}
                prompt = conversations[idx][1]["content"]
                results[problem_key]["prompt"] = prompt

            results[problem_key]["responses"][str(temp)] = response_entry

            if args.model.startswith("openai"):
                results[problem_key]["token_usages"][str(temp)] = {
                    "completion_tokens": response.usage.completion_tokens,
                    "prompt_tokens": response.usage.prompt_tokens,
                }
            else:
                results[problem_key]["token_usages"][str(temp)] = {
                    "completion_tokens": len(response.outputs[0].token_ids),
                    "prompt_tokens": len(response.prompt_token_ids),
                }

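    # Same token-usage summary as in perform_inference_and_check.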
    completion_tokens = [
        results[key].get("token_usages", {}).get(str(temp), {}).get("completion_tokens", 0)
        for key in results for temp in temperatures
    ]
    prompt_tokens = [
        results[key].get("token_usages", {}).get(str(temp), {}).get("prompt_tokens", 0)
        for key in results for temp in temperatures
    ]

    result_dir, result_name = os.path.split(result_file)
    token_usage_dir = os.path.join(result_dir, "token_usage")
    os.makedirs(token_usage_dir, exist_ok=True)

    token_usage_result_file = os.path.join(token_usage_dir, result_name)

    token_dict = {
        "completion_tokens": sum(completion_tokens),
        "prompt_tokens": sum(prompt_tokens),
        "avg_completion_tokens": round(sum(completion_tokens) / len(completion_tokens), 3) if completion_tokens else 0,
        "avg_prompt_tokens": round(sum(prompt_tokens) / len(prompt_tokens), 3) if prompt_tokens else 0,
    }

    with open(token_usage_result_file, "w") as f:
        json.dump(token_dict, f, indent=4)

    print(f"Token usage saved to {token_usage_result_file}")

    with open(result_file, 'w', encoding='utf-8') as file:
        json.dump(results, file, ensure_ascii=False, indent=4, cls=NumpyEncoder)


def main():
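    """Parse CLI arguments and dispatch to check-only, inference-only, or combined inference-and-check."""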
    parser = argparse.ArgumentParser(description="Unified inference and checking for different datasets/tasks.")
    parser.add_argument("--dataset", type=str, required=True, choices=["NUMINA", "APPS", "TACO", "MATH500", "AIME", "GPQADiamond", "MMLU", "LiveCodeBench"], help="Dataset to process.")
    parser.add_argument("--model", type=str, required=True, help="The model to run (e.g., Qwen/QwQ-32B-Preview).")
    parser.add_argument("--tp", type=int, default=8, help="Tensor parallelism degree.")
    parser.add_argument("--max_tokens", type=int, default=32768, help="Max tokens for the model.")
    parser.add_argument("--split", type=str, default="train", help="Dataset split to use (e.g., train, test).")
    parser.add_argument("--source", type=str, help="Source for the dataset.")
    parser.add_argument("--start", type=int, default=0, help="Start index.")
    parser.add_argument("--end", type=int, default=-1, help="End index.")
    parser.add_argument("--filter-difficulty", action="store_true", help="Filter the dataset by difficulty.")
    parser.add_argument("--result-dir", type=str, default="./", help="Directory to save result files.")
    parser.add_argument("--check", action="store_true", help="Perform evaluation checks on generated samples.")
    parser.add_argument("--inference", action="store_true", help="Perform inference only, without checking.")
    parser.add_argument("--temperatures", type=float, nargs="+", default=[0], help="Temperatures to sample at.")
    args = parser.parse_args()

    handler: TaskHandler = TASK_HANDLERS[args.dataset]()
    temperatures = [1] if args.model.startswith("openai/o1") else args.temperatures
    max_tokens = args.max_tokens

    if args.result_dir and not os.path.exists(args.result_dir):
        os.makedirs(args.result_dir)
    model_name = os.path.basename(args.model)
    result_file = os.path.join(args.result_dir, f"{model_name}_{args.dataset}_{args.split}_{args.source}_{args.start}_{args.end}.json")

    if args.check:
        # Prefer a converted result file if one exists.
        converted_file = f"{args.result_dir}/converted_{model_name}_{args.dataset}_{args.split}_{args.source}_{args.start}_{args.end}.json"
        if os.path.exists(converted_file):
            result_file = converted_file
        perform_check(handler, temperatures, result_file, args)
        return
    elif args.inference:
        llm = OpenAI() if args.model.startswith("openai") else LLM(model=args.model, tensor_parallel_size=args.tp)
        system_prompt = SYSTEM_PROMPT["NovaSky-AI/Sky-T1-32B-Preview"]
        perform_inference_and_save(handler, temperatures, max_tokens, result_file, llm, system_prompt, args)
        return

    llm = OpenAI() if args.model.startswith("openai") else LLM(model=args.model, tensor_parallel_size=args.tp)
    system_prompt = SYSTEM_PROMPT["NovaSky-AI/Sky-T1-32B-Preview"]
    perform_inference_and_check(handler, temperatures, max_tokens, result_file, llm, system_prompt, args)


if __name__ == "__main__":
    main()