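"""Gradio Space: semantic search over clinical trial descriptions.

Queries can be typed or spoken. Spoken queries are transcribed with a
wav2vec2 + n-gram language-model ASR pipeline; the query text is then
embedded with sentence-transformers and matched against a pre-embedded
corpus via semantic (cosine-similarity) search.
"""
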
import pickle
import subprocess

import gradio as gr
import pandas as pd
from librosa import load, resample
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline, Wav2Vec2ProcessorWithLM

# Constants

model_name = "sentence-transformers/msmarco-distilbert-base-v4"
max_sequence_length = 512

# Load the pre-embedded corpus (sentences plus their embeddings)
subprocess.run(["gdown", "1QVpyk_xyqNYrHT3NdUfBxbDV_eyCDa2Q"], check=True)
with open("embeddings.pkl", "rb") as fp:
    pickled_data = pickle.load(fp)
    sentences = pickled_data["sentences"]
    corpus_embeddings = pickled_data["embeddings"]

print(f"Number of documents: {len(sentences)}")
print(f"Number of embeddings: {corpus_embeddings.shape[0]}")

# Load embedding model
model = SentenceTransformer(model_name)
model.max_seq_length = max_sequence_length
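
# For reference, a minimal sketch (an assumption, not part of this Space) of
# how an embeddings.pkl with the layout loaded above could be produced:
#
#     corpus_embeddings = model.encode(sentences, convert_to_numpy=True)
#     with open("embeddings.pkl", "wb") as fp:
#         pickle.dump({"sentences": sentences, "embeddings": corpus_embeddings}, fp)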

# Load speech-to-text model (wav2vec2 with an n-gram LM-boosted decoder)
asr_model = "patrickvonplaten/wav2vec2-base-960h-4-gram"
processor = Wav2Vec2ProcessorWithLM.from_pretrained(asr_model)
asr = pipeline(
    "automatic-speech-recognition",
    model=asr_model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    decoder=processor.decoder,
)


def find_sentences(query, n_hits):
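    """Embed `query` and return the `n_hits` most similar corpus sentences.

    Returns a DataFrame with one row per hit: the sentence text and its
    cosine-similarity score.
    """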
    query_embedding = model.encode(query)
    # semantic_search returns one hit list per query; we pass a single query
    hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=n_hits)[0]

    output_texts = []
    output_scores = []
    for hit in hits:
        # Look up the original sentence by its index in the corpus
        output_texts.append(sentences[hit["corpus_id"]])
        output_scores.append(hit["score"])

    return pd.DataFrame(data={"Text": output_texts, "Score": output_scores})


def process(input_selection, query, filepath, hits):
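    """Resolve the query text (typed or transcribed speech) and run the search."""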
    if input_selection == "speech":
        speech, sampling_rate = load(filepath)
        if sampling_rate != 16000:
            # The wav2vec2 model expects 16 kHz audio; note that librosa >= 0.10
            # requires keyword arguments for resample
            speech = resample(speech, orig_sr=sampling_rate, target_sr=16000)
        text = asr(speech)["text"]
    else:
        text = query
    return text, find_sentences(text, hits)


# Gradio inputs
buttons = gr.inputs.Radio(
    ["text", "speech"], type="value", default="speech", label="Input selection"
)
text_query = gr.inputs.Textbox(
    lines=1, label="Text input", default="breast cancer biomarkers"
)
mic = gr.inputs.Audio(
    source="microphone", type="filepath", label="Speech input", optional=True
)
slider = gr.inputs.Slider(
    minimum=1, maximum=10, step=1, default=3, label="Number of hits"
)

# Gradio outputs
speech_query = gr.Textbox(type="auto", label="Query string")
results = gr.Dataframe(headers=["Text", "Score"], label="Query results")

iface = gr.Interface(
    theme="huggingface",
    description="This Space lets you query a text corpus containing 50,000 random clinical trial descriptions",
    fn=process,
    layout="horizontal",
    inputs=[buttons, text_query, mic, slider],
    outputs=[speech_query, results],
    allow_flagging=False,
)
iface.launch()