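"""Gradio demo: semantic search over 50,000 clinical trial descriptions.

Queries can be typed or spoken; speech input is transcribed with a wav2vec2
ASR pipeline, and the resulting text is matched against a pre-embedded corpus
using sentence-transformers.
"""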
import pickle
import subprocess

import gradio as gr
import pandas as pd
from librosa import load, resample
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline, Wav2Vec2ProcessorWithLM

# Constants
model_name = "sentence-transformers/msmarco-distilbert-base-v4"
max_sequence_length = 512

# Download the pre-embedded corpus from Google Drive and load it.
# The pickle holds the raw sentences alongside their precomputed embeddings.
subprocess.run(["gdown", "1QVpyk_xyqNYrHT3NdUfBxbDV_eyCDa2Q"], check=True)
with open("embeddings.pkl", "rb") as fp:
    pickled_data = pickle.load(fp)
sentences = pickled_data["sentences"]
corpus_embeddings = pickled_data["embeddings"]

print(f"Number of documents: {len(sentences)}")
print(f"Number of embeddings: {corpus_embeddings.shape[0]}")

# Load embedding model
model = SentenceTransformer(model_name)
model.max_seq_length = max_sequence_length
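
# For reference, a sketch (not executed here) of how embeddings.pkl could be
# regenerated, assuming the corpus was encoded with this same model:
#
#     with open("embeddings.pkl", "wb") as fp:
#         pickle.dump(
#             {"sentences": sentences,
#              "embeddings": model.encode(sentences, convert_to_tensor=True)},
#             fp,
#         )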

# Load speech-to-text model (wav2vec2 CTC with an n-gram LM decoder)
asr_model = "patrickvonplaten/wav2vec2-base-960h-4-gram"
processor = Wav2Vec2ProcessorWithLM.from_pretrained(asr_model)
asr = pipeline(
    "automatic-speech-recognition",
    model=asr_model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    decoder=processor.decoder,
)


def find_sentences(query, n_hits):
    """Embed the query and return the top n_hits corpus sentences with scores."""
    query_embedding = model.encode(query)
    hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=n_hits)[0]

    output_texts = []
    output_scores = []
    for hit in hits:
        # Look up the source sentence by its index in the corpus
        output_texts.append(sentences[hit["corpus_id"]])
        output_scores.append(hit["score"])

    return pd.DataFrame(data={"Text": output_texts, "Score": output_scores})
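
# Example: find_sentences("breast cancer biomarkers", 3) returns a 3-row
# DataFrame of the closest corpus sentences with their cosine-similarity scores.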


def process(input_selection, query, filepath, hits):
    """Route the query: transcribe speech with the ASR pipeline, or use the text as-is."""
    if input_selection == "speech":
        # Keep the native sampling rate, then resample to the 16 kHz
        # expected by wav2vec2
        speech, sampling_rate = load(filepath, sr=None)
        if sampling_rate != 16000:
            speech = resample(speech, orig_sr=sampling_rate, target_sr=16000)
        text = asr(speech)["text"]
    else:
        text = query
    return text, find_sentences(text, hits)


# Gradio inputs (Gradio 4.x component API; the legacy gr.inputs.* namespace,
# the "huggingface" theme, and layout="horizontal" no longer exist in
# recent Gradio releases)
buttons = gr.Radio(
    ["text", "speech"], value="speech", label="Input selection"
)
text_query = gr.Textbox(
    lines=1, label="Text input", value="breast cancer biomarkers"
)
mic = gr.Audio(
    sources=["microphone"], type="filepath", label="Speech input"
)
slider = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of hits")

# Gradio outputs
speech_query = gr.Textbox(label="Query string")
results = gr.Dataframe(headers=["Text", "Score"], label="Query results")

iface = gr.Interface(
    fn=process,
    description="This Space lets you query a text corpus containing 50,000 random clinical trial descriptions",
    inputs=[buttons, text_query, mic, slider],
    outputs=[speech_query, results],
    examples=[
        ["text", "breast cancer biomarkers", "dummy.wav", 3],
    ],
    allow_flagging="never",
)
iface.launch()
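
# When running locally, iface.launch(share=True) creates a temporary public link.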