Spaces:
Sleeping
Sleeping
from fastapi import FastAPI, UploadFile, File, HTTPException | |
from fastapi.responses import JSONResponse | |
import subprocess | |
import tempfile | |
import os | |
import shutil | |
from pydantic import BaseModel | |
import sys | |
import numpy as np # ViSQOL 可能需要 numpy | |
import soundfile as sf # 用于读取音频 | |
from typing import Optional, List # 导入 List | |
import librosa # Need librosa for resampling during conversion if soundfile fails | |
# FastAPI application object used by this module (and by the local runner below).
app = FastAPI(
    title="ViSQOL 音频质量 API",
)
# --- ViSQOL path configuration ---
# Paths are anchored to the directory that contains this file, so they resolve
# the same way no matter which working directory the server is started from.
_APP_DIR = os.path.dirname(os.path.abspath(__file__))
VISQOL_DIR = os.path.join(_APP_DIR, "build", "visqol")
VISQOL_LIB_PATH = os.path.join(VISQOL_DIR, "visqol_lib_py.so")
PB2_DIR = os.path.join(VISQOL_DIR, "pb2")  # directory holding the *_pb2 modules
MODEL_DIR = os.path.join(VISQOL_DIR, "model")
# NOTE: despite the constant names, evaluate_audio() passes SPEECH_MODEL_PATH
# (the libsvm .txt model) for mode 'audio' and AUDIO_MODEL_PATH (the .tflite
# model) for mode 'speech'.
SPEECH_MODEL_PATH = os.path.join(MODEL_DIR, "libsvm_nu_svr_model.txt")
AUDIO_MODEL_PATH = os.path.join(
    MODEL_DIR,
    "lattice_tcditugenmeetpackhref_ls2_nl60_lr12_bs2048_learn.005_ep2400_train1_7_raw.tflite",
)
# --- end of path configuration ---
# Fail fast at import time when the ViSQOL artefacts are not in place.
required_files = [VISQOL_LIB_PATH, SPEECH_MODEL_PATH, AUDIO_MODEL_PATH]
_missing_files = [path for path in required_files if not os.path.exists(path)]
if _missing_files:
    raise FileNotFoundError(f"ViSQOL 必需文件未找到: {', '.join(_missing_files)}")
# isdir() is False for a path that does not exist, so one check covers both cases.
if not os.path.isdir(PB2_DIR):
    raise FileNotFoundError(f"ViSQOL pb2 目录未找到: {PB2_DIR}")
# Dynamically import the ViSQOL extension module and its protobuf modules.
try:
    # Put both directories at the front of sys.path. VISQOL_DIR is inserted
    # last, so it ends up ahead of PB2_DIR in the search order.
    for _search_dir in (PB2_DIR, VISQOL_DIR):
        sys.path.insert(0, os.path.abspath(_search_dir))
    # Being on sys.path is normally enough for pure-Python imports; a compiled
    # .so may additionally need to be resolvable by the dynamic loader
    # (e.g. via LD_LIBRARY_PATH). Per the original author's note, the
    # Dockerfile is expected to take care of the library path.
    import visqol_lib_py
    import similarity_result_pb2
    import visqol_config_pb2
except ImportError as import_error:
    print("错误:无法导入 ViSQOL 库或 pb2 文件。")
    print(f"Python 搜索路径: {sys.path}")
    print(f"错误详情: {import_error}")
    # Deliberately not re-raised: on Hugging Face the startup log shows the
    # messages above, and the endpoints report the library as unavailable.
    visqol_lib_py = None  # marks ViSQOL as unavailable
else:
    print("ViSQOL 库和 pb2 文件导入成功。")
# Response model returned by the evaluation endpoint.
class VisqolResponse(BaseModel):
    # Names of the two uploaded files, echoed back to the caller.
    reference_filename: str
    degraded_filename: str
    # Evaluation mode that was requested.
    mode: str
    # MOS-LQO value taken from the ViSQOL result.
    moslqo: float
    # Optional vnsim value copied from the ViSQOL result when present.
    vnsim: Optional[float] = None
    # Optional fvnsim values copied from the ViSQOL result as a plain list.
    fvnsim: Optional[List[float]] = None
    # Status text describing how processing ended.
    status: str
    # Error description; stays None when no error was recorded.
    error_message: Optional[str] = None
# Convert and resample audio by shelling out to ffmpeg.
def convert_and_resample_audio(input_path: str, output_path: str, target_sr: int) -> bool:
    """Convert ``input_path`` to a mono file at ``target_sr`` Hz using ffmpeg.

    Args:
        input_path: Path of the audio file to read.
        output_path: Path ffmpeg writes to; an existing file is overwritten.
        target_sr: Sample rate, in Hz, passed to ffmpeg's ``-ar`` option.

    Returns:
        True when ffmpeg exits with status 0. False when the ffmpeg executable
        cannot be found, when it exits with a non-zero status, or when any
        other error occurs. Failures are printed and never raised.
    """
    cmd = [
        "ffmpeg",
        "-y",  # overwrite the output file if it exists
        "-i", input_path,
        "-ar", str(target_sr),  # target sample rate
        "-ac", "1",  # force a single (mono) channel
        output_path,
    ]
    print(f"Running ffmpeg: {' '.join(cmd)}")
    try:
        subprocess.run(
            cmd,
            check=True,
            # Do not let ffmpeg read from the server's standard input.
            stdin=subprocess.DEVNULL,
            capture_output=True,
            text=True,
            encoding="utf-8",
            # ffmpeg may echo non-UTF-8 bytes (e.g. file metadata); without
            # this a decode error would be reported as a failed conversion
            # even though ffmpeg itself succeeded.
            errors="replace",
        )
    except FileNotFoundError:
        print("错误: ffmpeg 未找到,无法转换音频。请确保已在 Docker 环境中安装 ffmpeg。")
        return False
    except subprocess.CalledProcessError as e:
        print(f"错误: ffmpeg 执行失败 (返回码 {e.returncode})。")
        print(f"ffmpeg stderr: {e.stderr}")
        return False
    except Exception as e:
        # Best-effort contract: callers only look at the boolean result.
        print(f"转换音频时发生未知错误: {e}")
        return False
    print("ffmpeg conversion successful.")
    return True
def _upload_suffix(filename: Optional[str]) -> str:
    """Return a filesystem-safe extension (e.g. ``".mp3"``) from an upload name.

    The name is supplied by the client, so only its extension is kept, and only
    when it is short and alphanumeric. Returns ``""`` otherwise.
    """
    if not filename:
        return ""
    # Treat both separator styles as path separators before taking the basename.
    base_name = os.path.basename(filename.replace("\\", "/"))
    suffix = os.path.splitext(base_name)[1]
    if 1 < len(suffix) <= 16 and suffix[1:].isalnum():
        return suffix
    return ""


# NOTE(review): no route decorator is visible on this function in the reviewed
# text; confirm that it is registered with `app`.
async def evaluate_audio(
    reference: UploadFile = File(..., description="参考音频文件"),
    degraded: UploadFile = File(..., description="待评估音频文件"),
    mode: str = "audio",  # 'audio' or 'speech'
):
    """
    Evaluate the perceptual similarity of two audio files with ViSQOL.

    Both uploads are converted with ffmpeg to mono WAV at 48000 Hz (mode
    'audio') or 16000 Hz (mode 'speech'), loaded as float64 arrays and passed
    to ViSQOL. Returns a VisqolResponse carrying the predicted MOS-LQO.

    Raises HTTPException (500) when the ViSQOL library was not loaded and
    HTTPException (400) for an unknown mode. Every later failure is reported
    inside the response body (status / error_message, moslqo left at -1.0)
    instead of being raised.
    """
    if visqol_lib_py is None:
        raise HTTPException(status_code=500, detail="ViSQOL 库未成功加载。")
    if mode not in ("audio", "speech"):
        raise HTTPException(status_code=400, detail="模式参数 'mode' 必须是 'audio' 或 'speech'")

    temp_dir = tempfile.mkdtemp()
    # Keep only a sanitized extension of the client-supplied names (it helps
    # ffmpeg identify the format); the names themselves never reach the path.
    ref_temp_orig = os.path.join(temp_dir, f"ref_upload{_upload_suffix(reference.filename)}")
    deg_temp_orig = os.path.join(temp_dir, f"deg_upload{_upload_suffix(degraded.filename)}")
    # Final WAV paths produced by ffmpeg.
    ref_path_wav = os.path.join(temp_dir, "reference.wav")
    deg_path_wav = os.path.join(temp_dir, "degraded.wav")

    mos = -1.0  # sentinel reported when no score was produced
    vnsim_val = None
    fvnsim_val = None
    status_msg = "处理失败"
    error_msg = None
    try:
        # 1. Save the raw uploads; close them even if reading or writing fails.
        try:
            with open(ref_temp_orig, "wb") as f:
                f.write(await reference.read())
            with open(deg_temp_orig, "wb") as f:
                f.write(await degraded.read())
        finally:
            await reference.close()
            await degraded.close()

        # 2. Pick the target sample rate and convert/resample both files.
        # NOTE(review): ffmpeg and ViSQOL are called synchronously inside this
        # coroutine, so other requests on the same event loop wait meanwhile.
        target_sr = 48000 if mode == 'audio' else 16000
        print(f"目标采样率: {target_sr} Hz for mode '{mode}'")
        conv_ref_ok = convert_and_resample_audio(ref_temp_orig, ref_path_wav, target_sr)
        conv_deg_ok = convert_and_resample_audio(deg_temp_orig, deg_path_wav, target_sr)
        if not (conv_ref_ok and conv_deg_ok):
            raise HTTPException(status_code=500, detail="使用 ffmpeg 转换或重采样音频文件失败。")

        # 3. Sanity-check the converted WAV files (a mismatch only warns).
        try:
            ref_info = sf.info(ref_path_wav)
            deg_info = sf.info(deg_path_wav)
            if ref_info.samplerate != target_sr or deg_info.samplerate != target_sr:
                print(f"警告:ffmpeg 转换后的采样率 ({ref_info.samplerate}/{deg_info.samplerate}) 与目标 ({target_sr}) 不符,可能影响 ViSQOL 结果。")
        except Exception as audio_e:
            # sf.info failing suggests the ffmpeg output is not a usable WAV.
            raise HTTPException(status_code=400, detail=f"无法读取转换后的 WAV 文件: {audio_e}")

        # 4. Load the converted audio as float64 samples.
        try:
            print(f"从 WAV 加载音频数据: {ref_path_wav}, {deg_path_wav}")
            ref_data, sr_ref = sf.read(ref_path_wav, dtype='float64')
            deg_data, sr_deg = sf.read(deg_path_wav, dtype='float64')
            # ffmpeg should already have enforced the rate; a mismatch only warns.
            if sr_ref != target_sr or sr_deg != target_sr:
                print(f"警告:读取的 WAV 文件采样率 ({sr_ref}/{sr_deg}) 与目标 ({target_sr}) 不符。")
            print("音频数据加载成功。")
        except Exception as read_e:
            raise HTTPException(status_code=500, detail=f"读取转换后的 WAV 文件时出错: {read_e}")

        # 5. Build the ViSQOL configuration.
        config = visqol_config_pb2.VisqolConfig()
        config.audio.sample_rate = target_sr
        # Model choice follows the original author's reading of the official
        # example: speech scoring uses the .tflite model, audio mode uses the
        # libsvm .txt model. The constant names suggest the opposite.
        if mode == "speech":
            config.options.use_speech_scoring = True
            model_file_to_use = AUDIO_MODEL_PATH  # .tflite model
        else:  # audio mode
            config.options.use_speech_scoring = False
            model_file_to_use = SPEECH_MODEL_PATH  # .txt model (libsvm)
        config.options.svr_model_path = os.path.abspath(model_file_to_use)
        print(f"使用模型: {model_file_to_use} for mode '{mode}'")

        # 6. Create the API instance and run the measurement on the arrays.
        api = visqol_lib_py.VisqolApi()
        api.Create(config)
        similarity_result_msg = api.Measure(ref_data, deg_data)

        # 7. Extract the score and the optional vnsim / fvnsim fields.
        if similarity_result_msg and hasattr(similarity_result_msg, 'moslqo'):
            mos = similarity_result_msg.moslqo
            status_msg = "处理成功"
            print(f"ViSQOL 评估完成: MOS-LQO = {mos}")
            if hasattr(similarity_result_msg, 'vnsim'):
                vnsim_val = similarity_result_msg.vnsim
                print(f"VNSIM = {vnsim_val}")
            else:
                print("ViSQOL 结果中未找到 vnsim 字段。")
            if hasattr(similarity_result_msg, 'fvnsim') and similarity_result_msg.fvnsim:
                # Copy into a plain Python list so the response model accepts it.
                fvnsim_val = list(similarity_result_msg.fvnsim)
                print(f"FVNSIM (第一个元素): {fvnsim_val[0] if fvnsim_val else 'N/A'}")
            else:
                print("ViSQOL 结果中未找到 fvnsim 字段或为空。")
        else:
            error_msg = "ViSQOL 未返回有效的 MOS-LQO 结果。"
            print(f"错误: {error_msg}")
    except ImportError as e:
        status_msg = "导入错误"
        error_msg = f"无法导入 ViSQOL 库或依赖: {e}"
        print(f"错误: {error_msg}")
    except FileNotFoundError as e:
        status_msg = "文件未找到错误"
        error_msg = f"必需文件丢失: {e}"
        print(f"错误: {error_msg}")
    except HTTPException as e:
        # HTTP errors raised above are reported in the body, not as HTTP errors.
        status_msg = "请求错误"
        error_msg = str(e.detail)
        print(f"错误: {error_msg}")
    except Exception as e:
        status_msg = "运行时错误"
        error_msg = f"处理过程中发生错误: {type(e).__name__} - {e}"
        print(f"错误: {error_msg}")
    finally:
        # Best-effort cleanup: a failure here must not replace the response.
        shutil.rmtree(temp_dir, ignore_errors=True)

    return VisqolResponse(
        # An upload may arrive without a filename; the model fields require str.
        reference_filename=reference.filename or "",
        degraded_filename=degraded.filename or "",
        mode=mode,
        moslqo=mos,
        vnsim=vnsim_val,
        fvnsim=fvnsim_val,
        status=status_msg,
        error_message=error_msg,
    )
# NOTE(review): no route decorator is visible on this function in the reviewed
# text; confirm that it is registered with `app`.
async def root():
    # Minimal landing payload: a welcome text that points callers at the
    # POST /evaluate/ endpoint.
    welcome_text = "欢迎使用 ViSQOL 音频质量评估 API。请使用 POST 方法访问 /evaluate/ 端点。"
    return {"message": welcome_text}
# Health-check handler.
# NOTE(review): no route decorator is visible on this function in the reviewed
# text; confirm that it is registered with `app`.
async def health_check():
    """Hugging Face Spaces health check endpoint."""
    # visqol_lib_py is set to None when the startup import failed, so a failed
    # library load is surfaced here as well.
    if visqol_lib_py is not None:
        return {"status": "ok"}
    return {"status": "error", "detail": "ViSQOL library not loaded"}
# Optional entry point for local testing when the script is run directly.
if __name__ == "__main__":
    import uvicorn

    print("运行本地测试服务器: http://127.0.0.1:8000")
    # NOTE: running locally may require LD_LIBRARY_PATH to be set correctly, or
    # the .so file to sit somewhere the system can find it.
    local_server_options = {"host": "127.0.0.1", "port": 8000}
    uvicorn.run(app, **local_server_options)