import os
import json
import re
import requests
from tqdm import tqdm
from datetime import datetime
import glob
import argparse

prompt_template = (
    "# Interactional Dialogue Evaluation\n\n"
    "**IMPORTANT**: Evaluation must include `<response think>` and `<fluency think>` analysis and `<overall score>` rating.\n"  
    "Evaluate the quality of the interaction in the given dialogue transcript, focusing on:\n"
    "**Response Relevance:** \n"
    "**logical consistency, topic coherence**\n"
    "**Interactional Fluency:**\n"
    "**Detect and evaluate extended overlaps in conversation.**\n"
    "**Detect and evaluate long pauses between speaker turns.\n\n**"
    "**Note**: Small pauses and brief overlaps in conversation are acceptable, while prolonged pauses and overlapping turns are harmful. You should consider Response Relevance and Interactional Fluency separately, and provide the corresponding thinking process.\n\n"
    "## Scoring Criteria\n"
    "Assign a single holistic score based on the combined evaluation:\n"
    "`1` (Poor): Significant issues in either  **Response Relevance ** or  **Interactional Fluency. **\n"
    "`2` (Excellent): Both **Response Relevance ** and  **Interactional Fluency ** are consistently appropriate and natural.\n"
    "## Evaluation Output Format:\n"
    "Strictly follow this template:\n"
    "<response think>\n"
    "[Analysing Response Relevance and giving reasons for scoring...]\n"
    "</response think>\n"
    "<fluency think>\n"
    "[Analysing Interactional Fluency and giving reasons for scoring.]\n"
    "</fluency think>\n"
    "<overall score>X</overall score>\n"
)

# API configuration
url = "https://api2.aigcbest.top/v1/chat/completions"
headers = {
    # Read the key from the environment rather than committing a live secret
    # to source control ("API_KEY" is an assumed variable name; export it
    # before running the script).
    "Authorization": f"Bearer {os.environ.get('API_KEY', '')}",
    "Content-Type": "application/json",
    "Accept": "application/json"
}

def parse_args():
    parser = argparse.ArgumentParser(description='Process text evaluation with Gemini model')
    parser.add_argument('--input_file', type=str, required=True,
                      help='Input JSON file containing text data')
    parser.add_argument('--output_file', type=str, default='texterror_gemini.json',
                      help='Output JSON file for results')
    parser.add_argument('--error_file', type=str, default='texterror_gemini_error.json',
                      help='Output JSON file for errors')
    parser.add_argument('--checkpoint_dir', type=str, default='checkpoints_test_text',
                      help='Directory for storing checkpoints')
    parser.add_argument('--max_retries', type=int, default=3,
                      help='Maximum number of retries for failed predictions')
    parser.add_argument('--checkpoint_interval', type=int, default=20,
                      help='Number of items to process before saving checkpoint')
    return parser.parse_args()
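
# Example invocation (hypothetical script and file names; adjust to your data):
#   python run_text_eval.py --input_file test_text.json \
#       --output_file texterror_gemini.json --checkpoint_dir checkpoints_test_text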

def extract_overall_score(output_str):
    """Extract <overall score>X</overall score> from model output."""
    score_pattern = r"<overall score>(\d+)</overall score>"
    match = re.search(score_pattern, output_str)
    if match:
        try:
            return int(match.group(1))
        except ValueError:
            pass
    return None
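
# Illustrative behaviour: extract_overall_score("<overall score>1</overall score>")
# returns 1, while output missing the tag (or with a non-integer score) yields None.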

def validate_model_output(output_str):
    """Validate that the model output contains all required tags"""
    required_tags = [
        "<response think>",
        "</response think>",
        "<fluency think>",
        "</fluency think>",
        "<overall score>",
        "</overall score>"
    ]
    
    for tag in required_tags:
        if tag not in output_str:
            return False
    return True

def extract_tag_content(output_str, tag_name):
    """Extract content between opening and closing tags"""
    start_tag = f"<{tag_name}>"
    end_tag = f"</{tag_name}>"
    start_idx = output_str.find(start_tag)
    end_idx = output_str.find(end_tag)
    # find() returns -1 for a missing tag; check before offsetting past the
    # start tag, otherwise a missing start tag slips through as a valid index.
    if start_idx == -1 or end_idx == -1:
        return None
    return output_str[start_idx + len(start_tag):end_idx].strip()

def format_model_output(output_str):
    """Extract and format content from all required tags"""
    response_content = extract_tag_content(output_str, "response think")
    fluency_content = extract_tag_content(output_str, "fluency think")
    score_content = extract_tag_content(output_str, "overall score")
    
    if not all([response_content, fluency_content, score_content]):
        return None
        
    formatted_output = (
        f"<response think>\n{response_content}\n</response think>\n\n"
        f"<fluency think>\n{fluency_content}\n</fluency think>\n\n"
        f"<overall score>{score_content}</overall score>"
    )
    return formatted_output

def make_api_call(text_input, retry_count=0, max_retries=5):
    """Make a single API call. Failures return (None, None) so the caller's
    retry loop can decide whether to try again; retry_count/max_retries only
    escalate to a raised exception if the caller threads them through (the
    loop in main() currently calls this with the defaults)."""
    try:
        print(f"Attempting API call (attempt {retry_count + 1}/{max_retries + 1})")
        data_req = {
            "model": "gemini-2.5-flash-preview-05-20-thinking",
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": prompt_template
                        },
                        {
                            "type": "text",
                            "text": text_input
                        },
                    ]
                }
            ],
            "temperature": 1,
        }
        
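        # requests takes a (connect, read) timeout tuple in seconds; the long
        # read timeout presumably allows for the model's slow "thinking" output.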
        response = requests.post(url, headers=headers, json=data_req, timeout=(200, 200))
        print(f"API response received with status code: {response.status_code}")
        
        if response.status_code == 200:
            model_output = response.json()['choices'][0]['message']['content']
            if not validate_model_output(model_output):
                print("Model output missing required tags, retrying...")
                return None, None
                
            formatted_output = format_model_output(model_output)
            if formatted_output is None:
                print("Failed to extract content from tags, retrying...")
                return None, None
                
            pred_score = extract_overall_score(model_output)
            return formatted_output, pred_score
        else:
            print(f"API returned error status {response.status_code}: {response.text}")
            if retry_count >= max_retries:
                raise Exception(f"POST error {response.status_code}: {response.text}")
            return None, None
    except requests.exceptions.ConnectTimeout:
        print("Connection timeout (>200s)")  # matches the configured connect timeout
        if retry_count >= max_retries:
            raise Exception("Connection timeout")
        return None, None
    except requests.exceptions.ReadTimeout:
        print("Read timeout (>200s)")  # matches the configured read timeout
        if retry_count >= max_retries:
            raise Exception("Read timeout")
        return None, None
    except Exception as e:
        print(f"Unexpected error during API call: {str(e)}")
        if retry_count >= max_retries:
            raise e
        return None, None

def get_latest_checkpoint(checkpoint_dir):
    """Get the latest checkpoint file and its processed count"""
    checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "checkpoint_*.json"))
    if not checkpoint_files:
        return None, 0
    
    latest_checkpoint = None
    max_count = 0
    for checkpoint in checkpoint_files:
        try:
            count = int(os.path.basename(checkpoint).split('_')[1])
            if count > max_count:
                max_count = count
                latest_checkpoint = checkpoint
        except (ValueError, IndexError):
            continue
    
    return latest_checkpoint, max_count

def save_checkpoint(results, processed_count, checkpoint_dir):
    """Save results to a checkpoint file"""
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    checkpoint_file = os.path.join(checkpoint_dir, f"checkpoint_{processed_count}_{timestamp}.json")
    with open(checkpoint_file, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    print(f"Checkpoint saved: {checkpoint_file}")

def main():
    args = parse_args()
    
    # Initialize results storage
    results = []
    save_file_name = args.output_file
    error_file_name = args.error_file
    
    # Create checkpoints directory
    checkpoint_dir = args.checkpoint_dir
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    
    # Load test data
    all_data_file = args.input_file
    with open(all_data_file, 'r', encoding='utf-8') as f:
        all_data = json.load(f)
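    # Each record is expected to carry a "key" identifier and the dialogue
    # transcript under "model_output" (the two fields read in the loop below).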
    
    # Initialize error tracking
    error_results = []
    
    # Load checkpoint if exists
    latest_checkpoint, checkpoint_count = get_latest_checkpoint(checkpoint_dir)
    if latest_checkpoint:
        print(f"Found latest checkpoint with {checkpoint_count} processed items: {latest_checkpoint}")
        try:
            with open(latest_checkpoint, 'r', encoding='utf-8') as f:
                results = json.load(f)
                print(f"Resumed from checkpoint: processed {len(results)} items")
        except Exception as e:
            print(f"Warning: Failed to load checkpoint {latest_checkpoint}: {e}")
            results = []
    else:
        print("No checkpoint found, starting from scratch")
        results = []
    
    max_prediction_retries = args.max_retries
    total_count = len(results)
    # Track keys already completed (e.g. restored from a checkpoint) so a
    # resumed run does not re-process and duplicate them.
    processed_keys = {entry.get("key") for entry in results}

    for item in tqdm(all_data, desc="Processing texts"):
        key = item.get('key')
        text_input = item.get('model_output')

        if key in processed_keys:
            continue
        if not text_input:
            print(f"No text input found for key {key}, skipping...")
            continue
        
        print(f"Processing text for key={key}")
        
        prediction_retry_count = 0
        success = False
        
        while prediction_retry_count < max_prediction_retries and not success:
            try:
                print(f"\nProcessing attempt {prediction_retry_count + 1}")
                model_output, pred_score = make_api_call(text_input)
                
                if model_output is None or pred_score is None:
                    print("API call failed, retrying...")
                    prediction_retry_count += 1
                    continue
                    
                print(f"Received prediction: {pred_score}")
                
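                # The default output names ("texterror_*") suggest these inputs
                # contain known interaction errors, so a score of 1 ("Poor") is
                # the expected label; other scores are re-queried up to
                # max_prediction_retries before the last prediction is accepted.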
                if pred_score == 1:
                    success = True
                    print("Prediction score is 1, accepting result")
                else:
                    prediction_retry_count += 1
                    print(f"Prediction score is not 1 (attempt {prediction_retry_count}/{max_prediction_retries})")
                    if prediction_retry_count >= max_prediction_retries:
                        print("Max retries reached, accepting last prediction")
                        success = True
                    else:
                        continue
                
                results.append({
                    "key": key,
                    "text_input": text_input,
                    "model_output": model_output,
                    "predicted_score": pred_score,
                    "prediction_attempts": prediction_retry_count + 1
                })
                
                with open(save_file_name, "w", encoding="utf-8") as f:
                    json.dump(results, f, indent=2, ensure_ascii=False)
                    
                total_count += 1
                
                if total_count % args.checkpoint_interval == 0:
                    save_checkpoint(results, total_count, checkpoint_dir)
                
            except Exception as e:
                error_msg = str(e)
                print(f"Failed to process text for key {key}: {error_msg}")
                error_results.append({
                    "key": key,
                    "text_input": text_input,
                    "error": f"Exception: {error_msg}"
                })
                break
                
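        # Re-write the error log after every item so failures are persisted
        # even if the run is interrupted.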
        with open(error_file_name, "w", encoding="utf-8") as f:
            json.dump(error_results, f, indent=2, ensure_ascii=False)
    
    # Save final results
    with open(save_file_name, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    
    print(f"Results saved to {save_file_name}")
    print(f"Total processed items: {total_count}")

if __name__ == "__main__":
    main()