import gradio as gr
import nltk
import numpy as np
import pandas as pd
from librosa import load, resample
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline

# Constants
filename = "df10k_SP500_2020.csv.zip"
model_name = "sentence-transformers/msmarco-distilbert-base-v4"
max_sequence_length = 512
embeddings_filename = "df10k_embeddings_msmarco-distilbert-base-v4.npz"
asr_model = "facebook/wav2vec2-xls-r-300m-21-to-en"
# Sampling rate (Hz) that recorded audio is converted to before it is handed to
# the ASR pipeline (the original code also targeted 16000).
asr_sampling_rate = 16000
# Columns of the results table, shared by find_sentences() and the output widget.
result_columns = ["Ticker", "Form type", "Filing date", "Text", "Score"]

# Load corpus
df = pd.read_csv(filename)
df.drop_duplicates(inplace=True)
print(f"Number of documents: {len(df)}")

nltk.download("punkt")

corpus = []
sentence_count = []
for _, row in df.iterrows():
    # We're interested in the 'mdna' column: 'Management discussion and analysis'
    sentences = nltk.tokenize.sent_tokenize(str(row["mdna"]), language="english")
    sentence_count.append(len(sentences))
    corpus.extend(sentences)
print(f"Number of sentences: {len(corpus)}")

# sentence_offsets[i] = number of corpus sentences contributed by documents 0..i
# (positional order of df). The document owning sentence k is therefore the first
# i with sentence_offsets[i] > k, which searchsorted(side="right") finds directly.
sentence_offsets = np.cumsum(sentence_count)

# Load pre-embedded corpus
corpus_embeddings = np.load(embeddings_filename)["arr_0"]
print(f"Number of embeddings: {corpus_embeddings.shape[0]}")
if corpus_embeddings.shape[0] != len(corpus):
    # The embeddings file must have been produced from exactly this sentence split;
    # otherwise hit indices would point at the wrong sentences/documents.
    print(
        f"Warning: {corpus_embeddings.shape[0]} embeddings for {len(corpus)} sentences"
    )

# Load embedding model
model = SentenceTransformer(model_name)
model.max_seq_length = max_sequence_length

# Load speech to text model
asr = pipeline(
    "automatic-speech-recognition", model=asr_model, feature_extractor=asr_model
)


def _find_document(corpus_id):
    """Return the ``df`` row whose MD&A section contains sentence ``corpus_id``.

    Returns None when ``corpus_id`` lies beyond the tokenized corpus.
    """
    idx = int(np.searchsorted(sentence_offsets, corpus_id, side="right"))
    if idx >= len(df):
        return None
    return df.iloc[idx]


def find_sentences(query, hits):
    """Semantically search the corpus for ``query``.

    Args:
        query: Query text in English.
        hits: Number of results to return (cast to int; Gradio sliders may
            deliver a float).

    Returns:
        A DataFrame with columns ``result_columns``: one row per hit with the
        source filing's ticker, form type and filing date, the first 80
        characters of the matching sentence, and the similarity score as a
        2-decimal string.
    """
    query_embedding = model.encode(query)
    results = util.semantic_search(
        query_embedding, corpus_embeddings, top_k=int(hits)
    )[0]  # a single query was encoded, so take the first result list

    rows = []
    for hit in results:
        corpus_id = hit["corpus_id"]
        doc = _find_document(corpus_id)
        if doc is None:
            continue
        rows.append(
            {
                "Ticker": doc["ticker"],
                "Form type": doc["form_type"],
                "Filing date": doc["filing_date"],
                "Text": corpus[corpus_id][:80],
                "Score": f"{hit['score']:.2f}",
            }
        )
    # DataFrame.append was removed in pandas 2.0; building the frame once is both
    # compatible and linear in the number of hits.
    return pd.DataFrame(rows, columns=result_columns)


def process(input_selection, query, filepath, hits):
    """Gradio callback: run a search from typed text or a recorded audio query.

    Args:
        input_selection: "text" or "speech" (from the radio buttons).
        query: Text typed into the text box.
        filepath: Path of the recorded audio file, or None if nothing was recorded.
        hits: Number of results requested.

    Returns:
        A tuple ``(query_text, results_dataframe)``.
    """
    if input_selection == "speech" and filepath:
        # sr=None keeps the file's native sampling rate. librosa's default would
        # resample everything to 22050 Hz first, forcing a second resample below.
        speech, sampling_rate = load(filepath, sr=None)
        if sampling_rate != asr_sampling_rate:
            speech = resample(
                speech, orig_sr=sampling_rate, target_sr=asr_sampling_rate
            )
        text = asr(speech)["text"]
    else:
        # Text mode, or speech mode with no recording (the mic input is optional):
        # fall back to the typed query instead of crashing on a missing file.
        text = query
    return text, find_sentences(text, hits)


# Gradio inputs
buttons = gr.Radio(
    ["text", "speech"], type="value", value="speech", label="Input selection"
)
text_query = gr.Textbox(
    lines=1,
    label="Text input",
    value="The company is under investigation by tax authorities for potential fraud.",
)
mic = gr.Audio(
    source="microphone", type="filepath", label="Speech input", optional=True
)
slider = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of hits")

# Gradio outputs
speech_query = gr.Textbox(type="text", label="Query string")
results = gr.Dataframe(
    type="pandas",
    headers=result_columns,
    label="Query results",
)

iface = gr.Interface(
    theme="huggingface",
    description="This Spaces lets you query a text corpus containing 2020 annual filings for all S&P500 companies. You can type a text query in English, or record an audio query in 21 languages. You can find a technical deep dive at https://www.youtube.com/watch?v=YPme-gR0f80",
    fn=process,
    inputs=[buttons, text_query, mic, slider],
    outputs=[speech_query, results],
    examples=[
        [
            "speech",
            "Nos ventes internationales ont significativement augmenté.",
            "sales_16k_fr.wav",
            3,
        ],
        [
            "speech",
            "Le prix de l'énergie pourrait avoir un impact négatif dans le futur.",
            "energy_16k_fr.wav",
            3,
        ],
        [
            "speech",
            "El precio de la energía podría tener un impacto negativo en el futuro.",
            "energy_24k_es.wav",
            3,
        ],
        [
            "speech",
            "Mehrere Steuerbehörden untersuchen unser Unternehmen.",
            "tax_24k_de.wav",
            3,
        ],
    ],
)
iface.launch()