ucare / app.py
antonin perrot-audet
log asr response
7062303
raw
history blame contribute delete
No virus
8.15 kB
## Quick-and-dirty single-file implementation, for experimental (and fun) purposes only
import os
import tempfile

import gradio as gr
import requests
from dotenv import load_dotenv
from gradio_client import Client
from pydub import AudioSegment
from tqdm.auto import tqdm
print("starting")
load_dotenv()

# Runtime settings, read from the process environment (or a local .env file).
HF_API = os.environ.get("HF_API")  # Hugging Face token, sent as a Bearer token
SEAMLESS_API_URL = os.environ.get("SEAMLESS_API_URL")  # path to Seamlessm4t API endpoint
GPU_AVAILABLE = os.environ.get("GPU_AVAILABLE")  # handed to torch.device() below

DEFAULT_TARGET_LANGUAGE = "French"

# Hugging Face Inference API endpoints of the two summarization models.
MISTRAL_SUMMARY_URL = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.2"
LLAMA_SUMMARY_URL = "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-8B-Instruct"

print("env setup ok")
# Markdown rendered at the top of the page.
DESCRIPTION = """
# Transcribe and create a summary of a conversation.
"""
# Markdown rendered at the bottom of the page: prerequisites for running a copy
# of this app. The diarizer loads the gated "pyannote/speaker-diarization-3.1"
# pipeline, so the repositories listed here must match that pipeline.
DUPLICATE = """
To duplicate this Space, you have to accept the user conditions of the following gated repositories:
1- https://hf.co/pyannote/segmentation-3.0
2- https://hf.co/pyannote/speaker-diarization-3.1

Then set the `HF_API` secret to an access token of the account that accepted them.
"""
from pyannote.audio import Pipeline
# Initialize the pyannote speaker-diarization pipeline. The checkpoint is gated
# on the Hugging Face Hub, so HF_API must be a token whose owner accepted the
# model's user conditions.
diarizer = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1", use_auth_token=HF_API
)
if diarizer is None:
    # NOTE(review): pyannote may report download/authentication failures by
    # returning None instead of raising; fail fast with an actionable message
    # rather than an AttributeError on the `.to()` call below.
    raise RuntimeError(
        "Could not load pyannote/speaker-diarization-3.1: check that HF_API is "
        "set and that its owner accepted the model's user conditions."
    )

# Move the pipeline to the device named by GPU_AVAILABLE (e.g. "cuda").
# torch.device(None) is rejected, so fall back to the CPU when it is unset.
import torch

diarizer.to(torch.device(GPU_AVAILABLE or "cpu"))
print("diarizer setup ok")
# predict is a generator that incrementally yields recognized text with speaker label
def predict(target_language, input_audio):
    """Diarize *input_audio*, then transcribe each speaker turn.

    This is a generator. After every successfully transcribed turn it yields
    the whole transcript accumulated so far, one ``speaker: ... text: ...``
    line per turn. Gradio therefore refreshes the output textbox
    incrementally.

    Args:
        target_language: Language name forwarded to
            ``automatic_speech_recognition`` to pick the ASR model.
        input_audio: Path of the recorded/uploaded audio file (the component
            uses ``gr.Audio(type="filepath")``), or ``None`` when no audio was
            provided.

    Yields:
        The transcript accumulated so far (``""`` when there is no audio).
    """
    print("->predict started")
    print(target_language, type(input_audio), input_audio)
    if not input_audio:
        # "Transcribe" clicked without audio: the diarizer cannot open a None path.
        yield ""
        return

    print("-->diarization")
    diarized = diarizer(input_audio, min_speakers=2, max_speakers=5)

    print("-->automatic speech recognition")
    # from_file lets ffmpeg detect the format, so non-WAV uploads work too.
    song = AudioSegment.from_file(input_audio)
    lines = []
    # Per-request temporary directory: concurrent requests cannot overwrite
    # each other's segment files. The files are removed once the generator
    # finishes or is closed.
    with tempfile.TemporaryDirectory() as tmp_dir:
        for turn, _, speaker in diarized.itertracks(yield_label=True):
            print(speaker, turn)
            filename = os.path.join(tmp_dir, f"{turn.start}_segment.wav")
            try:
                # pydub slices in milliseconds; the turn bounds are scaled by
                # 1000 from seconds.
                clipped = song[turn.start * 1000 : turn.end * 1000]
                # export() returns the open output file; close it so the
                # handle is not leaked.
                clipped.export(filename, format="wav", bitrate=16000).close()
                result = automatic_speech_recognition(target_language, filename)
            except Exception as exc:
                # Best effort: log the failing turn and keep transcribing the
                # rest of the conversation.
                print(f"skipping turn {turn} ({speaker}): {exc!r}")
                continue
            current_text = f"speaker: {speaker} text: {result} "
            print(current_text)
            lines.append(current_text)
            yield "\n".join(lines)
# Hugging Face Inference API speech-recognition endpoints, keyed by the
# language names offered in the UI dropdown.
_ASR_MODEL_URLS = {
    "French": "https://api-inference.huggingface.co/models/bofenghuang/whisper-large-v3-french",
    "English": "https://api-inference.huggingface.co/models/facebook/wav2vec2-base-960h",
}


def _extract_asr_text(result):
    """Return the ``text`` field of a decoded ASR Inference API response.

    Raises:
        RuntimeError: if the response holds no transcription (for instance an
            ``{"error": ...}`` body). The API's error message is included.
    """
    if isinstance(result, dict) and "text" in result:
        return result["text"]
    error = result.get("error") if isinstance(result, dict) else None
    detail = error or repr(result)
    raise RuntimeError(f"Speech recognition failed: {detail}")


def automatic_speech_recognition(language, filename, *, timeout=120):
    """Transcribe the audio file *filename* via the Hugging Face Inference API.

    The ASR model is selected by *language*. No translation happens here, so
    the audio is assumed to already be spoken in *language*.

    Args:
        language: ``"French"`` or ``"English"``.
        filename: Path of the audio file whose raw bytes are uploaded.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The transcription, or ``"Unknown language <language>"`` when no model
        is configured for *language*.

    Raises:
        RuntimeError: if the API response contains no transcription.
        requests.RequestException: on network errors or timeout.
    """
    api_url = _ASR_MODEL_URLS.get(language)
    if api_url is None:
        return f"Unknown language {language}"
    print(f"-> automatic_speech_recognition with {api_url}")
    with open(filename, "rb") as f:
        data = f.read()
    response = requests.post(
        api_url,
        headers={"Authorization": f"Bearer {HF_API}"},
        data=data,
        timeout=timeout,
    )
    # Decode the body once and reuse it for logging and extraction.
    result = response.json()
    print(result)
    return _extract_asr_text(result)
# Instructions shared by both summary models; ``{language}`` is filled per call.
_SUMMARY_INSTRUCTIONS = """You are a helpful and truthful patient-doctor encounter summary writer.
Users sends you transcripts of patient-doctor encounter and you create accurate and concise summaries.
The summary only contains informations from the transcript.
Your summary is written in {language}.
The summary only includes relevant sections.
<template>
# Chief Complaint
# History of Present Illness (HPI)
# Relevant Past Medical History
# Physical Examination
# Assessment and Plan
# Follow-up
# Additional Notes
</template>"""


def _llama3_prompt(language, transcript):
    """Build a Llama 3 Instruct prompt: system and user turns, then an open assistant turn."""
    system = _SUMMARY_INSTRUCTIONS.format(language=language)
    # <|begin_of_text|> appears once, at the very start; each header is
    # followed by a blank line, per the Llama 3 chat format.
    return (
        "<|begin_of_text|>"
        f"<|start_header_id|>system<|end_header_id|>\n\n{system}<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\n"
        f"<transcript>\n{transcript}\n</transcript><|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )


def _mistral_prompt(language, transcript):
    """Build a Mistral Instruct ``[INST] ... [/INST]`` prompt."""
    instructions = _SUMMARY_INSTRUCTIONS.format(language=language)
    return (
        f"<s>[INST]\n{instructions}\n"
        f"\n<transcript>\n{transcript}\n</transcript>\n[/INST]\n"
    )


def _extract_generated_text(result):
    """Return the generated text of a decoded text-generation API response.

    A leading ``<summary>`` tag is dropped only when the model actually
    emitted one.

    Raises:
        RuntimeError: if the response holds no generated text (for instance an
            ``{"error": ...}`` body). The API's error message is included.
    """
    if (
        isinstance(result, list)
        and result
        and isinstance(result[0], dict)
        and "generated_text" in result[0]
    ):
        return result[0]["generated_text"].lstrip().removeprefix("<summary>")
    error = result.get("error") if isinstance(result, dict) else None
    detail = error or repr(result)
    raise RuntimeError(f"Summary generation failed: {detail}")


def _query_summary_model(api_url, prompt, *, max_new_tokens=1000, timeout=300):
    """POST *prompt* to a text-generation Inference API endpoint and return the text.

    Args:
        api_url: Model endpoint URL.
        prompt: Fully formatted prompt for that model.
        max_new_tokens: Upper bound on the number of generated tokens.
        timeout: Seconds to wait for the HTTP request.
    """
    payload = {
        "inputs": prompt,
        "parameters": {
            # Only the completion is wanted, not the echoed prompt.
            "return_full_text": False,
            # Bounds the summary length ("min_length" is not a text-generation
            # parameter of the Inference API).
            "max_new_tokens": max_new_tokens,
        },
        # wait_for_model is an API option, not a generation parameter.
        "options": {"use_cache": False, "wait_for_model": True},
    }
    response = requests.post(
        api_url,
        headers={"Authorization": f"Bearer {HF_API}"},
        json=payload,
        timeout=timeout,
    )
    result = response.json()
    print(result)
    return _extract_generated_text(result)


def generate_summary_llama3(language, transcript):
    """Summarize a patient-doctor *transcript* in *language* with Llama 3 8B Instruct.

    Raises:
        RuntimeError: if the API returns no generated text.
    """
    return _query_summary_model(LLAMA_SUMMARY_URL, _llama3_prompt(language, transcript))


def generate_summary_mistral(language, transcript):
    """Summarize a patient-doctor *transcript* in *language* with Mistral 7B Instruct v0.2.

    Raises:
        RuntimeError: if the API returns no generated text.
    """
    return _query_summary_model(MISTRAL_SUMMARY_URL, _mistral_prompt(language, transcript))
def generate_summary(model, language, transcript):
    """Dispatch the summarization request to the model picked in the UI.

    Args:
        model: Either ``"Mistral-7B"`` or ``"LLAMA3"`` (the dropdown choices).
        language: Language the summary should be written in.
        transcript: Transcript text produced by ``predict``.

    Returns:
        The model's summary, or ``"Unknown model <model>"`` for any other name.
    """
    if model == "Mistral-7B":
        print("-> summarize with mistral")
        return generate_summary_mistral(language, transcript)
    if model == "LLAMA3":
        print("-> summarize with llama3")
        return generate_summary_llama3(language, transcript)
    return f"Unknown model {model}"
def update_audio_ui(audio_source: str) -> tuple[dict, dict]:
    """Show exactly one of the microphone / file-upload audio inputs.

    Both inputs are cleared. The microphone input is visible only when
    *audio_source* is ``"microphone"``; otherwise the file input is shown.

    NOTE(review): not referenced by the UI built in this file.

    Returns:
        ``(mic_update, file_update)`` Gradio update dicts.
    """
    use_microphone = audio_source == "microphone"
    mic_update = gr.update(visible=use_microphone, value=None)  # input_audio_mic
    file_update = gr.update(visible=not use_microphone, value=None)  # input_audio_file
    return mic_update, file_update
# Page layout: language + audio input -> transcript, then model choice -> summary.
with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Group():
        with gr.Row():
            target_language = gr.Dropdown(
                choices=["French", "English"],
                label="Output Language",
                value=DEFAULT_TARGET_LANGUAGE,
                interactive=True,
                info="Select your target language",
            )
        with gr.Row():
            input_audio = gr.Audio(type="filepath")
        submit = gr.Button("Transcribe")
        transcribe_output = gr.Textbox(
            label="Transcribed Text",
            value="",
            interactive=False,
            lines=10,
            scale=10,
            max_lines=100,
        )
        # predict is a generator, so the textbox fills in turn by turn.
        submit.click(
            fn=predict,
            inputs=[target_language, input_audio],
            outputs=[transcribe_output],
            api_name="predict",
        )
        with gr.Row():
            summary_model = gr.Dropdown(
                choices=["Mistral-7B", "LLAMA3"],
                label="Summary model",
                value="Mistral-7B",
                interactive=True,
                info="Select your summary model",
            )
        summarize = gr.Button("Summarize")
        summary_output = gr.Textbox(
            label="Summarized Text",
            value="",
            interactive=False,
            lines=10,
            scale=10,
            max_lines=100,
        )
        summarize.click(
            fn=generate_summary,
            inputs=[summary_model, target_language, transcribe_output],
            outputs=[summary_output],
            # Each event needs its own endpoint name; reusing "predict" made
            # it collide with the transcription endpoint.
            api_name="summarize",
        )
    gr.Markdown(DUPLICATE)

# Allow up to 50 pending requests in the queue, then start the server.
demo.queue(max_size=50).launch()