| | |
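"""
Update Math Vision prompts to include an instruction for multiple-choice questions.

This script updates train.json, valid.json and test.json in data/math_vision/ so
that every prompt carries the instruction:
"If it is a multiple choice question, directly give the option letter."
Prompts that still carry the earlier Chinese wording of that instruction
("如果是选择题直接给出选项字母") are rewritten to the English wording.
"""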
| |
|
import json
import os
import shutil
from pathlib import Path
| |
|
| |
|
def update_prompt(prompt: str) -> str:
    """
    Rewrite the instruction header of a prompt to the English multiple-choice form.

    Two legacy headers are recognised and checked in this order; only the
    first one found in the prompt is rewritten (every occurrence of it):

      1. the header carrying the Chinese hint "如果是选择题直接给出选项字母";
      2. the bare header that carries no multiple-choice hint at all.

    Either one is replaced by the header ending in
    "If it is a multiple choice question, directly give the option letter."
    A prompt that contains neither legacy header is returned unchanged.
    """
    # NOTE(review): "\\n" in these literals is a backslash followed by "n",
    # not a newline character -- confirm this matches how prompts are stored.
    replacement = "Solve the problem and output the answer in the format of \\boxed{your answer}. If it is a multiple choice question, directly give the option letter.\\n Question:"
    legacy_headers = (
        "Solve the problem and output the answer in the format of \\boxed{your answer}. 如果是选择题直接给出选项字母.\\n Question:",
        "Solve the problem and output the answer in the format of \\boxed{your answer}.\\n Question:",
    )

    for legacy in legacy_headers:
        if legacy in prompt:
            # Stop at the first legacy header that matches; the other one is
            # deliberately not touched in the same call.
            return prompt.replace(legacy, replacement)

    return prompt
| |
|
| |
|
def update_json_file(file_path: str) -> dict:
    """
    Rewrite the prompts stored in one JSON file, in place.

    The file is parsed as JSON and iterated; each element's 'prompt' value is
    passed through update_prompt(). Before the file is overwritten, a
    byte-for-byte copy of it is saved as '<file_path>.backup', unless a backup
    with that name already exists (an existing backup is never replaced). The
    file is then re-serialised with indent=2 and ensure_ascii=False, whether
    or not any prompt changed.

    Elements whose 'prompt' is not a string (for example JSON null) are left
    untouched and counted as skipped.

    Args:
        file_path: Path of the JSON file to update.

    Returns:
        dict with statistics: 'total', 'updated', 'skipped'.

    Raises:
        OSError: if the file cannot be read, backed up or written.
        json.JSONDecodeError: if the file does not contain valid JSON.
    """
    print(f"\n处理文件: {file_path}")

    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    total = len(data)
    updated = 0

    for item in data:
        old_prompt = item.get('prompt', '')
        if not isinstance(old_prompt, str):
            # Nothing to rewrite, and update_prompt() would raise on a
            # non-string value; the entry ends up in the 'skipped' count.
            continue

        new_prompt = update_prompt(old_prompt)
        if new_prompt != old_prompt:
            item['prompt'] = new_prompt
            updated += 1

    # Copy the raw bytes instead of re-serialising the parsed JSON: the backup
    # then matches the original exactly, and a failure cannot leave behind a
    # truncated '.backup' that would block the backup on every later run.
    backup_path = file_path + '.backup'
    if not os.path.exists(backup_path):
        shutil.copy2(file_path, backup_path)
        print(f" ✓ 备份创建: {backup_path}")

    with open(file_path, 'w', encoding='utf-8') as f:
        json.dump(data, f, ensure_ascii=False, indent=2)

    stats = {
        'total': total,
        'updated': updated,
        'skipped': total - updated,
    }

    print(f" ✓ 总样本数: {stats['total']}")
    print(f" ✓ 已更新: {stats['updated']}")
    print(f" ✓ 跳过: {stats['skipped']}")

    return stats
| |
|
| |
|
def main(data_dir: str = "data/math_vision"):
    """
    Update the prompts in the train/valid/test JSON files of a data directory.

    Each of train.json, valid.json and test.json found directly under
    *data_dir* is processed with update_json_file(); files that do not exist
    are reported and skipped. The per-file statistics are summed and printed,
    followed by the paths of any '.backup' files present for those three names.

    Args:
        data_dir: Directory holding the JSON files. Defaults to
            "data/math_vision", the previously hard-coded location.

    Returns:
        None. Progress and results are reported on stdout only.
    """
    # isdir rather than exists: a regular file at this path is not a usable
    # data directory and must be reported like a missing one.
    if not os.path.isdir(data_dir):
        print(f"错误: 目录不存在: {data_dir}")
        return

    print("=" * 80)
    print("Math Vision Prompt 更新脚本")
    print("=" * 80)
    print(f"数据目录: {data_dir}")
    print("更新内容: 添加 'If it is a multiple choice question, directly give the option letter.' 指令")

    json_files = ['train.json', 'valid.json', 'test.json']
    total_stats = {'total': 0, 'updated': 0, 'skipped': 0}

    for filename in json_files:
        file_path = os.path.join(data_dir, filename)

        if not os.path.exists(file_path):
            print(f"\n⚠ 跳过不存在的文件: {file_path}")
            continue

        stats = update_json_file(file_path)
        for key in total_stats:
            total_stats[key] += stats[key]

    print("\n" + "=" * 80)
    print("总结")
    print("=" * 80)
    print(f"总样本数: {total_stats['total']}")
    print(f"已更新: {total_stats['updated']}")
    print(f"跳过: {total_stats['skipped']}")
    print("\n✓ 完成!所有prompt已更新。")
    print("\n备份文件位置:")
    for filename in json_files:
        backup_path = os.path.join(data_dir, filename + '.backup')
        if os.path.exists(backup_path):
            print(f" - {backup_path}")
| | |
| |
|
# Script entry point: perform the update only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
| |
|
| |
|