| """ | |
| 说话人分离器基础类,包含可复用的方法 | |
| """ | |
| import os | |
| import logging | |
| from abc import ABC, abstractmethod | |
| from pydub import AudioSegment | |
| from typing import Any, Dict, List, Union, Optional, Tuple | |
| from ..schemas import DiarizationResult | |
| # 配置日志 | |
| logger = logging.getLogger("diarization") | |


class BaseDiarizer(ABC):
    """Base class for speaker diarizers."""

    def __init__(
        self,
        model_name: str,
        token: Optional[str] = None,
        device: str = "cpu",
        segmentation_batch_size: int = 32,
    ):
        """
        Initialize the diarizer's base parameters.

        Args:
            model_name: Name of the model.
            token: Hugging Face token used to access the model.
            device: Inference device, 'cpu' or 'cuda'.
            segmentation_batch_size: Batch size for segmentation, defaults to 32.
        """
        self.model_name = model_name
        self.token = token  # stored so subclasses can authenticate model access
        self.device = device
        self.segmentation_batch_size = segmentation_batch_size
        logger.info(
            f"Initializing diarizer, model: {model_name}, device: {device}, "
            f"segmentation batch size: {segmentation_batch_size}"
        )

    @abstractmethod
    def _load_model(self):
        """Load the model; subclasses must implement this."""
        pass

    def _prepare_audio(self, audio: AudioSegment) -> str:
        """
        Prepare the audio data and save it to a temporary file.

        Args:
            audio: Input AudioSegment object.

        Returns:
            Path to the temporary audio file.
        """
        logger.debug(
            f"Preparing audio: duration={len(audio)/1000:.2f}s, "
            f"sample rate={audio.frame_rate}Hz, channels={audio.channels}"
        )
        # Ensure a 16 kHz sample rate (required by pyannote models)
        if audio.frame_rate != 16000:
            logger.debug(f"Resampling audio from {audio.frame_rate}Hz to 16000Hz")
            audio = audio.set_frame_rate(16000)
        # Ensure mono audio
        if audio.channels > 1:
            logger.debug(f"Downmixing {audio.channels}-channel audio to mono")
            audio = audio.set_channels(1)
        # Save to a temporary file
        temp_audio_path = "_temp_audio_for_diarization.wav"
        audio.export(temp_audio_path, format="wav")
        logger.debug(f"Audio prepared, saved to: {temp_audio_path}")
        return temp_audio_path

    def _convert_segments(self, diarization) -> Tuple[List[Dict[str, Union[float, str, int]]], int]:
        """
        Convert pyannote diarization output into the expected format.

        Args:
            diarization: Diarization result returned by the pyannote model.

        Returns:
            The converted list of segments and the number of speakers.
        """
        segments = []
        speakers = set()
        # Iterate over the diarization result
        for turn, _, speaker in diarization.itertracks(yield_label=True):
            segments.append({
                "start": turn.start,
                "end": turn.end,
                "speaker": speaker,
            })
            speakers.add(speaker)
        # Sort segments by start time
        segments.sort(key=lambda x: x["start"])
        logger.debug(f"Converted {len(segments)} segments, detected {len(speakers)} speakers")
        return segments, len(speakers)

    @abstractmethod
    def diarize(self, audio: AudioSegment) -> DiarizationResult:
        """
        Run speaker diarization on the audio; subclasses must implement this.

        Args:
            audio: AudioSegment object to process.

        Returns:
            DiarizationResult containing the segments and the number of speakers.
        """
        pass
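

# Illustrative sketch only: a minimal concrete subclass wiring the reusable
# helpers above to the pyannote.audio Pipeline API. The class name
# `PyannoteDiarizer` and the DiarizationResult keyword arguments (`segments`,
# `num_speakers`) are assumptions inferred from the docstrings above, not
# confirmed by this module; the package's real implementation may differ.
class PyannoteDiarizer(BaseDiarizer):
    """Minimal example subclass; shown for illustration only."""

    def _load_model(self):
        # Lazy imports so the base module does not hard-depend on pyannote.
        import torch
        from pyannote.audio import Pipeline

        self.pipeline = Pipeline.from_pretrained(
            self.model_name, use_auth_token=self.token
        )
        self.pipeline.to(torch.device(self.device))

    def diarize(self, audio: AudioSegment) -> DiarizationResult:
        # Load the pipeline on first use.
        if not hasattr(self, "pipeline"):
            self._load_model()
        temp_path = self._prepare_audio(audio)
        try:
            diarization = self.pipeline(temp_path)
            segments, num_speakers = self._convert_segments(diarization)
            # Assumes DiarizationResult accepts these two fields.
            return DiarizationResult(segments=segments, num_speakers=num_speakers)
        finally:
            if os.path.exists(temp_path):
                os.remove(temp_path)  # remove the temporary WAV file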