# TTS-Arena-V2 / app.py
import os
from huggingface_hub import HfApi, hf_hub_download
from apscheduler.schedulers.background import BackgroundScheduler
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
import threading  # Used for the TTS cache lock
year = datetime.now().year
month = datetime.now().month
# Check if running in a Hugging Face Space
IS_SPACES = False
if os.getenv("SPACE_REPO_NAME"):
print("Running in a Hugging Face Space 🤗")
IS_SPACES = True
# Setup database sync for HF Spaces
if not os.path.exists("instance/tts_arena.db"):
os.makedirs("instance", exist_ok=True)
try:
print("Database not found, downloading from HF dataset...")
hf_hub_download(
repo_id="TTS-AGI/database-arena-v2",
filename="tts_arena.db",
repo_type="dataset",
local_dir="instance",
token=os.getenv("HF_TOKEN"),
)
print("Database downloaded successfully ✅")
except Exception as e:
print(f"Error downloading database from HF dataset: {str(e)} ⚠️")
from flask import (
Flask,
render_template,
g,
request,
jsonify,
send_file,
redirect,
url_for,
session,
abort,
)
from flask_login import LoginManager, current_user
from models import *
from auth import auth, init_oauth, is_admin
from admin import admin
from dotenv import load_dotenv
from flask_limiter import Limiter
from flask_limiter.util import get_remote_address
import uuid
import tempfile
import shutil
from tts import predict_tts
import random
import json
from datetime import datetime, timedelta
from flask_migrate import Migrate
import requests
import functools
import time  # Reserved for potential retry/backoff logic
# Load environment variables
if not IS_SPACES:
load_dotenv() # Only load .env if not running in a Hugging Face Space
app = Flask(__name__)
app.config["SECRET_KEY"] = os.getenv("SECRET_KEY", os.urandom(24))
app.config["SQLALCHEMY_DATABASE_URI"] = os.getenv(
"DATABASE_URI", "sqlite:///tts_arena.db"
)
app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
app.config["SESSION_COOKIE_SECURE"] = True
app.config["SESSION_COOKIE_SAMESITE"] = (
"None" if IS_SPACES else "Lax"
) # HF Spaces uses iframes to load the app, so we need to set SAMESITE to None
app.config["PERMANENT_SESSION_LIFETIME"] = timedelta(days=30) # Set to desired duration
# Force HTTPS when running in HuggingFace Spaces
if IS_SPACES:
app.config["PREFERRED_URL_SCHEME"] = "https"
# Cloudflare Turnstile settings
app.config["TURNSTILE_ENABLED"] = (
os.getenv("TURNSTILE_ENABLED", "False").lower() == "true"
)
app.config["TURNSTILE_SITE_KEY"] = os.getenv("TURNSTILE_SITE_KEY", "")
app.config["TURNSTILE_SECRET_KEY"] = os.getenv("TURNSTILE_SECRET_KEY", "")
app.config["TURNSTILE_VERIFY_URL"] = (
"https://challenges.cloudflare.com/turnstile/v0/siteverify"
)
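# A minimal .env sketch for local development (variable names taken from the
# os.getenv() calls in this file; values are illustrative placeholders only):
#
#   SECRET_KEY=replace-with-a-long-random-string
#   DATABASE_URI=sqlite:///tts_arena.db
#   TURNSTILE_ENABLED=true
#   TURNSTILE_SITE_KEY=<site key from the Cloudflare dashboard>
#   TURNSTILE_SECRET_KEY=<secret key from the Cloudflare dashboard>
#   TURNSTILE_TIMEOUT_HOURS=24
#   TTS_CACHE_SIZE=10
#   HF_TOKEN=<Hugging Face token, needed only for the Spaces sync jobs>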
migrate = Migrate(app, db)
# Initialize extensions
db.init_app(app)
login_manager = LoginManager()
login_manager.init_app(app)
login_manager.login_view = "auth.login"
# Initialize OAuth
init_oauth(app)
# Configure rate limits
limiter = Limiter(
app=app,
key_func=get_remote_address,
default_limits=["2000 per day", "50 per minute"],
storage_uri="memory://",
)
# TTS Cache Configuration - Read from environment
TTS_CACHE_SIZE = int(os.getenv("TTS_CACHE_SIZE", "10"))
CACHE_AUDIO_SUBDIR = "cache"
tts_cache = {} # sentence -> {model_a, model_b, audio_a, audio_b, created_at}
tts_cache_lock = threading.Lock()
cache_executor = ThreadPoolExecutor(max_workers=2, thread_name_prefix='CacheReplacer')
all_harvard_sentences = [] # Keep the full list available
# Create temp directories
TEMP_AUDIO_DIR = os.path.join(tempfile.gettempdir(), "tts_arena_audio")
CACHE_AUDIO_DIR = os.path.join(TEMP_AUDIO_DIR, CACHE_AUDIO_SUBDIR)
os.makedirs(TEMP_AUDIO_DIR, exist_ok=True)
os.makedirs(CACHE_AUDIO_DIR, exist_ok=True) # Ensure cache subdir exists
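# Resulting on-disk layout (tempfile.gettempdir() is typically /tmp on Linux):
#   <tmpdir>/tts_arena_audio/        - per-session audio generated on the fly
#   <tmpdir>/tts_arena_audio/cache/  - pre-generated audio for cached sentences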
# Store active TTS sessions
app.tts_sessions = {}
tts_sessions = app.tts_sessions
# Store active conversational sessions
app.conversational_sessions = {}
conversational_sessions = app.conversational_sessions
# Register blueprints
app.register_blueprint(auth, url_prefix="/auth")
app.register_blueprint(admin)
@login_manager.user_loader
def load_user(user_id):
return User.query.get(int(user_id))
@app.before_request
def before_request():
g.user = current_user
g.is_admin = is_admin(current_user)
# Ensure HTTPS for HuggingFace Spaces environment
if IS_SPACES and request.headers.get("X-Forwarded-Proto") == "http":
url = request.url.replace("http://", "https://", 1)
return redirect(url, code=301)
# Check if Turnstile verification is required
if app.config["TURNSTILE_ENABLED"]:
# Exclude verification routes
excluded_routes = ["verify_turnstile", "turnstile_page", "static"]
if request.endpoint not in excluded_routes:
# Check if user is verified
if not session.get("turnstile_verified"):
# Save original URL for redirect after verification
redirect_url = request.url
# Force HTTPS in HuggingFace Spaces
if IS_SPACES and redirect_url.startswith("http://"):
redirect_url = redirect_url.replace("http://", "https://", 1)
# If it's an API request, return a JSON response
if request.path.startswith("/api/"):
return jsonify({"error": "Turnstile verification required"}), 403
# For regular requests, redirect to verification page
return redirect(url_for("turnstile_page", redirect_url=redirect_url))
else:
# Check if verification has expired (default: 24 hours)
verification_timeout = (
int(os.getenv("TURNSTILE_TIMEOUT_HOURS", "24")) * 3600
) # Convert hours to seconds
verified_at = session.get("turnstile_verified_at", 0)
current_time = datetime.utcnow().timestamp()
if current_time - verified_at > verification_timeout:
# Verification expired, clear status and redirect to verification page
session.pop("turnstile_verified", None)
session.pop("turnstile_verified_at", None)
redirect_url = request.url
# Force HTTPS in HuggingFace Spaces
if IS_SPACES and redirect_url.startswith("http://"):
redirect_url = redirect_url.replace("http://", "https://", 1)
if request.path.startswith("/api/"):
return jsonify({"error": "Turnstile verification expired"}), 403
return redirect(
url_for("turnstile_page", redirect_url=redirect_url)
)
@app.route("/turnstile", methods=["GET"])
def turnstile_page():
"""Display Cloudflare Turnstile verification page"""
redirect_url = request.args.get("redirect_url", url_for("arena", _external=True))
# Force HTTPS in HuggingFace Spaces
if IS_SPACES and redirect_url.startswith("http://"):
redirect_url = redirect_url.replace("http://", "https://", 1)
return render_template(
"turnstile.html",
turnstile_site_key=app.config["TURNSTILE_SITE_KEY"],
redirect_url=redirect_url,
)
@app.route("/verify-turnstile", methods=["POST"])
def verify_turnstile():
"""Verify Cloudflare Turnstile token"""
token = request.form.get("cf-turnstile-response")
redirect_url = request.form.get("redirect_url", url_for("arena", _external=True))
# Force HTTPS in HuggingFace Spaces
if IS_SPACES and redirect_url.startswith("http://"):
redirect_url = redirect_url.replace("http://", "https://", 1)
if not token:
# If AJAX request, return JSON error
if request.headers.get("X-Requested-With") == "XMLHttpRequest":
return (
jsonify({"success": False, "error": "Missing verification token"}),
400,
)
# Otherwise redirect back to turnstile page
return redirect(url_for("turnstile_page", redirect_url=redirect_url))
# Verify token with Cloudflare
data = {
"secret": app.config["TURNSTILE_SECRET_KEY"],
"response": token,
"remoteip": request.remote_addr,
}
try:
response = requests.post(app.config["TURNSTILE_VERIFY_URL"], data=data)
result = response.json()
if result.get("success"):
# Set verification status in session
session["turnstile_verified"] = True
session["turnstile_verified_at"] = datetime.utcnow().timestamp()
# Determine response type based on request
is_xhr = request.headers.get("X-Requested-With") == "XMLHttpRequest"
accepts_json = "application/json" in request.headers.get("Accept", "")
# If AJAX or JSON request, return success JSON
if is_xhr or accepts_json:
return jsonify({"success": True, "redirect": redirect_url})
# For regular form submissions, redirect to the target URL
return redirect(redirect_url)
else:
# Verification failed
app.logger.warning(f"Turnstile verification failed: {result}")
# If AJAX request, return JSON error
if request.headers.get("X-Requested-With") == "XMLHttpRequest":
return jsonify({"success": False, "error": "Verification failed"}), 403
# Otherwise redirect back to turnstile page
return redirect(url_for("turnstile_page", redirect_url=redirect_url))
except Exception as e:
app.logger.error(f"Turnstile verification error: {str(e)}")
# If AJAX request, return JSON error
if request.headers.get("X-Requested-With") == "XMLHttpRequest":
return (
jsonify(
{"success": False, "error": "Server error during verification"}
),
500,
)
# Otherwise redirect back to turnstile page
return redirect(url_for("turnstile_page", redirect_url=redirect_url))
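# For reference, a successful siteverify response looks roughly like the
# following (per Cloudflare's Turnstile documentation; only "success" is
# inspected above):
#
#   {"success": true, "challenge_ts": "2024-01-01T00:00:00Z",
#    "hostname": "example.com", "error-codes": []}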
with open("sentences.txt", "r") as f, open("emotional_sentences.txt", "r") as f_emotional:
# Store all sentences and clean them up
all_harvard_sentences = [line.strip() for line in f.readlines() if line.strip()] + [line.strip() for line in f_emotional.readlines() if line.strip()]
    # Take a random subset for the arena template; the full list stays intact for cache selection
    initial_sentences = random.sample(all_harvard_sentences, min(len(all_harvard_sentences), 500))  # Limit payload passed to the template
@app.route("/")
def arena():
# Pass a subset of sentences for the random button fallback
return render_template("arena.html", harvard_sentences=json.dumps(initial_sentences))
@app.route("/leaderboard")
def leaderboard():
tts_leaderboard = get_leaderboard_data(ModelType.TTS)
conversational_leaderboard = get_leaderboard_data(ModelType.CONVERSATIONAL)
top_voters = get_top_voters(10) # Get top 10 voters
# Initialize personal leaderboard data
tts_personal_leaderboard = None
conversational_personal_leaderboard = None
user_leaderboard_visibility = None
# If user is logged in, get their personal leaderboard and visibility setting
if current_user.is_authenticated:
tts_personal_leaderboard = get_user_leaderboard(current_user.id, ModelType.TTS)
conversational_personal_leaderboard = get_user_leaderboard(
current_user.id, ModelType.CONVERSATIONAL
)
user_leaderboard_visibility = current_user.show_in_leaderboard
# Get key dates for the timeline
tts_key_dates = get_key_historical_dates(ModelType.TTS)
conversational_key_dates = get_key_historical_dates(ModelType.CONVERSATIONAL)
# Format dates for display in the dropdown
formatted_tts_dates = [date.strftime("%B %Y") for date in tts_key_dates]
formatted_conversational_dates = [
date.strftime("%B %Y") for date in conversational_key_dates
]
return render_template(
"leaderboard.html",
tts_leaderboard=tts_leaderboard,
conversational_leaderboard=conversational_leaderboard,
tts_personal_leaderboard=tts_personal_leaderboard,
conversational_personal_leaderboard=conversational_personal_leaderboard,
tts_key_dates=tts_key_dates,
conversational_key_dates=conversational_key_dates,
formatted_tts_dates=formatted_tts_dates,
formatted_conversational_dates=formatted_conversational_dates,
top_voters=top_voters,
user_leaderboard_visibility=user_leaderboard_visibility
)
@app.route("/api/historical-leaderboard/<model_type>")
def historical_leaderboard(model_type):
"""Get historical leaderboard data for a specific date"""
if model_type not in [ModelType.TTS, ModelType.CONVERSATIONAL]:
return jsonify({"error": "Invalid model type"}), 400
# Get date from query parameter
date_str = request.args.get("date")
if not date_str:
return jsonify({"error": "Date parameter is required"}), 400
try:
# Parse date from URL parameter (format: YYYY-MM-DD)
target_date = datetime.strptime(date_str, "%Y-%m-%d")
# Get historical leaderboard data
leaderboard_data = get_historical_leaderboard_data(model_type, target_date)
return jsonify(
{"date": target_date.strftime("%B %d, %Y"), "leaderboard": leaderboard_data}
)
except ValueError:
return jsonify({"error": "Invalid date format. Use YYYY-MM-DD"}), 400
@app.route("/about")
def about():
return render_template("about.html")
# --- TTS Caching Functions ---
def generate_and_save_tts(text, model_id, output_dir):
"""Generates TTS and saves it to a specific directory, returning the full path."""
temp_audio_path = None # Initialize to None
try:
app.logger.debug(f"[TTS Gen {model_id}] Starting generation for: '{text[:30]}...'")
# If predict_tts saves file itself and returns path:
temp_audio_path = predict_tts(text, model_id)
app.logger.debug(f"[TTS Gen {model_id}] predict_tts returned: {temp_audio_path}")
if not temp_audio_path or not os.path.exists(temp_audio_path):
app.logger.warning(f"[TTS Gen {model_id}] predict_tts failed or returned invalid path: {temp_audio_path}")
raise ValueError("predict_tts did not return a valid path or file does not exist")
file_uuid = str(uuid.uuid4())
dest_path = os.path.join(output_dir, f"{file_uuid}.wav")
app.logger.debug(f"[TTS Gen {model_id}] Moving {temp_audio_path} to {dest_path}")
# Move the file generated by predict_tts to the target cache directory
shutil.move(temp_audio_path, dest_path)
app.logger.debug(f"[TTS Gen {model_id}] Move successful. Returning {dest_path}")
return dest_path
except Exception as e:
app.logger.error(f"Error generating/saving TTS for model {model_id} and text '{text[:30]}...': {str(e)}")
# Ensure temporary file from predict_tts (if any) is cleaned up
if temp_audio_path and os.path.exists(temp_audio_path):
try:
app.logger.debug(f"[TTS Gen {model_id}] Cleaning up temporary file {temp_audio_path} after error.")
os.remove(temp_audio_path)
except OSError:
pass # Ignore error if file couldn't be removed
return None
def _generate_cache_entry_task(sentence):
"""Task function to generate audio for a sentence and add to cache."""
# Wrap the entire task in an application context
with app.app_context():
if not sentence:
# Select a new sentence if not provided (for replacement)
with tts_cache_lock:
cached_keys = set(tts_cache.keys())
available_sentences = [s for s in all_harvard_sentences if s not in cached_keys]
if not available_sentences:
app.logger.warning("No more unique Harvard sentences available for caching.")
return
sentence = random.choice(available_sentences)
        print(f"[Cache Task] Querying models for: '{sentence[:50]}...'")
available_models = Model.query.filter_by(
model_type=ModelType.TTS, is_active=True
).all()
if len(available_models) < 2:
app.logger.error("Not enough active TTS models to generate cache entry.")
return
try:
models = random.sample(available_models, 2)
model_a_id = models[0].id
model_b_id = models[1].id
# Generate audio concurrently using a local executor for clarity within the task
with ThreadPoolExecutor(max_workers=2, thread_name_prefix='AudioGen') as audio_executor:
future_a = audio_executor.submit(generate_and_save_tts, sentence, model_a_id, CACHE_AUDIO_DIR)
future_b = audio_executor.submit(generate_and_save_tts, sentence, model_b_id, CACHE_AUDIO_DIR)
timeout_seconds = 120
audio_a_path = future_a.result(timeout=timeout_seconds)
audio_b_path = future_b.result(timeout=timeout_seconds)
if audio_a_path and audio_b_path:
with tts_cache_lock:
# Only add if the sentence isn't already back in the cache
# And ensure cache size doesn't exceed limit
if sentence not in tts_cache and len(tts_cache) < TTS_CACHE_SIZE:
tts_cache[sentence] = {
"model_a": model_a_id,
"model_b": model_b_id,
"audio_a": audio_a_path,
"audio_b": audio_b_path,
"created_at": datetime.utcnow(),
}
app.logger.info(f"Successfully cached entry for: '{sentence[:50]}...'")
elif sentence in tts_cache:
app.logger.warning(f"Sentence '{sentence[:50]}...' already re-cached. Discarding new generation.")
# Clean up the newly generated files if not added
if os.path.exists(audio_a_path): os.remove(audio_a_path)
if os.path.exists(audio_b_path): os.remove(audio_b_path)
else: # Cache is full
app.logger.warning(f"Cache is full ({len(tts_cache)} entries). Discarding new generation for '{sentence[:50]}...'.")
# Clean up the newly generated files if not added
if os.path.exists(audio_a_path): os.remove(audio_a_path)
if os.path.exists(audio_b_path): os.remove(audio_b_path)
else:
app.logger.error(f"Failed to generate one or both audio files for cache: '{sentence[:50]}...'")
# Clean up whichever file might have been created
if audio_a_path and os.path.exists(audio_a_path): os.remove(audio_a_path)
if audio_b_path and os.path.exists(audio_b_path): os.remove(audio_b_path)
except Exception as e:
# Log the exception within the app context
app.logger.error(f"Exception in _generate_cache_entry_task for '{sentence[:50]}...': {str(e)}", exc_info=True)
def initialize_tts_cache():
    """Selects initial sentences and starts generation tasks."""
    print("Initializing TTS cache")
with app.app_context(): # Ensure access to models
if not all_harvard_sentences:
app.logger.error("Harvard sentences not loaded. Cannot initialize cache.")
return
initial_selection = random.sample(all_harvard_sentences, min(len(all_harvard_sentences), TTS_CACHE_SIZE))
app.logger.info(f"Initializing TTS cache with {len(initial_selection)} sentences...")
for sentence in initial_selection:
# Use the main cache_executor for initial population too
cache_executor.submit(_generate_cache_entry_task, sentence)
app.logger.info("Submitted initial cache generation tasks.")
# --- End TTS Caching Functions ---
@app.route("/api/tts/generate", methods=["POST"])
@limiter.limit("10 per minute") # Keep limit, cached responses are still requests
def generate_tts():
    # If Turnstile is enabled, require a verified session before serving this request
if app.config["TURNSTILE_ENABLED"] and not session.get("turnstile_verified"):
return jsonify({"error": "Turnstile verification required"}), 403
data = request.json
text = data.get("text", "").strip() # Ensure text is stripped
if not text or len(text) > 1000:
return jsonify({"error": "Invalid or too long text"}), 400
# --- Cache Check ---
cache_hit = False
session_data_from_cache = None
with tts_cache_lock:
if text in tts_cache:
cache_hit = True
cached_entry = tts_cache.pop(text) # Remove from cache immediately
app.logger.info(f"TTS Cache HIT for: '{text[:50]}...'")
# Prepare session data using cached info
session_id = str(uuid.uuid4())
session_data_from_cache = {
"model_a": cached_entry["model_a"],
"model_b": cached_entry["model_b"],
"audio_a": cached_entry["audio_a"], # Paths are now from cache_dir
"audio_b": cached_entry["audio_b"],
"text": text,
"created_at": datetime.utcnow(),
"expires_at": datetime.utcnow() + timedelta(minutes=30),
"voted": False,
}
app.tts_sessions[session_id] = session_data_from_cache
# Trigger background task to replace the used cache entry
cache_executor.submit(_generate_cache_entry_task, None) # Pass None to signal replacement
if cache_hit and session_data_from_cache:
# Return response using cached data
# Note: The files are now managed by the session lifecycle (cleanup_session)
return jsonify(
{
"session_id": session_id,
"audio_a": f"/api/tts/audio/{session_id}/a",
"audio_b": f"/api/tts/audio/{session_id}/b",
"expires_in": 1800, # 30 minutes in seconds
"cache_hit": True,
}
)
# --- End Cache Check ---
# --- Cache Miss: Generate on the fly ---
app.logger.info(f"TTS Cache MISS for: '{text[:50]}...'. Generating on the fly.")
available_models = Model.query.filter_by(
model_type=ModelType.TTS, is_active=True
).all()
if len(available_models) < 2:
return jsonify({"error": "Not enough TTS models available"}), 500
selected_models = random.sample(available_models, 2)
try:
audio_files = []
model_ids = []
# Function to process a single model (generate directly to TEMP_AUDIO_DIR, not cache subdir)
def process_model_on_the_fly(model):
# Generate and save directly to the main temp dir
# Assume predict_tts handles saving temporary files
temp_audio_path = predict_tts(text, model.id)
if not temp_audio_path or not os.path.exists(temp_audio_path):
raise ValueError(f"predict_tts failed for model {model.id}")
# Create a unique name in the main TEMP_AUDIO_DIR for the session
file_uuid = str(uuid.uuid4())
dest_path = os.path.join(TEMP_AUDIO_DIR, f"{file_uuid}.wav")
shutil.move(temp_audio_path, dest_path) # Move from predict_tts's temp location
return {"model_id": model.id, "audio_path": dest_path}
# Use ThreadPoolExecutor to process models concurrently
with ThreadPoolExecutor(max_workers=2) as executor:
results = list(executor.map(process_model_on_the_fly, selected_models))
# Extract results
for result in results:
model_ids.append(result["model_id"])
audio_files.append(result["audio_path"])
# Create session
session_id = str(uuid.uuid4())
app.tts_sessions[session_id] = {
"model_a": model_ids[0],
"model_b": model_ids[1],
"audio_a": audio_files[0], # Paths are now from TEMP_AUDIO_DIR directly
"audio_b": audio_files[1],
"text": text,
"created_at": datetime.utcnow(),
"expires_at": datetime.utcnow() + timedelta(minutes=30),
"voted": False,
}
# Return audio file paths and session
return jsonify(
{
"session_id": session_id,
"audio_a": f"/api/tts/audio/{session_id}/a",
"audio_b": f"/api/tts/audio/{session_id}/b",
"expires_in": 1800,
"cache_hit": False,
}
)
except Exception as e:
app.logger.error(f"TTS on-the-fly generation error: {str(e)}", exc_info=True)
# Cleanup any files potentially created during the failed attempt
if 'results' in locals():
for res in results:
if 'audio_path' in res and os.path.exists(res['audio_path']):
try:
os.remove(res['audio_path'])
except OSError:
pass
return jsonify({"error": "Failed to generate TTS"}), 500
# --- End Cache Miss ---
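# Illustrative request/response for the endpoint above (shape taken from the
# jsonify() calls; the session id is a made-up example):
#
#   POST /api/tts/generate   {"text": "The birch canoe slid on the smooth planks."}
#   -> {"session_id": "9f1c...", "audio_a": "/api/tts/audio/9f1c.../a",
#       "audio_b": "/api/tts/audio/9f1c.../b", "expires_in": 1800, "cache_hit": false}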
@app.route("/api/tts/audio/<session_id>/<model_key>")
def get_audio(session_id, model_key):
    # If Turnstile is enabled, require a verified session before serving this request
if app.config["TURNSTILE_ENABLED"] and not session.get("turnstile_verified"):
return jsonify({"error": "Turnstile verification required"}), 403
if session_id not in app.tts_sessions:
return jsonify({"error": "Invalid or expired session"}), 404
session_data = app.tts_sessions[session_id]
# Check if session expired
if datetime.utcnow() > session_data["expires_at"]:
cleanup_session(session_id)
return jsonify({"error": "Session expired"}), 410
if model_key == "a":
audio_path = session_data["audio_a"]
elif model_key == "b":
audio_path = session_data["audio_b"]
else:
return jsonify({"error": "Invalid model key"}), 400
# Check if file exists
if not os.path.exists(audio_path):
return jsonify({"error": "Audio file not found"}), 404
return send_file(audio_path, mimetype="audio/wav")
@app.route("/api/tts/vote", methods=["POST"])
@limiter.limit("30 per minute")
def submit_vote():
    # If Turnstile is enabled, require a verified session before serving this request
if app.config["TURNSTILE_ENABLED"] and not session.get("turnstile_verified"):
return jsonify({"error": "Turnstile verification required"}), 403
data = request.json
session_id = data.get("session_id")
chosen_model_key = data.get("chosen_model") # "a" or "b"
if not session_id or session_id not in app.tts_sessions:
return jsonify({"error": "Invalid or expired session"}), 404
if not chosen_model_key or chosen_model_key not in ["a", "b"]:
return jsonify({"error": "Invalid chosen model"}), 400
session_data = app.tts_sessions[session_id]
# Check if session expired
if datetime.utcnow() > session_data["expires_at"]:
cleanup_session(session_id)
return jsonify({"error": "Session expired"}), 410
# Check if already voted
if session_data["voted"]:
return jsonify({"error": "Vote already submitted for this session"}), 400
# Get model IDs and audio paths
chosen_id = (
session_data["model_a"] if chosen_model_key == "a" else session_data["model_b"]
)
rejected_id = (
session_data["model_b"] if chosen_model_key == "a" else session_data["model_a"]
)
chosen_audio_path = (
session_data["audio_a"] if chosen_model_key == "a" else session_data["audio_b"]
)
rejected_audio_path = (
session_data["audio_b"] if chosen_model_key == "a" else session_data["audio_a"]
)
# Record vote in database
user_id = current_user.id if current_user.is_authenticated else None
vote, error = record_vote(
user_id, session_data["text"], chosen_id, rejected_id, ModelType.TTS
)
if error:
return jsonify({"error": error}), 500
    # --- Save preference data ---
    chosen_model_obj = rejected_model_obj = None  # Predefine so the response below cannot hit a NameError if this block fails early
    try:
vote_uuid = str(uuid.uuid4())
vote_dir = os.path.join("./votes", vote_uuid)
os.makedirs(vote_dir, exist_ok=True)
# Copy audio files
shutil.copy(chosen_audio_path, os.path.join(vote_dir, "chosen.wav"))
shutil.copy(rejected_audio_path, os.path.join(vote_dir, "rejected.wav"))
# Create metadata
chosen_model_obj = Model.query.get(chosen_id)
rejected_model_obj = Model.query.get(rejected_id)
metadata = {
"text": session_data["text"],
"chosen_model": chosen_model_obj.name if chosen_model_obj else "Unknown",
"chosen_model_id": chosen_model_obj.id if chosen_model_obj else "Unknown",
"rejected_model": rejected_model_obj.name if rejected_model_obj else "Unknown",
"rejected_model_id": rejected_model_obj.id if rejected_model_obj else "Unknown",
"session_id": session_id,
"timestamp": datetime.utcnow().isoformat(),
"username": current_user.username if current_user.is_authenticated else None,
"model_type": "TTS"
}
with open(os.path.join(vote_dir, "metadata.json"), "w") as f:
json.dump(metadata, f, indent=2)
except Exception as e:
app.logger.error(f"Error saving preference data for vote {session_id}: {str(e)}")
# Continue even if saving preference data fails, vote is already recorded
# Mark session as voted
session_data["voted"] = True
# Return updated models (use previously fetched objects)
return jsonify(
{
"success": True,
"chosen_model": {"id": chosen_id, "name": chosen_model_obj.name if chosen_model_obj else "Unknown"},
"rejected_model": {
"id": rejected_id,
"name": rejected_model_obj.name if rejected_model_obj else "Unknown",
},
"names": {
"a": (
chosen_model_obj.name if chosen_model_key == "a" else rejected_model_obj.name
if chosen_model_obj and rejected_model_obj else "Unknown"
),
"b": (
rejected_model_obj.name if chosen_model_key == "a" else chosen_model_obj.name
if chosen_model_obj and rejected_model_obj else "Unknown"
),
},
}
)
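# Illustrative vote payload for the endpoint above (field names from the
# request.json reads; "chosen_model" must be "a" or "b"):
#
#   POST /api/tts/vote   {"session_id": "9f1c...", "chosen_model": "a"}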
def cleanup_session(session_id):
"""Remove session and its audio files"""
if session_id in app.tts_sessions:
session = app.tts_sessions[session_id]
# Remove audio files
for audio_file in [session["audio_a"], session["audio_b"]]:
if os.path.exists(audio_file):
try:
os.remove(audio_file)
except Exception as e:
app.logger.error(f"Error removing audio file: {str(e)}")
# Remove session
del app.tts_sessions[session_id]
@app.route("/api/conversational/generate", methods=["POST"])
@limiter.limit("5 per minute")
def generate_podcast():
    # If Turnstile is enabled, require a verified session before serving this request
if app.config["TURNSTILE_ENABLED"] and not session.get("turnstile_verified"):
return jsonify({"error": "Turnstile verification required"}), 403
data = request.json
script = data.get("script")
if not script or not isinstance(script, list) or len(script) < 2:
return jsonify({"error": "Invalid script format or too short"}), 400
# Validate script format
for line in script:
if not isinstance(line, dict) or "text" not in line or "speaker_id" not in line:
return (
jsonify(
{
"error": "Invalid script line format. Each line must have text and speaker_id"
}
),
400,
)
if (
not line["text"]
or not isinstance(line["speaker_id"], int)
or line["speaker_id"] not in [0, 1]
):
return (
jsonify({"error": "Invalid script content. Speaker ID must be 0 or 1"}),
400,
)
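    # Illustrative script payload satisfying the checks above (speaker_id must
    # be 0 or 1, and the script needs at least two lines):
    #
    #   {"script": [{"text": "Welcome to the show.", "speaker_id": 0},
    #               {"text": "Thanks for having me!", "speaker_id": 1}]}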
# Get two conversational models (currently only CSM and PlayDialog)
available_models = Model.query.filter_by(
model_type=ModelType.CONVERSATIONAL, is_active=True
).all()
if len(available_models) < 2:
return jsonify({"error": "Not enough conversational models available"}), 500
selected_models = random.sample(available_models, 2)
try:
# Generate audio for both models concurrently
audio_files = []
model_ids = []
# Function to process a single model
def process_model(model):
# Call conversational TTS service
audio_content = predict_tts(script, model.id)
# Save to temp file with unique name
file_uuid = str(uuid.uuid4())
dest_path = os.path.join(TEMP_AUDIO_DIR, f"{file_uuid}.wav")
with open(dest_path, "wb") as f:
f.write(audio_content)
return {"model_id": model.id, "audio_path": dest_path}
# Use ThreadPoolExecutor to process models concurrently
with ThreadPoolExecutor(max_workers=2) as executor:
results = list(executor.map(process_model, selected_models))
# Extract results
for result in results:
model_ids.append(result["model_id"])
audio_files.append(result["audio_path"])
# Create session
session_id = str(uuid.uuid4())
script_text = " ".join([line["text"] for line in script])
app.conversational_sessions[session_id] = {
"model_a": model_ids[0],
"model_b": model_ids[1],
"audio_a": audio_files[0],
"audio_b": audio_files[1],
"text": script_text[:1000], # Limit text length
"created_at": datetime.utcnow(),
"expires_at": datetime.utcnow() + timedelta(minutes=30),
"voted": False,
"script": script,
}
# Return audio file paths and session
return jsonify(
{
"session_id": session_id,
"audio_a": f"/api/conversational/audio/{session_id}/a",
"audio_b": f"/api/conversational/audio/{session_id}/b",
"expires_in": 1800, # 30 minutes in seconds
}
)
except Exception as e:
app.logger.error(f"Conversational generation error: {str(e)}")
return jsonify({"error": f"Failed to generate podcast: {str(e)}"}), 500
@app.route("/api/conversational/audio/<session_id>/<model_key>")
def get_podcast_audio(session_id, model_key):
    # If Turnstile is enabled, require a verified session before serving this request
if app.config["TURNSTILE_ENABLED"] and not session.get("turnstile_verified"):
return jsonify({"error": "Turnstile verification required"}), 403
if session_id not in app.conversational_sessions:
return jsonify({"error": "Invalid or expired session"}), 404
session_data = app.conversational_sessions[session_id]
# Check if session expired
if datetime.utcnow() > session_data["expires_at"]:
cleanup_conversational_session(session_id)
return jsonify({"error": "Session expired"}), 410
if model_key == "a":
audio_path = session_data["audio_a"]
elif model_key == "b":
audio_path = session_data["audio_b"]
else:
return jsonify({"error": "Invalid model key"}), 400
# Check if file exists
if not os.path.exists(audio_path):
return jsonify({"error": "Audio file not found"}), 404
return send_file(audio_path, mimetype="audio/wav")
@app.route("/api/conversational/vote", methods=["POST"])
@limiter.limit("30 per minute")
def submit_podcast_vote():
    # If Turnstile is enabled, require a verified session before serving this request
if app.config["TURNSTILE_ENABLED"] and not session.get("turnstile_verified"):
return jsonify({"error": "Turnstile verification required"}), 403
data = request.json
session_id = data.get("session_id")
chosen_model_key = data.get("chosen_model") # "a" or "b"
if not session_id or session_id not in app.conversational_sessions:
return jsonify({"error": "Invalid or expired session"}), 404
if not chosen_model_key or chosen_model_key not in ["a", "b"]:
return jsonify({"error": "Invalid chosen model"}), 400
session_data = app.conversational_sessions[session_id]
# Check if session expired
if datetime.utcnow() > session_data["expires_at"]:
cleanup_conversational_session(session_id)
return jsonify({"error": "Session expired"}), 410
# Check if already voted
if session_data["voted"]:
return jsonify({"error": "Vote already submitted for this session"}), 400
# Get model IDs and audio paths
chosen_id = (
session_data["model_a"] if chosen_model_key == "a" else session_data["model_b"]
)
rejected_id = (
session_data["model_b"] if chosen_model_key == "a" else session_data["model_a"]
)
chosen_audio_path = (
session_data["audio_a"] if chosen_model_key == "a" else session_data["audio_b"]
)
rejected_audio_path = (
session_data["audio_b"] if chosen_model_key == "a" else session_data["audio_a"]
)
# Record vote in database
user_id = current_user.id if current_user.is_authenticated else None
vote, error = record_vote(
user_id, session_data["text"], chosen_id, rejected_id, ModelType.CONVERSATIONAL
)
if error:
return jsonify({"error": error}), 500
    # --- Save preference data ---
    chosen_model_obj = rejected_model_obj = None  # Predefine so the response below cannot hit a NameError if this block fails early
    try:
vote_uuid = str(uuid.uuid4())
vote_dir = os.path.join("./votes", vote_uuid)
os.makedirs(vote_dir, exist_ok=True)
# Copy audio files
shutil.copy(chosen_audio_path, os.path.join(vote_dir, "chosen.wav"))
shutil.copy(rejected_audio_path, os.path.join(vote_dir, "rejected.wav"))
# Create metadata
chosen_model_obj = Model.query.get(chosen_id)
rejected_model_obj = Model.query.get(rejected_id)
metadata = {
"script": session_data["script"], # Save the full script
"chosen_model": chosen_model_obj.name if chosen_model_obj else "Unknown",
"chosen_model_id": chosen_model_obj.id if chosen_model_obj else "Unknown",
"rejected_model": rejected_model_obj.name if rejected_model_obj else "Unknown",
"rejected_model_id": rejected_model_obj.id if rejected_model_obj else "Unknown",
"session_id": session_id,
"timestamp": datetime.utcnow().isoformat(),
"username": current_user.username if current_user.is_authenticated else None,
"model_type": "CONVERSATIONAL"
}
with open(os.path.join(vote_dir, "metadata.json"), "w") as f:
json.dump(metadata, f, indent=2)
except Exception as e:
app.logger.error(f"Error saving preference data for conversational vote {session_id}: {str(e)}")
# Continue even if saving preference data fails, vote is already recorded
# Mark session as voted
session_data["voted"] = True
# Return updated models (use previously fetched objects)
return jsonify(
{
"success": True,
"chosen_model": {"id": chosen_id, "name": chosen_model_obj.name if chosen_model_obj else "Unknown"},
"rejected_model": {
"id": rejected_id,
"name": rejected_model_obj.name if rejected_model_obj else "Unknown",
},
"names": {
"a": Model.query.get(session_data["model_a"]).name,
"b": Model.query.get(session_data["model_b"]).name,
},
}
)
def cleanup_conversational_session(session_id):
"""Remove conversational session and its audio files"""
if session_id in app.conversational_sessions:
session = app.conversational_sessions[session_id]
# Remove audio files
for audio_file in [session["audio_a"], session["audio_b"]]:
if os.path.exists(audio_file):
try:
os.remove(audio_file)
except Exception as e:
app.logger.error(
f"Error removing conversational audio file: {str(e)}"
)
# Remove session
del app.conversational_sessions[session_id]
# Schedule periodic cleanup
def setup_cleanup():
def cleanup_expired_sessions():
with app.app_context(): # Ensure app context for logging
current_time = datetime.utcnow()
# Cleanup TTS sessions
expired_tts_sessions = [
sid
for sid, session_data in app.tts_sessions.items()
if current_time > session_data["expires_at"]
]
for sid in expired_tts_sessions:
cleanup_session(sid)
# Cleanup conversational sessions
expired_conv_sessions = [
sid
for sid, session_data in app.conversational_sessions.items()
if current_time > session_data["expires_at"]
]
for sid in expired_conv_sessions:
cleanup_conversational_session(sid)
app.logger.info(f"Cleaned up {len(expired_tts_sessions)} TTS and {len(expired_conv_sessions)} conversational sessions.")
# Also cleanup potentially expired cache entries (e.g., > 1 hour old)
# This prevents stale cache entries if generation is slow or failing
# cleanup_stale_cache_entries()
# Run cleanup every 15 minutes
scheduler = BackgroundScheduler(daemon=True) # Run scheduler as daemon thread
scheduler.add_job(cleanup_expired_sessions, "interval", minutes=15)
scheduler.start()
print("Cleanup scheduler started") # Use print for startup messages
# Schedule periodic tasks (database sync and preference upload)
def setup_periodic_tasks():
"""Setup periodic database synchronization and preference data upload for Spaces"""
if not IS_SPACES:
return
db_path = app.config["SQLALCHEMY_DATABASE_URI"].replace("sqlite:///", "instance/") # Get relative path
preferences_repo_id = "TTS-AGI/arena-v2-preferences"
database_repo_id = "TTS-AGI/database-arena-v2"
votes_dir = "./votes"
def sync_database():
"""Uploads the database to HF dataset"""
with app.app_context(): # Ensure app context for logging
try:
if not os.path.exists(db_path):
app.logger.warning(f"Database file not found at {db_path}, skipping sync.")
return
api = HfApi(token=os.getenv("HF_TOKEN"))
api.upload_file(
path_or_fileobj=db_path,
path_in_repo="tts_arena.db",
repo_id=database_repo_id,
repo_type="dataset",
)
app.logger.info(f"Database uploaded to {database_repo_id} at {datetime.utcnow()}")
except Exception as e:
app.logger.error(f"Error uploading database to {database_repo_id}: {str(e)}")
def sync_preferences_data():
"""Zips and uploads preference data folders in batches to HF dataset"""
with app.app_context(): # Ensure app context for logging
if not os.path.isdir(votes_dir):
return # Don't log every 5 mins if dir doesn't exist yet
            temp_batch_dir = None  # Initialize to manage cleanup
            temp_individual_zip_dir = None  # Initialize for individual zips
            local_batch_zip_path = None  # Initialize for batch zip path
            upload_failed = False  # Set on error so the finally block keeps the batch dir for retry
try:
api = HfApi(token=os.getenv("HF_TOKEN"))
vote_uuids = [d for d in os.listdir(votes_dir) if os.path.isdir(os.path.join(votes_dir, d))]
if not vote_uuids:
return # No data to process
app.logger.info(f"Found {len(vote_uuids)} vote directories to process.")
# Create temporary directories
temp_batch_dir = tempfile.mkdtemp(prefix="hf_batch_")
temp_individual_zip_dir = tempfile.mkdtemp(prefix="hf_indiv_zips_")
app.logger.debug(f"Created temp directories: {temp_batch_dir}, {temp_individual_zip_dir}")
processed_vote_dirs = []
individual_zips_in_batch = []
# 1. Create individual zips and move them to the batch directory
for vote_uuid in vote_uuids:
dir_path = os.path.join(votes_dir, vote_uuid)
individual_zip_base_path = os.path.join(temp_individual_zip_dir, vote_uuid)
individual_zip_path = f"{individual_zip_base_path}.zip"
try:
shutil.make_archive(individual_zip_base_path, 'zip', dir_path)
app.logger.debug(f"Created individual zip: {individual_zip_path}")
# Move the created zip into the batch directory
final_individual_zip_path = os.path.join(temp_batch_dir, f"{vote_uuid}.zip")
shutil.move(individual_zip_path, final_individual_zip_path)
app.logger.debug(f"Moved individual zip to batch dir: {final_individual_zip_path}")
processed_vote_dirs.append(dir_path) # Mark original dir for later cleanup
individual_zips_in_batch.append(final_individual_zip_path)
except Exception as zip_err:
app.logger.error(f"Error creating or moving zip for {vote_uuid}: {str(zip_err)}")
# Clean up partial zip if it exists
if os.path.exists(individual_zip_path):
try:
os.remove(individual_zip_path)
except OSError:
pass
# Continue processing other votes
# Clean up the temporary dir used for creating individual zips
shutil.rmtree(temp_individual_zip_dir)
temp_individual_zip_dir = None # Mark as cleaned
app.logger.debug("Cleaned up temporary individual zip directory.")
if not individual_zips_in_batch:
app.logger.warning("No individual zips were successfully created for batching.")
# Clean up batch dir if it's empty or only contains failed attempts
if temp_batch_dir and os.path.exists(temp_batch_dir):
shutil.rmtree(temp_batch_dir)
temp_batch_dir = None
return
# 2. Create the batch zip file
batch_timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
batch_uuid_short = str(uuid.uuid4())[:8]
batch_zip_filename = f"{batch_timestamp}_batch_{batch_uuid_short}.zip"
# Create batch zip in a standard temp location first
local_batch_zip_base = os.path.join(tempfile.gettempdir(), batch_zip_filename.replace('.zip', ''))
local_batch_zip_path = f"{local_batch_zip_base}.zip"
app.logger.info(f"Creating batch zip: {local_batch_zip_path} with {len(individual_zips_in_batch)} individual zips.")
shutil.make_archive(local_batch_zip_base, 'zip', temp_batch_dir)
app.logger.info(f"Batch zip created successfully: {local_batch_zip_path}")
# 3. Upload the batch zip file
hf_repo_path = f"votes/{year}/{month}/{batch_zip_filename}"
app.logger.info(f"Uploading batch zip to HF Hub: {preferences_repo_id}/{hf_repo_path}")
api.upload_file(
path_or_fileobj=local_batch_zip_path,
path_in_repo=hf_repo_path,
repo_id=preferences_repo_id,
repo_type="dataset",
commit_message=f"Add batch preference data {batch_zip_filename} ({len(individual_zips_in_batch)} votes)"
)
app.logger.info(f"Successfully uploaded batch {batch_zip_filename} to {preferences_repo_id}")
# 4. Cleanup after successful upload
app.logger.info("Cleaning up local files after successful upload.")
# Remove original vote directories that were successfully zipped and uploaded
for dir_path in processed_vote_dirs:
try:
shutil.rmtree(dir_path)
app.logger.debug(f"Removed original vote directory: {dir_path}")
except OSError as e:
app.logger.error(f"Error removing processed vote directory {dir_path}: {str(e)}")
# Remove the temporary batch directory (containing the individual zips)
shutil.rmtree(temp_batch_dir)
temp_batch_dir = None
app.logger.debug("Removed temporary batch directory.")
# Remove the local batch zip file
os.remove(local_batch_zip_path)
local_batch_zip_path = None
app.logger.debug("Removed local batch zip file.")
app.logger.info(f"Finished preference data sync. Uploaded batch {batch_zip_filename}.")
            except Exception as e:
                upload_failed = True  # Keep the batch dir so its contents can be retried next run
                app.logger.error(f"Error during preference data batch sync: {str(e)}", exc_info=True)
# If upload failed, the local batch zip might exist, clean it up.
if local_batch_zip_path and os.path.exists(local_batch_zip_path):
try:
os.remove(local_batch_zip_path)
app.logger.debug("Cleaned up local batch zip after failed upload.")
except OSError as clean_err:
app.logger.error(f"Error cleaning up batch zip after failed upload: {clean_err}")
# Do NOT remove temp_batch_dir if it exists; its contents will be retried next time.
# Do NOT remove original vote directories if upload failed.
finally:
# Final cleanup for temporary directories in case of unexpected exits
if temp_individual_zip_dir and os.path.exists(temp_individual_zip_dir):
try:
shutil.rmtree(temp_individual_zip_dir)
except Exception as final_clean_err:
app.logger.error(f"Error in final cleanup (indiv zips): {final_clean_err}")
                # Only clean up the batch dir here if it wasn't kept intentionally after an
                # upload failure. A flag is used because Python 3 unbinds `e` after the
                # except block, so checking `'e' in locals()` here would never detect it.
                if temp_batch_dir and os.path.exists(temp_batch_dir):
                    if not upload_failed:  # No upload error (or nothing to upload): safe to clean up
try:
shutil.rmtree(temp_batch_dir)
except Exception as final_clean_err:
app.logger.error(f"Error in final cleanup (batch dir): {final_clean_err}")
else:
app.logger.warning("Keeping temporary batch directory due to upload failure for next attempt.")
# Schedule periodic tasks
scheduler = BackgroundScheduler()
# Sync database less frequently if needed, e.g., every 15 minutes
scheduler.add_job(sync_database, "interval", minutes=15, id="sync_db_job")
# Sync preferences more frequently
scheduler.add_job(sync_preferences_data, "interval", minutes=5, id="sync_pref_job")
scheduler.start()
print("Periodic tasks scheduler started (DB sync and Preferences upload)") # Use print for startup
@app.cli.command("init-db")
def init_db():
"""Initialize the database."""
with app.app_context():
db.create_all()
print("Database initialized!")
@app.route("/api/toggle-leaderboard-visibility", methods=["POST"])
def toggle_leaderboard_visibility():
"""Toggle whether the current user appears in the top voters leaderboard"""
if not current_user.is_authenticated:
return jsonify({"error": "You must be logged in to change this setting"}), 401
new_status = toggle_user_leaderboard_visibility(current_user.id)
if new_status is None:
return jsonify({"error": "User not found"}), 404
return jsonify({
"success": True,
"visible": new_status,
"message": "You are now visible in the voters leaderboard" if new_status else "You are now hidden from the voters leaderboard"
})
@app.route("/api/tts/cached-sentences")
def get_cached_sentences():
"""Returns a list of sentences currently available in the TTS cache."""
with tts_cache_lock:
cached_keys = list(tts_cache.keys())
return jsonify(cached_keys)
if __name__ == "__main__":
with app.app_context():
# Ensure ./instance and ./votes directories exist
os.makedirs("instance", exist_ok=True)
os.makedirs("./votes", exist_ok=True) # Create votes directory if it doesn't exist
os.makedirs(CACHE_AUDIO_DIR, exist_ok=True) # Ensure cache audio dir exists
# Clean up old cache audio files on startup
try:
app.logger.info(f"Clearing old cache audio files from {CACHE_AUDIO_DIR}")
for filename in os.listdir(CACHE_AUDIO_DIR):
file_path = os.path.join(CACHE_AUDIO_DIR, filename)
try:
if os.path.isfile(file_path) or os.path.islink(file_path):
os.unlink(file_path)
elif os.path.isdir(file_path):
shutil.rmtree(file_path)
except Exception as e:
app.logger.error(f'Failed to delete {file_path}. Reason: {e}')
except Exception as e:
app.logger.error(f"Error clearing cache directory {CACHE_AUDIO_DIR}: {e}")
# Download database if it doesn't exist (only on initial space start)
        if IS_SPACES and not os.path.exists(app.config["SQLALCHEMY_DATABASE_URI"].replace("sqlite:///", "instance/")):  # Same path convention as sync_database
try:
print("Database not found, downloading from HF dataset...")
hf_hub_download(
repo_id="TTS-AGI/database-arena-v2",
filename="tts_arena.db",
repo_type="dataset",
local_dir="instance", # download to instance/
token=os.getenv("HF_TOKEN"),
)
print("Database downloaded successfully ✅")
except Exception as e:
print(f"Error downloading database from HF dataset: {str(e)} ⚠️")
db.create_all() # Create tables if they don't exist
insert_initial_models()
# Setup background tasks
initialize_tts_cache() # Start populating the cache
setup_cleanup()
        setup_periodic_tasks()
# Configure Flask to recognize HTTPS when behind a reverse proxy
from werkzeug.middleware.proxy_fix import ProxyFix
# Apply ProxyFix middleware to handle reverse proxy headers
# This ensures Flask generates correct URLs with https scheme
# X-Forwarded-Proto header will be used to detect the original protocol
app.wsgi_app = ProxyFix(app.wsgi_app, x_proto=1, x_host=1)
# Force Flask to prefer HTTPS for generated URLs
app.config["PREFERRED_URL_SCHEME"] = "https"
from waitress import serve
# Configuration for 2 vCPUs:
# - threads: typically 4-8 threads per CPU core is a good balance
# - connection_limit: maximum concurrent connections
# - channel_timeout: prevent hanging connections
threads = 12 # 6 threads per vCPU is a good balance for mixed IO/CPU workloads
if IS_SPACES:
serve(
app,
host="0.0.0.0",
port=int(os.environ.get("PORT", 7860)),
threads=threads,
connection_limit=100,
channel_timeout=30,
url_scheme='https'
)
else:
print(f"Starting Waitress server with {threads} threads")
serve(
app,
host="0.0.0.0",
port=5000,
threads=threads,
connection_limit=100,
channel_timeout=30,
url_scheme='https' # Keep https for local dev if using proxy/tunnel
)