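"""Gradio app that embeds a paper's title and abstract and retrieves the most
similar papers, tasks and methods from pre-built FAISS indexes (the data and
UI labels reference PapersWithCode)."""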
import gradio as gr
from sentence_transformers import SentenceTransformer, util
import torch
import pandas as pd
import faiss
import numpy as np
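
# Pre-trained sentence-embedding model used to encode the query (title + abstract)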
model = SentenceTransformer("Elise-hf/distilbert-base-pwc-multi-task")
# Load the Faiss indexes and data files
index = faiss.read_index("all_inst_index")
tasks_index = faiss.read_index("tasks_index")
methods_index = faiss.read_index("methods_index")
labels = pd.read_json('lbl_gpt3_davinci_clean_with_counts.json')
methods = pd.read_json('methods.json')
papers = pd.read_json('title_url_clean.json')
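
# Raw task and method embedding matrices, used later to rescale the raw FAISS
# scores into cosine similarities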
tasks_embeddings = np.load('tasks_embeddings.npy')
methods_embeddings = np.load('methods_embeddings.npy')

def search_faiss_single(index, inst_embeddings, top_k):
    # faiss.normalize_L2(inst_embeddings)
    D, I = index.search(inst_embeddings, top_k)
    return D, I

def find_similar_papers_tasks_methods(title, abstract, k=100):
    # Concatenate the title and the abstract into a single query string
    query = title + '</s>' + abstract

    # Encode the query into an embedding
    query_embedding = model.encode([query], convert_to_numpy=True)

    # Search for the top k most similar papers
    D, I = search_faiss_single(index, query_embedding, k)

    # Search for the top k most similar tasks and rescale the raw scores
    # into cosine similarities
    D_tasks, I_tasks = search_faiss_single(tasks_index, query_embedding, k)
    norm = np.linalg.norm(tasks_embeddings[I_tasks[0]], axis=1) * np.linalg.norm(query_embedding, axis=1)[:, None]
    D_tasks /= norm

    # Search for the top k most similar methods and rescale in the same way
    D_methods, I_methods = search_faiss_single(methods_index, query_embedding, k)
    norm = np.linalg.norm(methods_embeddings[I_methods[0]], axis=1) * np.linalg.norm(query_embedding, axis=1)[:, None]
    D_methods /= norm

    # Map the top k tasks and methods to their cosine similarities
    tasks_results = dict(zip(labels.loc[I_tasks[0]].title, D_tasks[0].tolist()))
    methods_results = dict(zip(methods.loc[I_methods[0]].title, D_methods[0].tolist()))

    # Return the task and method dictionaries and the dataframe of the top k similar papers
    return tasks_results, methods_results, papers.loc[I[0]]

with gr.Blocks() as demo:
    with gr.TabItem("Task Search"):
        gr.Markdown(
            """
            # Identify Relevant Tasks from Abstracts
            """
        )
        title = gr.components.Textbox(label="Enter a paper's title")
        abstract = gr.components.Textbox(label="Enter an abstract to discover relevant tasks from it")
        btn = gr.Button("Submit")

        with gr.Row():
            tasks_table = gr.components.Label(label="Relevant Tasks from PapersWithCode")
            methods_table = gr.components.Label(label="Relevant Methods from PapersWithCode")

        output_df = gr.Dataframe(
            headers=["title", "paper_url"],
            datatype=["str", "str"],
            row_count=10,
            col_count=(2, "fixed"),
            label="Relevant papers from PapersWithCode",
        )

        btn.click(fn=find_similar_papers_tasks_methods,
                  inputs=[title, abstract],
                  outputs=[tasks_table, methods_table, output_df])

        # gr.Examples(examples, inputs=[title, abstract], cache_examples=True, fn=find_similar_papers_tasks_methods,
        #             outputs=[tasks_table, methods_table, output_df])

demo.launch()