"""
Model evaluation script.

What it does:
1. Loads conversation data from a JSON file.
2. Extracts each human turn's value as the query and calls server:8020.
3. Parses the streamed (SSE) response.
4. Compares the results against the reference data and stores them.
"""

import json
import httpx
import asyncio
import time
import os
from collections import Counter
from typing import Dict, List, Optional

from loguru import logger

from utils.custom_logging import setup_logging
from utils.extraction import extract_json_from_string

setup_logging()
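
# utils.custom_logging and utils.extraction are project-local helpers;
# extract_json_from_string is assumed to return the parsed object, or None
# when the payload is not valid JSON (inferred from how it is used below).

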
class ModelEvaluator:
    def __init__(self, server_url: str = "http://localhost:8020/mcp_end2end/stream"):
        self.server_url = server_url
        self.results = []
        self.start_time = None
        self.error_count = 0
        self.success_count = 0

    def load_data(self, file_path: str) -> List[Dict]:
        """Load the JSON data file."""
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            logger.info(f"Loaded data file with {len(data)} records")
            return data
        except Exception as e:
            logger.error(f"Failed to load data file: {e}")
            return []

    def extract_human_queries(self, data: List[Dict]) -> List[Dict]:
        """Extract the first human turn's value from each record as the query."""
        queries = []
        for i, item in enumerate(data):
            if 'conversations' in item:
                for conv in item['conversations']:
                    if conv.get('from') == 'human':
                        query_data = {
                            'index': i,
                            'query': conv.get('value', ''),
                            'original_data': item
                        }
                        queries.append(query_data)
                        break
        logger.info(f"Extracted {len(queries)} queries")
        return queries

    def parse_sse_events(self, sse_content: str, filter_events: Optional[List[str]] = None) -> List[Dict]:
        """
        Parse SSE-formatted content and extract events of the given types.

        Args:
            sse_content: SSE-formatted text (may contain multiple events).
            filter_events: Event types to keep; if None, defaults to
                ['tool_call.created', 'tool_response.completed'].

        Returns:
            List of successfully parsed events.
        """
        events = []
        current_event = {}
        parsed_count = 0
        failed_count = 0

        if filter_events is None:
            filter_events = ['tool_call.created', 'tool_response.completed']

        for line in sse_content.split('\n'):
            line = line.strip()
            if not line:
                # A blank line terminates the current event; flush it if complete.
                if current_event and 'event' in current_event and 'data' in current_event:
                    event_type = current_event['event']
                    if event_type in filter_events:
                        data_content = extract_json_from_string(current_event['data'])
                        if data_content is not None:
                            event_obj = {
                                'id': current_event.get('id'),
                                'event': current_event['event'],
                                'data': data_content
                            }
                            events.append(event_obj)
                            parsed_count += 1
                            logger.debug(f"✅ Parsed event: {current_event['event']} (ID: {current_event.get('id', 'N/A')})")
                        else:
                            failed_count += 1
                            logger.warning(f"❌ Could not parse event data: {current_event['event']} - {current_event['data'][:100]}...")
                current_event = {}
                continue

            if line.startswith('id: '):
                current_event['id'] = line[4:]
            elif line.startswith('event: '):
                current_event['event'] = line[7:]
            elif line.startswith('data: '):
                current_event['data'] = line[6:]
                if current_event['data'].strip() == '[DONE]':
                    logger.debug("Received end-of-stream marker [DONE]")
                    break
            else:
                logger.debug(f"Line in unrecognized format: {line}")

        # Flush a trailing event that was not followed by a blank line.
        if current_event and 'event' in current_event and 'data' in current_event:
            event_type = current_event['event']
            if event_type in filter_events:
                data_content = extract_json_from_string(current_event['data'])
                if data_content is not None:
                    event_obj = {
                        'id': current_event.get('id'),
                        'event': current_event['event'],
                        'data': data_content
                    }
                    events.append(event_obj)
                    parsed_count += 1
                    logger.debug(f"✅ Parsed trailing event: {current_event['event']} (ID: {current_event.get('id', 'N/A')})")
                else:
                    failed_count += 1
                    logger.warning(f"❌ Could not parse trailing event data: {current_event['event']} - {current_event['data'][:100]}...")

        logger.info("=== SSE parsing statistics ===")
        logger.info(f"Events parsed successfully: {parsed_count}")
        logger.info(f"Events failed to parse: {failed_count}")
        logger.info(f"Total events kept: {len(events)}")

        if events:
            event_types = [event.get('event', 'unknown') for event in events]
            event_counts = Counter(event_types)
            logger.info(f"Event type distribution: {dict(event_counts)}")
        else:
            logger.warning("⚠️ No target events were parsed")

        return events

    async def call_server(self, query: str, max_retries: int = 3, retry_delay: float = 2.0) -> List[Dict]:
        """Call server:8020 asynchronously, parse the streamed response, and retry on failure.

        Retries use exponential backoff: retry_delay doubles after each
        failed attempt.
        """
        payload = {
            "user_id": "166",
            "role_code": 1,
            "query": query,
            "save_method": 0
        }
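
        # Note: payload fields other than `query` are fixed test-session values
        # (user_id, role_code, save_method are not varied per query).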

        for attempt in range(max_retries):
            try:
                async with httpx.AsyncClient(timeout=30.0) as client:
                    async with client.stream(
                        'POST',
                        self.server_url,
                        json=payload,
                        headers={'Accept': 'text/event-stream'}
                    ) as response:
                        response.raise_for_status()

                        sse_content = ""
                        async for chunk in response.aiter_text():
                            logger.debug(f"Received data: {chunk}")
                            sse_content += chunk

                            if '[DONE]' in chunk:
                                logger.debug("Received end-of-stream marker [DONE]")
                                break

                events = self.parse_sse_events(
                    sse_content,
                    filter_events=['tool_call.created', 'tool_response.completed']
                )

                has_tool_call = any(event.get('event') == 'tool_call.created' for event in events)
                has_tool_response = any(event.get('event') == 'tool_response.completed' for event in events)
                logger.info(f"Contains tool call events: {'✅' if has_tool_call else '❌'}")
                logger.info(f"Contains tool response events: {'✅' if has_tool_response else '❌'}")

                return events

            except httpx.TimeoutException as e:
                # TimeoutException subclasses RequestError in httpx, so it must
                # be caught first to be handled separately.
                logger.warning(f"Server timeout (attempt {attempt + 1}/{max_retries}): {e}")
                if attempt < max_retries - 1:
                    logger.info(f"Retrying in {retry_delay} seconds...")
                    await asyncio.sleep(retry_delay)
                    retry_delay *= 2
                else:
                    logger.error(f"Timeout after all retry attempts for query: {query[:50]}...")
                    raise Exception(f"Server timeout after {max_retries} attempts: {e}")
            except httpx.RequestError as e:
                logger.warning(f"Call server failed (attempt {attempt + 1}/{max_retries}): {e}")
                if attempt < max_retries - 1:
                    logger.info(f"Retrying in {retry_delay} seconds...")
                    await asyncio.sleep(retry_delay)
                    retry_delay *= 2
                else:
                    logger.error(f"All retry attempts failed for query: {query[:50]}...")
                    raise Exception(f"Server connection failed after {max_retries} attempts: {e}")
            except Exception as e:
                logger.error(f"Unexpected error processing response (attempt {attempt + 1}/{max_retries}): {e}")
                if attempt < max_retries - 1:
                    logger.info(f"Retrying in {retry_delay} seconds...")
                    await asyncio.sleep(retry_delay)
                    retry_delay *= 2
                else:
                    logger.error(f"Unexpected error after all retry attempts for query: {query[:50]}...")
                    raise Exception(f"Unexpected error after {max_retries} attempts: {e}")

        raise Exception("All retry attempts exhausted")

    def extract_tool_calls_and_observations(self, events: List[Dict]) -> Dict[str, List]:
        """Extract tool_call.created and tool_response.completed content from events."""
        tool_calls = []
        tool_responses = []

        logger.debug(f"Extracting tool calls and responses from {len(events)} events")

        for event in events:
            event_type = event.get('event')
            event_data = event.get('data', {})

            if event_type == 'tool_call.created':
                logger.debug(f"Extract tool_call.created content: {event}")
                tool_call_info = event_data.get('tool_call', {})
                if tool_call_info:
                    tool_calls.append(tool_call_info)
                    logger.debug(f"✅ Extracted tool call: {tool_call_info.get('name', 'unknown')}")
                else:
                    logger.warning("❌ tool_call.created event is missing tool_call info")

            elif event_type == 'tool_response.completed':
                logger.debug(f"Extract tool_response.completed content: {event}")
                if 'result_delta' in event_data:
                    tool_response = event_data['result_delta'].get('result', [])
                    tool_responses.append(tool_response)
                    logger.debug(f"✅ Extracted tool response: {len(str(tool_response))} characters")
                else:
                    # Keep positions aligned with tool_calls even when a
                    # response carries no result_delta.
                    tool_responses.append([])

        logger.info(f"Extracted {len(tool_calls)} tool calls, {len(tool_responses)} tool responses")
        return {
            'tool_calls': tool_calls,
            'tool_responses': tool_responses
        }

    def extract_original_data(self, original_data: Dict) -> Dict[str, List]:
        """Extract function_call and observation content from the original record."""
        function_calls = []
        observations = []

        if 'conversations' in original_data:
            for conv in original_data['conversations']:
                if conv.get('from') == 'function_call':
                    try:
                        function_call_obj = json.loads(conv.get('value', '{}'))
                        function_calls.append(function_call_obj)
                    except json.JSONDecodeError as e:
                        logger.warning(f"Failed to parse function_call JSON: {e}")
                        function_calls.append({})
                elif conv.get('from') == 'observation':
                    try:
                        observation_obj = json.loads(conv.get('value', '[]'))
                        observations.append(observation_obj)
                    except json.JSONDecodeError as e:
                        logger.warning(f"Failed to parse observation JSON: {e}")
                        observations.append([])

        return {
            'function_calls': function_calls,
            'observations': observations
        }

    def compare_tool_call(self, server_call: Dict, original_call: Dict) -> Dict:
        """Compare a single tool call, checking name and arguments for exact matches."""
        try:
            name_match = server_call.get('name') == original_call.get('name')
            name_score = 1.0 if name_match else 0.0

            server_args = server_call.get('arguments', {})
            original_args = original_call.get('arguments', {})
            arguments_match = server_args == original_args
            arguments_score = 1.0 if arguments_match else 0.0

            return {
                'name_match': name_match,
                'name_score': name_score,
                'arguments_match': arguments_match,
                'arguments_score': arguments_score,
                'server_name': server_call.get('name', ''),
                'original_name': original_call.get('name', ''),
                'server_arguments': server_args,
                'original_arguments': original_args
            }
        except (KeyError, TypeError) as e:
            logger.warning(f"Error comparing tool calls: {e}")
            return {
                'name_match': False,
                'name_score': 0.0,
                'arguments_match': False,
                'arguments_score': 0.0,
                'server_name': '',
                'original_name': '',
                'server_arguments': {},
                'original_arguments': {},
                'error': str(e)
            }

    def compare_results(self, server_data: Dict[str, List], original_data: Dict[str, List]) -> Dict:
        """Compare server results against the original data in detail."""
        comparison = {
            'tool_calls_comparison': {
                'server_count': len(server_data['tool_calls']),
                'original_count': len(original_data['function_calls']),
                'detailed_scores': [],
                'name_average_score': 0.0,
                'arguments_average_score': 0.0,
                'non_retrieval_name_average_score': 0.0,
                'non_retrieval_arguments_average_score': 0.0
            },
            'tool_responses_comparison': {
                'server_count': len(server_data['tool_responses']),
                'original_count': len(original_data['observations']),
                'detailed_scores': [],
                'average_score': 0.0
            },
            'overall_scores': {
                'tool_responses_avg': 0.0
            }
        }
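
        # Tool calls are paired positionally: server call i is compared with
        # reference call i, and any unpaired call on either side scores 0.0.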
        tool_call_name_scores = []
        tool_call_arguments_scores = []
        non_retrieval_name_scores = []
        non_retrieval_arguments_scores = []
        max_tool_calls = max(len(server_data['tool_calls']), len(original_data['function_calls']))

        for i in range(max_tool_calls):
            server_call = server_data['tool_calls'][i] if i < len(server_data['tool_calls']) else None
            original_call = original_data['function_calls'][i] if i < len(original_data['function_calls']) else None

            if server_call is None:
                score_detail = {
                    'index': i,
                    'server_present': False,
                    'original_present': True,
                    'name_score': 0.0,
                    'arguments_score': 0.0,
                    'original_call': original_call
                }
            elif original_call is None:
                score_detail = {
                    'index': i,
                    'server_present': True,
                    'original_present': False,
                    'name_score': 0.0,
                    'arguments_score': 0.0,
                    'server_call': server_call
                }
            else:
                call_comparison = self.compare_tool_call(server_call, original_call)
                score_detail = {
                    'index': i,
                    'server_present': True,
                    'original_present': True,
                    'name_score': call_comparison['name_score'],
                    'arguments_score': call_comparison['arguments_score'],
                    'name_match': call_comparison['name_match'],
                    'arguments_match': call_comparison['arguments_match'],
                    'server_name': call_comparison['server_name'],
                    'original_name': call_comparison['original_name'],
                    'server_call': server_call,
                    'original_call': original_call
                }
                if 'error' in call_comparison:
                    score_detail['error'] = call_comparison['error']

            comparison['tool_calls_comparison']['detailed_scores'].append(score_detail)
            tool_call_name_scores.append(score_detail['name_score'])
            tool_call_arguments_scores.append(score_detail['arguments_score'])

            if server_call and original_call:
                server_name = server_call.get('name', '')
                original_name = original_call.get('name', '')
                if server_name != 'retrieval_tool' and original_name != 'retrieval_tool':
                    non_retrieval_name_scores.append(score_detail['name_score'])
                    non_retrieval_arguments_scores.append(score_detail['arguments_score'])

        comparison['tool_calls_comparison']['name_average_score'] = (
            sum(tool_call_name_scores) / len(tool_call_name_scores) if tool_call_name_scores else 0.0
        )
        comparison['tool_calls_comparison']['arguments_average_score'] = (
            sum(tool_call_arguments_scores) / len(tool_call_arguments_scores) if tool_call_arguments_scores else 0.0
        )

        comparison['tool_calls_comparison']['non_retrieval_name_average_score'] = (
            sum(non_retrieval_name_scores) / len(non_retrieval_name_scores) if non_retrieval_name_scores else 0.0
        )
        comparison['tool_calls_comparison']['non_retrieval_arguments_average_score'] = (
            sum(non_retrieval_arguments_scores) / len(non_retrieval_arguments_scores) if non_retrieval_arguments_scores else 0.0
        )

        tool_response_scores = []
        max_tool_responses = max(len(server_data['tool_responses']), len(original_data['observations']))

        for i in range(max_tool_responses):
            server_response = server_data['tool_responses'][i] if i < len(server_data['tool_responses']) else None
            original_response = original_data['observations'][i] if i < len(original_data['observations']) else None

            if server_response is None:
                score_detail = {
                    'index': i,
                    'server_present': False,
                    'original_present': True,
                    'match_score': 0.0,
                    'original_response': original_response
                }
            elif original_response is None:
                score_detail = {
                    'index': i,
                    'server_present': True,
                    'original_present': False,
                    'match_score': 0.0,
                    'server_response': server_response
                }
            else:
                responses_match = server_response == original_response
                match_score = 1.0 if responses_match else 0.0
                score_detail = {
                    'index': i,
                    'server_present': True,
                    'original_present': True,
                    'match_score': match_score,
                    'responses_match': responses_match,
                    'server_response': server_response,
                    'original_response': original_response
                }

            comparison['tool_responses_comparison']['detailed_scores'].append(score_detail)
            tool_response_scores.append(score_detail['match_score'])

        comparison['tool_responses_comparison']['average_score'] = (
            sum(tool_response_scores) / len(tool_response_scores) if tool_response_scores else 0.0
        )

        comparison['overall_scores']['tool_responses_avg'] = comparison['tool_responses_comparison']['average_score']

        comparison['tool_calls_match'] = (
            comparison['tool_calls_comparison']['name_average_score'] == 1.0 and
            comparison['tool_calls_comparison']['arguments_average_score'] == 1.0
        )
        comparison['tool_responses_match'] = comparison['overall_scores']['tool_responses_avg'] == 1.0

        return comparison

    def calculate_global_scores(self, results: List[Dict]) -> Dict:
        """Compute global scores over all results.

        Tool-call scores are averaged over every paired call across all
        queries; tool-response scores are averaged over per-query averages.
        """
        if not results:
            return {
                'global_tool_responses_avg': 0.0,
                'global_tool_calls_name_avg': 0.0,
                'global_tool_calls_arguments_avg': 0.0,
                'global_non_retrieval_name_avg': 0.0,
                'global_non_retrieval_arguments_avg': 0.0,
                'total_queries': 0
            }

        all_tool_responses_scores = []
        all_tool_calls_name_scores = []
        all_tool_calls_arguments_scores = []
        all_non_retrieval_name_scores = []
        all_non_retrieval_arguments_scores = []

        for result in results:
            comparison = result.get('comparison', {})
            overall_scores = comparison.get('overall_scores', {})
            tool_calls_comparison = comparison.get('tool_calls_comparison', {})

            tool_responses_avg = overall_scores.get('tool_responses_avg', 0.0)

            detailed_scores = tool_calls_comparison.get('detailed_scores', [])
            for score_detail in detailed_scores:
                if score_detail.get('server_present') and score_detail.get('original_present'):
                    all_tool_calls_name_scores.append(score_detail.get('name_score', 0.0))
                    all_tool_calls_arguments_scores.append(score_detail.get('arguments_score', 0.0))

                    server_call = score_detail.get('server_call', {})
                    original_call = score_detail.get('original_call', {})
                    if server_call and original_call:
                        server_name = server_call.get('name', '')
                        original_name = original_call.get('name', '')
                        if server_name != 'retrieval_tool' and original_name != 'retrieval_tool':
                            all_non_retrieval_name_scores.append(score_detail.get('name_score', 0.0))
                            all_non_retrieval_arguments_scores.append(score_detail.get('arguments_score', 0.0))

            all_tool_responses_scores.append(tool_responses_avg)

        global_tool_responses_avg = sum(all_tool_responses_scores) / len(all_tool_responses_scores) if all_tool_responses_scores else 0.0
        global_tool_calls_name_avg = sum(all_tool_calls_name_scores) / len(all_tool_calls_name_scores) if all_tool_calls_name_scores else 0.0
        global_tool_calls_arguments_avg = sum(all_tool_calls_arguments_scores) / len(all_tool_calls_arguments_scores) if all_tool_calls_arguments_scores else 0.0
        global_non_retrieval_name_avg = sum(all_non_retrieval_name_scores) / len(all_non_retrieval_name_scores) if all_non_retrieval_name_scores else 0.0
        global_non_retrieval_arguments_avg = sum(all_non_retrieval_arguments_scores) / len(all_non_retrieval_arguments_scores) if all_non_retrieval_arguments_scores else 0.0

        return {
            'global_tool_responses_avg': global_tool_responses_avg,
            'global_tool_calls_name_avg': global_tool_calls_name_avg,
            'global_tool_calls_arguments_avg': global_tool_calls_arguments_avg,
            'global_non_retrieval_name_avg': global_non_retrieval_name_avg,
            'global_non_retrieval_arguments_avg': global_non_retrieval_arguments_avg,
            'total_queries': len(results)
        }

    def save_results(self, results: List[Dict], output_file: str):
        """Save evaluation results to a file."""
        try:
            global_scores = self.calculate_global_scores(results)

            complete_results = {
                'global_scores': global_scores,
                'results': results
            }

            with open(output_file, 'w', encoding='utf-8') as f:
                json.dump(complete_results, f, ensure_ascii=False, indent=2)
            logger.info(f"Results saved to: {output_file}")
        except Exception as e:
            logger.error(f"Save results failed: {e}")

    def save_checkpoint(self, results: List[Dict], checkpoint_file: str, processed_count: int, total_count: int):
        """Save a checkpoint file."""
        try:
            checkpoint_data = {
                'processed_count': processed_count,
                'total_count': total_count,
                'results': results,
                'timestamp': time.strftime('%Y-%m-%d %H:%M:%S')
            }

            with open(checkpoint_file, 'w', encoding='utf-8') as f:
                json.dump(checkpoint_data, f, ensure_ascii=False, indent=2)
            logger.info(f"Checkpoint saved: {processed_count}/{total_count} processed")
        except Exception as e:
            logger.error(f"Save checkpoint failed: {e}")

    def load_checkpoint(self, checkpoint_file: str) -> Optional[Dict]:
        """Load a checkpoint file, or return None if there is none."""
        try:
            if os.path.exists(checkpoint_file):
                with open(checkpoint_file, 'r', encoding='utf-8') as f:
                    checkpoint_data = json.load(f)
                logger.info(f"Checkpoint loaded: {checkpoint_data['processed_count']}/{checkpoint_data['total_count']} processed")
                return checkpoint_data
            else:
                logger.info("No checkpoint file found, starting from beginning")
                return None
        except Exception as e:
            logger.error(f"Load checkpoint failed: {e}")
            return None

    def print_progress(self, current: int, total: int, start_time: float):
        """Log progress information."""
        if total == 0:
            return

        elapsed_time = time.time() - start_time
        progress_percent = (current / total) * 100

        if current > 0:
            avg_time_per_query = elapsed_time / current
            remaining_queries = total - current
            estimated_remaining_time = remaining_queries * avg_time_per_query

            logger.info(f"Progress: {current}/{total} ({progress_percent:.1f}%) | "
                        f"success: {self.success_count} | errors: {self.error_count} | "
                        f"elapsed: {elapsed_time/60:.1f} min | "
                        f"est. remaining: {estimated_remaining_time/60:.1f} min")
        else:
            logger.info(f"Progress: {current}/{total} ({progress_percent:.1f}%) | "
                        f"success: {self.success_count} | errors: {self.error_count} | "
                        f"elapsed: {elapsed_time/60:.1f} min")

    def save_progress_report(self, output_file: str, current: int, total: int):
        """Save a progress report alongside the output file."""
        try:
            progress_data = {
                'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
                'current_progress': current,
                'total_queries': total,
                'success_count': self.success_count,
                'error_count': self.error_count,
                'progress_percentage': (current / total * 100) if total > 0 else 0,
                'elapsed_time_minutes': (time.time() - self.start_time) / 60 if self.start_time else 0
            }

            progress_file = f"{output_file}.progress"
            with open(progress_file, 'w', encoding='utf-8') as f:
                json.dump(progress_data, f, ensure_ascii=False, indent=2)

        except Exception as e:
            logger.error(f"Save progress report failed: {e}")

    def generate_interruption_report(self, output_file: str, processed_count: int, total_queries: int, error_message: str):
        """Generate a report describing an interrupted run."""
        try:
            total_time = time.time() - self.start_time if self.start_time else 0

            interruption_report = {
                'interruption_type': 'server_connection_failure',
                'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
                'processed_count': processed_count,
                'total_queries': total_queries,
                'success_count': self.success_count,
                'error_count': self.error_count,
                'progress_percentage': (processed_count / total_queries * 100) if total_queries > 0 else 0,
                'elapsed_time_minutes': total_time / 60,
                'error_message': error_message,
                'resume_instructions': {
                    'checkpoint_file': f"{output_file}.checkpoint",
                    'command': f"await evaluator.evaluate(input_file='data/9.17_evaluate_data_top5_final.json', output_file='{output_file}', resume=True)",
                    'note': 'Use resume=True to resume the evaluation from the checkpoint'
                }
            }

            interruption_file = f"{output_file}.interruption_report"
            with open(interruption_file, 'w', encoding='utf-8') as f:
                json.dump(interruption_report, f, ensure_ascii=False, indent=2)

            logger.info(f"Interruption report saved to: {interruption_file}")

        except Exception as e:
            logger.error(f"Generate interruption report failed: {e}")

    async def evaluate(self, input_file: str, output_file: str = "evaluation_results.json",
                       batch_size: int = 50, start_index: int = 0, max_queries: Optional[int] = None,
                       checkpoint_file: str = None, resume: bool = True):
        """Run the full evaluation with batch processing and checkpoint support.

        A checkpoint is written every 10 processed queries and after each
        batch, and is removed once the run completes successfully.
        """
        logger.info("Start model evaluation...")

        self.start_time = time.time()
        self.error_count = 0
        self.success_count = 0

        if checkpoint_file is None:
            checkpoint_file = f"{output_file}.checkpoint"

        checkpoint_data = None
        if resume:
            checkpoint_data = self.load_checkpoint(checkpoint_file)
            if checkpoint_data:
                self.results = checkpoint_data.get('results', [])
                start_index = checkpoint_data.get('processed_count', 0)
                self.success_count = len(self.results)
                logger.info(f"Resuming from checkpoint: {start_index} queries already processed")

        data = self.load_data(input_file)
        if not data:
            logger.error("Cannot load data, evaluation terminated")
            return

        queries = self.extract_human_queries(data)
        if not queries:
            logger.error("No valid queries found, evaluation terminated")
            return

        if max_queries:
            queries = queries[:max_queries]

        if start_index > 0:
            queries = queries[start_index:]
            logger.info(f"Starting from index {start_index}, processing {len(queries)} queries")

        remaining = len(queries)
        total_queries = start_index + remaining
        processed_count = len(self.results)
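
        # `queries` now holds only the remaining items; start_index offsets the
        # indices and totals below so progress is reported against the full
        # dataset even when resuming from a checkpoint.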
        for batch_start in range(0, remaining, batch_size):
            batch_end = min(batch_start + batch_size, remaining)
            batch_queries = queries[batch_start:batch_end]

            logger.info(f"Processing batch {batch_start//batch_size + 1}: queries {start_index + batch_start + 1}-{start_index + batch_end} of {total_queries}")

            for i, query_data in enumerate(batch_queries):
                global_index = start_index + batch_start + i
                logger.info(f"Process {global_index + 1}/{total_queries} query: {query_data['query'][:50]}...")

                try:
                    events = await self.call_server(query_data['query'])

                    server_data = self.extract_tool_calls_and_observations(events)

                    original_data = self.extract_original_data(query_data['original_data'])

                    comparison = self.compare_results(server_data, original_data)

                    result = {
                        'index': query_data['index'],
                        'query': query_data['query'],
                        'server_events_count': len(events),
                        'server_tool_calls': server_data['tool_calls'],
                        'server_tool_responses': server_data['tool_responses'],
                        'original_function_calls': original_data['function_calls'],
                        'original_observations': original_data['observations'],
                        'comparison': comparison,
                        'timestamp': time.strftime('%Y-%m-%d %H:%M:%S')
                    }

                    self.results.append(result)
                    processed_count += 1
                    self.success_count += 1

                    if processed_count % 10 == 0:
                        self.save_checkpoint(self.results, checkpoint_file, processed_count, total_queries)
                        self.save_progress_report(output_file, processed_count, total_queries)
                        self.print_progress(processed_count, total_queries, self.start_time)

                    await asyncio.sleep(1)

                except Exception as e:
                    logger.error(f"Error processing query {global_index + 1}: {e}")
                    self.error_count += 1

                    if "Server connection failed after" in str(e) or "Server timeout after" in str(e) or "Unexpected error after" in str(e):
                        logger.error("🚨 Server connection failed; saving checkpoint and stopping evaluation")
                        logger.error(f"Failed query: {query_data['query'][:50]}...")

                        self.save_checkpoint(self.results, checkpoint_file, processed_count, total_queries)
                        self.save_progress_report(output_file, processed_count, total_queries)

                        self.generate_interruption_report(output_file, processed_count, total_queries, str(e))

                        logger.error("Evaluation interrupted by server connection failure")
                        logger.error(f"Processed {processed_count}/{total_queries} queries")
                        logger.error(f"Checkpoint saved to: {checkpoint_file}")
                        logger.error("Re-run with resume=True later to continue from the checkpoint")

                        return
                    else:
                        logger.warning(f"Query failed; continuing with the next one: {e}")
                        continue

            batch_output_file = f"{output_file}.batch_{batch_start//batch_size + 1}"
            self.save_results(self.results, batch_output_file)
            logger.info(f"Batch {batch_start//batch_size + 1} completed, results saved to {batch_output_file}")

            self.save_checkpoint(self.results, checkpoint_file, processed_count, total_queries)
            self.save_progress_report(output_file, processed_count, total_queries)
            self.print_progress(processed_count, total_queries, self.start_time)

        self.save_results(self.results, output_file)

        if os.path.exists(checkpoint_file):
            os.remove(checkpoint_file)
            logger.info("Checkpoint file removed after successful completion")

        self.generate_summary_report()

        total_time = time.time() - self.start_time
        logger.info("=== Evaluation complete ===")
        logger.info(f"Total queries: {total_queries}")
        logger.info(f"Processed successfully: {self.success_count}")
        logger.info(f"Failed: {self.error_count}")
        logger.info(f"Total time: {total_time/60:.1f} min")
        logger.info(f"Average time per query: {total_time/total_queries:.1f} s")

    def generate_summary_report(self):
        """Generate a detailed evaluation summary report."""
        if not self.results:
            return

        total_queries = len(self.results)

        global_scores = self.calculate_global_scores(self.results)

        query_details = []

        for i, result in enumerate(self.results):
            comparison = result['comparison']
            overall_scores = comparison.get('overall_scores', {})
            tool_calls_comparison = comparison.get('tool_calls_comparison', {})

            tool_responses_avg = overall_scores.get('tool_responses_avg', 0.0)
            tool_calls_name_avg = tool_calls_comparison.get('name_average_score', 0.0)
            tool_calls_arguments_avg = tool_calls_comparison.get('arguments_average_score', 0.0)
            non_retrieval_name_avg = tool_calls_comparison.get('non_retrieval_name_average_score', 0.0)
            non_retrieval_arguments_avg = tool_calls_comparison.get('non_retrieval_arguments_average_score', 0.0)

            query_details.append({
                'index': i,
                'query': result['query'][:50] + '...' if len(result['query']) > 50 else result['query'],
                'tool_calls_name_score': tool_calls_name_avg,
                'tool_calls_arguments_score': tool_calls_arguments_avg,
                'non_retrieval_name_score': non_retrieval_name_avg,
                'non_retrieval_arguments_score': non_retrieval_arguments_avg,
                'tool_responses_score': tool_responses_avg
            })

        tool_calls_perfect_matches = sum(1 for r in self.results if r['comparison']['tool_calls_match'])
        tool_responses_perfect_matches = sum(1 for r in self.results if r['comparison']['tool_responses_match'])

        report = f"""
=== Model Evaluation Summary Report ===

[Overall statistics]
Total queries: {total_queries}
Tool calls fully matched: {tool_calls_perfect_matches} ({tool_calls_perfect_matches/total_queries*100:.1f}%)
Tool responses fully matched: {tool_responses_perfect_matches} ({tool_responses_perfect_matches/total_queries*100:.1f}%)

[Global average scores]
Tool name match average: {global_scores['global_tool_calls_name_avg']:.3f}
Tool arguments match average: {global_scores['global_tool_calls_arguments_avg']:.3f}
Non-retrieval tool name match average: {global_scores['global_non_retrieval_name_avg']:.3f}
Non-retrieval tool arguments match average: {global_scores['global_non_retrieval_arguments_avg']:.3f}
Tool response global average: {global_scores['global_tool_responses_avg']:.3f}

[Per-query scores]"""

        for detail in query_details:
            report += f"""
Query {detail['index']}: {detail['query']}
  - Tool name score: {detail['tool_calls_name_score']:.3f}
  - Tool arguments score: {detail['tool_calls_arguments_score']:.3f}
  - Non-retrieval tool name score: {detail['non_retrieval_name_score']:.3f}
  - Non-retrieval tool arguments score: {detail['non_retrieval_arguments_score']:.3f}
  - Tool response score: {detail['tool_responses_score']:.3f}"""

        report += """

[Scoring notes]
- Tool name match: 1 point for an exact name match, else 0.
- Tool arguments match: 1 point for exactly identical arguments, else 0.
- Non-retrieval tool name match: same as above with retrieval_tool calls excluded.
- Non-retrieval tool arguments match: same as above with retrieval_tool calls excluded.
- Tool response score: 1 point for an exactly identical response, else 0.

See the results JSON file for full details.
"""

        print(report)
        logger.info("Detailed summary report generated")

    def test_sse_parsing(self):
        """Test the SSE parsing logic against canned event samples."""
        test_data_tool_call = """id: 3
event: tool_call.created
data: {"conversation_id": "c_9c5b3617", "message_id": "m_1248", "sequence": 3, "role": "assistant", "timestamp": "2025-09-18T13:01:34.230464Z", "content": "", "tool_call": {"name": "intelligent_route_analysis", "arguments": {"access": "1"}}}
"""
        test_data_tool_response = """
id: 4
event: tool_response.completed
data: {"conversation_id": "c_9c5b3617", "message_id": "m_1248", "sequence": 4, "role": "tool", "timestamp": "2025-09-18T13:01:34.358678Z", "tool_call_id": "tool_2", "result_delta": {"chat_log_id": 1234, "content": "", "markdown": "智能路由分析结果\\n\\n访问链接: [上传派团单](https://testai.compassaihz.com/#/$&!upload \\"成功匹配到对应页面\\")\\n\\n", "result": {"success": true, "url": "https://testai.compassaihz.com/#/$&!upload", "message": "成功匹配到对应页面"}, "ambulance": "", "potential_tools": [{"api": "/intelligent_route_analysis", "api_cn": "页面跳转工具", "queryData": [{"type": "select", "key": "access", "label": "数字访问码", "value": {"options": [{"label": "1", "value": "1"}, {"label": "2", "value": "2"}, {"label": "3", "value": "3"}, {"label": "4", "value": "4"}, {"label": "5", "value": "5"}, {"label": "6", "value": "6"}, {"label": "7", "value": "7"}, {"label": "8", "value": "8"}, {"label": "9", "value": "9"}, {"label": "10", "value": "10"}]}, "default": "1", "multiple": false}], "description": "智能页面路由工具,通过输入1-10的数字快速跳转到对应的业务功能页面。业务功能包括:上传派团单,新增资源,新增产品,新增协议,新增协议模版,离线上传协议,离线上传价格政策,新增价格政策,新增报表,新增审批流"}], "tool_calling_chain": [{"role": "function", "tool_call": {"name": "intelligent_route_analysis", "arguments": {"access": "1"}}, "tool_response": {"success": true, "url": "https://testai.compassaihz.com/#/$&!upload", "message": "成功匹配到对应页面"}}], "api_Info": {"api": "/intelligent_route_analysis", "api_cn": "页面跳转工具", "queryData": [{"type": "select", "key": "access", "label": "数字访问码", "value": {"options": [{"label": "1", "value": "1"}, {"label": "2", "value": "2"}, {"label": "3", "value": "3"}, {"label": "4", "value": "4"}, {"label": "5", "value": "5"}, {"label": "6", "value": "6"}, {"label": "7", "value": "7"}, {"label": "8", "value": "8"}, {"label": "9", "value": "9"}, {"label": "10", "value": "10"}]}, "default": "1", "multiple": false}], "description": "智能页面路由工具,通过输入1-10的数字快速跳转到对应的业务功能页面。业务功能包括:上传派团单,新增资源,新增产品,新增协议,新增协议模版,离线上传协议,离线上传价格政策,新增价格政策,新增报表,新增审批流"}}, "success": true, "execution_time": 0.0}

"""

        logger.info("=== Testing SSE parsing ===")

        combined_test_data = test_data_tool_call + test_data_tool_response

        events = self.parse_sse_events(
            combined_test_data,
            filter_events=['tool_call.created', 'tool_response.completed']
        )

        logger.info("=== Test parsing summary ===")
        logger.info(f"Parsed {len(events)} events in total")
        for event in events:
            logger.info(f"Event: {event['event']}, ID: {event['id']}")
            logger.info(f"  Data summary: {str(event['data'])[:100]}...")

        extracted = self.extract_tool_calls_and_observations(events)
        logger.info(f"Extraction results from the test events: {extracted}")

        return events


async def main():
    """Main entry point."""
    evaluator = ModelEvaluator()

    input_file = "data/9.17_evaluate_data_top5_final.json"
    output_file = "eval_results/evaluation_results.json"

    await evaluator.evaluate(
        input_file=input_file,
        output_file=output_file,
        batch_size=50,
        max_queries=None,
        checkpoint_file="eval_results/evaluation_results.json.checkpoint",
        resume=True
    )


if __name__ == "__main__":
    asyncio.run(main())