# voice-queries / app.py
# juliensimon's picture
# juliensimon HF staff
# Update code
# 59b8055
import gradio as gr
import nltk
import numpy as np
import pandas as pd
from librosa import load, resample
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline
# Constants: data files and model identifiers used throughout the app
filename = "df10k_SP500_2020.csv.zip"
model_name = "sentence-transformers/msmarco-distilbert-base-v4"
max_sequence_length = 512
embeddings_filename = "df10k_embeddings_msmarco-distilbert-base-v4.npz"
asr_model = "facebook/wav2vec2-xls-r-300m-21-to-en"

# Load the filings and discard exact duplicate rows
df = pd.read_csv(filename)
df.drop_duplicates(inplace=True)
print(f"Number of documents: {len(df)}")

# Sentence tokenizer data required by nltk.tokenize.sent_tokenize
nltk.download("punkt")

# Split every document into sentences. `corpus` is the flat list of all
# sentences; `sentence_count[i]` is how many of them came from document i,
# which is what lets a flat sentence index be mapped back to its document.
corpus = []
sentence_count = []
# We're interested in the 'mdna' column: 'Management discussion and analysis'
for mdna_text in df["mdna"]:
    doc_sentences = nltk.tokenize.sent_tokenize(str(mdna_text), language="english")
    sentence_count.append(len(doc_sentences))
    corpus.extend(doc_sentences)
print(f"Number of sentences: {len(corpus)}")

# Load pre-embedded corpus (stored under the default "arr_0" key of the archive)
corpus_embeddings = np.load(embeddings_filename)["arr_0"]
print(f"Number of embeddings: {corpus_embeddings.shape[0]}")

# Load embedding model used to encode incoming queries
model = SentenceTransformer(model_name)
model.max_seq_length = max_sequence_length

# Load speech to text model
asr = pipeline(
    "automatic-speech-recognition",
    model=asr_model,
    feature_extractor=asr_model,
)
def find_sentences(query, hits):
    """Return the corpus sentences most similar to ``query``.

    The query is embedded with the module-level ``model`` and compared against
    ``corpus_embeddings`` via ``util.semantic_search``. Each matching sentence
    is mapped back to the document (row of ``df``) it was extracted from, using
    the per-document sentence counts in ``sentence_count``.

    Args:
        query: Text to search for.
        hits: Maximum number of results to return.

    Returns:
        A pandas DataFrame with the columns "Ticker", "Form type",
        "Filing date", "Text" (first 80 characters of the sentence) and
        "Score" (similarity formatted with two decimals), one row per match.
    """
    columns = ["Ticker", "Form type", "Filing date", "Text", "Score"]

    query_embedding = model.encode(query)
    # int() guards against a non-integer count (e.g. a float coming from a UI
    # widget); top_k is used as a result count and must be an integer.
    matches = util.semantic_search(
        query_embedding, corpus_embeddings, top_k=int(hits)
    )[0]

    # doc_ends[i] is the number of sentences contained in documents 0..i, so
    # the document owning sentence `corpus_id` is the first i for which
    # doc_ends[i] > corpus_id. Computed once instead of re-scanning per hit.
    doc_ends = np.cumsum(sentence_count)

    rows = []
    for match in matches:
        corpus_id = match["corpus_id"]
        doc_idx = int(np.searchsorted(doc_ends, corpus_id, side="right"))
        if doc_idx >= len(doc_ends):
            # Sentence index beyond the known corpus: no source document, so
            # the hit is skipped (same outcome as the original linear scan).
            continue
        doc = df.iloc[doc_idx]
        rows.append(
            {
                "Ticker": doc["ticker"],
                "Form type": doc["form_type"],
                "Filing date": doc["filing_date"],
                "Text": corpus[corpus_id][:80],
                "Score": "{:.2f}".format(match["score"]),
            }
        )

    # Build the frame in one go: DataFrame.append was removed in pandas 2.0
    # and appending row by row copies the whole frame on every iteration.
    return pd.DataFrame(rows, columns=columns)
def process(input_selection, query, filepath, hits):
    """Resolve the query text (typed or spoken) and run the semantic search.

    Args:
        input_selection: "speech" to transcribe the audio file at ``filepath``;
            any other value uses ``query`` as-is.
        query: Text query typed by the user.
        filepath: Path to the recorded audio file, or None/empty when nothing
            was recorded.
        hits: Number of results to return.

    Returns:
        A tuple ``(text, results)`` where ``text`` is the query string actually
        searched for and ``results`` is the value returned by
        ``find_sentences``.
    """
    if input_selection == "speech" and filepath:
        # Ask librosa for 16 kHz directly. With its default target rate the
        # audio would be resampled once on load and then a second time to
        # 16 kHz, which costs time and degrades the signal.
        speech, _ = load(filepath, sr=16000)
        text = asr(speech)["text"]
    else:
        # Text mode, or speech mode without a recording: use the typed query
        # rather than failing on a missing audio file.
        text = query
    return text, find_sentences(text, hits)
# Column headers of the results table shown to the user
result_headers = ["Ticker", "Form type", "Filing date", "Text", "Score"]

# Gradio inputs: input mode selector, typed query, recorded query, result count
buttons = gr.Radio(
    ["text", "speech"],
    label="Input selection",
    value="speech",
    type="value",
)
text_query = gr.Textbox(
    label="Text input",
    lines=1,
    value="The company is under investigation by tax authorities for potential fraud.",
)
mic = gr.Audio(
    label="Speech input",
    source="microphone",
    type="filepath",
    optional=True,
)
slider = gr.Slider(label="Number of hits", minimum=1, maximum=10, step=1, value=3)

# Gradio outputs: the query string that was searched, and the matching sentences
speech_query = gr.Textbox(label="Query string", type="text")
results = gr.Dataframe(
    label="Query results",
    headers=result_headers,
    type="pandas",
)

interface_description = "This Spaces lets you query a text corpus containing 2020 annual filings for all S&P500 companies. You can type a text query in English, or record an audio query in 21 languages. You can find a technical deep dive at https://www.youtube.com/watch?v=YPme-gR0f80"

# Each example is a spoken query: (text shown in the text box, audio file).
# All of them run in "speech" mode and ask for 3 hits.
spoken_examples = (
    (
        "Nos ventes internationales ont significativement augmenté.",
        "sales_16k_fr.wav",
    ),
    (
        "Le prix de l'énergie pourrait avoir un impact négatif dans le futur.",
        "energy_16k_fr.wav",
    ),
    (
        "El precio de la energía podría tener un impacto negativo en el futuro.",
        "energy_24k_es.wav",
    ),
    (
        "Mehrere Steuerbehörden untersuchen unser Unternehmen.",
        "tax_24k_de.wav",
    ),
)
example_rows = [
    ["speech", example_text, audio_file, 3]
    for example_text, audio_file in spoken_examples
]

iface = gr.Interface(
    fn=process,
    inputs=[buttons, text_query, mic, slider],
    outputs=[speech_query, results],
    examples=example_rows,
    description=interface_description,
    theme="huggingface",
    allow_flagging=False,
)
iface.launch()