# test_whisper / app.py
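"""French transcription + diarization + NER Gradio Space.

Pipeline: pyannote speaker diarization -> per-segment whisperX transcription and
word-level alignment -> CamemBERT named-entity recognition. The UI returns the
raw transcript, the detected entities, and two generated JSON files.
"""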
import os
import tempfile
import json
import torchaudio
import gradio as gr
import whisperx
from pyannote.audio import Pipeline as DiarizationPipeline
from transformers import pipeline
# NER via transformers (CamemBERT fine-tuned for French)
ner_pipeline = pipeline("ner", model="Jean-Baptiste/camembert-ner", aggregation_strategy="simple")
# Model configuration
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
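# pyannote/speaker-diarization-3.1 is gated: the token must belong to an account
# that has accepted the model's conditions on the Hugging Face Hub.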
diarization_pipeline = DiarizationPipeline.from_pretrained(
"pyannote/speaker-diarization-3.1",
use_auth_token=HUGGINGFACE_TOKEN
)
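# whisperX transcription model (CPU, int8 quantization) and the French word-level alignment model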
whisperx_model = whisperx.load_model("large-v2", device="cpu", compute_type="int8")
align_model, metadata = whisperx.load_align_model(language_code="fr", device="cpu")
def convert_to_wav_if_needed(audio_path: str) -> str:
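    """Convert an MP3 input to a sibling WAV file; other paths are returned unchanged."""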
if audio_path.lower().endswith(".mp3"):
new_path = audio_path[:-4] + ".wav"
waveform, sr = torchaudio.load(audio_path)
torchaudio.save(new_path, waveform, sr)
return new_path
return audio_path
def get_speaker_segments(audio_path: str) -> list:
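    """Diarize the audio and return speaker turns as {start, end, speaker} dicts.

    Turns shorter than 1 s are skipped; turns longer than 10 s are truncated to
    their first 10 s, so audio past the cap is never transcribed.
    """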
diarization = diarization_pipeline(audio_path)
segments = []
for turn, _, speaker in diarization.itertracks(yield_label=True):
start, end = float(turn.start), float(turn.end)
if end - start < 1.0:
continue
if end - start > 10.0:
end = start + 10.0
segments.append({"start": start, "end": end, "speaker": speaker})
return segments
def transcribe_with_alignment(audio_path: str, segments: list) -> list:
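    """Transcribe each speaker segment with whisperX, align to word level, and
    return word segments with absolute timestamps and a speaker label."""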
word_segments_all = []
waveform, sr = torchaudio.load(audio_path)
    for seg in segments:
        start, end, speaker = seg["start"], seg["end"], seg["speaker"]
        # Write the segment to a temporary WAV file for whisperX
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            temp_audio_path = tmp.name
        segment_waveform = waveform[:, int(start * sr): int(end * sr)]
        torchaudio.save(temp_audio_path, segment_waveform, sr)
        try:
            # Force French to skip per-clip language detection
            whisper_result = whisperx_model.transcribe(temp_audio_path, language="fr")
            aligned = whisperx.align(whisper_result["segments"], align_model, metadata, temp_audio_path, device="cpu")
            for word in aligned.get("word_segments", []):
                # Some words (e.g. digits) cannot be aligned and carry no timestamps
                if "start" in word and "end" in word:
                    word["start"] += start
                    word["end"] += start
                word["speaker"] = speaker
                word_segments_all.append(word)
        finally:
            # Clean up the temporary file even if transcription fails
            os.remove(temp_audio_path)
return word_segments_all
def extract_entities(word_segments: list):
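    """Join word segments into the full transcript and run French NER on it."""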
full_text = " ".join([w["text"] for w in word_segments])
entities_raw = ner_pipeline(full_text)
return full_text, entities_raw
def process_pipeline(audio_path: str):
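    """Run the full pipeline and write the aligned-words and metadata JSON files."""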
audio_path = convert_to_wav_if_needed(audio_path)
segments = get_speaker_segments(audio_path)
words = transcribe_with_alignment(audio_path, segments)
    # splitext handles any extension (str.replace would misfire for non-WAV inputs)
    base, _ = os.path.splitext(audio_path)
    aligned_path = base + "_aligned.json"
with open(aligned_path, "w", encoding="utf-8") as f:
json.dump(words, f, ensure_ascii=False, indent=2)
full_text, named_entities = extract_entities(words)
    meta_path = base + "_meta.json"
with open(meta_path, "w", encoding="utf-8") as f:
json.dump({"text": full_text, "entities": named_entities}, f, ensure_ascii=False, indent=2)
return full_text, named_entities, aligned_path, meta_path
def gradio_process(audio_file_path):
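    """Gradio callback: run the pipeline, returning an error message on failure."""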
if audio_file_path is None:
return "", [], None, None
try:
texte, ents, aligned_json, meta_json = process_pipeline(audio_file_path)
return texte, ents, aligned_json, meta_json
except Exception as e:
return f"Erreur : {str(e)}", [], None, None
with gr.Blocks() as demo:
gr.Markdown("## Transcription + Diarisation + NER en français")
gr.Markdown("- Le texte brut\n- Les entités nommées détectées\n- Les fichiers JSON générés")
with gr.Row():
        audio_input = gr.File(label="Select an audio file", type="filepath")
        run_button = gr.Button("Run")
with gr.Row():
        punctuated_output = gr.Textbox(label="Raw transcript", lines=10)
        entities_output = gr.JSON(label="Named entities")
with gr.Row():
aligned_output = gr.File(label="Aligned JSON")
meta_output = gr.File(label="Meta JSON")
run_button.click(
gradio_process,
inputs=[audio_input],
outputs=[punctuated_output, entities_output, aligned_output, meta_output]
)
if __name__ == "__main__":
    demo.launch()  # share=True is ignored on Hugging Face Spaces