# Hugging Face Spaces page header (status: Running) - scrape residue, not part of the program.
import os | |
import json | |
import pandas as pd | |
import numpy as np | |
import gradio as gr | |
from pathlib import Path | |
import matplotlib.pyplot as plt | |
import matplotlib as mpl | |
import re | |
import sqlite3 | |
import math | |
import time | |
from huggingface_hub import hf_hub_download | |
import psutil | |
import gc | |
# Subject-name translation table: Chinese -> English.
SUBJECT_TRANS = dict(
    [
        ("代数", "Algebra"),
        ("数论", "Number Theory"),
        ("几何", "Geometry"),
        ("组合", "Combinatorics"),
    ]
)
# Subject-name translation table: English -> Chinese (exact inverse of SUBJECT_TRANS,
# in the same entry order).
SUBJECT_TRANS_EN_TO_ZH = {
    english: chinese for chinese, english in SUBJECT_TRANS.items()
}
# Maps lower-case model identifiers to their human-readable display names.
MODEL_TRANS = {
    "acemath-rl-nemotron-7b": "AceMath-RL-Nemotron-7B",
    "deepseek-r1-distill-qwen-1.5b": "DeepSeek-R1-Distill-Qwen-1.5B",
    "light-r1-32b-ds": "Light-R1-32B-DS",
    "openmath-nemotron-1.5b": "OpenMath-Nemotron-1.5B",
    "openthinker2-7b": "OpenThinker2-7B",
    "qwq-32b": "QwQ-32B",
    "still-3-1.5b-preview": "STILL-3-1.5B-Preview",
    "deepseek-r1-distill-qwen-32b": "DeepSeek-R1-Distill-Qwen-32B",
    "light-r1-7b-ds": "Light-R1-7B-DS",
    "openmath-nemotron-32b": "OpenMath-Nemotron-32B",
    "qwen3-235b-a22b": "Qwen3-235B-A22B",
    "skywork-or1-32b-preview": "Skywork-OR1-32B-Preview",
    "deepscaler-1.5b-preview": "DeepScaler-1.5B-Preview",
    "deepseek-r1-distill-qwen-7b": "DeepSeek-R1-Distill-Qwen-7B",
    "openmath-nemotron-7b": "OpenMath-Nemotron-7B",
    "deepseek-r1-distill-qwen-14b": "DeepSeek-R1-Distill-Qwen-14B",
    "light-r1-14b-ds": "Light-R1-14B-DS",
    "openmath-nemotron-14b": "OpenMath-Nemotron-14B",
    "openthinker2-32b": "OpenThinker2-32B",
    "qwen3-4b": "Qwen3-4B",
    "skywork-or1-math-7b": "Skywork-OR1-Math-7B",
    "skywork-or1-7b-preview": "Skywork-OR1-7B-Preview",
    "qwen3-30b-a3b": "Qwen3-30B-A3B",
    "deepseek-r1": "DeepSeek-R1",
    "glm-z1-air": "GLM-Z1-Air",
    "gemini-2.5-pro-exp-03-25": "Gemini 2.5 Pro Exp 0325",
    "o3-mini-high": "OpenAI o3-mini (high)",
    "qwen3-0.6b": "Qwen3-0.6B"
    # Add more model mappings here
}
# Configure matplotlib for better display
plt.style.use('ggplot')
mpl.rcParams['figure.figsize'] = (10, 6)
mpl.rcParams['font.size'] = 10
# Constants
# Dataset names in "<LANGUAGE>-<DIFFICULTY>" form; also the fallback list returned by
# ModelDatabase.get_available_datasets() when the query fails.
DATASETS = ["EN-HARD", "EN-EASY", "ZH-HARD", "ZH-EASY"]
# Global database instance (assigned in create_ui)
db = None
# Global cache for Reference Solutions: problem id -> {"EN": accuracy, "ZH": accuracy},
# filled by precompute_reference_accuracies()
reference_accuracy_cache = {}
def precompute_reference_accuracies(db, reference_loader):
    """Fill the global ``reference_accuracy_cache`` for every reference problem.

    For each problem id reported by ``reference_loader`` and each model reported by
    ``db``, the per-model mean ``correctness`` is computed for the EN-HARD and ZH-HARD
    variants; the cache stores the mean of those per-model means per language
    (0.0 when no model has responses).

    Args:
        db: object exposing ``get_available_models()`` and
            ``get_problem_data(model, dataset, unique_id)``.
        reference_loader: object exposing ``get_all_problem_ids()``.

    Returns:
        None. Does nothing when either argument is falsy.
    """
    global reference_accuracy_cache
    if not (db and reference_loader):
        return

    print("Pre-computing reference problem accuracies...")
    started_at = time.time()
    problem_ids = reference_loader.get_all_problem_ids()
    reference_accuracy_cache = {}
    # Fetch the model list once, up front.
    all_models = db.get_available_models()
    print(f"Computing accuracies for {len(problem_ids)} problems across {len(all_models)} models...")

    # (language tag, dataset name) pairs, queried in this order for every model.
    language_datasets = (("EN", "EN-HARD"), ("ZH", "ZH-HARD"))

    for position, pid in enumerate(problem_ids):
        # Progress line every fifth problem.
        if position % 5 == 0:
            print(f"Processing problem {position+1}/{len(problem_ids)}: {pid}")
        try:
            unique_ids = {
                language: f"OlymMATH-HARD-{pid}-{language}"
                for language, _ in language_datasets
            }
            per_language = {language: [] for language, _ in language_datasets}
            for model in all_models:
                for language, dataset in language_datasets:
                    # Best effort: a failing lookup simply leaves this model out.
                    try:
                        _, responses = db.get_problem_data(model, dataset, unique_ids[language])
                        if responses and len(responses) > 0:
                            model_mean = sum(r['correctness'] for r in responses) / len(responses)
                            per_language[language].append(model_mean)
                    except Exception:
                        pass
            # Average the per-model means and store them in the cache.
            en_values = per_language["EN"]
            zh_values = per_language["ZH"]
            reference_accuracy_cache[pid] = {
                "EN": sum(en_values) / len(en_values) if en_values else 0.0,
                "ZH": sum(zh_values) / len(zh_values) if zh_values else 0.0,
            }
        except Exception as e:
            print(f"Error computing accuracy for problem {pid}: {e}")
            reference_accuracy_cache[pid] = {"EN": 0.0, "ZH": 0.0}

    elapsed_time = time.time() - started_at
    print(f"✅ Pre-computation completed in {elapsed_time:.2f} seconds")
    print(f"✅ Cached accuracies for {len(reference_accuracy_cache)} problems")
class ModelDatabase:
    """SQLite access layer for the ``responses`` and ``problems`` tables.

    All query results are memoised in per-instance dictionaries; call
    :meth:`clear_cache` to drop them. Arguments that expose a ``.value`` attribute
    (e.g. Gradio state objects) are unwrapped before use.
    """

    def __init__(self, db_path):
        """Open the database at ``db_path``, tune it, and try to create indices."""
        self.db_path = db_path
        # One shared connection, usable from any thread, in autocommit mode.
        self.conn = sqlite3.connect(db_path, check_same_thread=False, isolation_level=None, timeout=60)
        self.conn.execute("PRAGMA journal_mode = WAL")  # Use Write-Ahead Logging for better performance
        self.conn.execute("PRAGMA synchronous = NORMAL")  # Reduce synchronization overhead
        self.conn.execute("PRAGMA cache_size = -8000")  # Negative value is in KiB, i.e. roughly 8MB
        self.conn.execute("PRAGMA temp_store = MEMORY")  # Keep temporary tables in memory
        self.conn.execute("PRAGMA mmap_size = 8589934592")  # Request up to 8GB of memory mapping
        self.conn.row_factory = sqlite3.Row
        # Create indices that speed up the common queries (best effort).
        self._ensure_indices()
        # Display-name -> real-model-name mappings (filled by the UI layer).
        self.model_display_to_real = {}
        self.comp_model_display_to_real = {}
        # Caches.
        self._cache = {}
        self._problem_cache = {}
        self._response_cache = {}
        self._models_cache = None
        self._datasets_cache = None

    @staticmethod
    def _unwrap(value):
        """Return ``value.value`` when the object has that attribute, else ``value``."""
        return value.value if hasattr(value, 'value') else value

    @staticmethod
    def _numeric_sort_key(unique_id):
        """Return the first run of digits in ``unique_id`` as an int (0 if none)."""
        match = re.search(r'\d+', unique_id)
        return int(match.group(0)) if match else 0

    def _ensure_indices(self):
        """Create helpful indices and refresh planner statistics (best effort).

        Each statement runs independently so one failure (missing table, read-only
        file, ...) does not prevent the others; failures are reported, not raised.
        """
        statements = (
            "CREATE INDEX IF NOT EXISTS idx_responses_model_dataset ON responses(model_name, dataset)",
            "CREATE INDEX IF NOT EXISTS idx_responses_unique_id ON responses(unique_id)",
            "CREATE INDEX IF NOT EXISTS idx_problems_unique_id ON problems(unique_id)",
            "ANALYZE",  # Refresh statistics used by the query planner
        )
        for statement in statements:
            try:
                self.conn.execute(statement)
            except sqlite3.Error as e:
                print(f"Warning: could not run '{statement}': {e}")

    def get_available_models(self):
        """Return all distinct model names, sorted; ``[]`` on an operational error."""
        if self._models_cache:
            return self._models_cache
        try:
            cursor = self.conn.cursor()
            cursor.execute("SELECT DISTINCT model_name FROM responses ORDER BY model_name")
            models = [row['model_name'] for row in cursor.fetchall()]
            self._models_cache = models
            return models
        except sqlite3.OperationalError:
            return []

    def get_available_datasets(self):
        """Return all distinct dataset names upper-cased; ``DATASETS`` on an operational error."""
        if self._datasets_cache:
            return self._datasets_cache
        try:
            cursor = self.conn.cursor()
            cursor.execute("SELECT DISTINCT dataset FROM responses ORDER BY dataset")
            datasets = [row['dataset'].upper() for row in cursor.fetchall()]
            self._datasets_cache = datasets
            return datasets
        except sqlite3.OperationalError:
            return DATASETS

    def get_model_statistics(self, model_name, dataset):
        """Return ``[[label, formatted accuracy], ...]`` for a model on a dataset.

        The first row is the overall accuracy, followed by one row per subject
        (subject names translated through ``SUBJECT_TRANS``). Returns a single
        "Database Error" row on an operational error.
        """
        model_name = self._unwrap(model_name)
        dataset = self._unwrap(dataset)
        cache_key = f"stats_{model_name}_{dataset}"
        if cache_key in self._cache:
            return self._cache[cache_key]
        cursor = self.conn.cursor()
        try:
            # No INDEXED BY hint: a hint naming an index that could not be created
            # makes the statement fail; the planner picks the index when it exists.
            cursor.execute("""
                SELECT COUNT(*) as total_samples, AVG(correctness) as accuracy
                FROM responses
                WHERE model_name = ? AND dataset = ?
            """, (model_name, dataset.lower()))
            overall_stats = cursor.fetchone()
            cursor.execute("""
                SELECT p.subject, COUNT(r.id) as sample_count, AVG(r.correctness) as accuracy
                FROM responses r JOIN problems p ON r.unique_id = p.unique_id
                WHERE r.model_name = ? AND r.dataset = ?
                GROUP BY p.subject ORDER BY p.subject
            """, (model_name, dataset.lower()))
            subject_stats_rows = cursor.fetchall()
        except sqlite3.OperationalError:
            return [["Database Error", "No data available"]]

        stats_data = []
        if overall_stats and overall_stats['accuracy'] is not None:
            stats_data.append(["Overall Acc.", f"{overall_stats['accuracy']:.2%}"])
        else:
            stats_data.append(["Overall Acc.", "N/A"])
        for subject_row in subject_stats_rows:
            acc_val = f"{subject_row['accuracy']:.2%}" if subject_row['accuracy'] is not None else "N/A"
            subject_name = subject_row['subject']
            # Translate the subject name; unknown subjects are shown as stored.
            translated_subject = SUBJECT_TRANS.get(subject_name, subject_name)
            stats_data.append([f"{translated_subject} Acc.", acc_val])
        self._cache[cache_key] = stats_data
        return stats_data

    def get_all_model_accuracies(self, dataset):
        """Return ``[(model_name, accuracy), ...]`` for a dataset, best first; ``[]`` on error."""
        dataset = self._unwrap(dataset)
        cache_key = f"all_accuracies_{dataset}"
        if cache_key in self._cache:
            return self._cache[cache_key]
        try:
            cursor = self.conn.cursor()
            cursor.execute("""
                SELECT model_name, AVG(correctness) as accuracy
                FROM responses
                WHERE dataset = ? GROUP BY model_name ORDER BY accuracy DESC
            """, (dataset.lower(),))
            results = [(row['model_name'], row['accuracy']) for row in cursor.fetchall()]
            self._cache[cache_key] = results
            return results
        except sqlite3.OperationalError:
            return []

    def get_problems_by_model_dataset(self, model_name, dataset):
        """Return ``[(unique_id, accuracy, problem_text), ...]`` for a model and dataset.

        Rows are ordered by the first integer found in ``unique_id``; a NULL
        average is reported as 0.0. Returns ``[]`` on an operational error.
        """
        model_name = self._unwrap(model_name)
        dataset = self._unwrap(dataset)
        cache_key = f"problems_{model_name}_{dataset}"
        if cache_key in self._cache:
            return self._cache[cache_key]
        cursor = self.conn.cursor()
        try:
            # GROUP BY already yields one row per unique_id, so DISTINCT is unnecessary.
            cursor.execute("""
                SELECT r.unique_id, p.problem, AVG(r.correctness) as accuracy
                FROM responses r
                JOIN problems p ON r.unique_id = p.unique_id
                WHERE r.model_name = ? AND r.dataset = ?
                GROUP BY r.unique_id ORDER BY r.unique_id
            """, (model_name, dataset.lower()))
            results = [
                (row['unique_id'], row['accuracy'] if row['accuracy'] is not None else 0.0, row['problem'])
                for row in cursor.fetchall()
            ]
        except sqlite3.OperationalError:
            return []
        # Sort by the integer part of unique_id (stable, so SQL order breaks ties).
        sorted_results = sorted(results, key=lambda item: self._numeric_sort_key(item[0]))
        self._cache[cache_key] = sorted_results
        return sorted_results

    def _load_problem(self, problem_id):
        """Return the cached/fetched problem row as a dict, or ``None`` if missing or on error."""
        cache_key = f"problem_{problem_id}"
        if cache_key in self._problem_cache:
            return self._problem_cache[cache_key]
        if not self.conn:
            return None
        try:
            cursor = self.conn.cursor()
            cursor.execute("SELECT * FROM problems WHERE unique_id = ?", (problem_id,))
            row = cursor.fetchone()
        except Exception:
            return None
        if not row:
            return None
        # Store a plain dict so the cached value does not depend on the connection.
        problem = dict(row)
        self._problem_cache[cache_key] = problem
        return problem

    def get_problem_data(self, model_name, dataset, problem_id):
        """Return ``(problem, responses)`` for a problem.

        Args:
            model_name: model to fetch responses for; when falsy, responses of all
                models are returned (ordered by model name, then response id).
            dataset: dataset name; lower-cased for the query.
            problem_id: the problem's ``unique_id``.

        Returns:
            ``(None, None)`` when the problem is missing or cannot be read;
            ``(problem_dict, None)`` when the response query fails;
            otherwise ``(problem_dict, list_of_response_dicts)`` (possibly empty).
        """
        model_name = self._unwrap(model_name)
        dataset = self._unwrap(dataset)
        problem_id = self._unwrap(problem_id)

        problem = self._load_problem(problem_id)
        if not problem:
            return None, None

        if model_name:
            resp_cache_key = f"responses_{model_name}_{dataset}_{problem_id}"
            query = """
                SELECT * FROM responses
                WHERE model_name = ? AND dataset = ? AND unique_id = ?
                ORDER BY response_id
            """
            leading_params = (model_name,)
        else:
            resp_cache_key = f"all_responses_{dataset}_{problem_id}"
            query = """
                SELECT * FROM responses
                WHERE dataset = ? AND unique_id = ?
                ORDER BY model_name, response_id
            """
            leading_params = ()

        if resp_cache_key in self._response_cache:
            return problem, self._response_cache[resp_cache_key]
        if not self.conn:
            return problem, None
        try:
            cursor = self.conn.cursor()
            cursor.execute(query, leading_params + (dataset.lower(), problem_id))
            responses = [dict(row) for row in cursor.fetchall()]
        except Exception:
            return problem, None
        # Only non-empty results are cached, so data added later is still found.
        if responses:
            self._response_cache[resp_cache_key] = responses
        return problem, responses

    def get_model_responses(self, selected_models, dataset, problem_id):
        """Return ``(problem, {display_name: response_or_None})`` for several models.

        For each model the first response with ``correctness == 1`` is chosen,
        falling back to the model's first response. Returns ``(None, {})`` when an
        argument is empty or the problem cannot be found.
        """
        dataset = self._unwrap(dataset)
        problem_id = self._unwrap(problem_id)
        if not selected_models or not dataset or not problem_id:
            return None, {}
        # Only the problem row is needed here, so skip loading every model's responses.
        problem = self._load_problem(problem_id)
        if not problem:
            return None, {}
        model_responses_data = {}
        for model_display in selected_models:
            model_display_val = self._unwrap(model_display)
            # Resolve the display name to the real model name.
            model = self.comp_model_display_to_real.get(model_display_val, model_display_val)
            _, responses_for_model = self.get_problem_data(model, dataset, problem_id)
            if responses_for_model:
                correct_resp = next((r for r in responses_for_model if r['correctness'] == 1), None)
                model_responses_data[model_display_val] = correct_resp if correct_resp else responses_for_model[0]
            else:
                model_responses_data[model_display_val] = None
        return problem, model_responses_data

    def clear_cache(self, section=None):
        """Clear one cache section ('main', 'problem', 'response', 'models') or all when ``None``."""
        if section == 'main' or section is None:
            self._cache = {}
        if section == 'problem' or section is None:
            self._problem_cache = {}
        if section == 'response' or section is None:
            self._response_cache = {}
        if section == 'models' or section is None:
            self._models_cache = None
            self._datasets_cache = None

    def close(self):
        """Close the database connection (ignoring close errors) and clear all caches."""
        if hasattr(self, 'conn') and self.conn:
            try:
                self.conn.close()
            except Exception:
                pass
        self.clear_cache()
class ReferenceDataLoader:
    """Load reference-solution records from a JSONL file and serve them by ``unique_id``."""

    def __init__(self, jsonl_path):
        """Remember ``jsonl_path`` and load its records immediately."""
        self.jsonl_path = jsonl_path
        self.reference_data = {}
        self._load_data()

    def _load_data(self):
        """Read the JSONL file into ``self.reference_data`` keyed by ``unique_id``.

        Blank lines are ignored and a malformed line (invalid JSON, or a record
        without ``unique_id``) is reported and skipped, so one bad line no longer
        discards every record that follows it. File-level errors are reported and
        leave whatever was loaded so far in place.
        """
        try:
            with open(self.jsonl_path, 'r', encoding='utf-8') as f:
                for line_number, raw_line in enumerate(f, start=1):
                    line = raw_line.strip()
                    if not line:
                        continue
                    try:
                        data = json.loads(line)
                        self.reference_data[data['unique_id']] = data
                    except (ValueError, KeyError, TypeError) as e:
                        # ValueError covers json.JSONDecodeError; KeyError/TypeError cover
                        # records that are not objects carrying a hashable 'unique_id'.
                        print(f"Skipping malformed reference entry on line {line_number}: {e}")
        except Exception as e:
            print(f"Error loading reference data: {e}")

    def get_problem_data(self, unique_id):
        """Return the record stored for ``unique_id``, or ``None`` if unknown."""
        return self.reference_data.get(unique_id)

    def get_all_problem_ids(self):
        """Return all loaded ``unique_id`` values in sorted order."""
        return sorted(self.reference_data.keys())
def calculate_reference_problem_accuracy(db, unique_id):
    """Return ``(en_avg, zh_avg)``: mean per-model accuracy on a reference problem.

    For every model known to ``db`` the mean ``correctness`` of its responses is
    taken for the EN-HARD and ZH-HARD variants of the problem; each returned value
    is the mean of those per-model means (0.0 when no model has responses).
    Progress is printed along the way. Any unexpected error yields ``(0.0, 0.0)``.
    """
    try:
        # Build the English and Chinese unique ids for this problem.
        ids_by_language = {
            "EN": f"OlymMATH-HARD-{unique_id}-EN",
            "ZH": f"OlymMATH-HARD-{unique_id}-ZH",
        }
        print(f"Calculating accuracy for problem {unique_id}: EN={ids_by_language['EN']}, ZH={ids_by_language['ZH']}")
        accuracies = {"EN": [], "ZH": []}

        all_models = db.get_available_models()
        print(f"Found {len(all_models)} models in database")

        for model in all_models:
            # English first, then Chinese, for every model.
            for language in ("EN", "ZH"):
                try:
                    _, responses = db.get_problem_data(model, f"{language}-HARD", ids_by_language[language])
                    if responses and len(responses) > 0:
                        model_mean = sum(r['correctness'] for r in responses) / len(responses)
                        accuracies[language].append(model_mean)
                        print(f" Model {model} {language}: {model_mean:.2%} ({len(responses)} responses)")
                except Exception as e:
                    print(f" Error getting {language} data for model {model}: {e}")

        # Average the per-model means.
        en_values, zh_values = accuracies["EN"], accuracies["ZH"]
        en_avg = sum(en_values) / len(en_values) if en_values else 0.0
        zh_avg = sum(zh_values) / len(zh_values) if zh_values else 0.0
        print(f"Final averages for problem {unique_id}: EN={en_avg:.2%} (from {len(en_values)} models), ZH={zh_avg:.2%} (from {len(zh_values)} models)")
        return en_avg, zh_avg
    except Exception as e:
        print(f"Error calculating accuracy for problem {unique_id}: {e}")
        return 0.0, 0.0
def format_latex(text):
    """Wrap ``text`` in a ``math-inline`` span for KaTeX, turning newlines into ``<br>``.

    LaTeX backslashes are left untouched. ``None`` yields an empty string.
    """
    if text is None:
        return ""
    # Only line breaks are rewritten; everything else is passed through verbatim.
    with_breaks = '<br>'.join(text.split('\n'))
    return '<span class="math-inline">' + with_breaks + '</span>'
def format_markdown_with_math(text):
    """Prepare Markdown text containing ``$``-delimited math for rendering.

    ``$$...$$`` becomes ``\\[...\\]`` and ``$...$`` becomes ``\\(...\\)``; line
    endings are normalised to ``\\n`` and runs of three or more blank-separated
    newlines collapse to one blank line. ``None`` yields an empty string.
    """
    if text is None:
        return ""

    display_math = re.compile(r'\$\$(.*?)\$\$', flags=re.DOTALL)
    # Inline math: an unescaped `$`, no `$` or newline inside, closing `$` not followed by `]`.
    inline_math = re.compile(r'(?<!\\)\$([^$\n]+?)\$(?!\])')

    # Display math must be rewritten first so its `$$` is not taken for inline math.
    converted = display_math.sub(r'\\[\1\\]', text)
    converted = inline_math.sub(r'\\(\1\\)', converted)

    # Normalise line endings, then squeeze excessive blank lines.
    converted = converted.replace('\r\n', '\n')
    converted = converted.replace('\r', '\n')
    converted = re.sub(r'\n\s*\n\s*\n+', '\n\n', converted)

    # Debug aid: report texts that use the aligned environment.
    if '\\begin{aligned}' in converted:
        print(f"LaTeX aligned environment detected in text (first 200 chars): {converted[:200]}...")

    # Gradio's markdown component renders the returned text.
    return converted
def get_gradient_color(accuracy, color_map='RdYlGn'):
    """Return a darkened hex colour for ``accuracy`` taken from a matplotlib colormap.

    The colormap colour is scaled to 70% brightness so white text stays readable.
    ``"#505050"`` is returned for missing/non-numeric input or when the colour
    cannot be computed.
    """
    fallback = "#505050"  # Default for missing or invalid accuracy
    if accuracy is None or not isinstance(accuracy, (int, float)):
        return fallback
    try:
        red, green, blue, alpha = plt.colormaps.get_cmap(color_map)(float(accuracy))
        # Lower the brightness of each colour channel; alpha is passed through unchanged.
        darkened = tuple(channel * 0.7 for channel in (red, green, blue))
        return mpl.colors.rgb2hex(darkened + (alpha,))
    except Exception:
        return fallback
def get_contrasting_text_color(bg_color):
    """Return ``"#000"`` or ``"#fff"``, whichever reads better on ``bg_color``.

    Args:
        bg_color: background colour as ``#rrggbb`` (extra trailing digits such as
            an alpha pair are ignored) or the short form ``#rgb``.

    Returns:
        ``"#000"`` for light backgrounds and for the yellow / light-green / beige
        hues singled out below, ``"#fff"`` for dark backgrounds. Any value that is
        not a parseable hex colour (wrong type, other notation, bad digits) yields
        the default ``"#000"`` instead of raising.
    """
    default_color = "#000"
    if not isinstance(bg_color, str) or not bg_color.startswith('#'):
        # Unknown format: default to black.
        return default_color

    hex_digits = bg_color[1:]
    if len(hex_digits) == 3:
        # Expand the short form, e.g. "fa0" -> "ffaa00".
        hex_digits = "".join(digit * 2 for digit in hex_digits)
    if len(hex_digits) < 6:
        return default_color
    try:
        r, g, b = (int(hex_digits[start:start + 2], 16) for start in (0, 2, 4))
    except ValueError:
        return default_color

    # YIQ luma approximates how bright the colour appears to the human eye.
    yiq = (r * 299 + g * 587 + b * 114) / 1000
    # Yellow: high red and green, low blue.
    is_yellow = r > 200 and g > 200 and b < 150
    # Light green: high green, medium red, low blue.
    is_light_green = g > 200 and 100 < r < 180 and b < 150
    # Beige / light brown: high red, medium-high green, low blue.
    is_beige = r > 220 and 160 < g < 220 and b < 160
    # These hues always get black text regardless of luma.
    if is_yellow or is_light_green or is_beige:
        return "#000"
    # Everything else is decided by brightness.
    return "#000" if yiq > 160 else "#fff"
def format_sample_metadata(sample, show_correctness=True):
    """Build the HTML metadata box shown above a sample response.

    When ``show_correctness`` is true the box contains one row with the
    correct/incorrect marker and, when present, the extracted answer and the
    output / reasoning token counts. Returns ``""`` for ``None`` and
    ``"No sample data"`` when the sample is empty or not a mapping.
    """
    if sample is None:
        return ""
    if hasattr(sample, 'keys'):
        sample_dict = dict(sample)
    elif isinstance(sample, dict):
        sample_dict = sample
    else:
        sample_dict = {}
    if not sample_dict:
        return "No sample data"

    # Pull out the fields that are displayed.
    extracted = sample_dict.get('extracted', '')
    is_correct = sample_dict.get('correctness', 0)
    if is_correct:
        correctness_label, correctness_color = "✓ Correct", "var(--color-green)"
    else:
        correctness_label, correctness_color = "✗ Incorrect", "var(--color-red)"
    # Token information.
    output_tokens = sample_dict.get('output_tokens', None)
    reasoning_tokens = sample_dict.get('reasoning_tokens', None)

    # Outer metadata container.
    pieces = ["<div style='font-size: 0.85em; padding: 10px; border-radius: 8px; margin-bottom: 5px;' class='dark-mode-compatible dark-mode-bg-secondary'>"]
    if show_correctness:
        # Information row.
        pieces.append("<div style='display: flex; flex-wrap: wrap; align-items: center; margin-bottom: 5px;'>")
        # Correctness indicator.
        pieces.append(f"<span style='color: {correctness_color}; font-weight: bold; margin-right: 10px;'>{correctness_label}</span>")
        # Extracted answer.
        if extracted:
            pieces.append(f"<span style='background-color: rgba(0,0,0,0.05); padding: 2px 5px; border-radius: 3px; margin-right: 10px;'><b>Extracted:</b> ${extracted}$</span>")
        # Output token count.
        if output_tokens is not None:
            pieces.append(f"<span style='background-color: rgba(0,0,0,0.05); padding: 2px 5px; border-radius: 3px; margin-right: 10px;'><b>Output Tokens:</b> {output_tokens}</span>")
        # Reasoning token count, only when available.
        if reasoning_tokens is not None:
            pieces.append(f"<span style='background-color: rgba(0,0,0,0.05); padding: 2px 5px; border-radius: 3px;'><b>Reasoning Tokens:</b> {reasoning_tokens}</span>")
        pieces.append("</div>")
    pieces.append("</div>")
    return "".join(pieces)
def format_sample_response(sample):
    """Return a sample's response text with reasoning tags made safe for display.

    The literal tags ``<think>``, ``<reasoning>`` and ``<answer>`` (opening and
    closing) are replaced by their HTML-escaped form so the renderer shows them
    as text instead of parsing them as HTML elements. All other text is returned
    unchanged.

    Args:
        sample: mapping (e.g. a dict or database row) with a ``response`` entry.

    Returns:
        ``""`` for ``None`` or a missing/NULL response, ``"No sample data"`` when
        the sample is empty or not a mapping, otherwise the processed response.
    """
    # Local import: the escaped forms are generated rather than typed out, so they
    # cannot be silently turned back into raw tags by tooling that decodes entities.
    import html

    if sample is None:
        return ""
    sample_dict = dict(sample) if hasattr(sample, 'keys') else sample if isinstance(sample, dict) else {}
    if not sample_dict:
        return "No sample data"

    # A NULL response in the database arrives as None; treat it as empty text.
    response = sample_dict.get('response', '') or ''

    for tag_name in ("think", "reasoning", "answer"):
        for literal_tag in (f"<{tag_name}>", f"</{tag_name}>"):
            response = response.replace(literal_tag, html.escape(literal_tag))
    return response
def handle_sample_select(sample_number, samples_data):
    """Return ``(metadata_html, response_markdown)`` for the chosen sample.

    Args:
        sample_number: zero-based index of the sample; anything ``int()`` accepts.
        samples_data: list of samples, or an object whose ``.value`` is that list.

    Returns:
        The formatted metadata and response of the sample. On any problem the
        first element is an error/information message and the second is ``""``:
        a non-integer (including empty/``None``) sample number, missing sample
        data, an out-of-range index, or a formatting failure.
    """
    # Unwrap a Gradio State-like object.
    samples_list = samples_data.value if hasattr(samples_data, 'value') else samples_data

    # The sample number must be convertible to an integer. TypeError covers None
    # (e.g. an emptied number field); OverflowError covers infinite floats.
    try:
        sample_idx = int(sample_number)
    except (TypeError, ValueError, OverflowError):
        return "Error: Sample number must be an integer.", ""

    # Sample data must be a non-empty list.
    if not samples_list or not isinstance(samples_list, list):
        return "No sample data available. Please select a problem first.", ""

    # Reject indices outside the valid range instead of wrapping negatives around.
    if not 0 <= sample_idx < len(samples_list):
        err_msg = f"**Error:** Sample number {sample_idx} is out of range. Valid range is 0 to {len(samples_list) - 1}."
        return err_msg, ""

    try:
        sample = samples_list[sample_idx]
        formatted_metadata = format_sample_metadata(sample)
        formatted_response = format_sample_response(sample)
        return formatted_metadata, formatted_response
    except Exception as e:
        err_msg = f"**Error displaying sample {sample_idx}:** {str(e)}"
        return err_msg, ""
def handle_first_sample(samples_data):
    """Format and return the first sample (index 0) as ``(metadata, response)``.

    ``samples_data`` may be the list itself or an object whose ``.value`` holds it.
    When there is nothing to show, or formatting fails, the first element carries
    a message and the second is ``""``.
    """
    # Unwrap a Gradio State-like object.
    samples_list = getattr(samples_data, 'value', samples_data)

    # Anything other than a non-empty list means there is no sample to display.
    has_samples = bool(samples_list) and isinstance(samples_list, list) and len(samples_list) != 0
    if not has_samples:
        return "No sample data available. Please select the problem and dataset first.", ""

    # Take the first sample directly; no index validation is needed.
    try:
        first = samples_list[0]
        return format_sample_metadata(first), format_sample_response(first)
    except Exception as e:
        return f"**Error displaying first sample:** {str(e)}", ""
def handle_comparison_problem_update(problem_id, dataset_state):
    """Return ``(problem_markdown, answer_markdown)`` for the comparison page.

    Only the problem and its answer are looked up; no model is required. A purely
    numeric ``problem_id`` is expanded to ``OlymMATH-<difficulty>-<id>-<language>``
    using a ``<LANGUAGE>-<DIFFICULTY>`` dataset name. Missing input, an unknown
    problem, or any error produces a message plus ``"No answer available."``.
    """
    global db
    # Unwrap Gradio State-like objects.
    dataset_name = getattr(dataset_state, 'value', dataset_state)
    problem_id_value = getattr(problem_id, 'value', problem_id)

    if not (problem_id_value and dataset_name):
        return "Please select a dataset and enter a problem ID.", "No answer available."

    # A bare number is turned into a full unique id such as OlymMATH-HARD-0-EN.
    if problem_id_value.isdigit():
        pieces = dataset_name.split('-')
        if len(pieces) == 2:  # expected shape, e.g. "EN-HARD"
            language, difficulty = pieces
            problem_id_value = f"OlymMATH-{difficulty}-{problem_id_value}-{language}"

    try:
        # Fetch the problem only; model-specific responses are not needed here.
        problem_data, _ = db.get_problem_data(None, dataset_name, problem_id_value)
        if not problem_data:
            return f"Problem not found: {problem_id_value}. Please check the ID and try again.", "No answer available."
        problem_dict = dict(problem_data)
        return (
            # Problem text goes through the Markdown/math formatter ...
            format_markdown_with_math(problem_dict.get('problem', '')),
            # ... and the answer through the dedicated answer formatter.
            format_answer_with_math(problem_dict.get('answer', '')),
        )
    except Exception as e:
        return f"Error: {str(e)}", "No answer available."
def _render_sample_grid_row(row_samples, start_index, samples_per_row):
    """Render one row of the sample grid: a coloured cell per sample plus grey fillers."""
    # Only two colours can occur, so compute them once per row.
    correct_color = get_gradient_color(1.0)
    incorrect_color = get_gradient_color(0.0)
    parts = [
        f'<div style="display: grid; grid-template-columns: repeat({samples_per_row}, 1fr); gap: 2px; margin-bottom: 4px;">'
    ]
    for offset, resp in enumerate(row_samples):
        bg_color = correct_color if resp.get('correctness', 0) else incorrect_color
        # Display-only cell: no click handler or data attributes.
        parts.append(
            "\n<div\n"
            'class="sample-grid-btn"\n'
            f"style='background-color: {bg_color};\n"
            "border-radius: 2px; width: 100%; height: 20px;\n"
            "display: flex; align-items: center; justify-content: center;'>\n"
            f'<span style="color: white; font-size: 0.65em; font-weight: bold;">{start_index + offset}</span>\n'
            "</div>\n"
        )
    # Pad a short row so every row has the same number of columns.
    filler = "\n<div style='background-color: #505050; border-radius: 2px; width: 100%; height: 20px;'></div>\n"
    parts.extend([filler] * (samples_per_row - len(row_samples)))
    parts.append('</div>')
    return "".join(parts)


def handle_problem_select(problem_id_from_js, current_model_state, current_dataset_state, mode='default'):
    """Load a problem and build the outputs for the problem view.

    Args:
        problem_id_from_js: full ``unique_id`` or a bare problem number.
        current_model_state: selected model name (may be empty in comparison mode).
        current_dataset_state: dataset name such as ``"EN-HARD"``.
        mode: ``'comparison'`` uses 16 grid cells per row and allows a missing
            model; any other value uses 32 cells per row.

    Each argument may also be an object exposing the actual value as ``.value``.

    Returns:
        ``(problem_markdown, answer_markdown, samples_grid_html, gr.State(samples))``.
        On missing input, an unknown problem, or a database error the first element
        is a message, the grid is empty and the state holds an empty list. At most
        64 samples are drawn in the grid; the state holds all of them.
    """
    global db
    # Ensure we're using the actual values from Gradio State objects
    model_name = current_model_state.value if hasattr(current_model_state, 'value') else current_model_state
    dataset_name = current_dataset_state.value if hasattr(current_dataset_state, 'value') else current_dataset_state
    problem_id = problem_id_from_js.value if hasattr(problem_id_from_js, 'value') else problem_id_from_js

    # Validate before touching dataset_name: expanding a numeric id splits the
    # dataset name, which would fail on a missing dataset.
    if not problem_id or not dataset_name:
        return "Please fill in all the fields.", "No answer available.", "", gr.State([])

    # A bare number is turned into a full unique id such as OlymMATH-HARD-0-EN,
    # taking language and difficulty from a dataset name like "EN-HARD".
    if problem_id.isdigit():
        parts = dataset_name.split('-')
        if len(parts) == 2:
            language, difficulty = parts
            problem_id = f"OlymMATH-{difficulty}-{problem_id}-{language}"

    # For comparison mode, we might not have a model selected yet
    if not model_name and mode == 'comparison':
        try:
            # Just get the problem data without model-specific responses
            problem_data, _ = db.get_problem_data(None, dataset_name, problem_id)
            if not problem_data:
                return f"Problem not found: {problem_id}. Please check the ID and try again.", "No answer available.", "", gr.State([])
            problem_dict = dict(problem_data)
            problem_content = format_markdown_with_math(problem_dict.get('problem', ''))
            answer_content = format_answer_with_math(problem_dict.get('answer', ''))
            # Without a model there are no samples to display.
            return problem_content, answer_content, "", gr.State([])
        except Exception:
            return "Database error occurred. Please try again.", "No answer available.", "", gr.State([])

    # The regular flow needs a model.
    if not model_name:
        return "Please fill in all the fields.", "No answer available.", "", gr.State([])

    try:
        problem_data, responses_data = db.get_problem_data(model_name, dataset_name, problem_id)
        if not problem_data:
            return f"Problem not found: {problem_id}. Please check the ID and try again.", "No answer available.", "", gr.State([])
    except Exception:
        return "Database error occurred. Please try again.", "No answer available.", "", gr.State([])

    problem_dict = dict(problem_data)
    # Process problem and answer text for Markdown rendering
    problem_content = format_markdown_with_math(problem_dict.get('problem', ''))
    answer_content = format_answer_with_math(problem_dict.get('answer', ''))

    if not responses_data:
        samples_grid_html = "<div>No samples available for this problem.</div>"
        return problem_content, answer_content, samples_grid_html, gr.State([])

    # All samples are kept for later selection, even those not drawn in the grid.
    samples_data = [dict(resp) for resp in responses_data]
    correct_count = sum(1 for r in samples_data if r['correctness'])
    total_samples = len(samples_data)
    accuracy_on_problem = correct_count / total_samples if total_samples > 0 else 0

    # The grid shows at most 64 samples.
    displayed_samples = samples_data[:64]
    actual_display_count = len(displayed_samples)
    samples_per_row = 16 if mode == 'comparison' else 32

    samples_grid_html = "".join(
        _render_sample_grid_row(
            displayed_samples[row_start:row_start + samples_per_row],
            row_start,
            samples_per_row,
        )
        for row_start in range(0, actual_display_count, samples_per_row)
    )

    # The accuracy fraction uses the total sample count so that it matches the
    # percentage even when more samples exist than are displayed.
    final_html = (
        "\n<div style='margin-top:15px; padding: 10px; border-radius: 8px;' class='dark-mode-compatible dark-mode-bg-secondary'>\n"
        f'<h4 style="margin-top:0;">Samples {actual_display_count} - Model Accuracy: {correct_count}/{total_samples} = {accuracy_on_problem:.1%}</h4>\n'
        f"{samples_grid_html}\n"
        "</div>\n"
    )
    return problem_content, answer_content, final_html, gr.State(samples_data)
def create_problem_grid_html(problems, mode='default'):
    """Create the HTML for the clickable problem grid.

    Click handling is not implemented here: the buttons only carry the
    ``problem-btn`` class and a ``data-problem-id`` attribute, which a
    globally installed JS listener is expected to pick up.

    Args:
        problems: Sequence of indexable records ``(problem_id, accuracy, extra)``.
            ``accuracy`` may be ``None`` (treated as 0.0); the third element
            is read but not used for rendering.
        mode: ``'comparison'`` lays the grid out in 20 columns, any other
            value in 10 columns.

    Returns:
        An HTML string: the grid on success, or a single ``<div>`` message
        when ``problems`` is empty or the records cannot be normalised.
    """
    # Local import so the block stays self-contained; only this function needs it.
    from html import escape

    if not problems:
        return "<div>No problems found for this model/dataset. Please select a model and dataset.</div>"

    digits = re.compile(r'\d+')

    def numeric_key(item):
        # Sort by the first run of digits in the problem id; ids without digits sort as 0.
        found = digits.search(item[0])
        return int(found.group(0)) if found else 0

    try:
        sorted_problems = sorted(
            ((str(p[0]), float(p[1]) if p[1] is not None else 0.0, p[2]) for p in problems),
            key=numeric_key,
        )
    except Exception as e:
        # UI boundary: keep the page alive, but actually log the failure since
        # the message shown to the user points at the logs.
        print(f"Error displaying problems: {e}")
        return f"<div>Error displaying problems. Check logs. {escape(str(e))}</div>"

    # White text everywhere; !important keeps theme CSS from overriding it.
    text_color = "#ffffff"
    buttons = []
    for pid, accuracy, _ in sorted_problems:
        found = digits.search(pid)
        num_display = found.group(0) if found else pid
        # Truncate to a whole percent, but round away float noise first:
        # 0.29 * 100 == 28.999999999999996, which plain int() would show as 28.
        acc_pct = int(round(accuracy * 100, 6))
        bg_color = get_gradient_color(accuracy)
        # Escape ids before placing them in attributes/text so quotes or
        # angle brackets in an id cannot break the markup.
        safe_pid = escape(pid)
        safe_num = escape(num_display)
        buttons.append(f"""
        <div
            data-problem-id="{safe_pid}"
            class="problem-btn"
            title="ID: {safe_pid} - Acc: {acc_pct}%"
            style='background-color: {bg_color}; color: {text_color} !important;
                   border-radius: 4px; padding: 5px; text-align: center; font-size: 0.7em;
                   min-height: 36px; user-select: none; width: 100%;
                   display: flex; flex-direction: column; justify-content: center;
                   overflow: hidden; text-overflow: ellipsis; white-space: nowrap;'>
            <div style="font-weight: bold; color: {text_color} !important;">{safe_num}</div>
            <div style="color: {text_color} !important;">{acc_pct}%</div>
        </div>
        """)
    html_buttons = "".join(buttons)

    # Extra stylesheet rule that forces the button text to stay white.
    custom_style = "<style>.problem-btn, .problem-btn div { color: white !important; }</style>"
    # Number of grid columns depends on the display mode.
    grid_cols = 20 if mode == 'comparison' else 10
    grid_html = f"{custom_style}<div style='display: grid; grid-template-columns: repeat({grid_cols}, 1fr); gap: 4px;'>{html_buttons}</div>"
    return grid_html
def create_ui(db_path): | |
global db | |
db = ModelDatabase(db_path) | |
# Initialize reference data loader with better path handling | |
reference_loader = None | |
# Try multiple possible paths for extra.jsonl | |
possible_paths = [ | |
os.path.join(os.path.dirname(db_path), "extra.jsonl"), | |
os.path.join(os.getcwd(), "extra.jsonl"), | |
"extra.jsonl" | |
] | |
for extra_jsonl_path in possible_paths: | |
if os.path.exists(extra_jsonl_path): | |
try: | |
reference_loader = ReferenceDataLoader(extra_jsonl_path) | |
print(f"Successfully loaded reference data from: {extra_jsonl_path}") | |
break | |
except Exception as e: | |
print(f"Error loading reference data from {extra_jsonl_path}: {e}") | |
continue | |
# If not found locally, try to download from Hugging Face | |
if not reference_loader: | |
try: | |
print("Attempting to download extra.jsonl from Hugging Face...") | |
extra_jsonl_path = hf_hub_download( | |
repo_id="CoderBak/OlymMATH-data", | |
filename="extra.jsonl", | |
repo_type="dataset" | |
) | |
reference_loader = ReferenceDataLoader(extra_jsonl_path) | |
print(f"Successfully downloaded and loaded reference data from: {extra_jsonl_path}") | |
except Exception as e: | |
print(f"Failed to download extra.jsonl from Hugging Face: {e}") | |
if not reference_loader: | |
print("Warning: extra.jsonl not found in any of the expected locations:") | |
for path in possible_paths: | |
print(f" - {path}") | |
print("Reference Solutions tab will not be available.") | |
else: | |
# Test the reference data availability | |
test_reference_data_availability(db, reference_loader) | |
# Pre-compute reference problem accuracies for fast loading | |
precompute_reference_accuracies(db, reference_loader) | |
# Test LaTeX formatting | |
test_latex_formatting() | |
AVAILABLE_DATASETS = db.get_available_datasets() | |
if not AVAILABLE_DATASETS: | |
AVAILABLE_DATASETS = ["EN-HARD", "EN-EASY", "ZH-HARD", "ZH-EASY"] # Fallback | |
# Add MathJax support to the CSS | |
custom_css = """ | |
.padding.svelte-phx28p { padding: unset !important; } | |
body, .gradio-container { font-family: sans-serif; font-size: 0.95em; line-height: 1.6; } | |
.sample-btn { transition: all 0.15s ease-in-out; } | |
.sample-btn:hover { transform: translateY(-1px); box-shadow: 0 2px 5px rgba(0,0,0,0.1); } | |
.problem-grid-container { overflow: visible !important; } | |
.math-content { overflow: visible !important; padding: 5px; } | |
.sample-response { overflow: visible !important; max-height: none !important; height: auto !important; } | |
h1, h2, h3, h4, h5 { margin-top: 0.8em; margin-bottom: 0.4em; color: var(--color-text); } | |
.gradio-tabs > div[role='tablist'] button { font-size: 0.9em; padding: 8px 12px; } | |
.gr-dropdown select { font-size: 0.9em; } | |
.gr-radio label span { font-size: 0.9em; } | |
.gr-checkboxgroup label span { font-size: 0.9em; } | |
.gr-button { font-size: 0.9em; padding: 8px 12px; } | |
.gr-dataframe table { font-size:0.85em; } | |
.gr-markdown { font-size: 1em; } | |
/* 适应深色模式的样式 */ | |
.dark-mode-compatible { | |
background-color: var(--background-fill-primary); | |
color: var(--color-text); | |
border-color: var(--border-color-primary); | |
} | |
.dark-mode-bg-secondary { | |
background-color: var(--background-fill-secondary); | |
} | |
/* DataTable深色模式样式 */ | |
.dataframe-container { | |
//padding: 12px; | |
//border-radius: 8px; | |
//margin-top: 10px; | |
} | |
/* MathJax Styles for Gradio's Built-in LaTeX */ | |
.math-inline, .math-display { | |
font-size: 110%; | |
} | |
.math-container p { | |
margin: 0.5em 0; | |
} | |
/* Markdown content styles */ | |
.gr-markdown strong { | |
font-weight: bold; | |
} | |
.gr-markdown em { | |
font-style: italic; | |
} | |
.gr-markdown ul, .gr-markdown ol { | |
padding-left: 2em; | |
margin: 0.5em 0; | |
} | |
.gr-markdown blockquote { | |
border-left: 3px solid #ccc; | |
margin: 0.5em 0; | |
padding-left: 1em; | |
color: #666; | |
} | |
.gr-markdown pre, .gr-markdown code { | |
background-color: rgba(0,0,0,0.05); | |
padding: 2px 4px; | |
border-radius: 3px; | |
font-family: monospace; | |
} | |
.gr-markdown table { | |
border-collapse: collapse; | |
margin: 0.5em 0; | |
} | |
.gr-markdown th, .gr-markdown td { | |
border: 1px solid #ddd; | |
padding: 4px 8px; | |
} | |
/* 隐藏滚动条但保留功能 */ | |
::-webkit-scrollbar { | |
display: none !important; | |
width: 0px !important; | |
height: 0px !important; | |
} | |
/* 主容器禁用滚动 */ | |
.gradio-container { | |
overflow-x: hidden !important; | |
} | |
/* Gradio组件容器 */ | |
.gradio-row, .gradio-column { | |
overflow: visible !important; | |
max-height: none !important; | |
} | |
/* HTML组件 */ | |
.gr-html { | |
overflow: visible !important; | |
max-height: none !important; | |
} | |
/* Markdown组件保持可见 */ | |
.gr-markdown { | |
overflow: visible !important; | |
max-height: none !important; | |
} | |
/* 特定的问题网格容器 */ | |
#ref-problem-grid-container, #problem-grid-container, #comp-problem-grid-container-left, #comp-problem-grid-container-right { | |
overflow: visible !important; | |
max-height: none !important; | |
height: auto !important; | |
} | |
/* 样本网格 */ | |
.sample-grid-btn { | |
overflow: visible !important; | |
} | |
/* 确保内容区域不会产生滚动条 */ | |
.gr-form, .gr-box { | |
overflow: visible !important; | |
max-height: none !important; | |
} | |
/* Reference Solutions - 禁止Solution部分的滚动 */ | |
#ref-solution { | |
overflow: hidden !important; | |
max-height: none !important; | |
height: auto !important; | |
} | |
/* 确保Solution内容容器也禁止滚动 */ | |
#ref-solution .gr-markdown { | |
overflow: hidden !important; | |
max-height: none !important; | |
height: auto !important; | |
} | |
""" | |
with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue=gr.themes.colors.blue, secondary_hue=gr.themes.colors.sky)) as demo: | |
# Remove KaTeX loading script since we're using Gradio's native Markdown with LaTeX | |
current_dataset_state = gr.State(value=AVAILABLE_DATASETS[0] if AVAILABLE_DATASETS else "") | |
current_model_state = gr.State(value=None) | |
comparison_data_state = gr.State(value={}) | |
# 添加当前样本状态 | |
current_sample_state = gr.State(value="0") | |
# 添加当前问题的样本数据状态 | |
current_samples_data_state = gr.State(value=[]) | |
# 为Comparison标签页添加独立状态 | |
comp_dataset_state = gr.State(value=AVAILABLE_DATASETS[0] if AVAILABLE_DATASETS else "") | |
comp_model_state_left = gr.State(value=None) | |
comp_sample_state_left = gr.State(value="0") | |
comp_samples_data_state_left = gr.State(value=[]) | |
comp_model_state_right = gr.State(value=None) | |
comp_sample_state_right = gr.State(value="0") | |
comp_samples_data_state_right = gr.State(value=[]) | |
# 创建占位符State组件替代None | |
dummy_state = gr.State(value=None) | |
# Add JavaScript for handling problem grid clicks | |
demo.load(lambda: None, js=""" | |
() => { | |
// Handle problem button clicks for single model tab | |
function setupProblemGridListeners() { | |
document.addEventListener('click', function(e) { | |
if (e.target.closest('.problem-btn')) { | |
const problemBtn = e.target.closest('.problem-btn'); | |
const problemId = problemBtn.getAttribute('data-problem-id'); | |
if (problemId) { | |
const problemInput = document.getElementById('problem-state-input'); | |
if (problemInput) { | |
problemInput.querySelector('input').value = problemId; | |
problemInput.querySelector('input').dispatchEvent(new Event('input', {bubbles: true})); | |
} | |
} | |
} | |
// Handle comparison problem button clicks | |
if (e.target.closest('#comp-problem-grid-container-left .problem-btn') || | |
e.target.closest('#comp-problem-grid-container-right .problem-btn')) { | |
const problemBtn = e.target.closest('.problem-btn'); | |
const problemId = problemBtn.getAttribute('data-problem-id'); | |
if (problemId) { | |
const problemInput = document.getElementById('comp-problem-state-input'); | |
if (problemInput) { | |
problemInput.querySelector('input').value = problemId; | |
problemInput.querySelector('input').dispatchEvent(new Event('input', {bubbles: true})); | |
} | |
} | |
} | |
// Handle reference problem button clicks | |
if (e.target.closest('#ref-problem-grid-container .ref-problem-btn')) { | |
const problemBtn = e.target.closest('.ref-problem-btn'); | |
const problemId = problemBtn.getAttribute('data-problem-id'); | |
if (problemId) { | |
const problemInput = document.getElementById('ref-problem-state-input'); | |
if (problemInput) { | |
problemInput.querySelector('input').value = problemId; | |
problemInput.querySelector('input').dispatchEvent(new Event('input', {bubbles: true})); | |
} | |
} | |
} | |
}); | |
} | |
// Set up listeners initially and after any DOM changes | |
setupProblemGridListeners(); | |
// Re-setup listeners whenever the DOM changes (for dynamic content) | |
const observer = new MutationObserver(function(mutations) { | |
setupProblemGridListeners(); | |
}); | |
observer.observe(document.body, {childList: true, subtree: true}); | |
} | |
""") | |
with gr.Tabs(): | |
with gr.TabItem("Single Model Analysis"): | |
with gr.Row(variant='compact'): | |
with gr.Column(scale=1, min_width=280): | |
dataset_radio_single = gr.Radio( | |
choices=AVAILABLE_DATASETS, | |
value=AVAILABLE_DATASETS[0] if AVAILABLE_DATASETS else None, | |
label="Select Dataset", | |
interactive=True | |
) | |
model_dropdown = gr.Dropdown( | |
choices=[], # Populated by callback | |
label="Select Model", | |
interactive=True | |
) | |
problem_state_input = gr.Textbox( | |
value="", | |
elem_id="problem-state-input", | |
visible=True, | |
label="Enter Problem ID (0 - 99, acc. below)", | |
container=True, | |
interactive=True, | |
every=0.5 | |
) | |
#gr.Markdown("#### Problem Grid") | |
problem_grid_html_output = gr.HTML( | |
value="<div>Select model and dataset to see problems.</div>", | |
elem_id="problem-grid-container" | |
) | |
gr.Markdown("#### Model Statistics") | |
model_stats_df = gr.DataFrame( | |
headers=["Metric", "Value"], | |
wrap=True, | |
elem_classes="dataframe-container dark-mode-compatible dark-mode-bg-secondary" | |
) | |
with gr.Column(scale=3, min_width=400): | |
with gr.Tabs(): | |
with gr.TabItem("Problem Statement"): | |
problem_markdown_output = gr.Markdown( | |
"Please fill in all the fields.", | |
latex_delimiters=[ | |
{"left": "$", "right": "$", "display": False}, | |
{"left": "$$", "right": "$$", "display": True}, | |
{"left": "\\(", "right": "\\)", "display": False}, | |
{"left": "\\[", "right": "\\]", "display": True} | |
] | |
) | |
with gr.TabItem("Reference Answer"): | |
answer_markdown_output = gr.Markdown( | |
"No answer available.", | |
latex_delimiters=[ | |
{"left": "$", "right": "$", "display": False}, | |
{"left": "$$", "right": "$$", "display": True}, | |
{"left": "\\(", "right": "\\)", "display": False}, | |
{"left": "\\[", "right": "\\]", "display": True} | |
] | |
) | |
# 样本网格 | |
samples_grid_output = gr.HTML("") | |
# 在样本网格下方添加样本选择输入框 | |
with gr.Row(): | |
# 样本选择输入框 | |
sample_number_input = gr.Textbox( | |
value="0", | |
elem_id="sample-number-input", | |
visible=True, | |
label="Enter Sample Number (0 - 63)", | |
container=True, | |
interactive=True, | |
every=0.5 | |
) | |
# 样本内容显示区域 - 使用HTML和Markdown组件分别显示元数据和响应内容 | |
sample_metadata_output = gr.HTML( | |
value="<div>Select a problem first to view samples.</div>", | |
elem_classes="sample-metadata dark-mode-bg-secondary", | |
elem_id="sample-metadata-area" | |
) | |
sample_response_output = gr.Markdown( | |
value="Select a problem first to view samples.", | |
elem_classes="sample-response dark-mode-bg-secondary", | |
elem_id="sample-response-area", | |
latex_delimiters=[ | |
{"left": "$", "right": "$", "display": False}, | |
{"left": "$$", "right": "$$", "display": True}, | |
{"left": "\\(", "right": "\\)", "display": False}, | |
{"left": "\\[", "right": "\\]", "display": True} | |
] | |
) | |
with gr.TabItem("Model Comparison"): | |
# 共享部分 | |
with gr.Row(variant='compact'): | |
comp_dataset_radio = gr.Radio( | |
choices=AVAILABLE_DATASETS, | |
value=AVAILABLE_DATASETS[0] if AVAILABLE_DATASETS else None, | |
label="Select Dataset", | |
interactive=True | |
) | |
comp_problem_state_input = gr.Textbox( | |
value="", | |
elem_id="comp-problem-state-input", | |
visible=True, | |
label="Enter Problem ID (0 - 99, acc. below)", | |
container=True, | |
interactive=True, | |
every=0.5 | |
) | |
# 移动的共享问题和答案显示到这里 | |
with gr.Row(variant='compact'): | |
with gr.Column(scale=1): | |
with gr.Tabs(): | |
with gr.TabItem("Problem Statement"): | |
comp_problem_markdown_output = gr.Markdown( | |
"Please select models and problem.", | |
latex_delimiters=[ | |
{"left": "$", "right": "$", "display": False}, | |
{"left": "$$", "right": "$$", "display": True}, | |
{"left": "\\(", "right": "\\)", "display": False}, | |
{"left": "\\[", "right": "\\]", "display": True} | |
] | |
) | |
with gr.TabItem("Reference Answer"): | |
comp_answer_markdown_output = gr.Markdown( | |
"No answer available.", | |
latex_delimiters=[ | |
{"left": "$", "right": "$", "display": False}, | |
{"left": "$$", "right": "$$", "display": True}, | |
{"left": "\\(", "right": "\\)", "display": False}, | |
{"left": "\\[", "right": "\\]", "display": True} | |
] | |
) | |
# 左右两部分模型比较 | |
with gr.Row(variant='compact'): | |
# 左侧模型 | |
with gr.Column(scale=1): | |
comp_model_dropdown_left = gr.Dropdown( | |
choices=[], # Populated by callback | |
label="Select Model 1", | |
interactive=True | |
) | |
gr.Markdown("#### Problem Grid") | |
comp_problem_grid_html_output_left = gr.HTML( | |
value="<div>Select model and dataset to see problems.</div>", | |
elem_id="comp-problem-grid-container-left" | |
) | |
# 样本网格和选择器 | |
comp_samples_grid_output_left = gr.HTML("") | |
with gr.Row(): | |
comp_sample_number_input_left = gr.Textbox( | |
value="0", | |
elem_id="comp-sample-number-input-left", | |
visible=True, | |
label="Enter Sample Number (0 - 63)", | |
container=True, | |
interactive=True, | |
every=0.5 | |
) | |
# 样本内容显示区域 - 使用HTML和Markdown组件分别显示元数据和响应内容 | |
comp_sample_metadata_output_left = gr.HTML( | |
value="<div>Select a problem first to view samples.</div>", | |
elem_classes="sample-metadata dark-mode-bg-secondary", | |
elem_id="comp-sample-metadata-area-left" | |
) | |
comp_sample_response_output_left = gr.Markdown( | |
value="Select a problem first to view samples.", | |
elem_classes="sample-response dark-mode-bg-secondary", | |
elem_id="comp-sample-response-area-left", | |
latex_delimiters=[ | |
{"left": "$", "right": "$", "display": False}, | |
{"left": "$$", "right": "$$", "display": True}, | |
{"left": "\\(", "right": "\\)", "display": False}, | |
{"left": "\\[", "right": "\\]", "display": True} | |
] | |
) | |
# 右侧模型 | |
with gr.Column(scale=1): | |
comp_model_dropdown_right = gr.Dropdown( | |
choices=[], # Populated by callback | |
label="Select Model 2", | |
interactive=True | |
) | |
gr.Markdown("#### Problem Grid") | |
comp_problem_grid_html_output_right = gr.HTML( | |
value="<div>Select model and dataset to see problems.</div>", | |
elem_id="comp-problem-grid-container-right" | |
) | |
# 样本网格和选择器 | |
comp_samples_grid_output_right = gr.HTML("") | |
with gr.Row(): | |
comp_sample_number_input_right = gr.Textbox( | |
value="0", | |
elem_id="comp-sample-number-input-right", | |
visible=True, | |
label="Enter Sample Number (0 - 63)", | |
container=True, | |
interactive=True, | |
every=0.5 | |
) | |
# 样本内容显示区域 - 使用HTML和Markdown组件分别显示元数据和响应内容 | |
comp_sample_metadata_output_right = gr.HTML( | |
value="<div>Select a problem first to view samples.</div>", | |
elem_classes="sample-metadata dark-mode-bg-secondary", | |
elem_id="comp-sample-metadata-area-right" | |
) | |
comp_sample_response_output_right = gr.Markdown( | |
value="Select a problem first to view samples.", | |
elem_classes="sample-response dark-mode-bg-secondary", | |
elem_id="comp-sample-response-area-right", | |
latex_delimiters=[ | |
{"left": "$", "right": "$", "display": False}, | |
{"left": "$$", "right": "$$", "display": True}, | |
{"left": "\\(", "right": "\\)", "display": False}, | |
{"left": "\\[", "right": "\\]", "display": True} | |
] | |
) | |
with gr.TabItem("Reference Solutions"): | |
with gr.Row(variant='compact'): | |
with gr.Column(scale=1, min_width=280): | |
ref_problem_state_input = gr.Textbox( | |
value="", | |
elem_id="ref-problem-state-input", | |
visible=True, | |
label="Enter Problem ID", | |
container=True, | |
interactive=True, | |
every=0.5 | |
) | |
with gr.Column(scale=3, min_width=400): | |
gr.Markdown("#### Problem Grid (OlymMATH-HARD: All models avg. acc. - Top: EN, Bottom: ZH)") | |
ref_problem_grid_html_output = gr.HTML( | |
value="<div>Loading reference data...</div>", | |
elem_id="ref-problem-grid-container" | |
) | |
# 问题内容显示区域 - 左右分布 | |
with gr.Row(variant='compact'): | |
# 左侧:问题信息 | |
with gr.Column(scale=1): | |
gr.Markdown("#### Problem (EN)") | |
ref_problem_en_output = gr.Markdown( | |
"Please select a problem.", | |
latex_delimiters=[ | |
{"left": "$", "right": "$", "display": False}, | |
{"left": "$$", "right": "$$", "display": True}, | |
{"left": "\\(", "right": "\\)", "display": False}, | |
{"left": "\\[", "right": "\\]", "display": True} | |
] | |
) | |
gr.Markdown("#### Problem (ZH)") | |
ref_problem_zh_output = gr.Markdown( | |
"Please select a problem.", | |
latex_delimiters=[ | |
{"left": "$", "right": "$", "display": False}, | |
{"left": "$$", "right": "$$", "display": True}, | |
{"left": "\\(", "right": "\\)", "display": False}, | |
{"left": "\\[", "right": "\\]", "display": True} | |
] | |
) | |
gr.Markdown("#### Subject") | |
ref_subject_output = gr.Markdown("Please select a problem.") | |
gr.Markdown("#### Answer") | |
ref_answer_output = gr.Markdown( | |
"Please select a problem.", | |
latex_delimiters=[ | |
{"left": "$", "right": "$", "display": False}, | |
{"left": "$$", "right": "$$", "display": True}, | |
{"left": "\\(", "right": "\\)", "display": False}, | |
{"left": "\\[", "right": "\\]", "display": True} | |
] | |
) | |
# 右侧:解答 | |
with gr.Column(scale=1): | |
gr.Markdown("#### Solution") | |
ref_solution_output = gr.Markdown( | |
"Please select a problem.", | |
elem_id="ref-solution", | |
latex_delimiters=[ | |
{"left": "$", "right": "$", "display": False}, | |
{"left": "$$", "right": "$$", "display": True}, | |
{"left": "\\(", "right": "\\)", "display": False}, | |
{"left": "\\[", "right": "\\]", "display": True}, | |
{"left": "\\begin{align}", "right": "\\end{align}", "display": True}, | |
{"left": "\\begin{aligned}", "right": "\\end{aligned}", "display": True}, | |
{"left": "\\begin{equation}", "right": "\\end{equation}", "display": True} | |
] | |
) | |
# --- Event Handlers --- | |
def update_available_models_for_dropdowns(selected_dataset):
    """Build the model dropdown choices for ``selected_dataset``.

    Every model returned by ``db.get_available_models()`` is listed under its
    ``MODEL_TRANS`` display name. When a dataset is selected and an accuracy
    is known for the model, the label is suffixed with it, e.g. ``" (12.3%)"``.

    Side effects:
        The label -> real model name mapping is merged into
        ``db.model_display_to_real`` (also exposed as
        ``db.comp_model_display_to_real``). Merging instead of overwriting
        keeps labels built for another dataset resolvable, since this
        function is called for more than one dataset selector.

    Args:
        selected_dataset: Dataset name used for the accuracy lookup; a falsy
            value yields labels without accuracies.

    Returns:
        A pair of ``gr.Dropdown`` updates carrying the same choices, both
        with ``value=None``.
    """
    all_models = db.get_available_models() or []

    # Accuracies are only looked up when there is a dataset to look them up for.
    model_acc_map = {}
    if selected_dataset and all_models:
        model_accs = db.get_all_model_accuracies(selected_dataset)
        model_acc_map = {name: acc for name, acc in model_accs}

    model_options = []
    model_to_display_map = {}  # label shown in the dropdown -> real model name
    for name in all_models:
        display_name = MODEL_TRANS.get(name, name)
        acc = model_acc_map.get(name)
        display_text = f"{display_name} ({acc:.1%})" if acc is not None else display_name
        model_options.append(display_text)
        model_to_display_map[display_text] = name

    # Keep previously registered labels: a given label always maps to the same
    # model, so stale entries are harmless while dropping them breaks lookups
    # for labels that are still displayed elsewhere.
    merged_map = dict(getattr(db, "model_display_to_real", None) or {})
    merged_map.update(model_to_display_map)
    db.model_display_to_real = merged_map
    db.comp_model_display_to_real = merged_map

    # Both dropdowns get the same accuracy-annotated choices.
    return gr.Dropdown(choices=model_options, value=None), \
        gr.Dropdown(choices=model_options, value=None)
def update_problem_grid_and_stats(selected_model_formatted, selected_dataset, mode='default'):
    """Resolve the selected dropdown label to a model and load its stats and grid.

    Args:
        selected_model_formatted: Label chosen in the model dropdown (display
            name, optionally followed by an accuracy suffix such as
            ``" (12.3%)"``).
        selected_dataset: Name of the dataset to query.
        mode: Forwarded to ``create_problem_grid_html``.

    Returns:
        ``(gr.DataFrame, gr.HTML, model_name)``. ``model_name`` is the real
        model name to store in the model state, or ``None`` when the model or
        dataset is missing.
    """
    if not selected_model_formatted or not selected_dataset:
        # Return empty/default values for all outputs, including the state
        return gr.DataFrame(value=[]), gr.HTML("<div>Please select a model and dataset first.</div>"), None

    display_to_real = getattr(db, "model_display_to_real", None) or {}

    # Exact label match is authoritative.
    model_name = display_to_real.get(selected_model_formatted)
    if model_name is None:
        # The label may carry a different accuracy than the one in the mapping.
        # Compare labels with the accuracy suffix stripped. This must be an
        # equality test: a prefix test would let "DeepSeek-R1" capture
        # "DeepSeek-R1-Distill-Qwen-7B".
        def strip_accuracy(label):
            return re.sub(r" \(\d+(?:\.\d+)?%\)$", "", label)

        wanted = strip_accuracy(selected_model_formatted)
        # If nothing matches, fall back to the label itself as the model name.
        model_name = selected_model_formatted
        for display_name, real_name in display_to_real.items():
            if strip_accuracy(display_name) == wanted:
                model_name = real_name
                break

    stats_data = db.get_model_statistics(model_name, selected_dataset)
    problem_list = db.get_problems_by_model_dataset(model_name, selected_dataset)
    grid_html = create_problem_grid_html(problem_list, mode=mode)
    # The third value is the plain model name for the model state output.
    return gr.DataFrame(value=stats_data), gr.HTML(value=grid_html), model_name
# Single Model Tab interactions | |
dataset_radio_single.change( | |
fn=update_available_models_for_dropdowns, | |
inputs=[dataset_radio_single], | |
outputs=[model_dropdown, comp_model_dropdown_left] | |
).then( | |
lambda ds: (gr.DataFrame(value=[]), gr.HTML("<div>Select a model.</div>"), gr.State(value=None), ds, ""), # 清空所有输出,包括problem_state_input | |
inputs=[dataset_radio_single], | |
outputs=[model_stats_df, problem_grid_html_output, current_model_state, current_dataset_state, problem_state_input] | |
).then( | |
# 重置Sample Number为0 | |
fn=lambda: "0", | |
inputs=[], | |
outputs=[sample_number_input] | |
).then( | |
lambda: ("Please fill in all the fields.", "No answer available.", "", gr.State([]), "<div>Select a problem first to view samples.</div>", ""), | |
inputs=[], | |
outputs=[problem_markdown_output, answer_markdown_output, samples_grid_output, current_samples_data_state, sample_metadata_output, sample_response_output] | |
) | |
# Initial population of model dropdowns based on default dataset | |
demo.load( | |
fn=update_available_models_for_dropdowns, | |
inputs=[current_dataset_state], # Uses initial value of state | |
outputs=[model_dropdown, comp_model_dropdown_left] | |
).then( | |
lambda ds_val: (gr.DataFrame(value=[]), gr.HTML("<div>Select a model.</div>"), ds_val), # Also update dataset state for single tab | |
inputs=[current_dataset_state], | |
outputs=[model_stats_df, problem_grid_html_output, current_dataset_state] | |
).then( | |
lambda: ("Please fill in all the fields.", "No answer available.", "", gr.State([]), "<div>Select a problem first to view samples.</div>", ""), | |
inputs=[], | |
outputs=[problem_markdown_output, answer_markdown_output, samples_grid_output, current_samples_data_state, sample_metadata_output, sample_response_output] | |
).then( | |
# 重置Sample Number为0 | |
fn=lambda: "0", | |
inputs=[], | |
outputs=[sample_number_input] | |
) | |
# ==== 比较页面事件处理 ==== | |
# 初始化两侧模型下拉列表 | |
demo.load( | |
fn=update_available_models_for_dropdowns, | |
inputs=[comp_dataset_state], | |
outputs=[model_dropdown, comp_model_dropdown_left] | |
).then( | |
fn=update_available_models_for_dropdowns, | |
inputs=[comp_dataset_state], | |
outputs=[model_dropdown, comp_model_dropdown_right] | |
) | |
# 数据集改变事件 | |
comp_dataset_radio.change( | |
fn=lambda ds: ds, | |
inputs=[comp_dataset_radio], | |
outputs=[comp_dataset_state] | |
).then( | |
fn=update_available_models_for_dropdowns, | |
inputs=[comp_dataset_state], | |
outputs=[model_dropdown, comp_model_dropdown_left] | |
).then( | |
fn=update_available_models_for_dropdowns, | |
inputs=[comp_dataset_state], | |
outputs=[model_dropdown, comp_model_dropdown_right] | |
).then( | |
lambda: ("Please select a dataset and enter a problem ID.", "No answer available."), | |
inputs=[], | |
outputs=[comp_problem_markdown_output, comp_answer_markdown_output] | |
) | |
# 为比较页面的问题ID添加单独的更新逻辑 | |
comp_problem_state_input.change( | |
fn=handle_comparison_problem_update, | |
inputs=[comp_problem_state_input, comp_dataset_state], | |
outputs=[comp_problem_markdown_output, comp_answer_markdown_output] | |
) | |
# 创建包装函数,预设模式参数 | |
def update_problem_grid_comparison(model, dataset):
    """Call update_problem_grid_and_stats with the mode fixed to 'comparison'."""
    comparison_outputs = update_problem_grid_and_stats(model, dataset, mode='comparison')
    return comparison_outputs
# 问题选择的包装函数 | |
def handle_problem_select_comparison(problem_id, model_state, dataset_state):
    """Call handle_problem_select with the mode fixed to 'comparison'."""
    selection_outputs = handle_problem_select(problem_id, model_state, dataset_state, mode='comparison')
    return selection_outputs
# 修改model_dropdown的处理函数,以重新查询当前问题响应 - 比较页面左侧 | |
def update_model_and_requery_problem_left(model_dropdown_value, current_dataset, current_problem_id):
    """Refresh the left comparison column after its model selection changes.

    Returns an 8-tuple: problem grid HTML, resolved model state, problem
    text, answer text, samples grid HTML, samples data, and the metadata and
    response of the first sample. Without a problem id the last six entries
    are placeholder values.
    """
    # The first returned element (the stats component) is discarded here.
    _unused_stats, grid_html, new_model_state = update_problem_grid_comparison(
        model_dropdown_value, current_dataset
    )

    # Guard: nothing to re-query when no problem is selected.
    if not current_problem_id:
        return (
            grid_html,
            new_model_state,
            "Please enter a problem ID.",
            "No answer available.",
            "",
            gr.State([]),
            "<div>Select a problem first to view samples.</div>",
            "",
        )

    # Re-run the problem query against the newly selected model.
    selection = handle_problem_select_comparison(current_problem_id, new_model_state, current_dataset)
    problem_content, answer_content, samples_grid_html, new_samples_data = selection
    # Show the first sample right away.
    first_metadata, first_response = handle_first_sample(new_samples_data)
    return (
        grid_html,
        new_model_state,
        problem_content,
        answer_content,
        samples_grid_html,
        new_samples_data,
        first_metadata,
        first_response,
    )
# 修改model_dropdown的处理函数,以重新查询当前问题响应 - 比较页面右侧 | |
def update_model_and_requery_problem_right(model_dropdown_value, current_dataset, current_problem_id):
    """Refresh the right comparison column after its model selection changes.

    Returns a 6-tuple: problem grid HTML, resolved model state, samples grid
    HTML, samples data, and the metadata and response of the first sample.
    The problem and answer texts are not part of the result for this side.
    """
    # The first returned element (the stats component) is discarded here.
    _unused_stats, grid_html, new_model_state = update_problem_grid_comparison(
        model_dropdown_value, current_dataset
    )

    # Guard: nothing to re-query when no problem is selected.
    if not current_problem_id:
        return (
            grid_html,
            new_model_state,
            "",
            gr.State([]),
            "<div>Select a problem first to view samples.</div>",
            "",
        )

    # Only the samples are needed for this side; problem/answer text is dropped.
    selection = handle_problem_select_comparison(current_problem_id, new_model_state, current_dataset)
    samples_grid_html, new_samples_data = selection[2], selection[3]
    # Show the first sample right away.
    first_metadata, first_response = handle_first_sample(new_samples_data)
    return (
        grid_html,
        new_model_state,
        samples_grid_html,
        new_samples_data,
        first_metadata,
        first_response,
    )
# 左侧模型选择事件 | |
comp_model_dropdown_left.change( | |
fn=update_model_and_requery_problem_left, | |
inputs=[comp_model_dropdown_left, comp_dataset_state, comp_problem_state_input], | |
outputs=[comp_problem_grid_html_output_left, comp_model_state_left, comp_problem_markdown_output, comp_answer_markdown_output, comp_samples_grid_output_left, comp_samples_data_state_left, comp_sample_metadata_output_left, comp_sample_response_output_left] | |
).then( | |
# 重置Sample Number为0 | |
fn=lambda: "0", | |
inputs=[], | |
outputs=[comp_sample_number_input_left] | |
) | |
# 右侧模型选择事件 | |
comp_model_dropdown_right.change( | |
fn=update_model_and_requery_problem_right, | |
inputs=[comp_model_dropdown_right, comp_dataset_state, comp_problem_state_input], | |
outputs=[comp_problem_grid_html_output_right, comp_model_state_right, comp_samples_grid_output_right, comp_samples_data_state_right, comp_sample_metadata_output_right, comp_sample_response_output_right] | |
).then( | |
# 重置Sample Number为0 | |
fn=lambda: "0", | |
inputs=[], | |
outputs=[comp_sample_number_input_right] | |
) | |
# 左侧样本选择 | |
comp_sample_number_input_left.change( | |
fn=handle_sample_select, | |
inputs=[comp_sample_number_input_left, comp_samples_data_state_left], | |
outputs=[comp_sample_metadata_output_left, comp_sample_response_output_left] | |
) | |
# 右侧样本选择 | |
comp_sample_number_input_right.change( | |
fn=handle_sample_select, | |
inputs=[comp_sample_number_input_right, comp_samples_data_state_right], | |
outputs=[comp_sample_metadata_output_right, comp_sample_response_output_right] | |
) | |
# 为比较页面问题选择事件添加处理 | |
comp_problem_state_input.change( | |
fn=handle_problem_select_comparison, | |
inputs=[comp_problem_state_input, comp_model_state_left, comp_dataset_state], | |
outputs=[comp_problem_markdown_output, comp_answer_markdown_output, comp_samples_grid_output_left, comp_samples_data_state_left] | |
).then( | |
# 重置Sample Number为0 | |
fn=lambda: "0", | |
inputs=[], | |
outputs=[comp_sample_number_input_left] | |
).then( | |
fn=handle_first_sample, | |
inputs=[comp_samples_data_state_left], | |
outputs=[comp_sample_metadata_output_left, comp_sample_response_output_left] | |
) | |
# 问题选择事件 - 右侧模型 | |
comp_problem_state_input.change( | |
fn=handle_problem_select_comparison, | |
inputs=[comp_problem_state_input, comp_model_state_right, comp_dataset_state], | |
outputs=[dummy_state, dummy_state, comp_samples_grid_output_right, comp_samples_data_state_right] | |
).then( | |
# 重置Sample Number为0 | |
fn=lambda: "0", | |
inputs=[], | |
outputs=[comp_sample_number_input_right] | |
).then( | |
fn=handle_first_sample, | |
inputs=[comp_samples_data_state_right], | |
outputs=[comp_sample_metadata_output_right, comp_sample_response_output_right] | |
) | |
# This is the crucial link: problem_state_input is changed by user, triggers this Python callback. | |
problem_state_input.change( | |
fn=handle_problem_select, | |
inputs=[problem_state_input, current_model_state, current_dataset_state], | |
outputs=[problem_markdown_output, answer_markdown_output, samples_grid_output, current_samples_data_state] | |
).then( | |
# 重置Sample Number为0 | |
fn=lambda: "0", | |
inputs=[], | |
outputs=[sample_number_input] | |
).then( | |
fn=handle_first_sample, | |
inputs=[current_samples_data_state], | |
outputs=[sample_metadata_output, sample_response_output] | |
) | |
# Also listen for direct input event which may be more reliable than change | |
problem_state_input.input( | |
fn=handle_problem_select, | |
inputs=[problem_state_input, current_model_state, current_dataset_state], | |
outputs=[problem_markdown_output, answer_markdown_output, samples_grid_output, current_samples_data_state] | |
).then( | |
# 重置Sample Number为0 | |
fn=lambda: "0", | |
inputs=[], | |
outputs=[sample_number_input] | |
).then( | |
fn=handle_first_sample, | |
inputs=[current_samples_data_state], | |
outputs=[sample_metadata_output, sample_response_output] | |
) | |
# 添加样本编号的事件处理 | |
sample_number_input.change( | |
fn=handle_sample_select, | |
inputs=[sample_number_input, current_samples_data_state], | |
outputs=[sample_metadata_output, sample_response_output] | |
) | |
sample_number_input.input( | |
fn=handle_sample_select, | |
inputs=[sample_number_input, current_samples_data_state], | |
outputs=[sample_metadata_output, sample_response_output] | |
) | |
# 修改model_dropdown.change处理函数,以重新查询当前问题响应 | |
def update_model_and_requery_problem(model_dropdown_value, current_dataset, current_problem_id):
    """Refresh the model statistics/grid and re-query the selected problem for the new model.

    Args:
        model_dropdown_value: Value of the model dropdown; forwarded to
            ``update_problem_grid_and_stats``.
        current_dataset: Currently selected dataset.
        current_problem_id: Problem currently selected by the user; falsy when
            nothing has been selected yet.

    Returns:
        A 9-tuple: stats dataframe, problem grid HTML, new model state,
        problem content, answer content, samples grid HTML, samples data,
        first-sample metadata and first-sample response.
    """
    # Always refresh the model statistics and the problem grid first.
    stats_df, grid_html, new_model_state = update_problem_grid_and_stats(model_dropdown_value, current_dataset)

    if not current_problem_id:
        # Nothing selected: only the model-related outputs carry real data.
        # The samples state must be a plain value (an empty list), not a
        # gr.State component instance.
        return (
            stats_df, grid_html, new_model_state,
            "Please fill in all the fields.", "No answer available.", "",
            [],
            "<div>Select a problem first to view samples.</div>", "",
        )

    # A problem is selected: fetch its responses for the newly chosen model.
    problem_content, answer_content, samples_grid_html, new_samples_data = handle_problem_select(
        current_problem_id, new_model_state, current_dataset
    )
    # Show the first sample right away.
    first_metadata, first_response = handle_first_sample(new_samples_data)
    return (
        stats_df, grid_html, new_model_state,
        problem_content, answer_content, samples_grid_html,
        new_samples_data,
        first_metadata, first_response,
    )
model_dropdown.change( | |
fn=update_model_and_requery_problem, | |
inputs=[model_dropdown, current_dataset_state, problem_state_input], | |
outputs=[model_stats_df, problem_grid_html_output, current_model_state, problem_markdown_output, answer_markdown_output, samples_grid_output, current_samples_data_state, sample_metadata_output, sample_response_output] | |
).then( | |
# 重置Sample Number为0 | |
fn=lambda: "0", | |
inputs=[], | |
outputs=[sample_number_input] | |
) | |
# 为引用解决方案标签页添加处理器 | |
# 初始化引用问题网格 | |
demo.load( | |
fn=lambda: create_reference_problem_grid_html(reference_loader, db), | |
inputs=[], | |
outputs=[ref_problem_grid_html_output] | |
) | |
# 引用问题选择事件 | |
ref_problem_state_input.change( | |
fn=handle_reference_problem_select, | |
inputs=[ref_problem_state_input, gr.State(reference_loader)], | |
outputs=[ref_problem_en_output, ref_problem_zh_output, ref_subject_output, ref_answer_output, ref_solution_output] | |
) | |
# This is the crucial link: problem_state_input is changed by user, triggers this Python callback. | |
problem_state_input.change( | |
fn=handle_problem_select, | |
inputs=[problem_state_input, current_model_state, current_dataset_state], | |
outputs=[problem_markdown_output, answer_markdown_output, samples_grid_output, current_samples_data_state] | |
).then( | |
# 重置Sample Number为0 | |
fn=lambda: "0", | |
inputs=[], | |
outputs=[sample_number_input] | |
).then( | |
fn=handle_first_sample, | |
inputs=[current_samples_data_state], | |
outputs=[sample_metadata_output, sample_response_output] | |
) | |
return demo | |
def monitor_memory_usage():
    """Report the process's resident memory and trim database caches under pressure.

    Reads the module-level ``db``. When the resident set size exceeds
    12000 MB the response cache is cleared; when it exceeds 14000 MB
    ``db.clear_cache()`` is called with no argument (all caches, per the
    original comment). A garbage collection follows each clean-up step.

    Returns:
        ``"Memory: <value> MB"`` on success, or ``"Memory monitor error"`` if
        anything goes wrong while measuring or cleaning up.
    """
    try:
        rss_mb = psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024

        # Above 12 GB (aggressive setting): drop the response cache first.
        if rss_mb > 12000:
            if db:
                db.clear_cache('response')
            gc.collect()

        # Above 14 GB: clean up more aggressively.
        if rss_mb > 14000:
            if db:
                db.clear_cache()
            gc.collect()

        return f"Memory: {rss_mb:.1f} MB"
    except Exception as e:
        # Best-effort monitor: never raise, but leave a trace of the cause
        # instead of swallowing it silently.
        print(f"Memory monitor error: {e}")
        return "Memory monitor error"
def _reference_problem_button_html(pid, lang_tag, accuracy):
    """Render one clickable grid cell for problem ``pid`` in language ``lang_tag``."""
    bg_color = get_gradient_color(accuracy)
    acc_pct = int(accuracy * 100)
    return f"""
    <div
        data-problem-id="{pid}"
        class="ref-problem-btn"
        title="ID: {pid} ({lang_tag}) - Avg Acc: {acc_pct}%"
        style='background-color: {bg_color}; color: white !important;
        border-radius: 4px; padding: 5px; text-align: center; font-size: 0.7em;
        min-height: 36px; user-select: none; width: 100%;
        display: flex; flex-direction: column; justify-content: center;
        overflow: hidden; text-overflow: ellipsis; white-space: nowrap; cursor: pointer;'>
        <div style="font-weight: bold; color: white !important;">{pid}</div>
        <div style="color: white !important;">{acc_pct}%</div>
    </div>
    """


def create_reference_problem_grid_html(reference_loader, db):
    """Create HTML for the reference problem grid with average accuracies (using cache).

    The grid has two rows of buttons, English first and Chinese second, one
    button per reference problem, coloured by the cached average accuracy
    read from the module-level ``reference_accuracy_cache``.

    Args:
        reference_loader: Object exposing ``get_all_problem_ids()``.
        db: Database handle; only checked for truthiness here.

    Returns:
        An HTML string: either the grid or a short status message when the
        database, the reference data, the problem list or the accuracy cache
        is unavailable.
    """
    if not db:
        return "<div>Database not available.</div>"
    if not reference_loader:
        return "<div><strong>No reference data available.</strong><br>Please ensure <code>extra.jsonl</code> file is in the same directory as the database file or in the current working directory.</div>"

    problem_ids = reference_loader.get_all_problem_ids()
    if not problem_ids:
        return "<div>No reference problems found in extra.jsonl file.</div>"

    # An empty cache means the accuracies are not available yet: show a loading hint.
    if not reference_accuracy_cache:
        return "<div><strong>Computing problem accuracies...</strong><br>This may take a moment on first load.</div>"

    print(f"Using cached accuracies for {len(problem_ids)} reference problems")

    custom_style = "<style>.ref-problem-btn, .ref-problem-btn div { color: white !important; }</style>"

    # Sort numerically so that e.g. 10 comes after 9.
    sorted_problem_ids = sorted(problem_ids, key=int)

    en_cells = []
    zh_cells = []
    for pid in sorted_problem_ids:
        # Problems missing from the cache (or missing a language) show as 0%.
        accuracy_data = reference_accuracy_cache.get(pid) or {}
        en_cells.append(_reference_problem_button_html(pid, "EN", accuracy_data.get("EN", 0.0)))
        zh_cells.append(_reference_problem_button_html(pid, "ZH", accuracy_data.get("ZH", 0.0)))
    html_en = "".join(en_cells)
    html_zh = "".join(zh_cells)

    # One column per problem, capped at 30 columns.
    grid_cols = min(len(sorted_problem_ids), 30)

    return f"""
    {custom_style}
    <div style='margin-bottom: 10px;'>
        <div style='display: grid; grid-template-columns: repeat({grid_cols}, 1fr); gap: 2px;'>{html_en}</div>
    </div>
    <div>
        <div style='display: grid; grid-template-columns: repeat({grid_cols}, 1fr); gap: 2px;'>{html_zh}</div>
    </div>
    """
def handle_reference_problem_select(problem_id, reference_loader):
    """Resolve a selected reference problem into its five display fields.

    Args:
        problem_id: Identifier chosen by the user; must be convertible with ``int``.
        reference_loader: Object exposing ``get_problem_data(int)``.

    Returns:
        A 5-tuple ``(english problem, chinese problem, subject, answer,
        solution)``. When the selection is missing, not an integer, or not
        found, the tuple holds user-facing messages instead.
    """
    if not (problem_id and reference_loader):
        return ("Please select a problem.",) * 5

    try:
        numeric_id = int(problem_id)
    except ValueError:
        return ("Please enter a valid problem ID.",) * 5

    record = reference_loader.get_problem_data(numeric_id)
    if not record:
        not_found = f"Problem {numeric_id} not found in reference data."
        return (not_found, not_found, "No subject available.", "No answer available.", "Solution not available.")

    # Problem statements in both languages.
    en_problem = format_markdown_with_math(record.get('en_problem', 'Problem (EN) not available.'))
    zh_problem = format_markdown_with_math(record.get('zh_problem', 'Problem (ZH) not available.'))

    # Answers go through the dedicated answer formatter.
    answer = format_answer_with_math(record.get('answer', 'No answer available.'))

    # Subject is shown as "English / Chinese"; unknown subjects fall back to the raw value.
    subject_en = record.get('subject', 'Unknown')
    subject_display = f"**{subject_en}** / **{SUBJECT_TRANS_EN_TO_ZH.get(subject_en, subject_en)}**"

    # The placeholder text is passed through untouched; real solutions get their delimiters converted.
    raw_solution = record.get('solution', 'Solution not available.')
    if raw_solution == 'Solution not available.':
        solution = raw_solution
    else:
        solution = format_solution_latex(raw_solution)

    return (en_problem, zh_problem, subject_display, answer, solution)
def test_reference_data_availability(db, reference_loader):
    """Print a diagnostic report on the database and the reference loader.

    Args:
        db: Database handle exposing ``conn`` (a connection with ``cursor()``),
            ``get_available_models()`` and ``get_problem_data(...)``.
        reference_loader: Object exposing ``get_all_problem_ids()``.

    Returns:
        ``False`` when either ``db`` or ``reference_loader`` is falsy,
        ``True`` otherwise. Schema-inspection errors are printed, not raised.
    """
    print("=== Reference Data Availability Test ===")

    # Database presence.
    if not db:
        print("❌ Database is not available")
        return False

    # Schema overview; any failure here is reported and the test carries on.
    try:
        cur = db.conn.cursor()

        def first_column(sql):
            cur.execute(sql)
            return [record[0] for record in cur.fetchall()]

        def scalar(sql):
            cur.execute(sql)
            return cur.fetchone()[0]

        tables = first_column("SELECT name FROM sqlite_master WHERE type='table'")
        print(f"✅ Database tables: {tables}")

        problem_count = scalar("SELECT COUNT(*) FROM problems")
        print(f"✅ Problems table: {problem_count} problems")

        response_count = scalar("SELECT COUNT(*) FROM responses")
        print(f"✅ Responses table: {response_count} responses")

        datasets = first_column("SELECT DISTINCT dataset FROM responses")
        print(f"✅ Available datasets: {datasets}")

        sample_ids = first_column("SELECT unique_id FROM problems LIMIT 10")
        print(f"✅ Sample problem unique_ids: {sample_ids}")
    except Exception as e:
        print(f"❌ Error checking database schema: {e}")

    models = db.get_available_models()
    print(f"✅ Database connected: {len(models)} models available")

    # Reference loader presence.
    if not reference_loader:
        print("❌ Reference loader is not available (extra.jsonl not found)")
        return False

    problem_ids = reference_loader.get_all_problem_ids()
    print(f"✅ Reference loader: {len(problem_ids)} problems available: {problem_ids}")

    # Probe the first reference problem in both languages (simplified test).
    if problem_ids:
        probe_id = problem_ids[0]
        en_unique_id = f"OlymMATH-HARD-{probe_id}-EN"
        zh_unique_id = f"OlymMATH-HARD-{probe_id}-ZH"
        print(f"Testing with constructed IDs: {en_unique_id}, {zh_unique_id}")

        en_problem, en_responses = db.get_problem_data(None, "EN-HARD", en_unique_id)
        zh_problem, zh_responses = db.get_problem_data(None, "ZH-HARD", zh_unique_id)

        print(f"Test problem {probe_id}:")
        for tag, problem in (("EN", en_problem), ("ZH", zh_problem)):
            print(f" {tag} problem exists: {problem is not None}")
        for tag, responses in (("EN", en_responses), ("ZH", zh_responses)):
            if responses:
                print(f" {tag} responses: {len(responses)} found")

    print("=== End Test ===")
    return True
def test_latex_formatting():
    """Run ``format_markdown_with_math`` on a sample with an ``aligned`` environment.

    Prints whether the begin-aligned marker is present before and after
    formatting, plus the first 300 characters of the result.

    Returns:
        The formatted sample text.
    """
    aligned_marker = "\\begin{aligned}"
    sample = """
易知,1, 4, 6, 7, 9 这五个数中的任意两个数之差均不为 4 或 7.
$$
\\begin{aligned}
\\sum_{n=1}^{2023}f_{n} &= \\sum_{k=0}^{183}\\sum_{i=0}^{10}f_{11k+i} \\\\
&= \\sum_{k=0}^{183}(11 \\times 5k+1+2+3+5 \\times 4+2 \\times 5) \\\\
&= 55 \\times \\frac{183 \\times 184}{2}+184 \\times 36 \\\\
&= 932604.
\\end{aligned}
$$
故答案为:$\\boxed{932604}$.
"""
    rendered = format_markdown_with_math(sample)

    print("=== LaTeX Formatting Test ===")
    print("Original text contains \\begin{aligned}:", aligned_marker in sample)
    print("Formatted text contains \\begin{aligned}:", aligned_marker in rendered)
    print("Formatted text (first 300 chars):", rendered[:300])
    print("=== End Test ===")
    return rendered
def format_solution_latex(text):
    r"""Convert dollar-delimited math in a solution to bracket delimiters.

    ``$$...$$`` becomes ``\[...\]`` (display math, may span lines) and
    ``$...$`` becomes ``\(...\)`` (inline math, single line only). Line
    endings are normalized to ``\n`` and runs of blank lines are collapsed.

    Args:
        text: Solution text, or ``None``.

    Returns:
        The converted text; ``""`` when ``text`` is ``None``.
    """
    if text is None:
        return ""

    # Normalize line endings first so the single-line restriction of the
    # inline pattern also holds for "\r"-terminated input.
    text = text.replace('\r\n', '\n').replace('\r', '\n')

    # Display math: non-greedy, DOTALL so a formula may span several lines.
    text = re.sub(r'\$\$(.*?)\$\$', r'\\[\1\\]', text, flags=re.DOTALL)

    # Inline math: an unescaped "$", no "$" or newline inside. Display math
    # has no "$" left at this point, so no extra guard is needed; in
    # particular "$x$]" must still be converted.
    text = re.sub(r'(?<!\\)\$([^$\n]+?)\$', r'\\(\1\\)', text)

    # Collapse three or more consecutive (possibly whitespace-only) line breaks.
    text = re.sub(r'\n\s*\n\s*\n+', '\n\n', text)
    return text
def format_answer_with_math(text):
    r"""Format an answer field as inline math using ``\(...\)`` delimiters.

    ``$$...$$`` is downgraded to ``$...$``; an answer without any ``$`` is
    wrapped in one pair; every single-line ``$...$`` is then rewritten as
    ``\(...\)``. Line endings are normalized and blank-line runs collapsed.

    Args:
        text: The answer. ``None``, blank strings and the placeholder
            ``"No answer available."`` are returned unchanged. Non-string
            values (e.g. numbers) are converted with ``str``.

    Returns:
        The formatted answer, or the input itself for the pass-through cases.
    """
    if text is None:
        return text
    if not isinstance(text, str):
        # Answers loaded from JSON may be numbers.
        text = str(text)
    if text.strip() == "" or text == "No answer available.":
        return text

    # Normalize line endings.
    text = text.replace('\r\n', '\n').replace('\r', '\n')

    # Answers are always rendered inline: turn display math into inline math.
    text = re.sub(r'\$\$(.*?)\$\$', r'$\1$', text, flags=re.DOTALL)

    # A bare answer (no math delimiters at all) is treated as math as a whole.
    if '$' not in text:
        text = f"${text}$"

    # Rewrite $...$ as \(...\) for rendering.
    text = re.sub(r'(?<!\\)\$([^$\n]+?)\$', r'\\(\1\\)', text)

    # Collapse three or more consecutive (possibly whitespace-only) line breaks.
    text = re.sub(r'\n\s*\n\s*\n+', '\n\n', text)
    return text
# 修改主函数以使用优化策略 | |
if __name__ == "__main__":
    import atexit

    def _launch_error_app(title, detail):
        """Serve a minimal Gradio page that displays a startup error."""
        with gr.Blocks() as error_demo:
            gr.Markdown(f"# Error: {title}\n{detail}")
        error_demo.launch(server_name="0.0.0.0")

    DB_PATH = "data.db"

    # If the database file is not present locally, download it from Hugging Face.
    if not os.path.exists(DB_PATH):
        try:
            DB_PATH = hf_hub_download(
                repo_id="CoderBak/OlymMATH-data",
                filename="data.db",
                repo_type="dataset"
            )
        except Exception as e:
            # Show the failure in a simple Gradio app, then exit non-zero.
            _launch_error_app("Database Download Failed", str(e))
            # SystemExit instead of the site-provided exit(), which is
            # meant for interactive use.
            raise SystemExit(1)

    if os.path.exists(DB_PATH):
        # Module-level handle read by other functions in this file.
        db = ModelDatabase(DB_PATH)

        def cleanup():
            """Close the database when the interpreter exits."""
            if db:
                db.close()

        atexit.register(cleanup)

        # Build the UI and start the server.
        main_demo = create_ui(DB_PATH)
        main_demo.launch(
            server_name="0.0.0.0",
            share=False,
            inbrowser=False
        )
    else:
        _launch_error_app(
            "Database Not Found",
            f"Could not find `{DB_PATH}`. Please ensure the database file is correctly placed and accessible.",
        )