nbroad's picture
nbroad HF staff
first commit
e693db5
raw
history blame
No virus
2.95 kB
import pickle
import pandas as pd
import gradio as gr
import numpy as np
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline, Wav2Vec2ProcessorWithLM
from librosa import load, resample
# Constants
model_name = 'sentence-transformers/msmarco-distilbert-base-v4'
max_sequence_length = 512

# Fetch the corpus file with the gdown command-line tool.
# NOTE(review): presumably this download produces embeddings.pkl in the
# working directory -- confirm against the shared file's name.
import subprocess

_gdown_result = subprocess.run(["gdown", "1QVpyk_xyqNYrHT3NdUfBxbDV_eyCDa2Q"])
if _gdown_result.returncode != 0:
    # Best effort: a previously downloaded embeddings.pkl may still be
    # present, so report the failure instead of aborting here.
    print(f'Warning: gdown exited with status {_gdown_result.returncode}')

# Load pre-embedded corpus
# NOTE(review): pickle.load can execute arbitrary code from the file; this is
# only safe as long as the downloaded file comes from a trusted source.
with open("embeddings.pkl", "rb") as fp:
    pickled_data = pickle.load(fp)

sentences = pickled_data['sentences']
corpus_embeddings = pickled_data['embeddings']
print(f'Number of documents: {len(sentences)}')
print(f'Number of embeddings: {corpus_embeddings.shape[0]}')

# Load embedding model
model = SentenceTransformer(model_name)
model.max_seq_length = max_sequence_length

# Load speech to text model
asr_model = "patrickvonplaten/wav2vec2-base-960h-4-gram"
processor = Wav2Vec2ProcessorWithLM.from_pretrained(asr_model)
asr = pipeline(
    "automatic-speech-recognition",
    model=asr_model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    decoder=processor.decoder,
)
def find_sentences(query, hits):
    """Return the corpus sentences most similar to *query*.

    Args:
        query: Query text to embed with the module-level ``model``.
        hits: Number of results requested; passed as ``top_k`` to
            ``util.semantic_search``.

    Returns:
        A pandas DataFrame with a "Text" column (the matching sentences)
        and a "Score" column (the score reported for each match).
    """
    query_vector = model.encode(query)
    # Only one query is searched, so take the first entry of the result.
    matches = util.semantic_search(
        query_vector, corpus_embeddings, top_k=hits
    )[0]
    # Each match carries the index of its sentence in the corpus.
    texts = [sentences[match['corpus_id']] for match in matches]
    scores = [match['score'] for match in matches]
    return pd.DataFrame(data={"Text": texts, "Score": scores})
def process(input_selection, query, filepath, hits):
    """Resolve the query text (typed or spoken) and search the corpus.

    Args:
        input_selection: 'speech' to transcribe the recording at *filepath*;
            any other value uses *query* unchanged.
        query: Text query, used when speech input is not selected.
        filepath: Path of the recorded audio file; only read in speech mode.
        hits: Number of results, forwarded to ``find_sentences``.

    Returns:
        A tuple ``(text, results)``: the query string that was actually
        searched and the value returned by ``find_sentences`` for it.

    Raises:
        ValueError: If speech input is selected but no recording is given.
    """
    if input_selection != 'speech':
        return query, find_sentences(query, hits)

    if not filepath:
        # The microphone input is optional, so speech mode can arrive with
        # no file; fail with a clear message instead of inside the loader.
        raise ValueError(
            "Speech input was selected but no recording was provided."
        )

    # Load straight at 16 kHz. librosa's default loading rate is 22050 Hz,
    # which previously forced a second resampling step on every request.
    speech, _ = load(filepath, sr=16000)
    text = asr(speech)['text']
    return text, find_sentences(text, hits)
# Input widgets, in the order their values are passed to process()
buttons = gr.inputs.Radio(
    ['text','speech'],
    type='value',
    default='speech',
    label='Input selection',
)
text_query = gr.inputs.Textbox(
    lines=1,
    label='Text input',
    default='breast cancer biomarkers',
)
mic = gr.inputs.Audio(
    source='microphone',
    type='filepath',
    label='Speech input',
    optional=True,
)
slider = gr.inputs.Slider(
    minimum=1,
    maximum=10,
    step=1,
    default=3,
    label='Number of hits',
)

# Output widgets: the query string that was searched and the result table
speech_query = gr.outputs.Textbox(type='auto', label='Query string')
results = gr.outputs.Dataframe(headers=['Text', 'Score'], label='Query results')

# Assemble the interface around process() and start serving it
iface = gr.Interface(
    fn=process,
    inputs=[buttons, text_query, mic, slider],
    outputs=[speech_query, results],
    examples=[
        ['text', "breast cancer biomarkers", 'dummy.wav', 3],
    ],
    description='This Space lets you query a text corpus containing 50,000 random clinical trial descriptions',
    theme='huggingface',
    layout='horizontal',
    allow_flagging=False,
)
iface.launch()