# Page header residue: commit c8a961d ("Update app.py") by abidlabs (HF staff).
import gradio as gr
from transformers import pipeline, Wav2Vec2ProcessorWithLM
from pyannote.audio import Pipeline
from librosa import load, resample
from rpunct import RestorePuncts
# ---------------------------------------------------------------------------
# Audio components
# ---------------------------------------------------------------------------
# Checkpoint used for speech recognition; the processor loaded from it supplies
# the tokenizer, feature extractor and decoder handed to the ASR pipeline.
asr_model = 'patrickvonplaten/wav2vec2-base-960h-4-gram'
processor = Wav2Vec2ProcessorWithLM.from_pretrained(asr_model)
asr = pipeline(
    'automatic-speech-recognition',
    model=asr_model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    decoder=processor.decoder,
)

# Pipeline that yields the speaker turns consumed by speech_to_text().
speaker_segmentation = Pipeline.from_pretrained("pyannote/speaker-segmentation")

# Restores punctuation on the lower-cased text of each speaker turn.
rpunct = RestorePuncts()

# ---------------------------------------------------------------------------
# Text components
# ---------------------------------------------------------------------------
sentiment_pipeline = pipeline(
    'text-classification',
    model="distilbert-base-uncased-finetuned-sst-2-english",
)

# A sentiment label is only reported when the classifier score exceeds this.
sentiment_threshold = 0.75

# Audio files offered as pre-computed examples in the UI.
EXAMPLES = ["example_audio.wav"]
def speech_to_text(speech):
    """Transcribe an audio file and split the transcript into speaker turns.

    Args:
        speech: Path of the audio file. It is handed to the speaker
            segmentation pipeline and to librosa's ``load``.

    Returns:
        A ``(diarized_output, full_text)`` tuple.

        ``diarized_output`` is a flat list of ``(text, label)`` pairs. For
        every turn that produced text it holds the punctuated text labelled
        ``"Speaker 0"`` or ``"Speaker 1"``, followed by a
        ``('from <start>-<end>', None)`` marker for that turn.

        ``full_text`` is the complete lower-cased transcript.
    """
    speaker_output = speaker_segmentation(speech)

    waveform, sampling_rate = load(speech)
    if sampling_rate != 16000:
        # The sample rates are passed by keyword: recent librosa releases
        # reject them as positional arguments, and the keyword names are
        # accepted by older releases too.
        waveform = resample(waveform, orig_sr=sampling_rate, target_sr=16000)

    text = asr(waveform, return_timestamps="word")
    full_text = text['text'].lower()
    chunks = text['chunks']

    diarized_output = []
    i = 0  # index of the next word chunk not yet assigned to a turn
    speaker_counter = 0

    # New iteration every time the speaker changes.
    # The track name and label yielded by the segmentation are discarded;
    # speakers are simply alternated between "Speaker 0" and "Speaker 1".
    for turn, _, _ in speaker_output.itertracks(yield_label=True):
        speaker = "Speaker 0" if speaker_counter % 2 == 0 else "Speaker 1"

        # Consume every word whose end timestamp falls inside this turn.
        words = []
        while i < len(chunks) and chunks[i]['timestamp'][1] <= turn.end:
            words.append(chunks[i]['text'].lower())
            i += 1

        if not words:
            continue

        # The trailing space reproduces the "word + ' '" accumulation that the
        # punctuation model has always been fed.
        diarized = rpunct.punctuate(' '.join(words) + ' ')
        diarized_output.extend([
            (diarized, speaker),
            ('from {:.2f}-{:.2f}'.format(turn.start, turn.end), None),
        ])
        # NOTE(review): the counter is assumed to advance only for turns that
        # produced text, so empty turns do not flip the speaker - confirm.
        speaker_counter += 1

    return diarized_output, full_text
def sentiment(checked_options, diarized):
    """Run sentiment analysis on the utterances of the selected speaker.

    Args:
        checked_options: Label of the speaker to analyse (for example
            ``"Speaker 0"``), compared against the label of each entry.
        diarized: Iterable of ``(text, speaker_label)`` pairs. Entries whose
            label is ``None`` are skipped.

    Returns:
        A list of ``(text, label)`` tuples, one per utterance of the selected
        speaker. ``label`` is the classifier label when it is not
        ``"neutral"`` and its score exceeds ``sentiment_threshold``;
        otherwise it is ``None``.
    """
    customer_sentiments = []

    for speaker_speech, speaker_id in diarized:
        # Unlabelled entries are never speech to classify, even when no
        # speaker has been selected (checked_options is None).
        if speaker_id is None or speaker_id != checked_options:
            continue

        output = sentiment_pipeline(speaker_speech)[0]
        confident = (
            output["label"] != "neutral"
            and output["score"] > sentiment_threshold
        )
        # Always append a single (text, label) tuple: list.append() takes
        # exactly one argument.
        customer_sentiments.append(
            (speaker_speech, output["label"] if confident else None)
        )

    return customer_sentiments
# ---------------------------------------------------------------------------
# Gradio interface
# ---------------------------------------------------------------------------
# NOTE(review): the Row/Column nesting below is the most plausible reading of
# the statement order - confirm against the rendered layout.
demo = gr.Blocks(enable_queue=True)
demo.encrypt = False

with demo:
    with gr.Row():
        # Input side: audio file, transcribe button and example picker.
        with gr.Column():
            audio = gr.Audio(label="Audio file", type="filepath")
            with gr.Row():
                btn = gr.Button("Transcribe")
            with gr.Row():
                examples = gr.components.Dataset(
                    components=[audio],
                    samples=[EXAMPLES],
                    type="index",
                )

        # Output side: diarized transcript, full transcript and sentiment.
        with gr.Column():
            gr.Markdown("**Diarized Output:**")
            diarized = gr.HighlightedText(lines=5, label="Diarized Output")
            full = gr.Textbox(lines=4, label="Full Transcript")
            check = gr.Radio(
                ["Speaker 0", "Speaker 1"],
                label="Choose speaker for sentiment analysis",
            )
            analyzed = gr.HighlightedText(label="Customer Sentiment")

    # Event wiring: transcribe on click, re-run sentiment on speaker change.
    btn.click(speech_to_text, audio, [diarized, full])
    check.change(sentiment, [check, diarized], analyzed)

    def cache_example(example):
        """Return the preprocessed audio and transcription for one example."""
        processed_audio = audio.preprocess_example(example)
        diarized_text, transcript = speech_to_text(example)
        return processed_audio, diarized_text, transcript

    # Every example is transcribed once here, so clicking one later only
    # looks up the stored result.
    cache = list(map(cache_example, EXAMPLES))

    def load_example(example_id):
        """Return the cached outputs for the example at index ``example_id``."""
        return cache[example_id]

    examples._click_no_postprocess(
        load_example,
        inputs=[examples],
        outputs=[audio, diarized, full],
        queue=False,
    )

demo.launch()