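"""Rephrase LongBench questions with GPT-4o, answer them with a local vLLM model, and score EM.

Pipeline (see main()):
  1. Load a LongBench subset (default: 2wikimqa) from Hugging Face and take its test split.
  2. Rephrase each question with GPT-4o to an "opposite" or "similar" meaning.
  3. Answer both the original and the rephrased question with the vLLM-hosted model.
  4. Report exact-match (EM) counts against the first ground-truth answer.

Requires OPENAI_API_KEY (and optionally OPENAI_BASE_URL) in the environment.
"""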
import time
import os
import argparse
# import torch # torch might not be directly needed if vLLM handles all device aspects
from datasets import load_dataset
from tqdm import tqdm
from openai import OpenAI # For GPT-4o rephrasing
from vllm import LLM, SamplingParams # For vLLM inference
from transformers import AutoTokenizer # Import AutoTokenizer
from utils.metrics import qa_f1_score, qa_em_score

# This will be respected by vLLM if CUDA_VISIBLE_DEVICES is set before vLLM import
# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2" # User can set this outside the script

# --- OpenAI Client for Rephrasing ---
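# The client reads OPENAI_API_KEY from the environment; OPENAI_BASE_URL is optional and,
# when unset, the default OpenAI endpoint is used.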
openai_client = OpenAI(
    api_key=os.environ.get("OPENAI_API_KEY"), 
    base_url=os.environ.get("OPENAI_BASE_URL")
)

def get_openai_rephrase_response(prompt, model="gpt-4o", retries=3, delay=2):
    """Call OpenAI API for rephrasing."""
    for attempt in range(retries):
        try:
            completion = openai_client.chat.completions.create(
                model=model,
                messages=[{'role': 'user', 'content': prompt}],
                max_tokens=100 # Max tokens for rephrased question
            )
            return completion.choices[0].message.content.strip()
        except Exception as e:
            print(f"OpenAI Rephrase attempt {attempt + 1} failed: {e}")
            if attempt < retries - 1:
                print(f"Retrying OpenAI rephrase in {delay} seconds...")
                time.sleep(delay)
            else:
                print("Max retries for OpenAI rephrase reached.")
                return "Failed to rephrase question"

def rephrase_question_with_gpt4o(question, rephrase_type="opposite"):
    """Rephrase a question using GPT-4o (English prompt)."""
    if rephrase_type == "opposite":
        prompt = f"""Please rephrase the following question to have the exact opposite meaning. 
Question: {question}

Return only the rephrased question with the opposite meaning, without any explanations or other content."""
    elif rephrase_type == "similar":
        prompt = f"""Please rephrase the following question to be synonymous, maintaining the original meaning but using different wording:
Question: {question}

Return only the rephrased question, without any explanations or other content."""
    else:
        raise ValueError(f"Invalid rephrase_type: {rephrase_type}. Must be 'opposite' or 'similar'.")
    
    return get_openai_rephrase_response(prompt)

# --- vLLM Model Functions (for Answering) ---
def get_vllm_response(prompt_text, llm_instance, sampling_params_instance, retries=2, delay=5):
    """Generate a response from a vLLM instance."""
    for attempt in range(retries):
        try:
            # vLLM generate method expects a list of prompts
            outputs = llm_instance.generate([prompt_text], sampling_params_instance)
            # For a single prompt, the result is in the first element of the output list
            # Each output object has a list of `outputs` (for n>1 in SamplingParams)
            response = outputs[0].outputs[0].text.strip()
            return response
        except Exception as e:
            print(f"vLLM generation attempt {attempt + 1} failed: {e}")
            if attempt < retries - 1:
                print(f"Retrying vLLM generation in {delay} seconds...")
                time.sleep(delay)
            else:
                print("Max retries for vLLM generation reached.")
                return "Failed to get vLLM response"

def answer_question_with_context_vllm(question, context, llm_instance, sampling_params_instance, tokenizer):
    """Answer a question with context using a vLLM model and chat template (English prompt)."""
    # Construct prompt using chat template, similar to evaluation.py
    prompt_content = (
        f"Answer the question based on the given passages. "
        "Only give me your answer and do not output any other words.\n"
        "The following are given passages:\n"
        f"{context}\n"
        "Please strictly follow the context. "
        f"Question: {question}\n"
        "Answer:"
    )
    messages = [{"role": "user", "content": prompt_content}]
    
    # Apply chat template
    # Note: Some tokenizers might not have a chat template configured, or might have different ways to apply it.
    # This is a common way for many models.
    try:
        final_prompt_text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
    except Exception as e:
        print(f"Failed to apply chat template: {e}. Falling back to basic prompt string.")
        # Fallback to a simpler prompt if template application fails
        final_prompt_text = f"Context:\n{context}\n\nQuestion: {question}\n\nAnswer:"

    return get_vllm_response(final_prompt_text, llm_instance, sampling_params_instance)

def main(args):
    # Load Tokenizer for the vLLM model
    print(f"Loading tokenizer for model: {args.model_name}...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=args.trust_remote_code)
        print("Successfully loaded tokenizer.")
    except Exception as e:
        print(f"Failed to load tokenizer for {args.model_name}: {e}")
        print("Please ensure the model name is correct and the tokenizer can be loaded.")
        return

    # Load vLLM Model (for Answering)
    print(f"Loading vLLM model for Answering: {args.model_name}...")
    print(f"(This may take a while depending on the model size and download speed if not cached).")
    vllm_model = None
    try:
        # You can expose more vLLM LLM parameters as args if needed
        # (e.g., tensor_parallel_size, dtype, gpu_memory_utilization)
        vllm_model = LLM(
            model=args.model_name,
            trust_remote_code=args.trust_remote_code,
            dtype=args.dtype,  # dtype from the command line ('auto', 'bfloat16', ...)
            tensor_parallel_size=args.tensor_parallel_size,
            # Add other vLLM LLM constructor arguments here if needed,
            # e.g. gpu_memory_utilization.
        )
        print(f"Successfully loaded vLLM model {args.model_name} with dtype='{args.dtype}' and tensor_parallel_size={args.tensor_parallel_size}.")
    except Exception as e:
        print(f"Failed to load vLLM model {args.model_name}: {e}")
        print("Please ensure vLLM is installed correctly and the model identifier is valid.")
        return

    # Define Sampling Parameters for vLLM
    # max_tokens is equivalent to max_new_tokens in HF
    # temperature=0.0 for greedy decoding, good for QA tasks for more deterministic output.
    # Adjust temperature (e.g., 0.7) and top_p (e.g., 0.95) for more diverse outputs if needed.
    sampling_params = SamplingParams(temperature=0.0, max_tokens=30) # Set temperature to 0.0 for deterministic QA

    # Load dataset
    print(f"Loading dataset {args.dataset_name}, subset {args.dataset_subset}...")
    try:
        dataset = load_dataset(args.dataset_name, args.dataset_subset)["test"]
        print(f"Successfully loaded dataset with {len(dataset)} samples.")
    except Exception as e:
        print(f"Failed to load dataset: {e}")
        return

    em_match_count = 0  # EM matches for answers to rephrased questions
    em_match_original_count = 0  # EM matches for answers to original questions
    successfully_processed_samples = 0  # Counter for successfully processed samples

    num_samples_to_process = len(dataset) if args.sample_count == -1 else min(args.sample_count, len(dataset))
    
    print(f"Processing {num_samples_to_process} samples. Rephrasing with GPT-4o ({args.rephrase_type} meaning). Answering with vLLM model {args.model_name} (max {sampling_params.max_tokens} tokens)...")
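    # Note: each sample below issues two separate llm.generate() calls (rephrased and original
    # question). vLLM also accepts a list of prompts, so batching all prompts up front would be
    # faster; per-sample calls are kept here for simplicity and easier debugging.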

    for i in tqdm(range(num_samples_to_process), desc="Processing samples with vLLM"):
        example = dataset[i]
        original_question = example['input']
        context = example['context']
        ground_truth_answers = example['answers']
        
        rephrased_question = rephrase_question_with_gpt4o(original_question, args.rephrase_type)
        
        if rephrased_question == "Failed to rephrase question":
            print(f"Skipping sample {i+1} due to rephrasing failure.")
            continue
            
        rephrased_answer = answer_question_with_context_vllm(rephrased_question, context, vllm_model, sampling_params, tokenizer)
        # print(f"Rephrased question: {rephrased_question}") # Optional: for debugging
        # print(f"Answer to rephrased: {rephrased_answer}") # Optional: for debugging

        original_answer = answer_question_with_context_vllm(original_question, context, vllm_model, sampling_params, tokenizer)
        # print(f"Original question: {original_question}") # Optional: for debugging
        # print(f"Answer to original: {original_answer}") # Optional: for debugging
        
        if not ground_truth_answers:
            print(f"Skipping sample {i+1} due to missing ground truth answers.")
            continue
        successfully_processed_samples += 1

        # print(original_answer)  # Optional: for debugging
        # print(ground_truth_answers[0])  # Optional: for debugging

        # Exact match against the first ground-truth answer, for the rephrased and the original question.
        em_match_count += qa_em_score(rephrased_answer, ground_truth_answers[0])
        em_match_original_count += qa_em_score(original_answer, ground_truth_answers[0])

    if successfully_processed_samples > 0:
        print(f"Answering vLLM Model: {args.model_name}")
        print(f"Dataset             : {args.dataset_name} ({args.dataset_subset})")
        print(f"Successfully Processed Samples for Evaluation: {successfully_processed_samples}")
        print(f"Max Answer Tokens   : {sampling_params.max_tokens}")
        print(f"EM count vs. ground truth (rephrased questions): {em_match_count}")
        print(f"EM count vs. ground truth (original questions) : {em_match_original_count}")
    else:
        print("\nNo samples were processed adequately to provide an evaluation summary.")
    
    print("vLLM processing complete!")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Rephrase with GPT-4o, Answer with local vLLM-hosted Model, then Evaluate.")
    parser.add_argument("--model_name", type=str, default="facebook/opt-125m", help="Name/path of the Hugging Face model for Answering via vLLM (e.g., 'mistralai/Mistral-7B-Instruct-v0.1').")
    parser.add_argument("--dataset_name", type=str, default="THUDM/LongBench", help="Name of the Hugging Face dataset.")
    parser.add_argument("--dataset_subset", type=str, default="2wikimqa", help="Subset of the dataset.")
    parser.add_argument("--sample_count", type=int, default=3, help="Number of samples to process. -1 for all. Default: 3 for quick testing.")
    parser.add_argument("--trust_remote_code", action="store_true", help="Set to true if the Hugging Face model for vLLM requires remote code.")
    parser.add_argument("--tensor_parallel_size", type=int, default=1, help="Tensor parallel size for vLLM.")
    parser.add_argument("--dtype", type=str, default="auto", help="Data type for the model. Examples: 'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'. Default is 'auto'.")
    parser.add_argument("--rephrase_type", type=str, default="opposite", choices=["opposite", "similar"], help="Type of rephrasing: 'opposite' for opposite meaning or 'similar' for similar meaning.")
    
    args = parser.parse_args()
    main(args)
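
# Example invocation (model name and script filename are illustrative):
#   python rephrase_and_eval_vllm.py \
#       --model_name mistralai/Mistral-7B-Instruct-v0.2 \
#       --dataset_subset 2wikimqa \
#       --sample_count 100 \
#       --rephrase_type opposite \
#       --tensor_parallel_size 2 \
#       --dtype bfloat16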