Julien Simon committed on

Commit 70ec0d7
1 Parent(s): dad2516

Initial version
.gitattributes CHANGED
@@ -23,5 +23,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.tgz filter=lfs diff=lfs merge=lfs -text
 *.xz filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
 *.zstandard filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,8 +1,8 @@
 ---
 title: Voice Queries
-emoji: 💩
-colorFrom: pink
-colorTo: indigo
+emoji: 🐢
+colorFrom: green
+colorTo: yellow
 sdk: gradio
 app_file: app.py
 pinned: false
app.py ADDED
@@ -0,0 +1,110 @@
+import nltk
+import pickle
+import pandas as pd
+import gradio as gr
+import numpy as np
+from sentence_transformers import SentenceTransformer, util
+from transformers import pipeline
+from librosa import load, resample
+
+# Constants
+filename = 'df10k_SP500_2020.csv.zip'
+
+model_name = 'sentence-transformers/msmarco-distilbert-base-v4'
+max_sequence_length = 512
+embeddings_filename = 'df10k_embeddings_msmarco-distilbert-base-v4.npz'
+asr_model = 'facebook/wav2vec2-xls-r-300m-21-to-en'
+
+# Load corpus
+df = pd.read_csv(filename)
+df.drop_duplicates(inplace=True)
+print(f'Number of documents: {len(df)}')
+
+corpus = []
+sentence_count = []
+for _, row in df.iterrows():
+    # We're interested in the 'mdna' column: 'Management discussion and analysis'
+    sentences = nltk.tokenize.sent_tokenize(str(row['mdna']), language='english')
+    sentence_count.append(len(sentences))
+    for _, s in enumerate(sentences):
+        corpus.append(s)
+print(f'Number of sentences: {len(corpus)}')
+
+# Load pre-embedded corpus
+corpus_embeddings = np.load(embeddings_filename)['arr_0']
+print(f'Number of embeddings: {corpus_embeddings.shape[0]}')
+
+# Load embedding model
+model = SentenceTransformer(model_name)
+model.max_seq_length = max_sequence_length
+
+# Load speech-to-text model
+asr = pipeline('automatic-speech-recognition', model=asr_model, feature_extractor=asr_model)
+
+def find_sentences(query, hits):
+    query_embedding = model.encode(query)
+    hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=hits)
+    hits = hits[0]
+
+    output = pd.DataFrame(columns=['Ticker', 'Form type', 'Filing date', 'Text', 'Score'])
+    for hit in hits:
+        corpus_id = hit['corpus_id']
+        # Find source document based on sentence index
+        count = 0
+        for idx, c in enumerate(sentence_count):
+            count += c
+            if corpus_id > count - 1:
+                continue
+            else:
+                doc = df.iloc[idx]
+                new_row = {
+                    'Ticker': doc['ticker'],
+                    'Form type': doc['form_type'],
+                    'Filing date': doc['filing_date'],
+                    'Text': corpus[corpus_id],
+                    'Score': '{:.2f}'.format(hit['score'])
+                }
+                output = output.append(new_row, ignore_index=True)
+                break
+    return output
+
+
+def process(input_selection, query, filepath, hits):
+    if input_selection == 'speech':
+        speech, sampling_rate = load(filepath)
+        if sampling_rate != 16000:
+            speech = resample(speech, sampling_rate, 16000)
+        text = asr(speech)['text']
+    else:
+        text = query
+    return text, find_sentences(text, hits)
+
+# Gradio inputs
+buttons = gr.inputs.Radio(['text', 'speech'], type='value', default='speech', label='Input selection')
+text_query = gr.inputs.Textbox(lines=1, label='Text input', default='The company is under investigation by tax authorities for potential fraud.')
+mic = gr.inputs.Audio(source='microphone', type='filepath', label='Speech input', optional=True)
+slider = gr.inputs.Slider(minimum=1, maximum=10, step=1, default=3, label='Number of hits')
+
+# Gradio outputs
+speech_query = gr.outputs.Textbox(type='auto', label='Query string')
+results = gr.outputs.Dataframe(
+    headers=['Ticker', 'Form type', 'Filing date', 'Text', 'Score'],
+    label='Query results')
+
+iface = gr.Interface(
+    theme='huggingface',
+    description='This Space lets you query a text corpus containing 2020 annual filings for all S&P 500 companies. You can type a text query in English, or record an audio query in 21 languages.',
+    fn=process,
+    inputs=[buttons, text_query, mic, slider],
+    outputs=[speech_query, results],
+    examples=[
+        ['text', "The company is under investigation by tax authorities for potential fraud.", 'dummy.wav', 3],
+        ['text', "How much money does Microsoft make with Azure?", 'dummy.wav', 3],
+        ['speech', "Nos ventes internationales ont significativement augmenté.", 'sales_16k_fr.wav', 3],
+        ['speech', "Le prix de l'énergie pourrait avoir un impact négatif dans le futur.", 'energy_16k_fr.wav', 3],
+        ['speech', "El precio de la energía podría tener un impacto negativo en el futuro.", 'energy_24k_es.wav', 3],
+        ['speech', "Mehrere Steuerbehörden untersuchen unser Unternehmen.", 'tax_24k_de.wav', 3]
+    ],
+    allow_flagging=False
+)
+iface.launch()
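
Note on the lookup in find_sentences: each hit's global sentence index (corpus_id) is mapped back to its source filing by re-accumulating the per-document sentence counts in an inner Python loop. An equivalent lookup can be expressed with a cumulative-sum table; the sketch below is illustrative only (not part of this commit) and uses toy values for sentence_count and corpus_id.

```python
# Illustrative sketch, not part of this commit: the document lookup from
# find_sentences expressed with numpy's searchsorted over cumulative counts.
import numpy as np

sentence_count = [3, 2, 4]   # toy per-document sentence counts (built in app.py)
corpus_id = 4                # toy global sentence index returned by semantic_search

cumulative = np.cumsum(sentence_count)                              # [3, 5, 9]
doc_idx = int(np.searchsorted(cumulative, corpus_id, side='right'))
print(doc_idx)               # 1 -> the sentence comes from the second document
```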
df10k_SP500_2020.csv.zip ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:984d7a036dba3c32e176e6609392909b005bc5ac030de24427f0982c88aaaf0d
+size 134796242
df10k_embeddings_msmarco-distilbert-base-v4.npz ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:aa084a851fd82187d2e0aa1ab72fb4e1ac127a07cbcc3597a551719850b8b25d
+size 526747035
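
This .npz file holds the pre-computed sentence embeddings that app.py loads at startup via np.load(...)['arr_0']; the script that produced it is not part of the commit. Below is a minimal, hedged sketch of how such a file could be generated, assuming the same CSV, sentence splitting, and model settings used in app.py.

```python
# Hedged sketch, not part of this commit: pre-compute corpus embeddings and
# save them so app.py can load them as np.load(...)['arr_0'].
import nltk
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer

nltk.download('punkt')  # sentence tokenizer data needed by sent_tokenize

df = pd.read_csv('df10k_SP500_2020.csv.zip').drop_duplicates()
corpus = [s for _, row in df.iterrows()
          for s in nltk.tokenize.sent_tokenize(str(row['mdna']), language='english')]

model = SentenceTransformer('sentence-transformers/msmarco-distilbert-base-v4')
model.max_seq_length = 512
embeddings = model.encode(corpus, show_progress_bar=True)

np.savez('df10k_embeddings_msmarco-distilbert-base-v4.npz', embeddings)  # stored under 'arr_0'
```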
dummy.wav ADDED
File without changes
energy_16k_fr.wav ADDED
Binary file (156 kB).
energy_24k_es.wav ADDED
Binary file (182 kB).
requirements.txt ADDED
@@ -0,0 +1,7 @@
+torch
+transformers
+nltk
+pandas
+numpy
+sentence-transformers
+librosa
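
requirements.txt installs the nltk package, but app.py's call to nltk.tokenize.sent_tokenize also needs the NLTK 'punkt' tokenizer data at runtime. A small, hedged check that could be run before launching the app (not part of this commit):

```python
# Hedged sketch, not part of this commit: download the 'punkt' tokenizer data
# if it is not already available in the environment.
import nltk

try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt')
```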
sales_16k_fr.wav ADDED
Binary file (136 kB).
tax_24k_de.wav ADDED
Binary file (152 kB).