File size: 4,351 Bytes
70ec0d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import nltk
import pickle
import pandas as pd
import gradio as gr
import numpy as np
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline
from librosa import load, resample

# Constants: data files and model identifiers used by the app.
filename = 'df10k_SP500_2020.csv.zip'

model_name = 'sentence-transformers/msmarco-distilbert-base-v4'
max_sequence_length = 512
embeddings_filename = 'df10k_embeddings_msmarco-distilbert-base-v4.npz'
asr_model = 'facebook/wav2vec2-xls-r-300m-21-to-en'

# Load the filings table and discard exact duplicate rows.
df = pd.read_csv(filename).drop_duplicates()
print(f'Number of documents: {len(df)}')

# Split each document into sentences. Only the 'mdna' column
# ('Management discussion and analysis') is searched.
_doc_sentences = [
    nltk.tokenize.sent_tokenize(str(text), language='english')
    for text in df['mdna']
]
# sentence_count[i] is the number of sentences in the i-th row of df;
# find_sentences uses it to map a flat sentence index back to its document.
sentence_count = [len(doc) for doc in _doc_sentences]
# corpus is the flat list of every sentence, in document order.
corpus = [sentence for doc in _doc_sentences for sentence in doc]
print(f'Number of sentences: {len(corpus)}')

# Load the pre-computed sentence embeddings. find_sentences indexes `corpus`
# with positions returned by a search over this array, so both must follow
# the same sentence order.
_embedding_archive = np.load(embeddings_filename)
corpus_embeddings = _embedding_archive['arr_0']
print(f'Number of embeddings: {corpus_embeddings.shape[0]}')

# Load the embedding model used to encode queries.
model = SentenceTransformer(model_name)
model.max_seq_length = max_sequence_length

# Load the speech-to-text model.
asr = pipeline(
    'automatic-speech-recognition',
    model=asr_model,
    feature_extractor=asr_model,
)

def find_sentences(query, hits):
    """Return the corpus sentences most similar to ``query``.

    The query is embedded with the module-level ``model`` and compared
    against ``corpus_embeddings`` using ``util.semantic_search``. Each hit is
    then traced back to the filing it came from by locating its sentence
    index within the per-document sentence counts in ``sentence_count``.

    Args:
        query: Query text to embed.
        hits: Maximum number of results, passed to the search as ``top_k``.

    Returns:
        A ``pandas.DataFrame`` with the columns ``Ticker``, ``Form type``,
        ``Filing date``, ``Text`` and ``Score`` (the score is a string
        formatted to two decimals). It holds one row per hit whose source
        document could be located, in the order the search returned them.
    """
    columns = ['Ticker', 'Form type', 'Filing date', 'Text', 'Score']

    query_embedding = model.encode(query)
    # semantic_search returns one result list per query; there is one query.
    matches = util.semantic_search(
        query_embedding, corpus_embeddings, top_k=hits
    )[0]

    # doc_ends[i] is the total number of sentences in documents 0..i, i.e.
    # the exclusive upper bound of the sentence indices owned by document i.
    doc_ends = np.cumsum(sentence_count)

    # Rows are gathered in a list and the frame is built once at the end:
    # DataFrame.append no longer exists in pandas 2.x.
    rows = []
    for match in matches:
        corpus_id = match['corpus_id']
        # First document whose cumulative sentence count exceeds corpus_id.
        doc_idx = int(np.searchsorted(doc_ends, corpus_id, side='right'))
        if doc_idx >= len(doc_ends):
            # Sentence index lies beyond every document: nothing to report.
            continue
        doc = df.iloc[doc_idx]
        rows.append({
            'Ticker': doc['ticker'],
            'Form type': doc['form_type'],
            'Filing date': doc['filing_date'],
            'Text': corpus[corpus_id],
            'Score': f"{match['score']:.2f}",
        })
    return pd.DataFrame(rows, columns=columns)


def process(input_selection, query, filepath, hits):
    """Resolve the query text (typed or spoken) and search the corpus.

    Args:
        input_selection: ``'speech'`` to transcribe the audio at ``filepath``;
            any other value uses ``query`` unchanged.
        query: Typed query text.
        filepath: Path to the audio file to transcribe. May be ``None`` when
            nothing was recorded, in which case ``query`` is used instead.
        hits: Number of results to request from ``find_sentences``.

    Returns:
        A tuple ``(text, results)`` where ``text`` is the query string that
        was actually searched and ``results`` is the value returned by
        ``find_sentences`` for it.
    """
    if input_selection == 'speech' and filepath is not None:
        # The speech model is fed 16 kHz audio. Asking librosa for that rate
        # at load time avoids decoding to its 22,050 Hz default first and
        # then resampling a second time.
        speech, _ = load(filepath, sr=16000)
        text = asr(speech)['text']
    else:
        text = query
    return text, find_sentences(text, hits)

# Gradio inputs: input mode, typed query, recorded query, result count.
buttons = gr.inputs.Radio(
    ['text', 'speech'],
    type='value',
    default='speech',
    label='Input selection',
)
text_query = gr.inputs.Textbox(
    lines=1,
    label='Text input',
    default='The company is under investigation by tax authorities for potential fraud.',
)
mic = gr.inputs.Audio(
    source='microphone',
    type='filepath',
    label='Speech input',
    optional=True,
)
slider = gr.inputs.Slider(
    minimum=1,
    maximum=10,
    step=1,
    default=3,
    label='Number of hits',
)

# Gradio outputs: the query string that was searched and the matching rows.
speech_query = gr.outputs.Textbox(type='auto', label='Query string')
results = gr.outputs.Dataframe(
    headers=['Ticker', 'Form type', 'Filing date', 'Text', 'Score'],
    label='Query results',
)

# Example rows, in the same order as the interface inputs:
# [input selection, text query, audio file, number of hits].
example_queries = [
    ['text', "The company is under investigation by tax authorities for potential fraud.", 'dummy.wav', 3],
    ['text', "How much money does Microsoft make with Azure?", 'dummy.wav', 3],
    ['speech', "Nos ventes internationales ont significativement augmenté.", 'sales_16k_fr.wav', 3],
    ['speech', "Le prix de l'énergie pourrait avoir un impact négatif dans le futur.", 'energy_16k_fr.wav', 3],
    ['speech', "El precio de la energía podría tener un impacto negativo en el futuro.", 'energy_24k_es.wav', 3],
    ['speech', "Mehrere Steuerbehörden untersuchen unser Unternehmen.", 'tax_24k_de.wav', 3],
]

# Wire `process` to the components above and start the app.
iface = gr.Interface(
    fn=process,
    inputs=[buttons, text_query, mic, slider],
    outputs=[speech_query, results],
    examples=example_queries,
    theme='huggingface',
    description=(
        'This Spaces lets you query a text corpus containing 2020 annual filings '
        'for all S&P500 companies. You can type a text query in English, or '
        'record an audio query in 21 languages.'
    ),
    allow_flagging=False,
)
iface.launch()