"""
Transcriber module that combines ASR and speaker diarization for transcribing long spoken conversations
"""
import os
from pydub import AudioSegment
from typing import Dict, List, Union, Optional, Any
import logging
from concurrent.futures import ThreadPoolExecutor
import re
from .summary.speaker_identify import SpeakerIdentifier
# Import the ASR and speaker diarization modules via relative imports
from .asr import asr_router
from .asr.asr_base import TranscriptionResult
from .diarization import diarizer_router
from .schemas import EnhancedSegment, CombinedTranscriptionResult, PodcastChannel, PodcastEpisode, DiarizationResult
# Configure logging
logger = logging.getLogger("podcast_transcribe")
class CombinedTranscriber:
"""整合ASR和说话人分离的转录器"""
def __init__(
self,
asr_model_name: str,
asr_provider: str,
diarization_provider: str,
diarization_model_name: str,
llm_model_name: str,
llm_provider: str,
device: Optional[str] = None,
segmentation_batch_size: int = 64,
parallel: bool = False,
):
"""
初始化转录器
参数:
asr_model_name: ASR模型名称
asr_provider: ASR提供者名称
diarization_provider: 说话人分离提供者名称
diarization_model_name: 说话人分离模型名称
llm_model_name: LLM模型名称
llm_provider: LLM提供者名称
device: 推理设备,'cpu'或'cuda'
segmentation_batch_size: 分割批处理大小,默认为64
parallel: 是否并行执行ASR和说话人分离,默认为False
"""
if not device:
import torch
if torch.backends.mps.is_available():
device = "mps"
elif torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
self.asr_model_name = asr_model_name
self.asr_provider = asr_provider
self.diarization_provider = diarization_provider
self.diarization_model_name = diarization_model_name
self.device = device
self.segmentation_batch_size = segmentation_batch_size
self.parallel = parallel
self.speaker_identifier = SpeakerIdentifier(
llm_model_name=llm_model_name,
llm_provider=llm_provider,
device=device
)
logger.info(f"初始化组合转录器,ASR提供者: {asr_provider},ASR模型: {asr_model_name},分离提供者: {diarization_provider},分离模型: {diarization_model_name},分割批处理大小: {segmentation_batch_size},并行执行: {parallel},推理设备: {device}")
def _merge_adjacent_text_segments(self, segments: List[EnhancedSegment]) -> List[EnhancedSegment]:
"""
合并相邻的、可能属于同一句子的 EnhancedSegment。
合并条件:同一说话人,时间基本连续,文本内容可拼接。
"""
if not segments:
return []
merged_segments: List[EnhancedSegment] = []
if not segments: # 重复检查,可移除
return merged_segments
current_merged_segment = segments[0]
for i in range(1, len(segments)):
next_segment = segments[i]
time_gap_seconds = next_segment.start - current_merged_segment.end
can_merge_text = False
if current_merged_segment.text and next_segment.text:
current_text_stripped = current_merged_segment.text.strip()
                if current_text_stripped and current_text_stripped[-1] not in ".。?!?!":
can_merge_text = True
if (current_merged_segment.speaker == next_segment.speaker and
0 <= time_gap_seconds < 0.75 and
can_merge_text):
current_merged_segment = EnhancedSegment(
start=current_merged_segment.start,
end=next_segment.end,
text=(current_merged_segment.text.strip() + " " + next_segment.text.strip()).strip(),
speaker=current_merged_segment.speaker,
language=current_merged_segment.language
)
else:
merged_segments.append(current_merged_segment)
current_merged_segment = next_segment
merged_segments.append(current_merged_segment)
return merged_segments
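    # Illustrative example (hypothetical values): two segments from SPEAKER_00,
    # [0.0-2.0] "So I think" and [2.3-4.0] "we should start now.", have a 0.3 s gap
    # and the first does not end with terminal punctuation, so they merge into
    # [0.0-4.0] "So I think we should start now."; if the gap were >= 0.75 s or the
    # speakers differed, both segments would be kept as-is.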
def _run_asr(self, audio: AudioSegment) -> TranscriptionResult:
"""执行ASR处理"""
logger.debug("执行ASR...")
return asr_router.transcribe_audio(
audio,
provider=self.asr_provider,
model_name=self.asr_model_name,
device=self.device
)
def _run_diarization(self, audio: AudioSegment) -> DiarizationResult:
"""执行说话人分离处理"""
logger.debug("执行说话人分离...")
return diarizer_router.diarize_audio(
audio,
provider=self.diarization_provider,
model_name=self.diarization_model_name,
device=self.device,
segmentation_batch_size=self.segmentation_batch_size
)
def transcribe(self, audio: AudioSegment) -> CombinedTranscriptionResult:
"""
转录整个音频 (新的非流式逻辑将在这里实现)
参数:
audio: 要转录的AudioSegment对象
返回:
包含完整转录和说话人信息的结果
"""
logger.info(f"开始转录 {len(audio)/1000:.2f} 秒的音频 (非流式)")
if self.parallel:
# 并行执行ASR和说话人分离
logger.info("并行执行ASR和说话人分离")
with ThreadPoolExecutor(max_workers=2) as executor:
asr_future = executor.submit(self._run_asr, audio)
diarization_future = executor.submit(self._run_diarization, audio)
asr_result: TranscriptionResult = asr_future.result()
diarization_result: DiarizationResult = diarization_future.result()
logger.debug(f"ASR完成,识别语言: {asr_result.language},得到 {len(asr_result.segments)} 个分段")
logger.debug(f"说话人分离完成,得到 {len(diarization_result.segments)} 个说话人分段,检测到 {diarization_result.num_speakers} 个说话人")
else:
# 顺序执行ASR和说话人分离
# 步骤1: 对整个音频执行ASR
logger.debug("执行ASR...")
asr_result: TranscriptionResult = asr_router.transcribe_audio(
audio,
provider=self.asr_provider,
model_name=self.asr_model_name,
device=self.device
)
logger.debug(f"ASR完成,识别语言: {asr_result.language},得到 {len(asr_result.segments)} 个分段")
# 步骤2: 对整个音频执行说话人分离
logger.debug("执行说话人分离...")
diarization_result: DiarizationResult = diarizer_router.diarize_audio(
audio,
provider=self.diarization_provider,
model_name=self.diarization_model_name,
device=self.device,
segmentation_batch_size=self.segmentation_batch_size
)
logger.debug(f"说话人分离完成,得到 {len(diarization_result.segments)} 个说话人分段,检测到 {diarization_result.num_speakers} 个说话人")
        # Step 3: create enhanced segments
all_enhanced_segments: List[EnhancedSegment] = self._create_enhanced_segments_with_splitting(
asr_result.segments,
diarization_result.segments,
asr_result.language
)
        # Step 4: (optional) merge adjacent text segments
        if all_enhanced_segments:
            logger.debug(f"Have {len(all_enhanced_segments)} enhanced segments before merging, trying to merge adjacent segments...")
            final_segments = self._merge_adjacent_text_segments(all_enhanced_segments)
            logger.debug(f"Have {len(final_segments)} enhanced segments after merging")
        else:
            final_segments = []
            logger.debug("No enhanced segments available to merge.")
        # Assemble the combined text
        full_text = " ".join([segment.text for segment in final_segments]).strip()
        # Count the final number of speakers
num_speakers_set = set(s.speaker for s in final_segments if s.speaker != "UNKNOWN")
return CombinedTranscriptionResult(
segments=final_segments,
text=full_text,
language=asr_result.language or "unknown",
num_speakers=len(num_speakers_set) if num_speakers_set else diarization_result.num_speakers
)
    # Split an ASR text segment by punctuation
def _split_asr_segment_by_punctuation(
self,
asr_seg_text: str,
asr_seg_start: float,
asr_seg_end: float
) -> List[Dict[str, Any]]:
"""
根据标点符号分割ASR文本片段,并按字符比例估算子片段的时间戳。
返回: 字典列表,每个字典包含 'text', 'start', 'end'。
"""
sentence_terminators = ".。?!?!;;"
# 正则表达式:匹配句子内容以及紧随其后的标点(如果存在)
# 使用 re.split 保留分隔符,然后重组
parts = re.split(f'([{sentence_terminators}])', asr_seg_text)
sub_texts_final = []
current_s = ""
for s_part in parts:
if not s_part:
continue
current_s += s_part
if s_part in sentence_terminators:
if current_s.strip():
sub_texts_final.append(current_s.strip())
current_s = ""
if current_s.strip():
sub_texts_final.append(current_s.strip())
if not sub_texts_final or (len(sub_texts_final) == 1 and sub_texts_final[0] == asr_seg_text.strip()):
            # No effective split, or the split yielded a single sentence equal to the original text
return [{"text": asr_seg_text.strip(), "start": asr_seg_start, "end": asr_seg_end}]
output_sub_segments = []
        total_text_len = len(asr_seg_text)  # use the original text length for proportional allocation
if total_text_len == 0:
return [{"text": "", "start": asr_seg_start, "end": asr_seg_end}]
current_time = asr_seg_start
original_duration = asr_seg_end - asr_seg_start
for i, sub_text in enumerate(sub_texts_final):
sub_len = len(sub_text)
sub_duration = (sub_len / total_text_len) * original_duration
sub_start_time = current_time
sub_end_time = current_time + sub_duration
            # For the last sub-segment, pin the end time to the original segment end to avoid accumulated rounding error
if i == len(sub_texts_final) - 1:
sub_end_time = asr_seg_end
            # Make sure the end time does not exceed the original end time and the start is not later than the end
            sub_end_time = min(sub_end_time, asr_seg_end)
            if sub_start_time >= sub_end_time and sub_start_time == asr_seg_end:  # allow a tiny segment when the start equals the original end
                if sub_text:  # only when there is text
output_sub_segments.append({"text": sub_text, "start": sub_start_time, "end": sub_end_time})
elif sub_start_time < sub_end_time :
output_sub_segments.append({"text": sub_text, "start": sub_start_time, "end": sub_end_time})
current_time = sub_end_time
            if current_time >= asr_seg_end and i < len(sub_texts_final) - 1:  # time is used up but sentences remain
                # Append the remaining sentences to the end as zero-length segments
                logger.warning(f"Time was exhausted during splitting but some text has no time allocated. Original segment: [{asr_seg_start}-{asr_seg_end}], current sub-sentence: '{sub_text}'")
                # Create zero-duration (or minimal-duration) segments anchored at the segment end for the remaining text
for k in range(i + 1, len(sub_texts_final)):
remaining_text = sub_texts_final[k]
if remaining_text:
output_sub_segments.append({"text": remaining_text, "start": asr_seg_end, "end": asr_seg_end})
break
        # If no sub-segments were produced (e.g. empty original text or a splitting edge case), return the original info as a single segment
if not output_sub_segments and asr_seg_text.strip():
return [{"text": asr_seg_text.strip(), "start": asr_seg_start, "end": asr_seg_end}]
elif not output_sub_segments and not asr_seg_text.strip():
return [{"text": "", "start": asr_seg_start, "end": asr_seg_end}]
return output_sub_segments
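    # Worked example (hypothetical values): for asr_seg_text="Hi. Bye." with start=0.0
    # and end=4.0, the text splits into "Hi." (3 chars) and "Bye." (4 chars) out of 8
    # total characters, so the first sub-segment gets 3/8 * 4.0 s = 1.5 s ([0.0-1.5])
    # and the last one is pinned to the original end ([1.5-4.0]).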
    # Core method: create enhanced segments, including speaker assignment and on-demand splitting
def _create_enhanced_segments_with_splitting(
self,
asr_segments: List[Dict[str, Union[float, str]]],
diarization_segments: List[Dict[str, Union[float, str, int]]],
language: str
) -> List[EnhancedSegment]:
"""
        Assign a speaker to each ASR segment; if an ASR segment spans multiple speakers, try to split it at punctuation.
"""
final_enhanced_segments: List[EnhancedSegment] = []
if not asr_segments:
return []
        # For faster lookup the diarization_segments could be pre-processed, but a direct scan is fine for a modest number of segments
        # diarization_segments.sort(key=lambda x: x['start'])  # ensure they are ordered
for asr_seg in asr_segments:
asr_start = float(asr_seg["start"])
asr_end = float(asr_seg["end"])
asr_text = str(asr_seg["text"]).strip()
            if not asr_text or asr_start >= asr_end:  # skip invalid ASR segments
continue
            # Find all speaker segments that overlap the current ASR segment in time
overlapping_diar_segs = []
for diar_seg in diarization_segments:
diar_start = float(diar_seg["start"])
diar_end = float(diar_seg["end"])
overlap_start = max(asr_start, diar_start)
overlap_end = min(asr_end, diar_end)
                if overlap_end > overlap_start:  # they overlap
overlapping_diar_segs.append({
"speaker": str(diar_seg["speaker"]),
"start": diar_start,
"end": diar_end,
"overlap_duration": overlap_end - overlap_start
})
distinct_speakers_in_overlap = set(d['speaker'] for d in overlapping_diar_segs)
segments_to_process_further: List[Dict[str, Any]] = []
if len(distinct_speakers_in_overlap) > 1:
logger.debug(f"ASR段 [{asr_start:.2f}-{asr_end:.2f}] \"{asr_text[:50]}...\" 跨越 {len(distinct_speakers_in_overlap)} 个说话人。尝试按标点分裂。")
# 跨多个说话人,尝试按标点分裂ASR segment
sub_asr_segments_data = self._split_asr_segment_by_punctuation(
asr_text,
asr_start,
asr_end
)
                if len(sub_asr_segments_data) > 1:
                    logger.debug(f"Successfully split the ASR segment into {len(sub_asr_segments_data)} sub-sentences.")
                # Always keep the (possibly unsplit) result so the segment is not dropped
                segments_to_process_further.extend(sub_asr_segments_data)
            else:
                # Single speaker, or no overlapping speakers (also treated as a single processing unit)
                segments_to_process_further.append({"text": asr_text, "start": asr_start, "end": asr_end})
            # Assign a speaker to each original or split ASR (sub-)segment
for current_proc_seg_data in segments_to_process_further:
proc_text = current_proc_seg_data["text"].strip()
proc_start = current_proc_seg_data["start"]
proc_end = current_proc_seg_data["end"]
                if not proc_text or proc_start >= proc_end:  # skip invalid sub-segments
continue
                # Determine the best speaker for the (possibly sub-)segment being processed
speaker_overlaps_for_proc_seg = {}
                for diar_seg_info in overlapping_diar_segs:  # use the diar segments previously found to overlap the original ASR segment
                    # Now compute the overlap between this diar_seg_info and the current processing segment
overlap_start = max(proc_start, diar_seg_info["start"])
overlap_end = min(proc_end, diar_seg_info["end"])
if overlap_end > overlap_start:
overlap_duration = overlap_end - overlap_start
speaker = diar_seg_info["speaker"]
speaker_overlaps_for_proc_seg[speaker] = \
speaker_overlaps_for_proc_seg.get(speaker, 0) + overlap_duration
best_speaker = "UNKNOWN"
if speaker_overlaps_for_proc_seg:
best_speaker = max(speaker_overlaps_for_proc_seg.items(), key=lambda x: x[1])[0]
                elif overlapping_diar_segs:  # the sub-segment itself has no overlap, but the original ASR segment does
                    # One option would be to reuse the speaker with the largest overlap in the original ASR
                    # segment, or the nearest one; for simplicity the sub-segment is left as UNKNOWN here
                    pass  # best_speaker stays UNKNOWN
                # If best_speaker is still UNKNOWN but the original ASR segment has exactly one speaker, use that speaker
                if best_speaker == "UNKNOWN" and len(distinct_speakers_in_overlap) == 1:
                    best_speaker = list(distinct_speakers_in_overlap)[0]
                elif best_speaker == "UNKNOWN" and not overlapping_diar_segs:
                    # The whole ASR segment has no speaker information at all, so it genuinely is UNKNOWN
                    pass
final_enhanced_segments.append(
EnhancedSegment(
start=proc_start,
end=proc_end,
text=proc_text,
speaker=best_speaker,
                        language=language  # all sub-segments inherit the original ASR segment's language
)
)
        # Sort the final results by start time
final_enhanced_segments.sort(key=lambda seg: seg.start)
return final_enhanced_segments
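    # Illustrative example (hypothetical values): if an ASR segment covers 0.0-6.0 s
    # while diarization reports SPEAKER_00 for 0.0-3.2 s and SPEAKER_01 for 3.2-6.0 s,
    # the segment spans two speakers, so it is split at sentence punctuation and each
    # sub-sentence is assigned the speaker with the largest time overlap.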
def transcribe_podcast(
self,
audio: AudioSegment,
podcast_info: PodcastChannel,
episode_info: PodcastEpisode,
) -> CombinedTranscriptionResult:
"""
专门针对播客剧集的音频转录方法
参数:
audio: 要转录的AudioSegment对象
podcast_info: 播客频道信息
episode_info: 播客剧集信息
返回:
包含完整转录和识别后说话人名称的结果
"""
logger.info(f"开始转录播客剧集 {len(audio)/1000:.2f} 秒的音频")
# 1. 先执行基础转录流程
transcription_result = self.transcribe(audio)
# 3. 识别说话人名称
logger.info("识别说话人名称...")
speaker_name_map = self.speaker_identifier.recognize_speaker_names(
transcription_result.segments,
podcast_info,
episode_info
)
        # 3. Attach the identified speaker names to the transcription result
enhanced_segments_with_names = []
for segment in transcription_result.segments:
            # Copy the original segment and add the speaker name
speaker_id = segment.speaker
speaker_name = speaker_name_map.get(speaker_id, None)
            # Create a new segment object that includes the speaker name
new_segment = EnhancedSegment(
start=segment.start,
end=segment.end,
text=segment.text,
speaker=speaker_id,
language=segment.language,
speaker_name=speaker_name
)
enhanced_segments_with_names.append(new_segment)
        # 4. Build and return the new transcription result
return CombinedTranscriptionResult(
segments=enhanced_segments_with_names,
text=transcription_result.text,
language=transcription_result.language,
num_speakers=transcription_result.num_speakers
)
def transcribe_audio(
audio_segment: AudioSegment,
asr_model_name: str = "distil-whisper/distil-large-v3.5",
asr_provider: str = "distil_whisper_transformers",
diarization_model_name: str = "pyannote/speaker-diarization-3.1",
diarization_provider: str = "pyannote_transformers",
device: Optional[str] = None,
segmentation_batch_size: int = 64,
parallel: bool = False,
) -> CombinedTranscriptionResult:  # return type is fixed to CombinedTranscriptionResult
    """
    Audio transcription function combining ASR and speaker diarization (non-streaming only)
    Parameters:
        audio_segment: the input AudioSegment
        asr_model_name: ASR model name
        asr_provider: ASR provider name
        diarization_model_name: speaker diarization model name
        diarization_provider: speaker diarization provider name
        device: inference device, 'cpu', 'cuda', or 'mps'; auto-detected when not specified
        segmentation_batch_size: segmentation batch size, defaults to 64
        parallel: whether to run ASR and diarization in parallel, defaults to False
    Returns:
        the full transcription result
    """
    logger.info(f"transcribe_audio called (non-streaming), audio length: {len(audio_segment)/1000:.2f} seconds")
transcriber = CombinedTranscriber(
asr_model_name=asr_model_name,
asr_provider=asr_provider,
diarization_model_name=diarization_model_name,
diarization_provider=diarization_provider,
llm_model_name="",
llm_provider="",
device=device,
segmentation_batch_size=segmentation_batch_size,
parallel=parallel
)
    # Call the transcribe method directly
return transcriber.transcribe(audio_segment)
def transcribe_podcast_audio(
audio_segment: AudioSegment,
podcast_info: PodcastChannel,
episode_info: PodcastEpisode,
asr_model_name: str = "distil-whisper/distil-large-v3.5",
asr_provider: str = "distil_whisper_transformers",
diarization_model_name: str = "pyannote/speaker-diarization-3.1",
diarization_provider: str = "pyannote_transformers",
llm_model_name: str = "google/gemma-3-4b-it",
llm_provider: str = "gemma-transformers",
device: Optional[str] = None,
segmentation_batch_size: int = 64,
parallel: bool = False,
) -> CombinedTranscriptionResult:
"""
针对播客剧集的音频转录函数,包含说话人名称识别
参数:
audio_segment: 输入的AudioSegment对象
podcast_info: 播客频道信息
episode_info: 播客剧集信息
asr_model_name: ASR模型名称
asr_provider: ASR提供者名称
diarization_provider: 说话人分离提供者名称
diarization_model_name: 说话人分离模型名称
llm_model_name: LLM模型名称
llm_provider: LLM提供者名称
device: 推理设备,'cpu'或'cuda'
segmentation_batch_size: 分割批处理大小,默认为64
parallel: 是否并行执行ASR和说话人分离,默认为False
返回:
包含说话人名称的完整转录结果
"""
logger.info(f"调用transcribe_podcast_audio函数,音频长度: {len(audio_segment)/1000:.2f}秒")
transcriber = CombinedTranscriber(
asr_model_name=asr_model_name,
asr_provider=asr_provider,
diarization_provider=diarization_provider,
diarization_model_name=diarization_model_name,
llm_model_name=llm_model_name,
llm_provider=llm_provider,
device=device,
segmentation_batch_size=segmentation_batch_size,
parallel=parallel
)
    # Call the podcast-specific transcription method
return transcriber.transcribe_podcast(
audio=audio_segment,
podcast_info=podcast_info,
episode_info=episode_info,
)
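# Minimal usage sketch (illustrative only; the import path, "example.mp3", and the
# printed format are assumptions, not part of this module):
#
#     from pydub import AudioSegment
#     # import transcribe_audio from wherever this module lives in your package
#
#     audio = AudioSegment.from_file("example.mp3")
#     result = transcribe_audio(audio, device="cpu")
#     for seg in result.segments:
#         print(f"[{seg.start:.2f}-{seg.end:.2f}] {seg.speaker}: {seg.text}")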