# voice-queries / app.py
# Author: juliensimon (HF staff)
# Commit 70ec0d7 — "Initial version" (4.35 kB)
import nltk
import pickle
import pandas as pd
import gradio as gr
import numpy as np
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline
from librosa import load, resample
# ---------------------------------------------------------------------------
# Configuration constants
# ---------------------------------------------------------------------------
# Data files: the filings corpus and its precomputed sentence embeddings.
filename = "df10k_SP500_2020.csv.zip"
embeddings_filename = "df10k_embeddings_msmarco-distilbert-base-v4.npz"

# Sentence-embedding model used to encode queries, and the maximum sequence
# length it is configured with.
model_name = "sentence-transformers/msmarco-distilbert-base-v4"
max_sequence_length = 512

# Speech-recognition model used for spoken queries.
asr_model = "facebook/wav2vec2-xls-r-300m-21-to-en"
# Load corpus: split each filing's MD&A text into sentences.
# NOTE(review): the sentence order produced here presumably has to match the
# order used when the embeddings file was generated (find_sentences indexes
# `corpus` with embedding row ids), so the splitting logic is kept unchanged.
df = pd.read_csv(filename)
df.drop_duplicates(inplace=True)
print(f'Number of documents: {len(df)}')

corpus = []          # flat list of sentences, across all documents in row order
sentence_count = []  # sentence_count[i] = number of sentences taken from row i (positional)
for mdna in df['mdna']:
    # We're interested in the 'mdna' column: 'Management discussion and analysis'.
    # str() is kept so non-string cells are tokenized exactly as before.
    sentences = nltk.tokenize.sent_tokenize(str(mdna), language='english')
    sentence_count.append(len(sentences))
    corpus.extend(sentences)
print(f'Number of sentences: {len(corpus)}')
# Load pre-embedded corpus. The context manager closes the underlying .npz
# archive once the array has been read.
with np.load(embeddings_filename) as npz:
    corpus_embeddings = npz['arr_0']
print(f'Number of embeddings: {corpus_embeddings.shape[0]}')
if corpus_embeddings.shape[0] != len(corpus):
    # Search hits are embedding row ids that are used to index `corpus`, so the
    # two must line up one-to-one; a mismatch suggests a stale embeddings file.
    print(f'WARNING: {corpus_embeddings.shape[0]} embeddings but {len(corpus)} sentences; '
          'search results may point at the wrong sentences.')

# Load embedding model
model = SentenceTransformer(model_name)
model.max_seq_length = max_sequence_length

# Load speech to text model
asr = pipeline('automatic-speech-recognition', model=asr_model, feature_extractor=asr_model)
def find_sentences(query, hits):
    """Return the corpus sentences most semantically similar to *query*.

    Args:
        query: Query string, embedded with the module-level ``model``.
        hits: Number of results to return; passed to
            ``util.semantic_search`` as ``top_k``.

    Returns:
        A DataFrame with columns Ticker, Form type, Filing date, Text and
        Score (formatted to two decimals), one row per hit, in the order
        returned by ``util.semantic_search``.
    """
    columns = ['Ticker', 'Form type', 'Filing date', 'Text', 'Score']

    query_embedding = model.encode(query)
    # int() in case the UI slider delivers a float such as 3.0.
    search_results = util.semantic_search(query_embedding, corpus_embeddings, top_k=int(hits))[0]

    # doc_ends[i] is one past the last corpus index belonging to df row i, so
    # row i owns corpus ids in [doc_ends[i-1], doc_ends[i]). A right-side
    # searchsorted finds the first row whose end exceeds corpus_id.
    doc_ends = np.cumsum(sentence_count)

    rows = []
    for hit in search_results:
        corpus_id = hit['corpus_id']
        doc_idx = int(np.searchsorted(doc_ends, corpus_id, side='right'))
        if doc_idx >= len(doc_ends):
            # Id lies beyond the tokenized corpus (embeddings out of sync with
            # the CSV); skip it instead of failing the whole query.
            continue
        doc = df.iloc[doc_idx]
        rows.append({
            'Ticker': doc['ticker'],
            'Form type': doc['form_type'],
            'Filing date': doc['filing_date'],
            'Text': corpus[corpus_id],
            'Score': '{:.2f}'.format(hit['score']),
        })
    # Build the frame once: DataFrame.append was removed in pandas 2.0.
    return pd.DataFrame(rows, columns=columns)
def process(input_selection, query, filepath, hits):
    """Resolve the query text (typed or transcribed) and run the search.

    Args:
        input_selection: 'speech' to transcribe the recording at *filepath*;
            any other value uses the typed *query*.
        query: Typed query text.
        filepath: Path to the recorded audio file, or None/empty when nothing
            was recorded.
        hits: Number of results, forwarded to ``find_sentences``.

    Returns:
        A ``(text, results)`` tuple: the query string that was searched and
        the DataFrame produced by ``find_sentences``.
    """
    if input_selection == 'speech' and filepath:
        # Let librosa resample to 16 kHz while loading. Without ``sr`` it
        # always resamples to 22050 Hz first, and a positional
        # ``resample(y, orig_sr, target_sr)`` call is rejected by librosa >= 0.10.
        speech, _ = load(filepath, sr=16000)
        text = asr(speech)['text']
    else:
        # Typed input, or speech selected without a recording: use the text box.
        text = query
    return text, find_sentences(text, hits)
# ---------------------------------------------------------------------------
# Gradio user interface
# ---------------------------------------------------------------------------
# Input components, listed in the positional order process() receives them:
# (input_selection, query, filepath, hits).
buttons = gr.inputs.Radio(
    ['text', 'speech'],
    type='value',
    default='speech',
    label='Input selection',
)
text_query = gr.inputs.Textbox(
    lines=1,
    label='Text input',
    default='The company is under investigation by tax authorities for potential fraud.',
)
mic = gr.inputs.Audio(
    source='microphone',
    type='filepath',
    label='Speech input',
    optional=True,
)
slider = gr.inputs.Slider(
    minimum=1,
    maximum=10,
    step=1,
    default=3,
    label='Number of hits',
)

# Output components, matching the (text, DataFrame) pair returned by process().
speech_query = gr.outputs.Textbox(type='auto', label='Query string')
results = gr.outputs.Dataframe(
    headers=['Ticker', 'Form type', 'Filing date', 'Text', 'Score'],
    label='Query results',
)

# One value per input component, in order. 'dummy.wav' appears to be a
# placeholder audio path for the text-only examples.
example_rows = [
    ['text', "The company is under investigation by tax authorities for potential fraud.", 'dummy.wav', 3],
    ['text', "How much money does Microsoft make with Azure?", 'dummy.wav', 3],
    ['speech', "Nos ventes internationales ont significativement augmenté.", 'sales_16k_fr.wav', 3],
    ['speech', "Le prix de l'énergie pourrait avoir un impact négatif dans le futur.", 'energy_16k_fr.wav', 3],
    ['speech', "El precio de la energía podría tener un impacto negativo en el futuro.", 'energy_24k_es.wav', 3],
    ['speech', "Mehrere Steuerbehörden untersuchen unser Unternehmen.", 'tax_24k_de.wav', 3],
]

iface = gr.Interface(
    fn=process,
    inputs=[buttons, text_query, mic, slider],
    outputs=[speech_query, results],
    examples=example_rows,
    description='This Spaces lets you query a text corpus containing 2020 annual filings for all S&P500 companies. You can type a text query in English, or record an audio query in 21 languages.',
    theme='huggingface',
    allow_flagging=False,
)
iface.launch()