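# NOTE: the text "Spaces: Runtime error / Runtime error" was captured together
# with this file (apparently page-header text); it is not part of the program.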
import pickle
import subprocess
from pathlib import Path

import gradio as gr
import numpy as np
import pandas as pd
from librosa import load, resample
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline, Wav2Vec2ProcessorWithLM
# Constants
# Checkpoint name handed to SentenceTransformer when the embedding model is loaded.
model_name: str = "sentence-transformers/msmarco-distilbert-base-v4"
# Value assigned to the embedding model's ``max_seq_length`` attribute.
max_sequence_length: int = 512
# Load corpus
# Download the pickled corpus only when it is not already on disk, so a
# restart with the file present does not fetch it again.
# NOTE(review): assumes gdown saves the Drive file as "embeddings.pkl" — confirm.
if not Path("embeddings.pkl").exists():
    # check=True makes a failed download raise here rather than surfacing
    # later as a FileNotFoundError from open().
    subprocess.run(["gdown", "1QVpyk_xyqNYrHT3NdUfBxbDV_eyCDa2Q"], check=True)

# SECURITY: pickle.load can execute arbitrary code contained in the file.
# Keep this only as long as the downloaded artifact is trusted.
with open("embeddings.pkl", "rb") as fp:
    pickled_data = pickle.load(fp)

# Load pre-embedded corpus
sentences = pickled_data["sentences"]
corpus_embeddings = pickled_data["embeddings"]
print(f"Number of documents: {len(sentences)}")
print(f"Number of embeddings: {corpus_embeddings.shape[0]}")
# Load embedding model
# The same model object is used later to encode incoming queries.
model = SentenceTransformer(model_name)
# Apply the configured sequence-length limit to the loaded model.
model.max_seq_length = max_sequence_length
# Load speech to text model
asr_model = "patrickvonplaten/wav2vec2-base-960h-4-gram"
# The processor is loaded separately so that its tokenizer, feature extractor
# and decoder can be passed to the pipeline explicitly.
processor = Wav2Vec2ProcessorWithLM.from_pretrained(asr_model)
asr = pipeline(
    "automatic-speech-recognition",
    model=asr_model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    decoder=processor.decoder,
)
def find_sentences(query, n_hits):
    """Return the corpus sentences that best match ``query``.

    Args:
        query: Query text; encoded with the module-level ``model``.
        n_hits: Maximum number of hits to return. Coerced to ``int`` before
            being used as ``top_k``, since a UI widget may deliver the
            number as a float.

    Returns:
        A ``pandas.DataFrame`` with a "Text" column (the matching entries of
        ``sentences``) and a "Score" column (the score reported for each
        hit), one row per hit in the order ``util.semantic_search`` returns
        them.
    """
    query_embedding = model.encode(query)
    # semantic_search yields one hit list per query; a single query is sent,
    # so only the first list is relevant.
    hits = util.semantic_search(
        query_embedding, corpus_embeddings, top_k=int(n_hits)
    )[0]
    return pd.DataFrame(
        data={
            # corpus_id indexes into the sentence list loaded with the embeddings.
            "Text": [sentences[hit["corpus_id"]] for hit in hits],
            "Score": [hit["score"] for hit in hits],
        }
    )
def process(input_selection, query, filepath, hits):
    """Resolve the query text (typed or spoken) and search the corpus with it.

    Args:
        input_selection: ``"speech"`` to transcribe the audio at ``filepath``;
            any other value uses ``query`` as-is.
        query: Typed query text; used when ``input_selection`` is not
            ``"speech"``.
        filepath: Path to the recorded audio file; read only in speech mode.
        hits: Number of results requested; forwarded to ``find_sentences``.

    Returns:
        A tuple ``(text, results)`` where ``text`` is the query string that
        was searched and ``results`` is the value returned by
        ``find_sentences``.

    Raises:
        ValueError: If speech input is selected but no audio file was given.
    """
    if input_selection == "speech":
        if not filepath:
            # Fail with a clear message instead of an opaque loader error.
            raise ValueError("Speech input selected but no audio was recorded.")
        # Ask librosa for 16 kHz audio directly. This replaces the earlier
        # load-at-default-rate + positional resample() call, which resampled
        # twice and breaks on librosa releases where resample()'s rate
        # arguments are keyword-only.
        # NOTE(review): 16 kHz is presumably the rate the ASR model expects — confirm.
        speech, _ = load(filepath, sr=16000)
        text = asr(speech)["text"]
    else:
        text = query
    return text, find_sentences(text, hits)
# Gradio inputs
# NOTE(review): these use the legacy ``gr.inputs`` namespace and its
# ``default=`` / ``optional=`` keywords; verify they exist in the Gradio
# version this app is pinned to.

# Chooses between the typed query and the recorded audio.
buttons = gr.inputs.Radio(
    ["text", "speech"], type="value", default="speech", label="Input selection"
)
text_query = gr.inputs.Textbox(
    lines=1, label="Text input", default="breast cancer biomarkers"
)
# type="filepath" hands the handler a path to the recording, not raw samples.
mic = gr.inputs.Audio(
    source="microphone", type="filepath", label="Speech input", optional=True
)
slider = gr.inputs.Slider(
    minimum=1, maximum=10, step=1, default=3, label="Number of hits"
)
# Gradio outputs
# Shows the query string that was actually searched (typed or transcribed).
speech_query = gr.Textbox(type="auto", label="Query string")
# Headers match the column names of the DataFrame built by find_sentences.
results = gr.Dataframe(headers=["Text", "Score"], label="Query results")
# Wire the input widgets, the handler and the output widgets together.
# The order of ``inputs`` matches the parameter order of ``process``; the
# order of ``outputs`` matches the tuple it returns.
# NOTE(review): ``theme``, ``layout`` and a boolean ``allow_flagging`` look like
# older Gradio options — confirm against the pinned Gradio version.
iface = gr.Interface(
    theme="huggingface",
    description="This Space lets you query a text corpus containing 50,000 random clinical trial descriptions",
    fn=process,
    layout="horizontal",
    inputs=[buttons, text_query, mic, slider],
    outputs=[speech_query, results],
    # NOTE(review): the example row references "dummy.wav"; verify the file
    # ships alongside the app.
    examples=[
        ["text", "breast cancer biomarkers", "dummy.wav", 3],
    ],
    allow_flagging=False,
)
# Start the web app as soon as the script runs.
iface.launch()