# Page header captured together with the file (not program code):
# nbroad's picture
# nbroad HF staff
# no examples
# edc4b42
import pickle
import pandas as pd
import gradio as gr
import numpy as np
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline, Wav2Vec2ProcessorWithLM
from librosa import load, resample
# Constants
# Sequence-length limit assigned to the embedding model after it is loaded.
max_sequence_length = 512
# Identifier passed to SentenceTransformer to load the embedding model.
model_name = "sentence-transformers/msmarco-distilbert-base-v4"
# Load the pre-embedded corpus
import subprocess

# Fetch the pickled corpus from Google Drive. check=True makes a failed
# download abort startup right here with CalledProcessError instead of
# surfacing later as a confusing FileNotFoundError (or silently reusing a
# stale file left over from an earlier run).
# NOTE(review): gdown is expected to save the file as "embeddings.pkl" in the
# working directory — confirm against the Drive file's name.
subprocess.run(["gdown", "1QVpyk_xyqNYrHT3NdUfBxbDV_eyCDa2Q"], check=True)

# NOTE(security): pickle.load can execute arbitrary code embedded in the file.
# This is only acceptable while the Drive file stays under the maintainers'
# control; switch to a data-only format if that ever changes.
with open("embeddings.pkl", "rb") as fp:
    pickled_data = pickle.load(fp)

# The pickle holds the raw sentences and their embeddings under these keys.
sentences = pickled_data["sentences"]
corpus_embeddings = pickled_data["embeddings"]

# Print both counts so a mismatch between texts and vectors is visible in logs.
print(f"Number of documents: {len(sentences)}")
print(f"Number of embeddings: {corpus_embeddings.shape[0]}")
# Embedding model used to encode incoming queries
model = SentenceTransformer(model_name)
model.max_seq_length = max_sequence_length

# Speech-to-text model: the processor supplies the tokenizer, the feature
# extractor and the decoder that the recognition pipeline is built from.
asr_model = "patrickvonplaten/wav2vec2-base-960h-4-gram"
processor = Wav2Vec2ProcessorWithLM.from_pretrained(asr_model)
asr = pipeline(
    task="automatic-speech-recognition",
    model=asr_model,
    decoder=processor.decoder,
    feature_extractor=processor.feature_extractor,
    tokenizer=processor.tokenizer,
)
def find_sentences(query, n_hits):
    """Return the corpus sentences most similar to ``query``.

    Args:
        query: Free-text search string; it is embedded with the module-level
            ``model``.
        n_hits: Maximum number of results to return. Coerced to ``int``
            because a UI slider may deliver the value as a float, while the
            top-k search needs an integer.

    Returns:
        A ``pandas.DataFrame`` with one row per hit and two columns:
        ``"Text"`` (the matching sentence) and ``"Score"`` (the score
        reported by ``util.semantic_search``).
    """
    query_embedding = model.encode(query)
    # semantic_search returns one list of hits per query; only one query is
    # sent, so its hits are at index 0.
    hits = util.semantic_search(
        query_embedding, corpus_embeddings, top_k=int(n_hits)
    )[0]
    # Each hit's "corpus_id" is the index of the matching entry in `sentences`.
    return pd.DataFrame(
        data={
            "Text": [sentences[hit["corpus_id"]] for hit in hits],
            "Score": [hit["score"] for hit in hits],
        }
    )
def process(input_selection, query, filepath, hits):
    """Resolve the query text (typed or spoken) and search the corpus with it.

    Args:
        input_selection: ``"speech"`` to transcribe the recording at
            ``filepath``; any other value uses ``query`` as-is.
        query: Text typed by the user.
        filepath: Path to the recorded audio file, or ``None`` when nothing
            was recorded (the microphone input is optional).
        hits: Number of results to request from ``find_sentences``.

    Returns:
        A tuple ``(text, results)`` where ``text`` is the query string that
        was actually searched and ``results`` is the DataFrame produced by
        ``find_sentences``.
    """
    # With "speech" selected but no recording available, fall back to the
    # typed query instead of crashing inside the audio loader.
    if input_selection == "speech" and filepath is not None:
        # Ask librosa for 16 kHz directly: a bare load() first resamples to
        # its 22050 Hz default, which forced a second resample afterwards, and
        # the positional resample() call breaks on librosa versions where the
        # rates are keyword-only.
        speech, _ = load(filepath, sr=16000)
        text = asr(speech)["text"]
    else:
        text = query
    return text, find_sentences(text, hits)
# Input widgets, listed in the order `process` takes its arguments.
buttons = gr.inputs.Radio(
    ["text", "speech"],
    label="Input selection",
    default="speech",
    type="value",
)
text_query = gr.inputs.Textbox(
    label="Text input",
    default="breast cancer biomarkers",
    lines=1,
)
mic = gr.inputs.Audio(
    label="Speech input",
    source="microphone",
    type="filepath",
    optional=True,
)
slider = gr.inputs.Slider(
    label="Number of hits",
    minimum=1,
    maximum=10,
    step=1,
    default=3,
)

# Output widgets, matching the (query text, results table) pair that
# `process` returns.
speech_query = gr.Textbox(label="Query string", type="auto")
results = gr.Dataframe(label="Query results", headers=["Text", "Score"])

# Wire everything into a single interface and start serving it.
iface = gr.Interface(
    fn=process,
    inputs=[buttons, text_query, mic, slider],
    outputs=[speech_query, results],
    description="This Space lets you query a text corpus containing 50,000 random clinical trial descriptions",
    theme="huggingface",
    layout="horizontal",
    allow_flagging=False,
)
iface.launch()