# serve.py
# Loads all completed shards and finds the most similar vector to a given query vector.
import requests
from sentence_transformers import SentenceTransformer
import faiss
import gradio as gr
from markdown_it import MarkdownIt # used for overriding default markdown renderer
model = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
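# all-MiniLM-L6-v2 maps each text to a 384-dimensional embedding; the faiss
# index loaded below must have been built from the same model's output for the
# query embeddings to be comparable.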
works_ids_path = 'openalex_ids.txt'
with open(works_ids_path) as f:
    idxs = f.read().splitlines()
index = faiss.read_index('index.faiss')
ps = faiss.ParameterSpace()
ps.initialize(index)
ps.set_index_parameters(index, 'nprobe=16,ht=512')
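# nprobe is the number of IVF clusters scanned per query, and ht is the Hamming
# threshold used for polysemous-code filtering; raising either improves recall
# at the cost of latency. These values appear to be hand-tuned for this index.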
def _recover_abstract(inverted_index):
    abstract_size = max(max(appearances) for appearances in inverted_index.values()) + 1

    abstract = [None] * abstract_size
    for word, appearances in inverted_index.items():  # yes, this is a second iteration over inverted_index
        for appearance in appearances:
            abstract[appearance] = word

    abstract = [word for word in abstract if word is not None]
    abstract = ' '.join(abstract)
    return abstract
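# Illustrative example (hypothetical input): OpenAlex delivers abstracts as
# {word: [positions]} maps, so
#     _recover_abstract({'Deep': [0], 'learning': [1, 3], 'is': [2]})
# returns 'Deep learning is learning'.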
def search(query):
    global model, index, idxs

    query_embedding = model.encode(query)
    query_embedding = query_embedding.reshape(1, -1)
    distances, faiss_ids = index.search(query_embedding, 20)

    distances = distances[0]
    faiss_ids = faiss_ids[0]

    openalex_ids = [idxs[faiss_id] for faiss_id in faiss_ids]
    search_filter = f'openalex_id:{"|".join(openalex_ids)}'
    search_select = 'id,title,abstract_inverted_index,authorships,primary_location,publication_year,cited_by_count,doi'

    neighbors = [(distance, openalex_id) for distance, openalex_id in zip(distances, openalex_ids)]
    request_str = f'https://api.openalex.org/works?filter={search_filter}&select={search_select}'

    return neighbors, request_str
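# For illustration (with made-up IDs), the request built above looks like:
#   https://api.openalex.org/works?filter=openalex_id:W123|W456&select=id,title,...
# OpenAlex treats the pipe-separated IDs as an OR filter, so one request fetches
# all 20 neighbors at once.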
def execute_request(request_str):
    response = requests.get(request_str).json()
    return response
def format_response(neighbors, response):
    response = {doc['id']: doc for doc in response['results']}

    result_string = ''
    for distance, openalex_id in neighbors:
        doc = response[openalex_id]

        # collect attributes from the openalex doc for the given openalex_id
        title = doc['title']
        abstract = _recover_abstract(doc['abstract_inverted_index'])
        author_names = [authorship['author']['display_name'] for authorship in doc['authorships']]
        publication_year = doc['publication_year']
        citation_count = doc['cited_by_count']
        doi = doc['doi']

        # try to get the journal name, or else set it to None
        try:
            journal_name = doc['primary_location']['source']['display_name']
        except (TypeError, KeyError):
            journal_name = None

        # title: knock out escape sequences
        title = title.replace('\n', '\\n').replace('\r', '\\r')

        # abstract: knock out escape sequences, then truncate to 2000 characters if necessary
        abstract = abstract.replace('\n', '\\n').replace('\r', '\\r')
        if len(abstract) > 2000:
            abstract = abstract[:2000] + '...'

        # authors: truncate to 3 authors if necessary
        if len(author_names) > 3:
            authors_str = ', '.join(author_names[:3]) + ', ...'
        else:
            authors_str = ', '.join(author_names)

        entry_string = ''

        if doi:  # edge case: for now, no doi -> no link
            entry_string += f'## [{title}]({doi})\n\n'
        else:
            entry_string += f'## {title}\n\n'

        if journal_name:
            entry_string += f'**{authors_str} - {journal_name}, {publication_year}**\n'
        else:
            entry_string += f'**{authors_str}, {publication_year}**\n'

        entry_string += f'{abstract}\n\n'

        if citation_count:  # edge case: we shouldn't tack "Cited-by count: 0" onto someone's paper
            entry_string += f'*Cited-by count: {citation_count}*'
            entry_string += ' '

        if doi:  # list the doi if it exists
            entry_string += f'*DOI: {doi.replace("https://doi.org/", "")}*'
            entry_string += ' '

        entry_string += f'*Similarity: {distance:.2f}*'
        entry_string += '  \n'  # two trailing spaces force a markdown line break

        result_string += entry_string

    return result_string
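# Each formatted entry renders roughly as:
#   ## [Title](https://doi.org/...)
#   **Author A, Author B, Author C, ... - Journal Name, 2020**
#   Abstract text...
#   *Cited-by count: 42* *DOI: 10.xxxx/...* *Similarity: 0.83*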
with gr.Blocks() as demo:
    gr.Markdown('# abstracts-search demo')
    gr.Markdown(
        'Explore 95 million academic publications selected from the [OpenAlex](https://openalex.org) dataset. This '
        'project is an index of the embeddings generated from their titles and abstracts. The embeddings were '
        'generated using the `all-MiniLM-L6-v2` model provided by the [sentence-transformers](https://www.sbert.net/) '
        'module, and the index was built using the [faiss](https://github.com/facebookresearch/faiss) module.'
    )

    neighbors_var = gr.State()
    request_str_var = gr.State()
    response_var = gr.State()

    query = gr.Textbox(lines=1, placeholder='Enter your query here', show_label=False)
    btn = gr.Button('Search')
    with gr.Box():
        results = gr.Markdown()
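    # Swap in a markdown-it instance whose 'js-default' preset keeps raw HTML
    # disabled, so HTML in fetched abstracts is escaped rather than rendered.
    # (Relying on the component's `.md` attribute like this is a
    # gradio-version-dependent hack.)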
    md = MarkdownIt('js-default', {'linkify': True, 'typographer': True})  # don't render html or latex!
    results.md = md
    query.submit(search, inputs=[query], outputs=[neighbors_var, request_str_var]) \
        .success(execute_request, inputs=[request_str_var], outputs=[response_var]) \
        .success(format_response, inputs=[neighbors_var, response_var], outputs=[results])
    btn.click(search, inputs=[query], outputs=[neighbors_var, request_str_var]) \
        .success(execute_request, inputs=[request_str_var], outputs=[response_var]) \
        .success(format_response, inputs=[neighbors_var, response_var], outputs=[results])
demo.queue(2)
demo.launch()