voice-queries / app.py
Julien Simon
Truncate sentences
d0dbd25
raw
history blame
4.49 kB
import nltk
import pickle
import pandas as pd
import gradio as gr
import numpy as np
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline
from librosa import load, resample
# Constants
# Zipped CSV of 2020 S&P500 10-K filings, read with pandas below.
filename = 'df10k_SP500_2020.csv.zip'
# Sentence-embedding model used to encode queries at search time.
model_name = 'sentence-transformers/msmarco-distilbert-base-v4'
# Maximum input length (in tokens) for the embedding model; longer inputs are truncated.
max_sequence_length = 512
# The pre-computed embeddings archive is named after the model that produced it.
embeddings_filename = f"df10k_embeddings_{model_name.rsplit('/', 1)[-1]}.npz"
# Speech-to-text model handed to the transformers ASR pipeline.
asr_model = 'facebook/wav2vec2-xls-r-300m-21-to-en'
# Load corpus
df = pd.read_csv(filename)
df.drop_duplicates(inplace=True)
print(f'Number of documents: {len(df)}')
# Recent NLTK releases (3.8.2+) load the sentence tokenizer from 'punkt_tab'
# rather than the pickled 'punkt' model; fetch both so sent_tokenize works
# regardless of the installed NLTK version.
nltk.download('punkt')
nltk.download('punkt_tab')
# corpus: every sentence of every document, in df row order.
# sentence_count[i]: number of sentences contributed by the i-th row of df.
# find_sentences() treats embedding row k as corpus[k], so this ordering must
# stay in sync with how the pre-computed embeddings file was built.
corpus = []
sentence_count = []
# We're interested in the 'mdna' column: 'Management discussion and analysis'.
# str() is kept on purpose: missing values become the text 'nan', and changing
# that would shift sentence indices relative to the stored embeddings.
for mdna_text in df['mdna']:
    sentences = nltk.tokenize.sent_tokenize(str(mdna_text), language='english')
    sentence_count.append(len(sentences))
    corpus.extend(sentences)
print(f'Number of sentences: {len(corpus)}')
# Load pre-embedded corpus. np.load on an .npz returns an archive object that
# holds the file open; the context manager closes it once the array is read
# into memory.
with np.load(embeddings_filename) as embeddings_archive:
    corpus_embeddings = embeddings_archive['arr_0']
print(f'Number of embeddings: {corpus_embeddings.shape[0]}')
# find_sentences() uses the embedding row number as an index into corpus, so a
# size mismatch means results may point at the wrong sentence.
if corpus_embeddings.shape[0] != len(corpus):
    print(f'WARNING: {corpus_embeddings.shape[0]} embeddings but {len(corpus)} sentences; '
          'search results may not match the displayed text')
# Load embedding model
model = SentenceTransformer(model_name)
model.max_seq_length = max_sequence_length
# Load speech to text model
asr = pipeline('automatic-speech-recognition', model=asr_model, feature_extractor=asr_model)
def find_sentences(query, hits):
    """Search the corpus for the sentences closest to ``query``.

    Parameters
    ----------
    query : str
        Text to embed and search for.
    hits : int
        Number of results to return (passed as ``top_k`` to
        ``util.semantic_search``).

    Returns
    -------
    pandas.DataFrame
        One row per hit, with columns 'Ticker', 'Form type' and 'Filing date'
        (from the source document), 'Text' (the first 80 characters of the
        matching sentence), and 'Score' (the search score as a string with
        two decimals).
    """
    columns = ['Ticker', 'Form type', 'Filing date', 'Text', 'Score']
    query_embedding = model.encode(query)
    # Cast defensively in case the UI slider value arrives as a float.
    search_results = util.semantic_search(query_embedding, corpus_embeddings, top_k=int(hits))
    # A single query was encoded, so only the first result list is relevant.
    top_hits = search_results[0]

    # sentence_count[i] sentences came from df row i, so the running total gives
    # the (exclusive) end offset of each document's sentences in corpus.
    doc_ends = np.cumsum(sentence_count)

    rows = []
    for hit in top_hits:
        corpus_id = hit['corpus_id']
        # The owning document is the first one whose end offset exceeds corpus_id
        # (same result as scanning the running count, but O(log n)).
        idx = int(np.searchsorted(doc_ends, corpus_id, side='right'))
        if idx >= len(doc_ends):
            # Out-of-range sentence index: skip it, as the original scan did.
            continue
        doc = df.iloc[idx]
        rows.append({
            'Ticker': doc['ticker'],
            'Form type': doc['form_type'],
            'Filing date': doc['filing_date'],
            'Text': corpus[corpus_id][:80],
            'Score': '{:.2f}'.format(hit['score']),
        })
    # Build the frame once; DataFrame.append was removed in pandas 2.0.
    return pd.DataFrame(rows, columns=columns)
def process(input_selection, query, filepath, hits):
    """Resolve the user's query (typed or spoken) and run the corpus search.

    Parameters
    ----------
    input_selection : str
        'speech' to transcribe the recording at ``filepath``; any other value
        uses the typed ``query``.
    query : str
        Text typed by the user.
    filepath : str or None
        Path of the recorded audio. The microphone input is declared
        ``optional=True``, so this may be empty when nothing was recorded.
    hits : int
        Number of results, forwarded to ``find_sentences``.

    Returns
    -------
    tuple
        ``(text, results)``: the query string that was actually searched and
        the DataFrame returned by ``find_sentences``.
    """
    if input_selection == 'speech' and filepath:
        # Resample while decoding, so the ASR pipeline always gets 16 kHz audio.
        # (librosa.load's default sr=22050 previously forced a second resample,
        # and positional orig_sr/target_sr in resample() fail on librosa >= 0.10.)
        speech, _ = load(filepath, sr=16000)
        text = asr(speech)['text']
    else:
        # Typed input, or 'speech' selected with no recording: use the text box.
        # The returned text shows the user which query was actually run.
        text = query
    return text, find_sentences(text, hits)
# Gradio inputs
# NOTE(review): gr.inputs / gr.outputs are the legacy Gradio 2.x component API
# (removed in Gradio 4); keep the Gradio version pinned accordingly.
default_query = 'The company is under investigation by tax authorities for potential fraud.'
buttons = gr.inputs.Radio(
    ['text', 'speech'],
    type='value',
    default='speech',
    label='Input selection',
)
text_query = gr.inputs.Textbox(lines=1, label='Text input', default=default_query)
mic = gr.inputs.Audio(
    source='microphone',
    type='filepath',
    label='Speech input',
    optional=True,
)
slider = gr.inputs.Slider(
    minimum=1,
    maximum=10,
    step=1,
    default=3,
    label='Number of hits',
)
# Gradio outputs
speech_query = gr.outputs.Textbox(type='auto', label='Query string')
results = gr.outputs.Dataframe(
    headers=['Ticker', 'Form type', 'Filing date', 'Text', 'Score'],
    label='Query results',
)
# Each example row follows the order of `inputs`:
# [input selection, text query, audio file, number of hits].
# Text examples pass 'dummy.wav' as a placeholder for the unused audio input.
example_rows = [
    ['text', default_query, 'dummy.wav', 3],
    ['text', "How much money does Microsoft make with Azure?", 'dummy.wav', 3],
    ['speech', "Nos ventes internationales ont significativement augmenté.", 'sales_16k_fr.wav', 3],
    ['speech', "Le prix de l'énergie pourrait avoir un impact négatif dans le futur.", 'energy_16k_fr.wav', 3],
    ['speech', "El precio de la energía podría tener un impacto negativo en el futuro.", 'energy_24k_es.wav', 3],
    ['speech', "Mehrere Steuerbehörden untersuchen unser Unternehmen.", 'tax_24k_de.wav', 3],
]
app_description = (
    'This Spaces lets you query a text corpus containing 2020 annual filings '
    'for all S&P500 companies. You can type a text query in English, or record '
    'an audio query in 21 languages. You can find a technical deep dive at '
    'https://www.youtube.com/watch?v=YPme-gR0f80'
)
iface = gr.Interface(
    fn=process,
    inputs=[buttons, text_query, mic, slider],
    outputs=[speech_query, results],
    examples=example_rows,
    description=app_description,
    layout='horizontal',
    theme='huggingface',
    allow_flagging=False,
)
iface.launch()