juliensimon HF staff commited on
Commit
59b8055
1 Parent(s): 0d6c6e0

Update code

Browse files
Files changed (2) hide show
  1. app.py +88 -57
  2. dummy.wav +0 -0
app.py CHANGED
@@ -1,70 +1,74 @@
1
- import nltk
2
- import pickle
3
- import pandas as pd
4
  import gradio as gr
 
5
  import numpy as np
 
 
6
  from sentence_transformers import SentenceTransformer, util
7
  from transformers import pipeline
8
- from librosa import load, resample
9
 
10
  # Constants
11
- filename = 'df10k_SP500_2020.csv.zip'
12
 
13
- model_name = 'sentence-transformers/msmarco-distilbert-base-v4'
14
  max_sequence_length = 512
15
- embeddings_filename = 'df10k_embeddings_msmarco-distilbert-base-v4.npz'
16
- asr_model = 'facebook/wav2vec2-xls-r-300m-21-to-en'
17
 
18
  # Load corpus
19
  df = pd.read_csv(filename)
20
  df.drop_duplicates(inplace=True)
21
- print(f'Number of documents: {len(df)}')
22
 
23
- nltk.download('punkt')
24
 
25
  corpus = []
26
  sentence_count = []
27
  for _, row in df.iterrows():
28
  # We're interested in the 'mdna' column: 'Management discussion and analysis'
29
- sentences = nltk.tokenize.sent_tokenize(str(row['mdna']), language='english')
30
  sentence_count.append(len(sentences))
31
- for _,s in enumerate(sentences):
32
  corpus.append(s)
33
- print(f'Number of sentences: {len(corpus)}')
34
 
35
  # Load pre-embedded corpus
36
- corpus_embeddings = np.load(embeddings_filename)['arr_0']
37
- print(f'Number of embeddings: {corpus_embeddings.shape[0]}')
38
 
39
  # Load embedding model
40
  model = SentenceTransformer(model_name)
41
  model.max_seq_length = max_sequence_length
42
 
43
  # Load speech to text model
44
- asr = pipeline('automatic-speech-recognition', model=asr_model, feature_extractor=asr_model)
 
 
 
45
 
46
  def find_sentences(query, hits):
47
  query_embedding = model.encode(query)
48
  hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=hits)
49
  hits = hits[0]
50
 
51
- output = pd.DataFrame(columns=['Ticker', 'Form type', 'Filing date', 'Text', 'Score'])
 
 
52
  for hit in hits:
53
- corpus_id = hit['corpus_id']
54
  # Find source document based on sentence index
55
  count = 0
56
  for idx, c in enumerate(sentence_count):
57
- count+=c
58
- if (corpus_id > count-1):
59
  continue
60
  else:
61
  doc = df.iloc[idx]
62
- new_row = {
63
- 'Ticker' : doc['ticker'],
64
- 'Form type' : doc['form_type'],
65
- 'Filing date': doc['filing_date'],
66
- 'Text' : corpus[corpus_id][:80],
67
- 'Score' : '{:.2f}'.format(hit['score'])
68
  }
69
  output = output.append(new_row, ignore_index=True)
70
  break
@@ -72,43 +76,70 @@ def find_sentences(query, hits):
72
 
73
 
74
  def process(input_selection, query, filepath, hits):
75
- if input_selection=='speech':
76
- speech, sampling_rate = load(filepath)
77
- if sampling_rate != 16000:
78
- speech = resample(speech, sampling_rate, 16000)
79
- text = asr(speech)['text']
80
- else:
81
- text = query
82
- return text, find_sentences(text, hits)
 
83
 
84
  # Gradio inputs
85
- buttons = gr.Radio(['text','speech'], type='value', default='speech', label='Input selection')
86
- text_query = gr.Textbox(lines=1, label='Text input', default='The company is under investigation by tax authorities for potential fraud.')
87
- mic = gr.Audio(source='microphone', type='filepath', label='Speech input', optional=True)
88
- slider = gr.Slider(minimum=1, maximum=10, step=1, default=3, label='Number of hits')
 
 
 
 
 
 
 
 
89
 
90
  # Gradio outputs
91
- speech_query = gr.Textbox(type='text', label='Query string')
92
- results = gr.Dataframe(
93
- type='pandas',
94
- headers=['Ticker', 'Form type', 'Filing date', 'Text', 'Score'],
95
- label='Query results')
 
96
 
97
  iface = gr.Interface(
98
- theme='huggingface',
99
- description='This Spaces lets you query a text corpus containing 2020 annual filings for all S&P500 companies. You can type a text query in English, or record an audio query in 21 languages. You can find a technical deep dive at https://www.youtube.com/watch?v=YPme-gR0f80',
100
- fn=process,
101
- layout='horizontal',
102
- inputs=[buttons,text_query,mic,slider],
103
- outputs=[speech_query, results],
104
- examples=[
105
- ['text', "The company is under investigation by tax authorities for potential fraud.", 'dummy.wav', 3],
106
- ['text', "How much money does Microsoft make with Azure?", 'dummy.wav', 3],
107
- ['speech', "Nos ventes internationales ont significativement augmenté.", 'sales_16k_fr.wav', 3],
108
- ['speech', "Le prix de l'énergie pourrait avoir un impact négatif dans le futur.", 'energy_16k_fr.wav', 3],
109
- ['speech', "El precio de la energía podría tener un impacto negativo en el futuro.", 'energy_24k_es.wav', 3],
110
- ['speech', "Mehrere Steuerbehörden untersuchen unser Unternehmen.", 'tax_24k_de.wav', 3]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  ],
112
- allow_flagging=False
113
  )
114
  iface.launch()
 
 
 
 
1
  import gradio as gr
2
+ import nltk
3
  import numpy as np
4
+ import pandas as pd
5
+ from librosa import load, resample
6
  from sentence_transformers import SentenceTransformer, util
7
  from transformers import pipeline
 
8
 
9
  # Constants
10
+ filename = "df10k_SP500_2020.csv.zip"
11
 
12
+ model_name = "sentence-transformers/msmarco-distilbert-base-v4"
13
  max_sequence_length = 512
14
+ embeddings_filename = "df10k_embeddings_msmarco-distilbert-base-v4.npz"
15
+ asr_model = "facebook/wav2vec2-xls-r-300m-21-to-en"
16
 
17
  # Load corpus
18
  df = pd.read_csv(filename)
19
  df.drop_duplicates(inplace=True)
20
+ print(f"Number of documents: {len(df)}")
21
 
22
+ nltk.download("punkt")
23
 
24
  corpus = []
25
  sentence_count = []
26
  for _, row in df.iterrows():
27
  # We're interested in the 'mdna' column: 'Management discussion and analysis'
28
+ sentences = nltk.tokenize.sent_tokenize(str(row["mdna"]), language="english")
29
  sentence_count.append(len(sentences))
30
+ for _, s in enumerate(sentences):
31
  corpus.append(s)
32
+ print(f"Number of sentences: {len(corpus)}")
33
 
34
  # Load pre-embedded corpus
35
+ corpus_embeddings = np.load(embeddings_filename)["arr_0"]
36
+ print(f"Number of embeddings: {corpus_embeddings.shape[0]}")
37
 
38
  # Load embedding model
39
  model = SentenceTransformer(model_name)
40
  model.max_seq_length = max_sequence_length
41
 
42
  # Load speech to text model
43
+ asr = pipeline(
44
+ "automatic-speech-recognition", model=asr_model, feature_extractor=asr_model
45
+ )
46
+
47
 
48
  def find_sentences(query, hits):
49
  query_embedding = model.encode(query)
50
  hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=hits)
51
  hits = hits[0]
52
 
53
+ output = pd.DataFrame(
54
+ columns=["Ticker", "Form type", "Filing date", "Text", "Score"]
55
+ )
56
  for hit in hits:
57
+ corpus_id = hit["corpus_id"]
58
  # Find source document based on sentence index
59
  count = 0
60
  for idx, c in enumerate(sentence_count):
61
+ count += c
62
+ if corpus_id > count - 1:
63
  continue
64
  else:
65
  doc = df.iloc[idx]
66
+ new_row = {
67
+ "Ticker": doc["ticker"],
68
+ "Form type": doc["form_type"],
69
+ "Filing date": doc["filing_date"],
70
+ "Text": corpus[corpus_id][:80],
71
+ "Score": "{:.2f}".format(hit["score"]),
72
  }
73
  output = output.append(new_row, ignore_index=True)
74
  break
 
76
 
77
 
78
  def process(input_selection, query, filepath, hits):
79
+ if input_selection == "speech":
80
+ speech, sampling_rate = load(filepath)
81
+ if sampling_rate != 16000:
82
+ speech = resample(speech, orig_sr=sampling_rate, target_sr=16000)
83
+ text = asr(speech)["text"]
84
+ else:
85
+ text = query
86
+ return text, find_sentences(text, hits)
87
+
88
 
89
  # Gradio inputs
90
+ buttons = gr.Radio(
91
+ ["text", "speech"], type="value", value="speech", label="Input selection"
92
+ )
93
+ text_query = gr.Textbox(
94
+ lines=1,
95
+ label="Text input",
96
+ value="The company is under investigation by tax authorities for potential fraud.",
97
+ )
98
+ mic = gr.Audio(
99
+ source="microphone", type="filepath", label="Speech input", optional=True
100
+ )
101
+ slider = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of hits")
102
 
103
  # Gradio outputs
104
+ speech_query = gr.Textbox(type="text", label="Query string")
105
+ results = gr.Dataframe(
106
+ type="pandas",
107
+ headers=["Ticker", "Form type", "Filing date", "Text", "Score"],
108
+ label="Query results",
109
+ )
110
 
111
  iface = gr.Interface(
112
+ theme="huggingface",
113
+ description="This Spaces lets you query a text corpus containing 2020 annual filings for all S&P500 companies. You can type a text query in English, or record an audio query in 21 languages. You can find a technical deep dive at https://www.youtube.com/watch?v=YPme-gR0f80",
114
+ fn=process,
115
+ inputs=[buttons, text_query, mic, slider],
116
+ outputs=[speech_query, results],
117
+ examples=[
118
+ [
119
+ "speech",
120
+ "Nos ventes internationales ont significativement augmenté.",
121
+ "sales_16k_fr.wav",
122
+ 3,
123
+ ],
124
+ [
125
+ "speech",
126
+ "Le prix de l'énergie pourrait avoir un impact négatif dans le futur.",
127
+ "energy_16k_fr.wav",
128
+ 3,
129
+ ],
130
+ [
131
+ "speech",
132
+ "El precio de la energía podría tener un impacto negativo en el futuro.",
133
+ "energy_24k_es.wav",
134
+ 3,
135
+ ],
136
+ [
137
+ "speech",
138
+ "Mehrere Steuerbehörden untersuchen unser Unternehmen.",
139
+ "tax_24k_de.wav",
140
+ 3,
141
+ ],
142
  ],
143
+ allow_flagging=False,
144
  )
145
  iface.launch()
dummy.wav DELETED
File without changes