|
import gradio as gr |
|
from sentence_transformers import SentenceTransformer, util |
|
import torch |
|
import pandas as pd |
|
import faiss |
|
import numpy as np |
|
|
|
# --- Encoder ---------------------------------------------------------------
# Sentence encoder used to embed the "title</s>abstract" query string.
model = SentenceTransformer("Elise-hf/distilbert-base-pwc-multi-task")

# --- FAISS indexes -----------------------------------------------------------
# One index per searchable collection: papers, tasks and methods.
index = faiss.read_index("all_inst_index")
tasks_index = faiss.read_index("tasks_index")
methods_index = faiss.read_index("methods_index")

# --- Lookup tables -----------------------------------------------------------
# Ids returned by a FAISS search are used as `.loc` labels into these frames:
# `labels` pairs with `tasks_index`, `methods` with `methods_index` and
# `papers` with `index`.
labels = pd.read_json("lbl_gpt3_davinci_clean_with_counts.json")
methods = pd.read_json("methods.json")
papers = pd.read_json("title_url_clean.json")

# --- Raw embeddings ----------------------------------------------------------
# Only their row norms are used, to rescale the scores FAISS returns.
tasks_embeddings = np.load("tasks_embeddings.npy")
methods_embeddings = np.load("methods_embeddings.npy")
|
|
|
def search_faiss_single(index, inst_embeddings, top_k):
    """Query ``index`` for the ``top_k`` nearest neighbours of each embedding.

    Args:
        index: Object exposing ``search(embeddings, k)`` (a FAISS index in
            this file).
        inst_embeddings: Query embeddings, forwarded to ``index.search``
            unchanged.
        top_k: Number of neighbours requested per query.

    Returns:
        A ``(scores, ids)`` tuple holding the two values produced by
        ``index.search``.
    """
    scores, neighbour_ids = index.search(inst_embeddings, top_k)
    return scores, neighbour_ids
|
|
|
|
|
def _score_titles(faiss_index, embeddings, frame, query_embedding, k):
    """Map the titles of the nearest rows of ``frame`` to rescaled scores.

    Args:
        faiss_index: Index to search. The ids it returns are used both as
            row numbers into ``embeddings`` and as ``.loc`` labels of
            ``frame``.
        embeddings: Array with one vector per indexed item.
        frame: DataFrame exposing a ``title`` column.
        query_embedding: Encoded query; only the first result row is used.
        k: Number of neighbours to request.

    Returns:
        dict mapping title -> score. Repeated titles keep the last score
        seen, because the mapping is built with ``dict(zip(...))``.
    """
    scores, ids = search_faiss_single(faiss_index, query_embedding, k)
    scores, ids = scores[0], ids[0]

    # FAISS pads the result with id -1 when it finds fewer than k neighbours.
    # Drop those slots: as a numpy index -1 silently means "last row", and as
    # a .loc label it does not exist on a default integer index.
    found = ids >= 0
    scores, ids = scores[found], ids[found]

    # Rescale every raw score by |item| * |query|.
    # NOTE(review): this yields cosine similarity only if the index returns
    # inner products -- confirm against how the indexes were built.
    scores /= np.linalg.norm(embeddings[ids], axis=1) * np.linalg.norm(query_embedding, axis=1)

    return dict(zip(frame.loc[ids].title, scores.tolist()))


def find_similar_papers_tasks_methods(title, abstract, k=100):
    """Retrieve the tasks, methods and papers closest to a title/abstract pair.

    The title and abstract are joined with ``'</s>'``, encoded with the
    module-level ``model`` and searched against the three module-level FAISS
    indexes.

    Args:
        title: Paper title.
        abstract: Paper abstract.
        k: Number of neighbours requested from each index.

    Returns:
        A 3-tuple ``(tasks_results, methods_results, papers_frame)``: two
        dicts mapping a title to its rescaled score, and the rows of
        ``papers`` selected by the paper search. Each may hold fewer than
        ``k`` entries when an index returns fewer neighbours.
    """
    query = title + '</s>' + abstract
    query_embedding = model.encode([query], convert_to_numpy=True)

    # Paper scores are not shown anywhere, only the matching rows.
    _, paper_ids = search_faiss_single(index, query_embedding, k)
    paper_ids = paper_ids[0]
    paper_ids = paper_ids[paper_ids >= 0]  # skip FAISS "no result" padding

    tasks_results = _score_titles(tasks_index, tasks_embeddings, labels, query_embedding, k)
    methods_results = _score_titles(methods_index, methods_embeddings, methods, query_embedding, k)

    return tasks_results, methods_results, papers.loc[paper_ids]
|
|
|
|
|
|
|
|
|
# Single-tab UI: two text inputs and a submit button feed
# find_similar_papers_tasks_methods, whose three return values fill the two
# label widgets and the papers table.
with gr.Blocks() as demo:
    with gr.TabItem("Task Search"):
        gr.Markdown(
            """
            # Identify Relevant Tasks from Abstracts

            """
        )

        # Inputs
        title = gr.Textbox(label="Enter an paper's title")
        abstract = gr.Textbox(label="Enter an abstract to discover relevant tasks from it")
        btn = gr.Button("Submit")

        # Outputs: task and method scores side by side
        with gr.Row():
            tasks_table = gr.Label(label="Relevant Tasks from PapersWithCode")
            methods_table = gr.Label(label="Relevant Methods from PapersWithCode")

        # Outputs: matching papers
        output_df = gr.Dataframe(
            headers=["title", "paper_url"],
            datatype=["str", "str"],
            row_count=10,
            col_count=(2, "fixed"),
            label="Relevant papers from PapersWithCode",
        )

        # Output order must match the order of the function's return tuple.
        btn.click(
            fn=find_similar_papers_tasks_methods,
            inputs=[title, abstract],
            outputs=[tasks_table, methods_table, output_df],
        )

demo.launch()
|
|