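# Gradio demo: transcribes a call recording with a wav2vec2 + n-gram LM ASR
# pipeline, segments it by speaker with pyannote.audio, restores punctuation
# with rpunct, and highlights non-neutral sentiment in one speaker's turns.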
import gradio as gr
from transformers import pipeline, Wav2Vec2ProcessorWithLM
from pyannote.audio import Pipeline
from librosa import load, resample
from rpunct import RestorePuncts
# Audio components
asr_model = 'patrickvonplaten/wav2vec2-base-960h-4-gram'
processor = Wav2Vec2ProcessorWithLM.from_pretrained(asr_model)
asr = pipeline('automatic-speech-recognition', model=asr_model, tokenizer=processor.tokenizer, feature_extractor=processor.feature_extractor, decoder=processor.decoder)
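# Note: recent pyannote.audio releases gate this model; you may need to accept
# its terms on the Hugging Face Hub and pass use_auth_token=<your token> below.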
speaker_segmentation = Pipeline.from_pretrained("pyannote/speaker-segmentation")
rpunct = RestorePuncts()
# Text components
sentiment_pipeline = pipeline('text-classification', model="distilbert-base-uncased-finetuned-sst-2-english")
sentiment_threshold = 0.75
EXAMPLES = ["example_audio.wav"]
def speech_to_text(speech):
    # Diarize first, then transcribe with word-level timestamps so each word
    # can be assigned to the speaker turn it falls inside.
    speaker_output = speaker_segmentation(speech)
    speech, sampling_rate = load(speech)
    if sampling_rate != 16000:
        # Keyword arguments work across librosa versions (they became
        # keyword-only in librosa 0.10).
        speech = resample(speech, orig_sr=sampling_rate, target_sr=16000)
    text = asr(speech, return_timestamps="word")
    full_text = text['text'].lower()
    chunks = text['chunks']
    diarized_output = []
    i = 0
    speaker_counter = 0

    # New segment every time the speaker changes
    for turn, _, _ in speaker_output.itertracks(yield_label=True):
        speaker = "Speaker 0" if speaker_counter % 2 == 0 else "Speaker 1"
        diarized = ""
        # Collect every word whose end timestamp falls within this turn
        while i < len(chunks) and chunks[i]['timestamp'][1] <= turn.end:
            diarized += chunks[i]['text'].lower() + ' '
            i += 1

        if diarized != "":
            diarized = rpunct.punctuate(diarized)
            diarized_output.extend([(diarized, speaker), ('from {:.2f}-{:.2f}'.format(turn.start, turn.end), None)])

        speaker_counter += 1

    return diarized_output, full_text
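# diarized_output feeds gr.HighlightedText, which takes (text, label) pairs;
# a None label renders the span without a category. Illustrative (made-up) value:
#   [("hello, thanks for calling.", "Speaker 0"), ("from 0.00-2.47", None), ...]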
def sentiment(checked_options, diarized):
    # Run sentiment analysis on the turns of the speaker picked in the UI.
    customer_id = checked_options
    customer_sentiments = []
    for transcript in diarized:
        speaker_speech, speaker_id = transcript
        if speaker_id == customer_id:
            output = sentiment_pipeline(speaker_speech)[0]
            # The SST-2 model only emits POSITIVE/NEGATIVE labels, so the
            # "neutral" check never fires; the score threshold does the filtering.
            if output["label"] != "neutral" and output["score"] > sentiment_threshold:
                customer_sentiments.append((speaker_speech, output["label"]))
            else:
                customer_sentiments.append((speaker_speech, None))
    return customer_sentiments
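# Note: demo.encrypt, gr.StatusTracker, and Dataset._click_no_postprocess come
# from an early Gradio 3.x API; later Gradio releases removed or renamed them.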
demo = gr.Blocks()
demo.encrypt = False
with demo:
    with gr.Row():
        with gr.Column():
            audio = gr.Audio(label="Audio file", type='filepath')
            with gr.Row():
                btn = gr.Button("Transcribe")
            with gr.Row():
                examples = gr.components.Dataset(components=[audio], samples=[EXAMPLES], type="index")
        with gr.Column():
            gr.Markdown("**Diarized Output:**")
            diarized = gr.HighlightedText(lines=5, label="Diarized Output")
            full = gr.Textbox(lines=4, label="Full Transcript")
            check = gr.Radio(["Speaker 0", "Speaker 1"], label='Choose speaker for sentiment analysis')
            analyzed = gr.HighlightedText(label="Customer Sentiment")

    btn.click(speech_to_text, audio, [diarized, full], status_tracker=gr.StatusTracker(cover_container=True))
    check.change(sentiment, [check, diarized], analyzed, status_tracker=gr.StatusTracker(cover_container=True))

    def load_example(example_id):
        processed_examples = audio.preprocess_example(EXAMPLES[example_id])
        return processed_examples

    examples._click_no_postprocess(load_example, inputs=[examples], outputs=[audio])
demo.launch()