File size: 4,376 Bytes
70ec0d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15763b2
 
70ec0d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15763b2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import nltk
import pickle
import pandas as pd
import gradio as gr
import numpy as np
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline
from librosa import load, resample

# Constants
filename = 'df10k_SP500_2020.csv.zip'

model_name = 'sentence-transformers/msmarco-distilbert-base-v4'
max_sequence_length = 512
embeddings_filename = 'df10k_embeddings_msmarco-distilbert-base-v4.npz'
asr_model = 'facebook/wav2vec2-xls-r-300m-21-to-en'

# Load corpus. Columns used later: 'mdna', 'ticker', 'form_type', 'filing_date'.
df = pd.read_csv(filename)
df.drop_duplicates(inplace=True)
print(f'Number of documents: {len(df)}')

nltk.download('punkt')
# Recent NLTK releases make sent_tokenize look up the 'punkt_tab' resource
# instead of the pickled 'punkt' model; fetch both so either version works.
nltk.download('punkt_tab')

# Split each row's MD&A text into sentences.
# `corpus` is the flat list of all sentences, in row order.
# `sentence_count[i]` is how many of them came from row i of `df`, so a
# sentence index can be mapped back to its source row.
# NOTE(review): str() turns a missing value into the literal sentence 'nan'.
# Search results index `corpus` by the position of each precomputed embedding,
# so this splitting presumably has to match how the embeddings file was built.
corpus = []
sentence_count = []
for mdna in df['mdna']:
    # 'mdna' = 'Management discussion and analysis'
    sentences = nltk.tokenize.sent_tokenize(str(mdna), language='english')
    sentence_count.append(len(sentences))
    corpus.extend(sentences)
print(f'Number of sentences: {len(corpus)}')

# Load pre-embedded corpus ('arr_0' is np.savez's key for an unnamed array)
corpus_embeddings = np.load(embeddings_filename)['arr_0']
print(f'Number of embeddings: {corpus_embeddings.shape[0]}')
if corpus_embeddings.shape[0] != len(corpus):
    # Hits are mapped to sentences purely by position. A size mismatch means
    # results would show the wrong text, or index past the end of `corpus`.
    print(f'WARNING: {corpus_embeddings.shape[0]} embeddings but {len(corpus)} sentences; '
          'search results will not line up with the corpus.')

# Load embedding model. The embeddings file name suggests it was produced by this same model.
model = SentenceTransformer(model_name)
model.max_seq_length = max_sequence_length

# Load speech to text model
asr = pipeline('automatic-speech-recognition', model=asr_model, feature_extractor=asr_model)

def find_sentences(query, hits):
    """Return the corpus sentences most similar to *query* as a DataFrame.

    The query is embedded with the module-level ``model`` and compared against
    the precomputed ``corpus_embeddings`` via ``util.semantic_search``. Each hit
    is mapped back to the row of ``df`` it came from, using ``sentence_count``
    (sentences contributed by each row, in row order).

    Args:
        query: Text to search for.
        hits: Maximum number of results. Cast to ``int`` because UI sliders
            may deliver a float.

    Returns:
        A DataFrame with one row per hit, in ``semantic_search`` order.
        Columns are 'Ticker', 'Form type', 'Filing date', 'Text' and 'Score';
        the score is formatted to two decimals.
        Hits whose corpus id lies beyond the counted sentences are skipped.
    """
    columns = ['Ticker', 'Form type', 'Filing date', 'Text', 'Score']
    query_embedding = model.encode(query)
    # semantic_search returns one result list per query; there is exactly one query.
    results = util.semantic_search(query_embedding, corpus_embeddings, top_k=int(hits))[0]

    # doc_ends[i] is one past the index of the last sentence of row i.
    # Computed once per call instead of re-scanning sentence_count per hit.
    doc_ends = np.cumsum(sentence_count)

    rows = []
    for hit in results:
        corpus_id = hit['corpus_id']
        # The owning row is the first one whose cumulative end exceeds corpus_id.
        doc_idx = int(np.searchsorted(doc_ends, corpus_id, side='right'))
        if doc_idx >= len(doc_ends):
            continue  # corpus_id beyond all counted sentences: no owning row
        doc = df.iloc[doc_idx]
        rows.append({
            'Ticker': doc['ticker'],
            'Form type': doc['form_type'],
            'Filing date': doc['filing_date'],
            'Text': corpus[corpus_id],
            'Score': '{:.2f}'.format(hit['score']),
        })
    # DataFrame.append was removed in pandas 2.0; build the frame in one step.
    return pd.DataFrame(rows, columns=columns)


def process(input_selection, query, filepath, hits):
    """Run a search from either a typed query or a recorded speech query.

    Args:
        input_selection: 'speech' to transcribe the audio at *filepath* with the
            module-level ``asr`` pipeline; any other value uses *query* as typed.
        query: Text query, used when *input_selection* is not 'speech'.
        filepath: Path to the recorded audio file. May be None/empty when
            nothing was recorded.
        hits: Number of results to return, forwarded to ``find_sentences``.

    Returns:
        A tuple ``(text, results)``. *text* is the query actually searched
        (the transcription, for speech input). *results* is the DataFrame
        from ``find_sentences``.

    Raises:
        ValueError: if speech input is selected but no audio file was given.
    """
    target_sr = 16000  # sampling rate at which audio is fed to `asr`
    if input_selection == 'speech':
        if not filepath:
            raise ValueError('Speech input selected, but no audio was recorded.')
        # Let librosa resample while loading. Without `sr`, load() resamples
        # to its 22050 Hz default, which would need a second resampling pass.
        speech, sampling_rate = load(filepath, sr=target_sr)
        if sampling_rate != target_sr:
            # Keyword arguments: librosa >= 0.10 makes orig_sr/target_sr keyword-only.
            speech = resample(speech, orig_sr=sampling_rate, target_sr=target_sr)
        text = asr(speech)['text']
    else:
        text = query
    return text, find_sentences(text, hits)

# ---------------------------------------------------------------------------
# Gradio UI. `process` receives the four inputs positionally, in the order of
# `inputs=` below. Its two return values fill the two outputs.
# ---------------------------------------------------------------------------

# Inputs
buttons = gr.inputs.Radio(
    ['text', 'speech'], type='value', default='speech', label='Input selection'
)
text_query = gr.inputs.Textbox(
    lines=1,
    label='Text input',
    default='The company is under investigation by tax authorities for potential fraud.',
)
mic = gr.inputs.Audio(
    source='microphone', type='filepath', label='Speech input', optional=True
)
slider = gr.inputs.Slider(
    minimum=1, maximum=10, step=1, default=3, label='Number of hits'
)

# Outputs: the query string that was searched, and the table of hits.
speech_query = gr.outputs.Textbox(type='auto', label='Query string')
results = gr.outputs.Dataframe(
    headers=['Ticker', 'Form type', 'Filing date', 'Text', 'Score'],
    label='Query results',
)

# Each example row follows the input order: selection, text, audio file, hits.
# Text examples carry a placeholder audio path ('dummy.wav').
example_queries = [
    ['text', "The company is under investigation by tax authorities for potential fraud.", 'dummy.wav', 3],
    ['text', "How much money does Microsoft make with Azure?", 'dummy.wav', 3],
    ['speech', "Nos ventes internationales ont significativement augmenté.", 'sales_16k_fr.wav', 3],
    ['speech', "Le prix de l'énergie pourrait avoir un impact négatif dans le futur.", 'energy_16k_fr.wav', 3],
    ['speech', "El precio de la energía podría tener un impacto negativo en el futuro.", 'energy_24k_es.wav', 3],
    ['speech', "Mehrere Steuerbehörden untersuchen unser Unternehmen.", 'tax_24k_de.wav', 3],
]

app_description = (
    'This Spaces lets you query a text corpus containing 2020 annual filings '
    'for all S&P500 companies. You can type a text query in English, or record '
    'an audio query in 21 languages.'
)

iface = gr.Interface(
    fn=process,
    inputs=[buttons, text_query, mic, slider],
    outputs=[speech_query, results],
    examples=example_queries,
    description=app_description,
    theme='huggingface',
    allow_flagging=False,
)
iface.launch()