import pickle
import subprocess

import gradio as gr
import numpy as np
import pandas as pd
from librosa import load, resample
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline, Wav2Vec2ProcessorWithLM

# Constants
model_name = "sentence-transformers/msmarco-distilbert-base-v4"
max_sequence_length = 512
# Sampling rate (Hz) the recorded audio is converted to before transcription.
asr_sampling_rate = 16000

# Load corpus
# check=True aborts start-up right away if the download fails, instead of
# surfacing later as a confusing error from open() below.
# NOTE(review): presumably gdown saves the file as "embeddings.pkl" — confirm.
subprocess.run(["gdown", "1QVpyk_xyqNYrHT3NdUfBxbDV_eyCDa2Q"], check=True)
# NOTE(review): pickle.load can execute arbitrary code from the file. This is
# only acceptable while the downloaded file is fully trusted.
with open("embeddings.pkl", "rb") as fp:
    pickled_data = pickle.load(fp)
sentences = pickled_data["sentences"]
corpus_embeddings = pickled_data["embeddings"]
print(f"Number of documents: {len(sentences)}")

# Load pre-embedded corpus
print(f"Number of embeddings: {corpus_embeddings.shape[0]}")

# Load embedding model
model = SentenceTransformer(model_name)
model.max_seq_length = max_sequence_length

# Load speech to text model
asr_model = "patrickvonplaten/wav2vec2-base-960h-4-gram"
processor = Wav2Vec2ProcessorWithLM.from_pretrained(asr_model)
asr = pipeline(
    "automatic-speech-recognition",
    model=asr_model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    decoder=processor.decoder,
)


def find_sentences(query, n_hits):
    """Return the corpus sentences most similar to *query*.

    Args:
        query: Query string to embed and search for.
        n_hits: Maximum number of results to return.

    Returns:
        A DataFrame with a "Text" column (the matched sentences) and a
        "Score" column (the score reported by ``util.semantic_search``),
        in the order the search returned them.
    """
    query_embedding = model.encode(query)
    # NOTE(review): the UI slider may deliver a float; top_k is a count, so
    # coerce it to int before searching.
    all_hits = util.semantic_search(
        query_embedding, corpus_embeddings, top_k=int(n_hits)
    )
    # semantic_search returns one result list per query; there is one query.
    hits = all_hits[0]
    return pd.DataFrame(
        data={
            # corpus_id indexes into the sentence list loaded from the pickle.
            "Text": [sentences[hit["corpus_id"]] for hit in hits],
            "Score": [hit["score"] for hit in hits],
        }
    )


def process(input_selection, query, filepath, hits):
    """Resolve the query text (typed or spoken) and search the corpus.

    Args:
        input_selection: "speech" transcribes the recording at *filepath*;
            any other value uses *query* as typed.
        query: Text query, used unless *input_selection* is "speech".
        filepath: Path to the recorded audio file; only read in speech mode.
        hits: Number of results to return.

    Returns:
        A tuple ``(text, results)``: the query string that was searched and
        the DataFrame produced by ``find_sentences``.

    Raises:
        ValueError: If speech input is selected but no recording was given.
    """
    if input_selection == "speech":
        if not filepath:
            raise ValueError("Speech input selected but no audio was recorded")
        # sr=None keeps the file's native rate; librosa's default would first
        # resample to 22050 Hz and force a second resample below.
        speech, sampling_rate = load(filepath, sr=None)
        if sampling_rate != asr_sampling_rate:
            # Keyword arguments: recent librosa releases reject positional
            # sampling rates here.
            speech = resample(
                speech, orig_sr=sampling_rate, target_sr=asr_sampling_rate
            )
        text = asr(speech)["text"]
    else:
        text = query
    return text, find_sentences(text, hits)


# Gradio inputs
buttons = gr.inputs.Radio(
    ["text", "speech"], type="value", default="speech", label="Input selection"
)
# Free-text query box, pre-filled with an example query.
text_query = gr.inputs.Textbox(
    label="Text input",
    lines=1,
    default="breast cancer biomarkers",
)
# Microphone recording, handed to the callback as a file path; optional so
# that text-only queries can be submitted without recording anything.
mic = gr.inputs.Audio(
    label="Speech input",
    source="microphone",
    type="filepath",
    optional=True,
)
# Number of search results to show (1 to 10, default 3).
slider = gr.inputs.Slider(
    label="Number of hits",
    minimum=1,
    maximum=10,
    step=1,
    default=3,
)

# Gradio outputs
speech_query = gr.Textbox(label="Query string", type="auto")
results = gr.Dataframe(label="Query results", headers=["Text", "Score"])

# Wire the widgets to the callback. The input order must match the parameter
# order of process(input_selection, query, filepath, hits), and the output
# order must match its (text, DataFrame) return value.
iface = gr.Interface(
    fn=process,
    inputs=[buttons, text_query, mic, slider],
    outputs=[speech_query, results],
    description="This Space lets you query a text corpus containing 50,000 random clinical trial descriptions",
    theme="huggingface",
    layout="horizontal",
    allow_flagging=False,
)

iface.launch()