Spaces:
Sleeping
Sleeping
import json
import logging
import math
import os
import tempfile

import gradio as gr
import torchaudio
import whisperx
from faster_whisper import WhisperModel
from pyannote.audio import Pipeline as DiarizationPipeline
from transformers import pipeline
# Named-entity recognition via transformers (CamemBERT fine-tuned for French NER).
# aggregation_strategy="simple" merges sub-word tokens into whole entities.
ner_pipeline = pipeline("ner", model="Jean-Baptiste/camembert-ner", aggregation_strategy="simple")

# Model configuration
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
# The pyannote diarization model is gated; HUGGINGFACE_TOKEN is forwarded for authentication.
diarization_pipeline = DiarizationPipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1",
    use_auth_token=HUGGINGFACE_TOKEN,
)
# WhisperX model used for transcription, plus the French alignment model
# used to obtain word-level timestamps for its output.
whisperx_model = whisperx.load_model("large-v2", device="cpu", compute_type="int8")
align_model, metadata = whisperx.load_align_model(language_code="fr", device="cpu")

# The standalone faster-whisper model is not used by this file's pipeline
# (transcription goes through ``whisperx_model``). Loading it eagerly kept a
# second large-v2 model in memory from startup, so it is now created only on
# first access to the ``whisper_model`` module attribute.
_whisper_model = None


def __getattr__(name):
    """Lazily build ``whisper_model`` on first module-attribute access (PEP 562).

    Keeps ``module.whisper_model`` / ``from module import whisper_model``
    working for any external code that relied on it.

    Raises:
        AttributeError: for any other missing module attribute.
    """
    global _whisper_model
    if name == "whisper_model":
        if _whisper_model is None:
            _whisper_model = WhisperModel("large-v2", device="cpu", compute_type="int8")
        return _whisper_model
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
def convert_to_wav_if_needed(audio_path: str) -> str:
    """Return a path to a WAV version of *audio_path*.

    Paths not ending in ``.mp3`` (case-insensitive) are returned unchanged.
    An MP3 is decoded with torchaudio and re-saved, at its original sample
    rate, as a ``.wav`` file next to it (same name, extension swapped); the
    new path is returned.
    """
    if not audio_path.lower().endswith(".mp3"):
        return audio_path

    wav_path = f"{audio_path[:-4]}.wav"
    signal, sample_rate = torchaudio.load(audio_path)
    torchaudio.save(wav_path, signal, sample_rate)
    return wav_path
def get_speaker_segments(
    audio_path: str,
    *,
    min_duration: float = 1.0,
    max_duration: float = 10.0,
) -> list:
    """Diarize *audio_path* into speaker turns sized for per-segment transcription.

    Turns shorter than ``min_duration`` seconds are dropped. Turns longer than
    ``max_duration`` seconds are split into equal consecutive pieces of at
    most ``max_duration`` seconds, all labelled with the turn's speaker, so
    the whole turn is transcribed rather than only its first piece.

    Args:
        audio_path: Audio file passed to the pyannote diarization pipeline.
        min_duration: Minimum turn length (seconds) to keep.
        max_duration: Maximum length (seconds) of an emitted segment; must be > 0.

    Returns:
        A list of ``{"start": float, "end": float, "speaker": str}`` dicts in
        the order yielded by ``itertracks``.

    Raises:
        RuntimeError: if the diarization pipeline failed to load.
        ValueError: if ``max_duration`` is not positive.
    """
    if diarization_pipeline is None:
        # pyannote's Pipeline.from_pretrained returns None (instead of raising)
        # when the gated model cannot be downloaded, e.g. missing/invalid token.
        raise RuntimeError(
            "Diarization pipeline is not loaded: check HUGGINGFACE_TOKEN and access "
            "to pyannote/speaker-diarization-3.1"
        )
    if max_duration <= 0:
        raise ValueError("max_duration must be positive")

    diarization = diarization_pipeline(audio_path)
    segments = []
    for turn, _, speaker in diarization.itertracks(yield_label=True):
        start, end = float(turn.start), float(turn.end)
        duration = end - start
        if duration < min_duration:
            continue
        # Split instead of truncating so speech past max_duration is not lost.
        # Equal pieces are each longer than max_duration / 2 whenever a split
        # happens, so splitting never produces a uselessly short tail piece.
        n_pieces = max(1, math.ceil(duration / max_duration))
        step = duration / n_pieces
        for i in range(n_pieces):
            piece_start = start + i * step
            piece_end = end if i == n_pieces - 1 else start + (i + 1) * step
            segments.append({"start": piece_start, "end": piece_end, "speaker": speaker})
    return segments
def transcribe_with_alignment(audio_path: str, segments: list) -> list:
    """Transcribe and word-align each diarized segment of *audio_path* with WhisperX.

    Each segment is cut from the full recording, written to a temporary WAV
    file, transcribed in French, then aligned with the French alignment
    model. The clip-relative word timestamps are shifted by the segment's
    start so they refer to the full recording, and every word is tagged with
    the segment's speaker.

    Args:
        audio_path: Path to the full recording (loaded with torchaudio).
        segments: Dicts with ``"start"``/``"end"`` in seconds and ``"speaker"``,
            as returned by ``get_speaker_segments``.

    Returns:
        All word dicts from each segment's ``aligned["word_segments"]``, in
        segment order, each with an added ``"speaker"`` key.
    """
    word_segments_all = []
    waveform, sr = torchaudio.load(audio_path)
    for seg in segments:
        start, end, speaker = seg["start"], seg["end"], seg["speaker"]
        # Only reserve a unique name here; the file is written after closing.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            temp_audio_path = tmp.name
        try:
            segment_waveform = waveform[:, int(start * sr): int(end * sr)]
            torchaudio.save(temp_audio_path, segment_waveform, sr)
            # Force French: the alignment model is French-only, and language
            # auto-detection on short clips is unreliable.
            whisper_result = whisperx_model.transcribe(temp_audio_path, language="fr")
            aligned = whisperx.align(
                whisper_result["segments"], align_model, metadata, temp_audio_path, device="cpu"
            )
        finally:
            # Remove the temporary clip even if saving/transcription/alignment fails.
            os.remove(temp_audio_path)
        for word in aligned.get("word_segments", []):
            # WhisperX leaves out "start"/"end" on words it could not align
            # (e.g. digits), so only shift the timestamps that are present.
            if "start" in word:
                word["start"] += start
            if "end" in word:
                word["end"] += start
            word["speaker"] = speaker
            word_segments_all.append(word)
    return word_segments_all
def extract_entities(word_segments: list): | |
full_text = " ".join([w["text"] for w in word_segments]) | |
entities_raw = ner_pipeline(full_text) | |
return full_text, entities_raw | |
def _json_default(obj):
    """``json.dump`` fallback that unwraps numpy/torch scalars via ``.item()``.

    Raises:
        TypeError: for objects without a callable ``.item()``, like json's default.
    """
    item = getattr(obj, "item", None)
    if callable(item):
        return item()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def process_pipeline(audio_path: str):
    """Run conversion, diarization, transcription/alignment and NER on one file.

    Writes two JSON files next to the (possibly converted) audio file:
    ``<stem>_aligned.json`` with the aligned words and ``<stem>_meta.json``
    with the full text and the named entities.

    Args:
        audio_path: Path to the input audio file.

    Returns:
        ``(full_text, named_entities, aligned_path, meta_path)``.
    """
    audio_path = convert_to_wav_if_needed(audio_path)
    segments = get_speaker_segments(audio_path)
    words = transcribe_with_alignment(audio_path, segments)

    # Build output names from the path stem. str.replace(".wav", ...) is a
    # no-op for other extensions (".flac", ".WAV", ...), which made both JSON
    # paths equal to the input audio and overwrote it. It also rewrote
    # ".wav" inside directory names.
    base_path, _ = os.path.splitext(audio_path)

    aligned_path = base_path + "_aligned.json"
    with open(aligned_path, "w", encoding="utf-8") as f:
        json.dump(words, f, ensure_ascii=False, indent=2, default=_json_default)

    full_text, named_entities = extract_entities(words)
    meta_path = base_path + "_meta.json"
    with open(meta_path, "w", encoding="utf-8") as f:
        # default= handles numpy scalars such as the NER pipeline's float32 scores.
        json.dump(
            {"text": full_text, "entities": named_entities},
            f,
            ensure_ascii=False,
            indent=2,
            default=_json_default,
        )
    return full_text, named_entities, aligned_path, meta_path
def gradio_process(audio_file_path):
    """Gradio click handler: run the pipeline and map results to the four outputs.

    Args:
        audio_file_path: Path of the uploaded file, or ``None`` when no file
            was selected.

    Returns:
        ``(text, entities, aligned_json_path, meta_json_path)``. With no file,
        ``("", [], None, None)``. On any error, ``("Erreur : <message>", [],
        None, None)``; the traceback is also logged so failures can be
        diagnosed from the server logs.
    """
    if audio_file_path is None:
        return "", [], None, None
    try:
        texte, ents, aligned_json, meta_json = process_pipeline(audio_file_path)
    except Exception as e:  # UI boundary: report to the user instead of crashing.
        logging.getLogger(__name__).exception("Processing failed for %s", audio_file_path)
        return f"Erreur : {str(e)}", [], None, None
    return texte, ents, aligned_json, meta_json
# Gradio UI: a file picker and a run button, then the raw text, the detected
# named entities, and the two JSON files written by process_pipeline.
# Components are laid out in the order they are created inside each context
# manager, so statement order here defines the page layout.
with gr.Blocks() as demo:
    gr.Markdown("## Transcription + Diarisation + NER en français")
    gr.Markdown("- Le texte brut\n- Les entités nommées détectées\n- Les fichiers JSON générés")
    with gr.Row():
        # type="filepath": the handler receives the uploaded file's path.
        audio_input = gr.File(label="Sélectionnez un fichier audio", type="filepath")
        run_button = gr.Button("Lancer")
    with gr.Row():
        punctuated_output = gr.Textbox(label="Texte brut", lines=10)
        entities_output = gr.JSON(label="Entités nommées")
    with gr.Row():
        aligned_output = gr.File(label="Aligned JSON")
        meta_output = gr.File(label="Meta JSON")
    # The outputs list must match the order of the 4-tuple returned by gradio_process.
    run_button.click(
        gradio_process,
        inputs=[audio_input],
        outputs=[punctuated_output, entities_output, aligned_output, meta_output]
    )

if __name__ == "__main__":
    # share=True asks Gradio to also create a public share link.
    demo.launch(share=True)