"""Gradio demo: semantic search over a pre-embedded text corpus.

The query is either typed or spoken; spoken queries are transcribed with a
wav2vec2 speech-recognition pipeline before being embedded and matched
against the corpus embeddings.
"""
import pickle
import subprocess
from typing import Optional, Tuple

import gradio as gr
import numpy as np
import pandas as pd
from librosa import load, resample
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline, Wav2Vec2ProcessorWithLM

# Constants
model_name = 'sentence-transformers/msmarco-distilbert-base-v4'
max_sequence_length = 512
# Sampling rate (Hz) that recorded audio is converted to before transcription.
asr_sampling_rate = 16000

# Load corpus
# The download's exit status is not checked here; if it did not produce
# "embeddings.pkl", the open() below fails.
subprocess.run(["gdown", "1QVpyk_xyqNYrHT3NdUfBxbDV_eyCDa2Q"])
# NOTE(security): pickle.load can execute arbitrary code contained in the
# file. The file is fetched at startup, so it must come from a trusted source.
with open("embeddings.pkl", "rb") as fp:
    pickled_data = pickle.load(fp)
sentences = pickled_data['sentences']
corpus_embeddings = pickled_data['embeddings']
print(f'Number of documents: {len(sentences)}')

# Load pre-embedded corpus
print(f'Number of embeddings: {corpus_embeddings.shape[0]}')

# Load embedding model
model = SentenceTransformer(model_name)
model.max_seq_length = max_sequence_length

# Load speech to text model
asr_model = "patrickvonplaten/wav2vec2-base-960h-4-gram"
processor = Wav2Vec2ProcessorWithLM.from_pretrained(asr_model)
asr = pipeline(
    "automatic-speech-recognition",
    model=asr_model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    decoder=processor.decoder,
)


def find_sentences(query: str, hits: int) -> pd.DataFrame:
    """Return the corpus sentences most similar to *query*.

    Args:
        query: Free-text search string.
        hits: Number of results to return.

    Returns:
        A DataFrame with a "Text" column (the matching sentences) and a
        "Score" column (their similarity scores), in the order produced by
        ``util.semantic_search``.
    """
    query_embedding = model.encode(query)
    # int() guards against the UI slider delivering the count as a float.
    search_results = util.semantic_search(
        query_embedding, corpus_embeddings, top_k=int(hits)
    )
    # One result list per query embedding; only a single query is sent.
    matches = search_results[0]
    return pd.DataFrame(
        data={
            # Find source document based on sentence index
            "Text": [sentences[match['corpus_id']] for match in matches],
            "Score": [match['score'] for match in matches],
        }
    )


def process(
    input_selection: str,
    query: str,
    filepath: Optional[str],
    hits: int,
) -> Tuple[str, pd.DataFrame]:
    """Resolve the query text (typed or spoken) and search the corpus.

    Args:
        input_selection: ``'speech'`` to transcribe the recording at
            *filepath*; any other value uses *query* as typed.
        query: Text query, used when speech input is not selected.
        filepath: Path to the recorded audio file; may be ``None`` because
            the microphone input is optional.
        hits: Number of results to return.

    Returns:
        A ``(text, results)`` tuple: the query string actually searched for
        and the DataFrame produced by :func:`find_sentences`.

    Raises:
        ValueError: If speech input is selected but no recording was given.
    """
    if input_selection == 'speech':
        if not filepath:
            raise ValueError(
                "Speech input was selected but no audio recording was provided."
            )
        # sr=None keeps the file's native rate so the audio is resampled at
        # most once (the default would first resample to librosa's own rate).
        speech, sampling_rate = load(filepath, sr=None)
        if sampling_rate != asr_sampling_rate:
            # Keyword arguments: newer librosa releases reject positional rates.
            speech = resample(
                speech, orig_sr=sampling_rate, target_sr=asr_sampling_rate
            )
        text = asr(speech)['text']
    else:
        text = query
    return text, find_sentences(text, hits)


# Gradio inputs
buttons = gr.inputs.Radio(
    ['text', 'speech'], type='value', default='speech', label='Input selection'
)
text_query = gr.inputs.Textbox(
    lines=1, label='Text input', default='breast cancer biomarkers'
)
mic = gr.inputs.Audio(
    source='microphone', type='filepath', label='Speech input', optional=True
)
slider = gr.inputs.Slider(
    minimum=1, maximum=10, step=1, default=3, label='Number of hits'
)

# Gradio outputs
speech_query = gr.outputs.Textbox(type='auto', label='Query string')
results = gr.outputs.Dataframe(
    headers=['Text', 'Score'], col_width=200, label='Query results'
)

iface = gr.Interface(
    theme='huggingface',
    description='This Space lets you query a text corpus containing 50,000 random clinical trial descriptions',
    fn=process,
    layout='horizontal',
    inputs=[buttons, text_query, mic, slider],
    outputs=[speech_query, results],
    examples=[
        ['text', "breast cancer biomarkers", 'dummy.wav', 3],
    ],
    allow_flagging=False,
)

iface.launch()