File size: 4,578 Bytes
70ec0d7
59b8055
70ec0d7
59b8055
 
70ec0d7
 
 
 
59b8055
70ec0d7
59b8055
70ec0d7
59b8055
 
70ec0d7
 
 
 
59b8055
70ec0d7
59b8055
15763b2
70ec0d7
 
 
 
59b8055
70ec0d7
59b8055
70ec0d7
59b8055
70ec0d7
 
59b8055
 
70ec0d7
 
 
 
 
 
59b8055
 
 
 
70ec0d7
 
 
 
 
 
59b8055
 
 
70ec0d7
59b8055
70ec0d7
 
 
59b8055
 
70ec0d7
 
 
59b8055
 
 
 
 
 
70ec0d7
ea11adf
70ec0d7
 
 
 
 
59b8055
 
 
 
 
 
 
 
 
70ec0d7
 
59b8055
 
 
 
 
 
 
 
 
 
 
 
70ec0d7
 
59b8055
 
 
 
 
 
70ec0d7
 
59b8055
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70ec0d7
 
15763b2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import gradio as gr
import nltk
import numpy as np
import pandas as pd
from librosa import load, resample
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline

# Constants
filename = "df10k_SP500_2020.csv.zip"

model_name = "sentence-transformers/msmarco-distilbert-base-v4"
max_sequence_length = 512
embeddings_filename = "df10k_embeddings_msmarco-distilbert-base-v4.npz"
asr_model = "facebook/wav2vec2-xls-r-300m-21-to-en"

# Load corpus (a zipped CSV; pandas infers the compression from the extension).
df = pd.read_csv(filename)
df.drop_duplicates(inplace=True)
print(f"Number of documents: {len(df)}")

nltk.download("punkt")

# Split each filing's 'mdna' column ('Management discussion and analysis') into
# sentences. `corpus` is the flat list of every sentence in row order, and
# `sentence_count[i]` is how many of those sentences came from row i of `df`;
# find_sentences relies on this to map a sentence index back to its filing.
corpus = []
sentence_count = []
for mdna_text in df["mdna"]:
    # str() turns missing values (NaN) into the literal text "nan", as before.
    sentences = nltk.tokenize.sent_tokenize(str(mdna_text), language="english")
    sentence_count.append(len(sentences))
    corpus.extend(sentences)
print(f"Number of sentences: {len(corpus)}")

# Load pre-embedded corpus (np.savez stores an unnamed array under "arr_0").
corpus_embeddings = np.load(embeddings_filename)["arr_0"]
print(f"Number of embeddings: {corpus_embeddings.shape[0]}")
if corpus_embeddings.shape[0] != len(corpus):
    # Search hits are indices into the embedding matrix and are looked up in
    # `corpus`/`sentence_count`; a mismatch silently attributes sentences to the
    # wrong filings, so make it visible at startup.
    print(
        f"WARNING: {corpus_embeddings.shape[0]} embeddings for {len(corpus)} "
        "sentences; search results may be attributed to the wrong documents."
    )

# Load embedding model
model = SentenceTransformer(model_name)
model.max_seq_length = max_sequence_length

# Load speech to text model
asr = pipeline(
    "automatic-speech-recognition", model=asr_model, feature_extractor=asr_model
)


def find_sentences(query, hits):
    """Return the corpus sentences most similar to *query*, with their source filing.

    The query is embedded with the module-level ``model`` and compared against the
    pre-computed ``corpus_embeddings`` using ``util.semantic_search``. Each hit is then
    mapped back to the row of ``df`` it came from, using ``sentence_count`` (the number
    of corpus sentences contributed by each row, in corpus order).

    Args:
        query: Free-text query string.
        hits: Number of results to return. It is cast to ``int`` because
            ``semantic_search`` needs an integer ``top_k``, and UI sliders may deliver
            a float.

    Returns:
        A DataFrame with columns Ticker, Form type, Filing date, Text and Score:
        - Text is the first 80 characters of the sentence.
        - Score is the similarity score formatted with two decimals.
        - Rows follow the order returned by ``semantic_search``.
        - Hits whose sentence index lies beyond the documents counted in
          ``sentence_count`` are skipped.
    """
    columns = ["Ticker", "Form type", "Filing date", "Text", "Score"]

    query_embedding = model.encode(query)
    # semantic_search returns one result list per query; we pass a single query.
    search_results = util.semantic_search(
        query_embedding, corpus_embeddings, top_k=int(hits)
    )[0]

    # doc_ends[i] is one past the last sentence index belonging to df row i, so the
    # row owning sentence `corpus_id` is the first i with doc_ends[i] > corpus_id.
    doc_ends = np.cumsum(sentence_count)

    rows = []
    for hit in search_results:
        corpus_id = hit["corpus_id"]
        doc_idx = int(np.searchsorted(doc_ends, corpus_id, side="right"))
        if doc_idx >= len(doc_ends):
            # Index beyond the tokenized corpus (embeddings/corpus mismatch): no
            # owning document, so produce no row for this hit.
            continue
        doc = df.iloc[doc_idx]
        rows.append(
            {
                "Ticker": doc["ticker"],
                "Form type": doc["form_type"],
                "Filing date": doc["filing_date"],
                "Text": corpus[corpus_id][:80],
                "Score": "{:.2f}".format(hit["score"]),
            }
        )
    # Build the frame once instead of concatenating row by row.
    return pd.DataFrame(rows, columns=columns)


def process(input_selection, query, filepath, hits):
    """Resolve the query text from the selected input and run the semantic search.

    Args:
        input_selection: "speech" to transcribe the recording at *filepath*; any
            other value uses the typed *query*.
        query: Typed query text. It is also used when speech is selected but no
            recording was provided (the microphone input is optional).
        filepath: Path to the recorded audio file, or None when nothing was
            recorded.
        hits: Number of search results to return.

    Returns:
        A tuple ``(text, results)``:
        - ``text`` is the query actually searched (the transcript or the typed
          text).
        - ``results`` is the DataFrame returned by ``find_sentences``.
    """
    if input_selection == "speech" and filepath:
        # librosa.load resamples to 22050 Hz unless `sr` is given, so request the
        # 16 kHz rate used for the ASR model directly instead of resampling twice.
        speech, _ = load(filepath, sr=16000)
        text = asr(speech)["text"]
    else:
        text = query
    return text, find_sentences(text, hits)


# --- Gradio inputs: query mode, typed query, recorded query, result count ---
buttons = gr.Radio(
    ["text", "speech"],
    label="Input selection",
    value="speech",
    type="value",
)
text_query = gr.Textbox(
    label="Text input",
    value="The company is under investigation by tax authorities for potential fraud.",
    lines=1,
)
mic = gr.Audio(
    label="Speech input",
    type="filepath",
    source="microphone",
    optional=True,
)
slider = gr.Slider(
    label="Number of hits",
    minimum=1,
    maximum=10,
    step=1,
    value=3,
)

# --- Gradio outputs: the query that was searched, and the matching sentences ---
speech_query = gr.Textbox(label="Query string", type="text")
results = gr.Dataframe(
    label="Query results",
    headers=["Ticker", "Form type", "Filing date", "Text", "Score"],
    type="pandas",
)

# Example rows fill the four inputs in order: mode, text query (apparently the
# clip's transcript), audio file, hit count. All examples use speech mode, 3 hits.
_example_clips = (
    ("Nos ventes internationales ont significativement augmenté.", "sales_16k_fr.wav"),
    (
        "Le prix de l'énergie pourrait avoir un impact négatif dans le futur.",
        "energy_16k_fr.wav",
    ),
    (
        "El precio de la energía podría tener un impacto negativo en el futuro.",
        "energy_24k_es.wav",
    ),
    ("Mehrere Steuerbehörden untersuchen unser Unternehmen.", "tax_24k_de.wav"),
)
_examples = [["speech", spoken, clip, 3] for spoken, clip in _example_clips]

_description = (
    "This Spaces lets you query a text corpus containing 2020 annual filings "
    "for all S&P500 companies. You can type a text query in English, or record "
    "an audio query in 21 languages. You can find a technical deep dive at "
    "https://www.youtube.com/watch?v=YPme-gR0f80"
)

iface = gr.Interface(
    fn=process,
    inputs=[buttons, text_query, mic, slider],
    outputs=[speech_query, results],
    examples=_examples,
    description=_description,
    theme="huggingface",
)
iface.launch()