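"""Rephrase LongBench questions with GPT-4o ('opposite' or 'similar' meaning),
answer both the original and the rephrased question with a local vLLM-hosted
model, and report exact-match (EM) counts against the ground truth for each."""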

import argparse
import os
import time

from datasets import load_dataset
from openai import OpenAI
from tqdm import tqdm
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

from utils.metrics import qa_em_score, qa_f1_score
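
# The OpenAI client is configured from the environment; OPENAI_BASE_URL is
# optional and, when unset, the SDK falls back to the default OpenAI endpoint.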
openai_client = OpenAI(
    api_key=os.environ.get("OPENAI_API_KEY"),
    base_url=os.environ.get("OPENAI_BASE_URL"),
)


def get_openai_rephrase_response(prompt, model="gpt-4o", retries=3, delay=2):
    """Call the OpenAI API for rephrasing, retrying on transient failures."""
    for attempt in range(retries):
        try:
            completion = openai_client.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=100,
            )
            return completion.choices[0].message.content.strip()
        except Exception as e:
            print(f"OpenAI rephrase attempt {attempt + 1} failed: {e}")
            if attempt < retries - 1:
                print(f"Retrying OpenAI rephrase in {delay} seconds...")
                time.sleep(delay)
            else:
                print("Max retries for OpenAI rephrase reached.")
    # Sentinel value checked by the caller in main().
    return "Failed to rephrase question"


def rephrase_question_with_gpt4o(question, rephrase_type="opposite"):
    """Rephrase a question using GPT-4o (English prompt)."""
    if rephrase_type == "opposite":
        prompt = f"""Please rephrase the following question to have the exact opposite meaning.
Question: {question}

Return only the rephrased question with the opposite meaning, without any explanations or other content."""
    elif rephrase_type == "similar":
        prompt = f"""Please rephrase the following question to be synonymous, maintaining the original meaning but using different wording:
Question: {question}

Return only the rephrased question, without any explanations or other content."""
    else:
        raise ValueError(f"Invalid rephrase_type: {rephrase_type}. Must be 'opposite' or 'similar'.")

    return get_openai_rephrase_response(prompt)


def get_vllm_response(prompt_text, llm_instance, sampling_params_instance, retries=2, delay=5):
    """Generate a response from a vLLM instance, retrying on transient failures."""
    for attempt in range(retries):
        try:
            outputs = llm_instance.generate([prompt_text], sampling_params_instance)
            response = outputs[0].outputs[0].text.strip()
            return response
        except Exception as e:
            print(f"vLLM generation attempt {attempt + 1} failed: {e}")
            if attempt < retries - 1:
                print(f"Retrying vLLM generation in {delay} seconds...")
                time.sleep(delay)
            else:
                print("Max retries for vLLM generation reached.")
    return "Failed to get vLLM response"


def answer_question_with_context_vllm(question, context, llm_instance, sampling_params_instance, tokenizer):
    """Answer a question with context using a vLLM model and chat template (English prompt)."""
    prompt_content = (
        "Answer the question based on the given passages. "
        "Only give me your answer and do not output any other words.\n"
        "The following are given passages:\n"
        f"{context}\n"
        "Please strictly follow the context. "
        f"Question: {question}\n"
        "Answer:"
    )
    messages = [{"role": "user", "content": prompt_content}]

    try:
        final_prompt_text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
    except Exception as e:
        print(f"Failed to apply chat template: {e}. Falling back to a basic prompt string.")
        final_prompt_text = f"Context:\n{context}\n\nQuestion: {question}\n\nAnswer:"

    return get_vllm_response(final_prompt_text, llm_instance, sampling_params_instance)
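
# Note: tokenizer.apply_chat_template requires the tokenizer to define a chat
# template; base (non-chat) checkpoints may not ship one, hence the
# plain-string fallback above.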


def main(args):
    print(f"Loading tokenizer for model: {args.model_name}...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=args.trust_remote_code)
        print("Successfully loaded tokenizer.")
    except Exception as e:
        print(f"Failed to load tokenizer for {args.model_name}: {e}")
        print("Please ensure the model name is correct and the tokenizer can be loaded.")
        return

    print(f"Loading vLLM model for answering: {args.model_name}...")
    print("(This may take a while depending on the model size and download speed if not cached.)")
    vllm_model = None
    try:
        vllm_model = LLM(
            model=args.model_name,
            trust_remote_code=args.trust_remote_code,
            dtype=args.dtype,
            tensor_parallel_size=args.tensor_parallel_size,
        )
        print(f"Successfully loaded vLLM model {args.model_name} with dtype='{args.dtype}' and tensor_parallel_size={args.tensor_parallel_size}.")
    except Exception as e:
        print(f"Failed to load vLLM model {args.model_name}: {e}")
        print("Please ensure vLLM is installed correctly and the model identifier is valid.")
        return
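
    # Greedy decoding with a short generation budget: LongBench QA answers are
    # short, so 30 new tokens is assumed to be enough for an exact-match check.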
    sampling_params = SamplingParams(temperature=0.0, max_tokens=30)

    print(f"Loading dataset {args.dataset_name}, subset {args.dataset_subset}...")
    try:
        dataset = load_dataset(args.dataset_name, args.dataset_subset)["test"]
        print(f"Successfully loaded dataset with {len(dataset)} samples.")
    except Exception as e:
        print(f"Failed to load dataset: {e}")
        return

    em_match_count = 0
    em_match_original_count = 0
    successfully_processed_samples = 0

    num_samples_to_process = len(dataset) if args.sample_count == -1 else min(args.sample_count, len(dataset))

    print(f"Processing {num_samples_to_process} samples. Rephrasing with GPT-4o ({args.rephrase_type} meaning). Answering with vLLM model {args.model_name} (max 30 tokens)...")

    for i in tqdm(range(num_samples_to_process), desc="Processing samples with vLLM"):
        example = dataset[i]
        original_question = example["input"]
        context = example["context"]
        ground_truth_answers = example["answers"]

        if not ground_truth_answers:
            print(f"Skipping sample {i + 1} due to missing ground truth answers.")
            continue

        rephrased_question = rephrase_question_with_gpt4o(original_question, args.rephrase_type)
        if rephrased_question == "Failed to rephrase question":
            print(f"Skipping sample {i + 1} due to rephrasing failure.")
            continue

        # Answer the same context with both the rephrased and the original
        # question so the two EM counts are directly comparable.
        rephrased_answer = answer_question_with_context_vllm(rephrased_question, context, vllm_model, sampling_params, tokenizer)
        original_answer = answer_question_with_context_vllm(original_question, context, vllm_model, sampling_params, tokenizer)

        successfully_processed_samples += 1

        # Exact-match score of each answer against the first reference answer.
        em_match_count += qa_em_score(rephrased_answer, ground_truth_answers[0])
        em_match_original_count += qa_em_score(original_answer, ground_truth_answers[0])

    if successfully_processed_samples > 0:
        print(f"Answering vLLM model: {args.model_name}")
        print(f"Dataset: {args.dataset_name} ({args.dataset_subset})")
        print(f"Successfully processed samples for evaluation: {successfully_processed_samples}")
        print("Max answer tokens: 30")
        print(f"EM count against ground truth (rephrased question): {em_match_count}")
        print(f"EM count against ground truth (original question): {em_match_original_count}")
    else:
        print("\nNo samples were processed adequately to provide an evaluation summary.")

    print("vLLM processing complete!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Rephrase with GPT-4o, answer with a local vLLM-hosted model, then evaluate.")
    parser.add_argument("--model_name", type=str, default="facebook/opt-125m", help="Name/path of the Hugging Face model for answering via vLLM (e.g., 'mistralai/Mistral-7B-Instruct-v0.1').")
    parser.add_argument("--dataset_name", type=str, default="THUDM/LongBench", help="Name of the Hugging Face dataset.")
    parser.add_argument("--dataset_subset", type=str, default="2wikimqa", help="Subset of the dataset.")
    parser.add_argument("--sample_count", type=int, default=3, help="Number of samples to process. -1 for all. Default: 3 for quick testing.")
    parser.add_argument("--trust_remote_code", action="store_true", help="Set this flag if the Hugging Face model for vLLM requires remote code.")
    parser.add_argument("--tensor_parallel_size", type=int, default=1, help="Tensor parallel size for vLLM.")
    parser.add_argument("--dtype", type=str, default="auto", help="Data type for the model. Examples: 'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'.")
    parser.add_argument("--rephrase_type", type=str, default="opposite", choices=["opposite", "similar"], help="Type of rephrasing: 'opposite' for opposite meaning or 'similar' for similar meaning.")

    args = parser.parse_args()
    main(args)
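
# Example invocation (script name, model, and GPU count are illustrative):
#   OPENAI_API_KEY=... python rephrase_eval_vllm.py \
#       --model_name mistralai/Mistral-7B-Instruct-v0.1 \
#       --dataset_subset 2wikimqa --sample_count 100 \
#       --tensor_parallel_size 2 --dtype bfloat16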