# --- Hugging Face file-viewer page residue (not part of the program) ---
# nbroad's picture
# nbroad HF staff
# col width smaller
# db0d58f
# raw
# history blame
# No virus
# 2.98 kB
import pickle
import pandas as pd
import gradio as gr
import numpy as np
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline, Wav2Vec2ProcessorWithLM
from librosa import load, resample
# Constants
# Identifier handed to SentenceTransformer when the embedding model is loaded.
model_name: str = "sentence-transformers/msmarco-distilbert-base-v4"
# Assigned to the embedding model's ``max_seq_length`` after it is loaded.
max_sequence_length: int = 512
# Load corpus
import subprocess

# Fetch the corpus file with the gdown CLI (presumably this is what writes
# "embeddings.pkl" into the working directory -- verify against the Drive file).
# check=True makes a failed download raise CalledProcessError right here
# instead of surfacing later as an unrelated error while opening the pickle.
subprocess.run(["gdown", "1QVpyk_xyqNYrHT3NdUfBxbDV_eyCDa2Q"], check=True)

# Load pre-embedded corpus
# NOTE(security): pickle.load can execute arbitrary code during unpickling.
# This is only acceptable while the downloaded file comes from a trusted source.
with open("embeddings.pkl", "rb") as fp:
    pickled_data = pickle.load(fp)

# The pickle is a mapping holding the raw texts and their embeddings.
sentences = pickled_data['sentences']
corpus_embeddings = pickled_data['embeddings']

print(f'Number of documents: {len(sentences)}')
print(f'Number of embeddings: {corpus_embeddings.shape[0]}')
# Load embedding model
# The sequence-length cap is applied right after construction.
model = SentenceTransformer(model_name)
model.max_seq_length = max_sequence_length

# Load speech to text model
# The processor supplies the tokenizer, feature extractor and decoder that are
# handed to the speech-recognition pipeline below.
asr_model = "patrickvonplaten/wav2vec2-base-960h-4-gram"
processor = Wav2Vec2ProcessorWithLM.from_pretrained(asr_model)
asr = pipeline(
    task="automatic-speech-recognition",
    model=asr_model,
    feature_extractor=processor.feature_extractor,
    tokenizer=processor.tokenizer,
    decoder=processor.decoder,
)
def find_sentences(query, hits):
    """Search the pre-embedded corpus for the entries closest to ``query``.

    The query is encoded with the module-level ``model`` and compared against
    ``corpus_embeddings`` through ``util.semantic_search``.

    Args:
        query: Query text to encode.
        hits: Number of results to request (passed on as ``top_k``).

    Returns:
        A pandas DataFrame with a "Text" column (the matching entries of
        ``sentences``) and a "Score" column (the score of each match).
    """
    encoded_query = model.encode(query)
    # Only one query is submitted, so the first result list is the one we want.
    matches = util.semantic_search(encoded_query, corpus_embeddings, top_k=hits)[0]

    # Each match's 'corpus_id' is the index of the source entry in ``sentences``.
    texts = [sentences[match['corpus_id']] for match in matches]
    scores = [match['score'] for match in matches]
    return pd.DataFrame(data={"Text": texts, "Score": scores})
def process(input_selection, query, filepath, hits):
    """Resolve the query text (typed or spoken) and search the corpus with it.

    Args:
        input_selection: ``'speech'`` to transcribe the audio at ``filepath``;
            any other value uses ``query`` unchanged.
        query: Typed query text, used when speech input is not selected.
        filepath: Path of the recorded audio file; only read for speech input.
        hits: Number of results to return, forwarded to ``find_sentences``.

    Returns:
        A tuple ``(text, results)`` where ``text`` is the query string that was
        actually searched and ``results`` is the value returned by
        ``find_sentences`` for it.

    Raises:
        ValueError: If speech input is selected but no audio file was supplied.
    """
    if input_selection == 'speech':
        # The audio input is optional in the UI, so the path may be missing.
        # Fail with a clear message instead of an obscure loader error.
        if not filepath:
            raise ValueError(
                "Speech input was selected but no audio recording was provided."
            )
        # Load straight at 16 kHz, the rate the ASR pipeline is fed. This
        # replaces the previous load-then-resample sequence, which resampled
        # the audio twice and passed the sample rates to ``resample``
        # positionally (rejected by newer librosa releases, where they are
        # keyword-only).
        target_rate = 16000
        speech, _ = load(filepath, sr=target_rate)
        text = asr(speech)['text']
    else:
        text = query
    return text, find_sentences(text, hits)
# Gradio inputs
# One component per parameter of ``process``, declared in the same order.
buttons = gr.inputs.Radio(
    ['text', 'speech'],
    label='Input selection',
    default='speech',
    type='value',
)
text_query = gr.inputs.Textbox(
    label='Text input',
    default='breast cancer biomarkers',
    lines=1,
)
mic = gr.inputs.Audio(
    label='Speech input',
    source='microphone',
    type='filepath',
    optional=True,
)
slider = gr.inputs.Slider(
    label='Number of hits',
    minimum=1,
    maximum=10,
    step=1,
    default=3,
)

# Gradio outputs
# ``process`` returns (query text, results table); the outputs mirror that pair.
speech_query = gr.outputs.Textbox(label='Query string', type='auto')
results = gr.outputs.Dataframe(
    label='Query results',
    headers=['Text', 'Score'],
    col_width=200,
)

# Assemble the interface and start serving it.
iface = gr.Interface(
    fn=process,
    inputs=[buttons, text_query, mic, slider],
    outputs=[speech_query, results],
    examples=[
        ['text', "breast cancer biomarkers", 'dummy.wav', 3],
    ],
    description='This Space lets you query a text corpus containing 50,000 random clinical trial descriptions',
    theme='huggingface',
    layout='horizontal',
    allow_flagging=False,
)
iface.launch()