S2S-Arena / app.py
KurtDu's picture
Update app.py
8c71012 verified
raw
history blame
7.3 kB
import os
import json
import random
import uuid
from flask import Flask, request, jsonify, session, render_template
from flask_cors import CORS
from flask_session import Session # 引入 Flask-Session
from datetime import datetime
from elo_rank import EloRank
app = Flask(__name__)
CORS(app, supports_credentials=True)

# Server-side sessions via Flask-Session, stored as files on local disk.
SESSION_DIR = '/tmp/flask_session/'
app.config['SESSION_TYPE'] = 'filesystem'  # store session data on the filesystem
app.config['SESSION_PERMANENT'] = False  # sessions are not made permanent
app.config['SESSION_USE_SIGNER'] = True  # sign the session cookie; requires app.secret_key
app.config['SESSION_FILE_DIR'] = SESSION_DIR  # directory holding the session files
# exist_ok=True avoids the exists()/makedirs() race when several workers start.
os.makedirs(SESSION_DIR, exist_ok=True)

# NOTE(review): the fallback key is visible in the source, so cookie signing
# gives no real protection while it is in use; set SECRET_KEY in the
# environment for any shared deployment.
app.secret_key = os.environ.get('SECRET_KEY', 'supersecretkey')
Session(app)

base_dir = os.path.dirname(os.path.abspath(__file__))
# Absolute default locations. base_dir is deliberately not joined in: on POSIX,
# os.path.join(base_dir, '/app/data') would discard base_dir anyway because
# the second argument is absolute. Both paths can be overridden via env vars.
DATA_DIR = os.environ.get('DATA_DIR', '/app/data')
RESULTS_DIR = os.environ.get('RESULTS_DIR', '/app/results')

# Elo ranking system, a single module-level instance shared by all requests
# handled in this process.
elo_rank_system = EloRank()
# Model identifiers registered with the Elo system. Each one is also the key,
# in every test item, of that model's output audio path.
models = [
    'output_path_4o', 'output_path_miniomni', 'output_path_speechgpt',
    'output_path_funaudio', 'output_path_4o_cascade', 'output_path_4o_llama_omni'
]
for model in models:
    elo_rank_system.add_model(model)
def print_directory_structure(start_path, indent=''):
    """Print a recursive tree listing of ``start_path`` to stdout.

    Each entry is printed on its own line, prefixed by ``indent``.
    Directories get a folder glyph and a trailing slash, and their contents
    are then listed recursively with a deeper indent. All other entries get a
    file glyph. Entries appear in the order returned by ``os.listdir``.
    """
    for item in os.listdir(start_path):
        child = os.path.join(start_path, item)
        if not os.path.isdir(child):
            print(f"{indent}📄 {item}")
            continue
        print(f"{indent}📁 {item}/")
        print_directory_structure(child, indent + ' ')
def load_test_data(task):
    """Load the JSON test items for ``task`` and rewrite their audio paths.

    Reads ``<DATA_DIR>/<task>.json`` (presumably a list of dicts, one per
    test item). It then prefixes ``input_path`` and every model key listed
    in ``models`` with ``/app/static/audio``, pointing them at the Flask
    static folder.

    Args:
        task: Task name taken from the client request (untrusted).

    Returns:
        The list of test items on success. On failure, a
        ``(jsonify(...), 400)`` tuple: either the task name resolves outside
        ``DATA_DIR`` or the data file does not exist. Callers detect the
        error case with ``isinstance(result, tuple)``.

    Raises:
        KeyError: if a test item lacks ``input_path`` or one of the model keys.
    """
    # Debug aid: echo /app/test_text.txt to stdout when it exists.
    try:
        with open('/app/test_text.txt', 'r') as file:
            content = file.read()
            print(content)
    except FileNotFoundError:
        print("Test text file not found.")

    # ``task`` comes from the request body. Reject names such as "../x" or
    # "/etc/x" that would resolve to a file outside DATA_DIR.
    data_root = os.path.abspath(DATA_DIR)
    data_file = os.path.abspath(os.path.join(data_root, f"{task}.json"))
    if os.path.commonpath([data_root, data_file]) != data_root:
        return jsonify({"message": "Invalid task"}), 400

    try:
        with open(data_file, "r", encoding='utf-8') as f:
            test_data = json.load(f)
    except FileNotFoundError:
        return jsonify({"message": "Test data file not found"}), 400

    # Point the input audio and every model's output audio at the Flask
    # static folder. Iterating over ``models`` keeps this in sync with the
    # models registered in the Elo system.
    for item in test_data:
        for key in ('input_path', *models):
            item[key] = f"/app/static/audio{item[key]}"
    return test_data
def _safe_filename_part(value):
    """Return ``str(value)`` with every character unsafe in a file name replaced by ``_``.

    Letters and digits (including non-ASCII ones), ``-``, ``.`` and ``_`` are
    kept. Everything else, notably ``/`` and ``\\``, becomes ``_``, so the
    result can never add a path component.
    """
    return "".join(ch if ch.isalnum() or ch in "-._" else "_" for ch in str(value))


def save_result(task, username, result_data, session_id):
    """Append one result record to the user's JSON Lines results file.

    The file is ``<RESULTS_DIR>/<task>_<username>_<session_id>.jsonl``, with
    each component sanitised by ``_safe_filename_part``.

    ``result_data`` is modified in place before it is written: the current
    Elo rating of every model in ``models`` is added under ``elo_scores``,
    and a local-time ISO-8601 timestamp under ``timestamp``.
    """
    # task and username originate from the client. Sanitise them so values
    # like "../../x" cannot make the write land outside RESULTS_DIR.
    file_stem = "_".join(_safe_filename_part(part) for part in (task, username, session_id))
    file_path = os.path.join(RESULTS_DIR, f"{file_stem}.jsonl")
    # RESULTS_DIR may not exist when the app is served by a WSGI server
    # rather than started as a script, so create it on demand.
    os.makedirs(RESULTS_DIR, exist_ok=True)

    # Snapshot the Elo rating of every model alongside the result.
    result_data['elo_scores'] = {model: elo_rank_system.get_rating(model) for model in models}
    result_data['timestamp'] = datetime.now().isoformat()
    with open(file_path, "a", encoding='utf-8') as f:
        f.write(json.dumps(result_data) + "\n")
@app.route('/start_test', methods=['POST'])
def start_test():
    """Start a test session for the user and task named in the JSON body.

    Expects a JSON object with non-empty ``task`` and ``username`` fields.
    It loads and shuffles the task's test items, then stores in the session:
    the items, a fresh session id, and a progress index of 0. It responds
    with the number of items and the ``task_description`` of the first item
    ('' when there is none).

    Responds 400 when the body is not a JSON object, a field is missing, or
    ``load_test_data`` reports an error.
    """
    # get_json(silent=True) returns None instead of raising on a missing or
    # malformed JSON body, so the client gets a 400 rather than a 415/500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    task = data.get('task')
    username = data.get('username')
    if not task or not username:
        return jsonify({"message": "Missing task or username"}), 400

    test_data = load_test_data(task)
    if isinstance(test_data, tuple):
        return test_data  # error response produced by load_test_data

    # Present the items in a per-user random order.
    random.shuffle(test_data)
    session_id = str(uuid.uuid4())

    session['task'] = task
    session['username'] = username
    session['test_data'] = test_data
    session['current_index'] = 0
    session['session_id'] = session_id

    # An empty data file yields zero tests instead of an IndexError on [0].
    task_description = test_data[0].get('task_description', '') if test_data else ''
    return jsonify({
        "message": "Test started",
        "total_tests": len(test_data),
        "task_description": task_description
    })
@app.route('/next_test', methods=['GET'])
def next_test():
    """Serve the next test item paired with two models chosen by the Elo system.

    Responds 400 if no test is in progress, and with "Test completed" once
    every item has been served. Otherwise it:

    - stores the chosen model pair in the session;
    - advances the progress index;
    - returns the item's text and input audio path, plus the two model names
      and their output audio paths.
    """
    session_incomplete = 'current_index' not in session or 'test_data' not in session
    if session_incomplete:
        return jsonify({"message": "Session data missing"}), 400

    position = session['current_index']
    items = session['test_data']
    if position >= len(items):
        return jsonify({"message": "Test completed"}), 200

    # Ask the Elo ranker which pair of models to compare next.
    pair = elo_rank_system.sample_next_match()
    if not pair or len(pair) != 2:
        return jsonify({"message": "Error selecting models"}), 500

    item = items[position]
    session['selected_models'] = pair
    session['current_index'] += 1

    first, second = pair[0], pair[1]
    return jsonify({
        "text": item["text"],
        "input_path": item["input_path"],
        "model_a": first,
        "model_b": second,
        "audio_a": item[first],
        "audio_b": item[second]
    })
@app.route('/submit_result', methods=['POST'])
def submit_result():
    """Record the user's vote for the pair served by the last /next_test call.

    Expects JSON ``{"chosen_model": "A" | "B"}``, where "A"/"B" refer to
    ``model_a``/``model_b`` of the previous /next_test response. The handler:

    - saves the test item merged with the vote via ``save_result`` (which
      adds the Elo scores as they were before this vote);
    - updates the Elo ranking with the winner;
    - consumes the served pair so the same vote cannot be counted twice.

    Responds 400 when no test item is awaiting a vote, or when
    ``chosen_model`` is not "A" or "B".
    """
    data = request.get_json(silent=True)
    chosen_model = data.get('chosen_model') if isinstance(data, dict) else None

    username = session.get('username')
    task = session.get('task')
    # Default to 0 so a request without an active session gives -1 and is
    # rejected below, rather than raising TypeError on ``None - 1``.
    current_index = session.get('current_index', 0) - 1
    session_id = session.get('session_id')
    selected_models = session.get('selected_models')
    if not username or not task or current_index < 0 or not selected_models:
        return jsonify({"message": "No active test found"}), 400
    # Any other value would otherwise be saved as a 0/0 result but counted
    # as a win for B in the Elo ranking.
    if chosen_model not in ('A', 'B'):
        return jsonify({"message": "chosen_model must be 'A' or 'B'"}), 400

    model_a = selected_models[0]
    model_b = selected_models[1]
    result = {
        "name": username,
        "chosen_model": chosen_model,
        "model_a": model_a,
        "model_b": model_b,
        "result": {
            model_a: 1 if chosen_model == 'A' else 0,
            model_b: 1 if chosen_model == 'B' else 0
        }
    }
    # The item just served is at current_index - 1, because /next_test
    # advances the index when it serves an item.
    test_item = session['test_data'][current_index]
    result_data = {**test_item, **result}
    save_result(task, username, result_data, session_id)

    # Update the Elo ranking; the winner is passed first.
    if chosen_model == 'A':
        elo_rank_system.record_match(model_a, model_b)
    else:
        elo_rank_system.record_match(model_b, model_a)
    # Consume the pair so a repeated POST cannot count the same vote again.
    session.pop('selected_models', None)

    return jsonify({
        "message": "Result submitted",
        "model_a": model_a,
        "model_b": model_b,
        "chosen_model": chosen_model
    })
@app.route('/end_test', methods=['GET'])
def end_test():
    """Finish the current test by discarding all of this client's session state."""
    session.clear()
    response_body = {"message": "Test completed"}
    return jsonify(response_body)
@app.route('/')
def index():
    """Render the ``index.html`` template as the landing page."""
    page = render_template('index.html')
    return page
if __name__ == '__main__':
    # Ensure the results directory exists; exist_ok avoids the
    # exists()/makedirs() race.
    os.makedirs(RESULTS_DIR, exist_ok=True)
    # The Werkzeug debugger can execute arbitrary code sent from a browser,
    # and the server listens on all interfaces. Keep debug mode off unless
    # explicitly requested with FLASK_DEBUG=1/true/yes for local development.
    debug = os.environ.get('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes')
    app.run(host="0.0.0.0", debug=debug, port=8080)