Spaces:
Running
Running
Commit
·
99b1671
1
Parent(s):
367e06e
(wip)modify models
Browse files
app.py
CHANGED
|
@@ -129,6 +129,34 @@ CACHE_AUDIO_DIR = os.path.join(TEMP_AUDIO_DIR, CACHE_AUDIO_SUBDIR)
|
|
| 129 |
os.makedirs(TEMP_AUDIO_DIR, exist_ok=True)
|
| 130 |
os.makedirs(CACHE_AUDIO_DIR, exist_ok=True) # Ensure cache subdir exists
|
| 131 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
|
| 133 |
# Store active TTS sessions
|
| 134 |
app.tts_sessions = {}
|
|
@@ -382,8 +410,13 @@ def generate_and_save_tts(text, model_id, output_dir):
|
|
| 382 |
temp_audio_path = None # Initialize to None
|
| 383 |
try:
|
| 384 |
app.logger.debug(f"[TTS Gen {model_id}] Starting generation for: '{text[:30]}...'")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 385 |
# If predict_tts saves file itself and returns path:
|
| 386 |
-
temp_audio_path = predict_tts(text, model_id)
|
| 387 |
app.logger.debug(f"[TTS Gen {model_id}] predict_tts returned: {temp_audio_path}")
|
| 388 |
|
| 389 |
if not temp_audio_path or not os.path.exists(temp_audio_path):
|
|
@@ -396,7 +429,7 @@ def generate_and_save_tts(text, model_id, output_dir):
|
|
| 396 |
# Move the file generated by predict_tts to the target cache directory
|
| 397 |
shutil.move(temp_audio_path, dest_path)
|
| 398 |
app.logger.debug(f"[TTS Gen {model_id}] Move successful. Returning {dest_path}")
|
| 399 |
-
return dest_path
|
| 400 |
|
| 401 |
except Exception as e:
|
| 402 |
app.logger.error(f"Error generating/saving TTS for model {model_id} and text '{text[:30]}...': {str(e)}")
|
|
@@ -407,7 +440,7 @@ def generate_and_save_tts(text, model_id, output_dir):
|
|
| 407 |
os.remove(temp_audio_path)
|
| 408 |
except OSError:
|
| 409 |
pass # Ignore error if file couldn't be removed
|
| 410 |
-
return None
|
| 411 |
|
| 412 |
|
| 413 |
def _generate_cache_entry_task(sentence):
|
|
@@ -445,8 +478,8 @@ def _generate_cache_entry_task(sentence):
|
|
| 445 |
future_b = audio_executor.submit(generate_and_save_tts, sentence, model_b_id, CACHE_AUDIO_DIR)
|
| 446 |
|
| 447 |
timeout_seconds = 120
|
| 448 |
-
audio_a_path = future_a.result(timeout=timeout_seconds)
|
| 449 |
-
audio_b_path = future_b.result(timeout=timeout_seconds)
|
| 450 |
|
| 451 |
if audio_a_path and audio_b_path:
|
| 452 |
with tts_cache_lock:
|
|
@@ -458,6 +491,8 @@ def _generate_cache_entry_task(sentence):
|
|
| 458 |
"model_b": model_b_id,
|
| 459 |
"audio_a": audio_a_path,
|
| 460 |
"audio_b": audio_b_path,
|
|
|
|
|
|
|
| 461 |
"created_at": datetime.utcnow(),
|
| 462 |
}
|
| 463 |
app.logger.info(f"Successfully cached entry for: '{sentence[:50]}...'")
|
|
@@ -1112,7 +1147,7 @@ def setup_periodic_tasks():
|
|
| 1112 |
|
| 1113 |
db_path = app.config["SQLALCHEMY_DATABASE_URI"].replace("sqlite:///", "instance/") # Get relative path
|
| 1114 |
preferences_repo_id = "kemuriririn/arena-preferences"
|
| 1115 |
-
database_repo_id = "kemuriririn/database-arena
|
| 1116 |
votes_dir = "./votes"
|
| 1117 |
|
| 1118 |
def sync_database():
|
|
@@ -1318,10 +1353,27 @@ def toggle_leaderboard_visibility():
|
|
| 1318 |
|
| 1319 |
@app.route("/api/tts/cached-sentences")
|
| 1320 |
def get_cached_sentences():
|
| 1321 |
-
"""Returns a list of sentences currently available in the TTS cache."""
|
| 1322 |
with tts_cache_lock:
|
| 1323 |
-
|
| 1324 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1325 |
|
| 1326 |
|
| 1327 |
def get_weighted_random_models(
|
|
@@ -1414,6 +1466,7 @@ if __name__ == "__main__":
|
|
| 1414 |
except Exception as e:
|
| 1415 |
print(f"Error downloading database from HF dataset: {str(e)} ⚠️")
|
| 1416 |
|
|
|
|
| 1417 |
|
| 1418 |
db.create_all() # Create tables if they don't exist
|
| 1419 |
insert_initial_models()
|
|
|
|
| 129 |
os.makedirs(TEMP_AUDIO_DIR, exist_ok=True)
|
| 130 |
os.makedirs(CACHE_AUDIO_DIR, exist_ok=True) # Ensure cache subdir exists
|
| 131 |
|
| 132 |
+
# --- 参考音色下载与管理 ---
|
| 133 |
+
REFERENCE_AUDIO_DIR = os.path.join(TEMP_AUDIO_DIR, "reference_audios")
|
| 134 |
+
REFERENCE_AUDIO_DATASET = os.getenv("REFERENCE_AUDIO_DATASET", "kemuriririn/arena-files")
|
| 135 |
+
REFERENCE_AUDIO_PATTERN = os.getenv("REFERENCE_AUDIO_PATTERN", "reference_audios/")
|
| 136 |
+
reference_audio_files = []
|
| 137 |
+
|
| 138 |
+
def download_reference_audios():
|
| 139 |
+
"""从 Hugging Face dataset 下载参考音频到本地目录,并生成文件列表"""
|
| 140 |
+
global reference_audio_files
|
| 141 |
+
os.makedirs(REFERENCE_AUDIO_DIR, exist_ok=True)
|
| 142 |
+
try:
|
| 143 |
+
api = HfApi(token=os.getenv("HF_TOKEN"))
|
| 144 |
+
files = api.list_repo_files(repo_id=REFERENCE_AUDIO_DATASET, repo_type="dataset")
|
| 145 |
+
# 只下载 wav 文件
|
| 146 |
+
wav_files = [f for f in files if f.startswith(REFERENCE_AUDIO_PATTERN) and f.endswith(".wav")]
|
| 147 |
+
for f in wav_files:
|
| 148 |
+
local_path = hf_hub_download(
|
| 149 |
+
repo_id=REFERENCE_AUDIO_DATASET,
|
| 150 |
+
filename=f,
|
| 151 |
+
repo_type="dataset",
|
| 152 |
+
local_dir=REFERENCE_AUDIO_DIR,
|
| 153 |
+
token=os.getenv("HF_TOKEN"),
|
| 154 |
+
)
|
| 155 |
+
reference_audio_files.append(local_path)
|
| 156 |
+
print(f"Downloaded {len(reference_audio_files)} reference audios.")
|
| 157 |
+
except Exception as e:
|
| 158 |
+
print(f"Error downloading reference audios: {e}")
|
| 159 |
+
reference_audio_files = []
|
| 160 |
|
| 161 |
# Store active TTS sessions
|
| 162 |
app.tts_sessions = {}
|
|
|
|
| 410 |
temp_audio_path = None # Initialize to None
|
| 411 |
try:
|
| 412 |
app.logger.debug(f"[TTS Gen {model_id}] Starting generation for: '{text[:30]}...'")
|
| 413 |
+
# 随机选一个参考音频
|
| 414 |
+
reference_audio_path = None
|
| 415 |
+
if reference_audio_files:
|
| 416 |
+
reference_audio_path = random.choice(reference_audio_files)
|
| 417 |
+
app.logger.debug(f"[TTS Gen {model_id}] Using reference audio: {reference_audio_path}")
|
| 418 |
# If predict_tts saves file itself and returns path:
|
| 419 |
+
temp_audio_path = predict_tts(text, model_id, reference_audio_path=reference_audio_path)
|
| 420 |
app.logger.debug(f"[TTS Gen {model_id}] predict_tts returned: {temp_audio_path}")
|
| 421 |
|
| 422 |
if not temp_audio_path or not os.path.exists(temp_audio_path):
|
|
|
|
| 429 |
# Move the file generated by predict_tts to the target cache directory
|
| 430 |
shutil.move(temp_audio_path, dest_path)
|
| 431 |
app.logger.debug(f"[TTS Gen {model_id}] Move successful. Returning {dest_path}")
|
| 432 |
+
return dest_path, reference_audio_path
|
| 433 |
|
| 434 |
except Exception as e:
|
| 435 |
app.logger.error(f"Error generating/saving TTS for model {model_id} and text '{text[:30]}...': {str(e)}")
|
|
|
|
| 440 |
os.remove(temp_audio_path)
|
| 441 |
except OSError:
|
| 442 |
pass # Ignore error if file couldn't be removed
|
| 443 |
+
return None, None
|
| 444 |
|
| 445 |
|
| 446 |
def _generate_cache_entry_task(sentence):
|
|
|
|
| 478 |
future_b = audio_executor.submit(generate_and_save_tts, sentence, model_b_id, CACHE_AUDIO_DIR)
|
| 479 |
|
| 480 |
timeout_seconds = 120
|
| 481 |
+
audio_a_path, ref_a = future_a.result(timeout=timeout_seconds)
|
| 482 |
+
audio_b_path, ref_b = future_b.result(timeout=timeout_seconds)
|
| 483 |
|
| 484 |
if audio_a_path and audio_b_path:
|
| 485 |
with tts_cache_lock:
|
|
|
|
| 491 |
"model_b": model_b_id,
|
| 492 |
"audio_a": audio_a_path,
|
| 493 |
"audio_b": audio_b_path,
|
| 494 |
+
"ref_a": ref_a,
|
| 495 |
+
"ref_b": ref_b,
|
| 496 |
"created_at": datetime.utcnow(),
|
| 497 |
}
|
| 498 |
app.logger.info(f"Successfully cached entry for: '{sentence[:50]}...'")
|
|
|
|
| 1147 |
|
| 1148 |
db_path = app.config["SQLALCHEMY_DATABASE_URI"].replace("sqlite:///", "instance/") # Get relative path
|
| 1149 |
preferences_repo_id = "kemuriririn/arena-preferences"
|
| 1150 |
+
database_repo_id = "kemuriririn/database-arena"
|
| 1151 |
votes_dir = "./votes"
|
| 1152 |
|
| 1153 |
def sync_database():
|
|
|
|
| 1353 |
|
| 1354 |
@app.route("/api/tts/cached-sentences")
|
| 1355 |
def get_cached_sentences():
|
| 1356 |
+
"""Returns a list of sentences currently available in the TTS cache, with reference audio."""
|
| 1357 |
with tts_cache_lock:
|
| 1358 |
+
cached = [
|
| 1359 |
+
{
|
| 1360 |
+
"sentence": k,
|
| 1361 |
+
"model_a": v["model_a"],
|
| 1362 |
+
"model_b": v["model_b"],
|
| 1363 |
+
"ref_a": os.path.relpath(v["ref_a"], start=REFERENCE_AUDIO_DIR) if v.get("ref_a") else None,
|
| 1364 |
+
"ref_b": os.path.relpath(v["ref_b"], start=REFERENCE_AUDIO_DIR) if v.get("ref_b") else None,
|
| 1365 |
+
}
|
| 1366 |
+
for k, v in tts_cache.items()
|
| 1367 |
+
]
|
| 1368 |
+
return jsonify(cached)
|
| 1369 |
+
|
| 1370 |
+
@app.route("/api/tts/reference-audio/<filename>")
|
| 1371 |
+
def get_reference_audio(filename):
|
| 1372 |
+
"""试听参考音频"""
|
| 1373 |
+
file_path = os.path.join(REFERENCE_AUDIO_DIR, filename)
|
| 1374 |
+
if not os.path.exists(file_path):
|
| 1375 |
+
return jsonify({"error": "Reference audio not found"}), 404
|
| 1376 |
+
return send_file(file_path, mimetype="audio/wav")
|
| 1377 |
|
| 1378 |
|
| 1379 |
def get_weighted_random_models(
|
|
|
|
| 1466 |
except Exception as e:
|
| 1467 |
print(f"Error downloading database from HF dataset: {str(e)} ⚠️")
|
| 1468 |
|
| 1469 |
+
download_reference_audios()
|
| 1470 |
|
| 1471 |
db.create_all() # Create tables if they don't exist
|
| 1472 |
insert_initial_models()
|
models.py
CHANGED
|
@@ -446,7 +446,7 @@ def insert_initial_models():
|
|
| 446 |
name="Spark TTS",
|
| 447 |
model_type=ModelType.TTS,
|
| 448 |
is_open=False,
|
| 449 |
-
is_active=
|
| 450 |
model_url="https://github.com/SparkAudio/Spark-TTS",
|
| 451 |
),
|
| 452 |
# Model(
|
|
|
|
| 446 |
name="Spark TTS",
|
| 447 |
model_type=ModelType.TTS,
|
| 448 |
is_open=False,
|
| 449 |
+
is_active=True, # API stopped working
|
| 450 |
model_url="https://github.com/SparkAudio/Spark-TTS",
|
| 451 |
),
|
| 452 |
# Model(
|