import nltk
import pickle
import pandas as pd
import gradio as gr
import numpy as np
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline
from librosa import load, resample

# Constants
filename = 'df10k_SP500_2020.csv.zip'
model_name = 'sentence-transformers/msmarco-distilbert-base-v4'
max_sequence_length = 512
embeddings_filename = 'df10k_embeddings_msmarco-distilbert-base-v4.npz'
asr_model = 'facebook/wav2vec2-xls-r-300m-21-to-en'
# Recorded audio is resampled to this rate before being passed to the ASR pipeline.
asr_sampling_rate = 16000
# Column layout of the search results table shown in the UI.
result_columns = ['Ticker', 'Form type', 'Filing date', 'Text', 'Score']

# Load corpus
df = pd.read_csv(filename)
df.drop_duplicates(inplace=True)
print(f'Number of documents: {len(df)}')

# Sentence tokenizer data. Recent NLTK releases load 'punkt_tab' for
# sent_tokenize instead of 'punkt'; fetching both keeps old and new NLTK working.
nltk.download('punkt')
nltk.download('punkt_tab')

corpus = []          # every sentence of every document, in document order
sentence_count = []  # number of sentences contributed by each document (positional df order)
# We're interested in the 'mdna' column: 'Management discussion and analysis'
for mdna in df['mdna']:
    sentences = nltk.tokenize.sent_tokenize(str(mdna), language='english')
    sentence_count.append(len(sentences))
    corpus.extend(sentences)
print(f'Number of sentences: {len(corpus)}')

# Running totals of sentence_count: the sentences of document i (positional)
# occupy corpus indices [sentence_offsets[i] - sentence_count[i], sentence_offsets[i]).
sentence_offsets = np.cumsum(sentence_count)

# Load pre-embedded corpus (the context manager closes the .npz archive).
with np.load(embeddings_filename) as embeddings_archive:
    corpus_embeddings = embeddings_archive['arr_0']
print(f'Number of embeddings: {corpus_embeddings.shape[0]}')
if corpus_embeddings.shape[0] != len(corpus):
    # Search hits are mapped back to `corpus` by position, so a size mismatch
    # means the embeddings were built from a different tokenization of the data.
    print(f'WARNING: {corpus_embeddings.shape[0]} embeddings but {len(corpus)} sentences; '
          'search results may point to the wrong text.')

# Load embedding model
model = SentenceTransformer(model_name)
model.max_seq_length = max_sequence_length

# Load speech to text model
asr = pipeline('automatic-speech-recognition', model=asr_model, feature_extractor=asr_model)


def _document_index(corpus_id):
    """Return the positional df index of the document containing sentence `corpus_id`, or None if out of range."""
    # First document whose cumulative sentence total exceeds corpus_id.
    doc_idx = int(np.searchsorted(sentence_offsets, corpus_id, side='right'))
    return doc_idx if doc_idx < len(sentence_count) else None


def find_sentences(query, hits):
    """Return the corpus sentences most semantically similar to `query`.

    Args:
        query: Query text.
        hits: Number of results to return; cast to int in case the UI delivers a float.

    Returns:
        pandas DataFrame with columns `result_columns`: ticker, form type and
        filing date of the source filing, the matching sentence, and the
        similarity score formatted with two decimals. Empty (but with those
        columns) when nothing matches.
    """
    query_embedding = model.encode(query)
    # semantic_search returns one list of hits per query; there is exactly one query.
    top_hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=int(hits))[0]

    rows = []
    for hit in top_hits:
        corpus_id = hit['corpus_id']
        doc_idx = _document_index(corpus_id)
        if doc_idx is None:
            continue  # index past the tokenized corpus: no source document to report
        doc = df.iloc[doc_idx]
        rows.append({
            'Ticker': doc['ticker'],
            'Form type': doc['form_type'],
            'Filing date': doc['filing_date'],
            'Text': corpus[corpus_id],
            'Score': f"{hit['score']:.2f}",
        })
    # DataFrame.append was removed in pandas 2.0; build the frame in one go instead.
    return pd.DataFrame(rows, columns=result_columns)


def process(input_selection, query, filepath, hits):
    """Run a search from either the typed query or a recorded audio query.

    Args:
        input_selection: 'speech' to transcribe the recording at `filepath`;
            any other value searches with `query`.
        query: Text typed by the user.
        filepath: Path of the recorded audio file, or None if nothing was recorded.
        hits: Number of results to return.

    Returns:
        Tuple of (the query string actually searched, results DataFrame).
    """
    if input_selection == 'speech' and filepath:
        # Keep the file's native rate (sr=None) so audio is resampled only once,
        # instead of first being resampled to librosa's default rate by load().
        speech, sampling_rate = load(filepath, sr=None)
        if sampling_rate != asr_sampling_rate:
            # librosa >= 0.10 accepts the rates only as keyword arguments.
            speech = resample(speech, orig_sr=sampling_rate, target_sr=asr_sampling_rate)
        text = asr(speech)['text']
    else:
        # Text mode, or speech mode with no recording (the mic input is optional).
        text = query
    return text, find_sentences(text, hits)


# Gradio inputs
buttons = gr.inputs.Radio(['text', 'speech'], type='value', default='speech', label='Input selection')
text_query = gr.inputs.Textbox(lines=1, label='Text input', default='The company is under investigation by tax authorities for potential fraud.')
mic = gr.inputs.Audio(source='microphone', type='filepath', label='Speech input', optional=True)
slider = gr.inputs.Slider(minimum=1, maximum=10, step=1, default=3, label='Number of hits')

# Gradio outputs
speech_query = gr.outputs.Textbox(type='auto', label='Query string')
results = gr.outputs.Dataframe(
    headers=['Ticker', 'Form type', 'Filing date', 'Text', 'Score'],
    label='Query results')

iface = gr.Interface(
    theme='huggingface',
    description='This Spaces lets you query a text corpus containing 2020 annual filings for all S&P500 companies. You can type a text query in English, or record an audio query in 21 languages. You can find a technical deep dive at https://www.youtube.com/watch?v=YPme-gR0f80',
    fn=process,
    inputs=[buttons, text_query, mic, slider],
    outputs=[speech_query, results],
    examples=[
        ['text', "The company is under investigation by tax authorities for potential fraud.", 'dummy.wav', 3],
        ['text', "How much money does Microsoft make with Azure?", 'dummy.wav', 3],
        ['speech', "Nos ventes internationales ont significativement augmenté.", 'sales_16k_fr.wav', 3],
        ['speech', "Le prix de l'énergie pourrait avoir un impact négatif dans le futur.", 'energy_16k_fr.wav', 3],
        ['speech', "El precio de la energía podría tener un impacto negativo en el futuro.", 'energy_24k_es.wav', 3],
        ['speech', "Mehrere Steuerbehörden untersuchen unser Unternehmen.", 'tax_24k_de.wav', 3]
    ],
    allow_flagging=False
)
iface.launch()