"""Gradio demo: semantic search over a sample of FiQA documents.

The UI has three steps, each backed by one callback in this module:

1. embed a free-text query (``text_embedding_single``),
2. embed a random sample of FiQA documents (``text_embedding_batch``),
3. rank the sampled documents against the query (``rank_document``).

The callbacks communicate through the module-level globals below.
"""

import gradio as gr
import ir_datasets
import numpy as np
import pandas as pd
from autogluon.multimodal import MultiModalPredictor

MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
DATASET_NAME = "beir/fiqa/dev"
# Fraction of the corpus that is sampled and embedded for the demo.
SAMPLE_FRACTION = 0.0001
# Number of documents returned by ``rank_document``; matches the UI text.
TOP_K = 3

# State shared between the Gradio callbacks. Each stays ``None`` until the
# corresponding button has been clicked.
query_embedding = None      # embedding of the most recent query
document_embedding = None   # embeddings of the rows of ``docs_df``
docs_df = None              # sampled documents, indexed by ``doc_id``


def _build_predictor():
    """Return a feature-extraction predictor backed by ``MODEL_NAME``."""
    return MultiModalPredictor(
        pipeline="feature_extraction",
        hyperparameters={"model.hf_text.checkpoint_name": MODEL_NAME},
    )


def _cosine_scores(query_vector, doc_matrix):
    """Return the cosine similarity of ``query_vector`` to each row of ``doc_matrix``."""
    query_unit = query_vector / np.linalg.norm(query_vector)
    doc_units = doc_matrix / np.linalg.norm(doc_matrix, axis=-1, keepdims=True)
    return doc_units.dot(query_unit)


def text_embedding_batch():
    """Sample documents from the FiQA dataset and embed them.

    Stores the sampled frame in ``docs_df`` and its embeddings in
    ``document_embedding`` so that ``rank_document`` can use them.

    Returns:
        The value stored under the ``"text"`` key of the predictor output
        (presumably one embedding row per sampled document -- confirm
        against the AutoGluon version in use).
    """
    global document_embedding
    global docs_df

    dataset = ir_datasets.load(DATASET_NAME)
    sampled_docs = (
        pd.DataFrame(dataset.docs_iter())
        .set_index("doc_id")
        .sample(frac=SAMPLE_FRACTION)
    )
    embedding = _build_predictor().extract_embedding(sampled_docs)

    # Publish the frame and its embeddings together, only after extraction
    # succeeded, so the two globals always describe the same rows.
    docs_df = sampled_docs
    document_embedding = embedding["text"]
    return document_embedding


def text_embedding_single(query: str):
    """Embed a single query string.

    Stores the result in ``query_embedding`` so that ``rank_document`` can
    use it.

    Args:
        query: Free-text query entered in the UI.

    Returns:
        The value stored under the ``"0"`` key of the predictor output
        (presumably a single-row embedding array -- confirm against the
        AutoGluon version in use).
    """
    global query_embedding

    embedding = _build_predictor().extract_embedding([query])
    query_embedding = embedding["0"]
    return query_embedding


def rank_document():
    """Return the text of the ``TOP_K`` sampled documents closest to the query.

    Documents are ordered by descending cosine similarity between their
    embedding and the query embedding.

    Returns:
        A list of at most ``TOP_K`` document texts, best match first.

    Raises:
        gr.Error: If the query or the documents have not been embedded yet.
    """
    if query_embedding is None or document_embedding is None or docs_df is None:
        raise gr.Error(
            "Generate both the query embedding and the document embeddings "
            "before ranking."
        )

    # ``atleast_2d`` lets this work whether the query embedding was stored
    # as a single vector or as a one-row matrix.
    query_vector = np.atleast_2d(np.asarray(query_embedding))[0]
    scores = _cosine_scores(query_vector, np.asarray(document_embedding))

    # Negate so that argsort (ascending) yields the highest scores first.
    top_indices = np.argsort(-scores)[:TOP_K]
    return [docs_df["text"].iloc[idx] for idx in top_indices]


def main():
    """Build the Gradio interface, wire up the callbacks and launch it."""
    with gr.Blocks(title="OpenSearch Demo") as demo:
        gr.Markdown("# Semantic Search with Autogluon")
        gr.Markdown("Ask an open question!")
        with gr.Row():
            inp_single = gr.Textbox(show_label=False)
        with gr.Row():
            btn_single = gr.Button("Generate Embedding")
        with gr.Row():
            out_single = gr.DataFrame(label="Embedding", show_label=True)

        gr.Markdown("You can select one of the sample datasets for document embedding")
        with gr.Row():
            btn_fiqa = gr.Button("fiqa")
        with gr.Row():
            out_batch = gr.DataFrame(label="Sample Embeddings", show_label=True, row_count=5)

        gr.Markdown("Now rank the documents and pick the top 3 most relevant from the dataset")
        with gr.Row():
            btn_rank = gr.Button("Rank documents")
        with gr.Row():
            out_rank = gr.DataFrame(label="Top ranked documents", show_label=True, row_count=5)

        # File-upload variant, currently disabled:
        # with gr.Row():
        #     out_batch = gr.File(interactive=True)
        # with gr.Row():
        #     btn_file = gr.Button("Generate Embedding")

        btn_single.click(fn=text_embedding_single, inputs=inp_single, outputs=out_single)
        btn_fiqa.click(fn=text_embedding_batch, inputs=None, outputs=out_batch)
        btn_rank.click(fn=rank_document, inputs=None, outputs=out_rank)
        # btn_file.click(fn=text_embedding_batch, inputs=inp_single, outputs=out_single)

    demo.launch()


if __name__ == "__main__":
    main()