# asr-inference / whisper_cs.py
import spaces
from pydub import AudioSegment
import os
import torchaudio
import torch
import re
import whisper_timestamped as whisper_ts
from typing import Dict
from faster_whisper import WhisperModel
device = 0 if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float32
DEBUG_MODE = True
MODEL_PATH_V2 = "langtech-veu/whisper-timestamped-cs"
MODEL_PATH_V2_FAST = "langtech-veu/faster-whisper-timestamped-cs"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("[INFO] CUDA available:", torch.cuda.is_available())
def clean_text(input_text):
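    """Replace punctuation with spaces, collapse whitespace and lowercase the text."""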
    remove_chars = ['.', ',', ';', ':', '¿', '?', '«', '»', '-', '¡', '!', '@',
                    '*', '{', '}', '[', ']', '=', '/', '\\', '&', '#', '…']
    output_text = ''.join(char if char not in remove_chars else ' ' for char in input_text)
    return ' '.join(output_text.split()).lower()
def split_stereo_channels(audio_path):
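    """Split a stereo .wav/.mp3 file into two temporary mono WAV files, one per channel."""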
    ext = os.path.splitext(audio_path)[1].lower()
    if ext == ".wav":
        audio = AudioSegment.from_wav(audio_path)
    elif ext == ".mp3":
        audio = AudioSegment.from_file(audio_path, format="mp3")
    else:
        raise ValueError(f"Unsupported file format: {audio_path}")
    channels = audio.split_to_mono()
    if len(channels) != 2:
        raise ValueError(f"Audio {audio_path} does not have 2 channels.")
    channels[0].export("temp_mono_speaker1.wav", format="wav")  # first channel -> speaker 1
    channels[1].export("temp_mono_speaker2.wav", format="wav")  # second channel -> speaker 2
def format_audio(audio_path):
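    """Load an audio file, downmix stereo to mono if needed, resample to 16 kHz and return (waveform, sample_rate)."""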
    input_audio, sample_rate = torchaudio.load(audio_path)
    if input_audio.shape[0] == 2:
        input_audio = torch.mean(input_audio, dim=0, keepdim=True)
    resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
    input_audio = resampler(input_audio)
    return input_audio.squeeze(), 16000
def post_process_transcription(transcription, max_repeats=2):
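    """Reduce repetitions inside tokens and drop a token once it repeats more than max_repeats times in a row."""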
    tokens = re.findall(r'\b\w+\'?\w*\b[.,!?]?', transcription)
    cleaned_tokens = []
    repetition_count = 0
    previous_token = None
    for token in tokens:
        # Collapse short sequences repeated within a token (e.g. "sisisi" -> "si").
        reduced_token = re.sub(r"(\w{1,3})(\1{2,})", r"\1", token)
        if reduced_token == previous_token:
            repetition_count += 1
            if repetition_count <= max_repeats:
                cleaned_tokens.append(reduced_token)
        else:
            repetition_count = 1
            cleaned_tokens.append(reduced_token)
        previous_token = reduced_token
    cleaned_transcription = " ".join(cleaned_tokens)
    cleaned_transcription = re.sub(r'\s+', ' ', cleaned_transcription).strip()
    return cleaned_transcription
def post_merge_consecutive_segments_from_text(transcription_text: str) -> str:
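    """Merge consecutive segments that share the same [SPEAKER_XX] tag into one line. Currently not called by generate()."""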
    segments = re.split(r'(\[SPEAKER_\d{2}\])', transcription_text)
    merged_transcription = ''
    current_speaker = None
    current_segment = []
    for i in range(1, len(segments) - 1, 2):
        speaker_tag = segments[i]
        text = segments[i + 1].strip()
        speaker = re.search(r'\d{2}', speaker_tag).group()
        if speaker == current_speaker:
            current_segment.append(text)
        else:
            if current_speaker is not None:
                merged_transcription += f'[SPEAKER_{current_speaker}] {" ".join(current_segment)}\n'
            current_speaker = speaker
            current_segment = [text]
    if current_speaker is not None:
        merged_transcription += f'[SPEAKER_{current_speaker}] {" ".join(current_segment)}\n'
    return merged_transcription.strip()
def cleanup_temp_files(*file_paths):
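    """Delete the given temporary files if they exist."""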
if DEBUG_MODE: print(f"Entered cleanup_temp_files function...")
if DEBUG_MODE: print(f"file_paths: {file_paths}")
for path in file_paths:
if path and os.path.exists(path):
if DEBUG_MODE: print(f"Removing path: {path}")
os.remove(path)
if DEBUG_MODE: print(f"Exited cleanup_temp_files function.")
def load_whisper_model(model_path: str):
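    """Load the whisper-timestamped model on GPU if available, otherwise on CPU."""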
device = "cuda" if torch.cuda.is_available() else "cpu"
model = whisper_ts.load_model(model_path, device=device)
return model
def transcribe_audio(model, audio_path) -> Dict:
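    """Run whisper-timestamped on a file path or a preloaded 16 kHz waveform and return text, segments and word-level timestamps; on failure, return a dict with success=False."""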
    try:
        result = whisper_ts.transcribe(
            model,
            audio_path,
            beam_size=5,
            best_of=5,
            temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
            vad=False,
            detect_disfluencies=True,
        )
        words = []
        for segment in result.get('segments', []):
            for word in segment.get('words', []):
                word_text = word.get('word', '').strip()
                words.append({
                    'word': word_text,
                    'start': word.get('start', 0),
                    'end': word.get('end', 0),
                    'confidence': word.get('confidence', 0)
                })
        return {
            'audio_path': audio_path,
            'text': result['text'].strip(),
            'segments': result.get('segments', []),
            'words': words,
            'duration': result.get('duration', 0),
            'success': True
        }
    except Exception as e:
        return {
            'audio_path': audio_path,
            'error': str(e),
            'success': False
        }
def generate(audio_path, use_v2_fast):
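    """Transcribe a two-channel recording (one speaker per channel) and return a merged, speaker-labelled transcript. Uses faster-whisper when use_v2_fast is True, otherwise whisper-timestamped."""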
if DEBUG_MODE: print(f"Entering generate function...")
if DEBUG_MODE: print(f"use_v2_fast: {use_v2_fast}")
faster_model = None
if use_v2_fast:
if torch.cuda.is_available():
try:
if DEBUG_MODE: print("[INFO] GPU detected. Loading model on GPU with float16...")
faster_model = WhisperModel(
MODEL_PATH_V2_FAST,
device="cuda",
compute_type="float16"
)
except RuntimeError as e:
print(f"[WARNING] Failed to load model on GPU: {e}")
if DEBUG_MODE: print("[INFO] Falling back to CPU with int8...")
faster_model = WhisperModel(
MODEL_PATH_V2_FAST,
device="cpu",
compute_type="int8"
)
else:
if DEBUG_MODE: print("[INFO] No GPU detected. Loading model on CPU with int8...")
faster_model = WhisperModel(
MODEL_PATH_V2_FAST,
device="cpu",
compute_type="int8"
)
split_stereo_channels(audio_path)
left_channel_path = "temp_mono_speaker2.wav"
right_channel_path = "temp_mono_speaker1.wav"
left_waveform, _ = format_audio(left_channel_path)
right_waveform, _ = format_audio(right_channel_path)
left_waveform = left_waveform.numpy().astype("float32")
right_waveform = right_waveform.numpy().astype("float32")
left_result, _ = faster_model.transcribe(left_waveform, beam_size=5, task="transcribe")
right_result, _ = faster_model.transcribe(right_waveform, beam_size=5, task="transcribe")
left_result = list(left_result)
right_result = list(right_result)
def get_faster_segments(segments, speaker_label):
return [
(seg.start, seg.end, speaker_label, post_process_transcription(seg.text.strip()))
for seg in segments if seg.text
]
left_segs = get_faster_segments(left_result, "Speaker 1")
right_segs = get_faster_segments(right_result, "Speaker 2")
merged_transcript = sorted(
left_segs + right_segs,
key=lambda x: float(x[0]) if x[0] is not None else float("inf")
)
clean_output = ""
for start, end, speaker, text in merged_transcript:
clean_output += f"[{speaker}]: {text}\n"
if DEBUG_MODE: print(f"clean_output: {clean_output}")
else:
model = load_whisper_model(MODEL_PATH_V2)
split_stereo_channels(audio_path)
left_channel_path = "temp_mono_speaker2.wav"
right_channel_path = "temp_mono_speaker1.wav"
left_waveform, _ = format_audio(left_channel_path)
right_waveform, _ = format_audio(right_channel_path)
left_result = transcribe_audio(model, left_waveform)
right_result = transcribe_audio(model, right_waveform)
def get_segments(result, speaker_label):
segments = result.get("segments", [])
if not segments:
return []
return [
(seg.get("start", 0.0), seg.get("end", 0.0), speaker_label,
post_process_transcription(seg.get("text", "").strip()))
for seg in segments if seg.get("text")
]
left_segs = get_segments(left_result, "Speaker 1")
right_segs = get_segments(right_result, "Speaker 2")
merged_transcript = sorted(
left_segs + right_segs,
key=lambda x: float(x[0]) if x[0] is not None else float("inf")
)
clean_output = ""
for start, end, speaker, text in merged_transcript:
clean_output += f"[{speaker}]: {text}\n"
cleanup_temp_files("temp_mono_speaker1.wav", "temp_mono_speaker2.wav")
if DEBUG_MODE: print(f"Exiting generate function...")
return clean_output.strip()
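

# Minimal usage sketch. Assumption: a local stereo recording exists at the
# hypothetical path below (it is not part of this repo). Uses the faster-whisper
# path when a GPU is available, otherwise the whisper-timestamped path.
if __name__ == "__main__":
    sample_path = "samples/example_stereo_call.wav"  # hypothetical test file
    transcript = generate(sample_path, use_v2_fast=torch.cuda.is_available())
    print(transcript)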