# app.py
# Loads all completed shards and finds the most similar vector to a given query vector.

import json
import os
from dataclasses import dataclass
from math import log10
from pathlib import Path
from sys import stderr
from typing import Any, Callable, TypedDict, TypeVar
from urllib.parse import urlparse

import faiss
import gradio as gr
import numpy as np
import numpy.typing as npt
import requests
import torch
from datasets import Dataset
from datasets.search import FaissIndex
from huggingface_hub import snapshot_download
from sentence_transformers import SentenceTransformer

try:
    import spaces
except ImportError:
    spaces = None

T = TypeVar("T")
U = TypeVar("U")


class IndexParameters(TypedDict):
    recall: float  # in this case 10-recall@10
    exec_time: float  # seconds (raw faiss measure is in milliseconds)
    param_string: str  # pass directly to faiss index


class Params(TypedDict):
    dimensions: int | None
    normalize: bool
    optimal_params: list[IndexParameters]
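

# A Work is the subset of an OpenAlex work record that the UI renders; field types
# are validated in __post_init__.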
@dataclass
class Work:
    title: str | None
    abstract: str | None  # recovered from abstract_inverted_index
    authors: list[str]  # takes raw_author_name field from Authorship objects
    journal_name: str | None  # takes the display_name field of the first location
    year: int
    citations: int
    doi: str | None
    oa_url: str | None

    def __post_init__(self):
        self._check_type(self.title, str, nullable=True)
        self._check_type(self.abstract, str, nullable=True)
        self._check_type(self.authors, list)
        for author in self.authors:
            self._check_type(author, str)
        self._check_type(self.journal_name, str, nullable=True)
        self._check_type(self.year, int)
        self._check_type(self.citations, int)
        self._check_type(self.doi, str, nullable=True)
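
    # Builds a Work from one raw /works record, tolerating null or missing nested
    # fields (journal and open-access location).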
    @classmethod
    def from_dict(cls, d: dict) -> "Work":
        inverted_index: None | dict[str, list[int]] = d["abstract_inverted_index"]
        abstract = cls._recover_abstract(inverted_index) if inverted_index else None

        try:
            journal_name = d["primary_location"]["source"]["display_name"]
        except (TypeError, KeyError):  # key didn't exist or a value was null
            journal_name = None

        try:
            location = d["best_oa_location"]
            pdf_url = location["pdf_url"]
            landing_page_url = location["landing_page_url"]
            oa_url = landing_page_url if landing_page_url else pdf_url
        except (TypeError, KeyError):
            oa_url = None

        return cls(
            title=d["title"],
            abstract=abstract,
            authors=[authorship["raw_author_name"] for authorship in d["authorships"]],
            journal_name=journal_name,
            year=d["publication_year"],
            citations=d["cited_by_count"],
            doi=d["doi"],
            oa_url=oa_url,
        )

    @staticmethod
    def get_raw_fields() -> list[str]:
        return [
            "title",
            "abstract_inverted_index",
            "authorships",
            "primary_location",
            "publication_year",
            "cited_by_count",
            "doi",
            "best_oa_location",
        ]

    @staticmethod
    def _check_type(v: Any, t: type, nullable: bool = False):
        if not ((nullable and v is None) or isinstance(v, t)):
            v_type_name = f"{type(v)}" if v is not None else "None"
            t_name = f"{t}"
            if nullable:
                t_name += " | None"
            raise ValueError(f"expected {t_name}, got {v_type_name}")
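
    # OpenAlex ships abstracts as an inverted index (word -> list of positions), so
    # the text is rebuilt by writing each word back into its positions.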
    @staticmethod
    def _recover_abstract(inverted_index: dict[str, list[int]]) -> str:
        abstract_size = max(max(locs) for locs in inverted_index.values()) + 1

        abstract_words: list[str | None] = [None] * abstract_size
        for word, locs in inverted_index.items():
            for loc in locs:
                abstract_words[loc] = word

        return " ".join(word for word in abstract_words if word is not None)
def get_env_var(key: str, type_: Callable[[str], T] = str, default: U = None) -> T | U:
    var = os.getenv(key)
    if var is not None:
        var = type_(var)
    else:
        var = default
    return var


def get_model(
    model_name: str, params_dir: Path, trust_remote_code: bool
) -> tuple[bool, SentenceTransformer]:
    # TODO: params["normalize"] for models like all-MiniLM-v6, which already normalize?
    with open(params_dir / "params.json", "r") as f:
        params: Params = json.load(f)

    return params["normalize"], SentenceTransformer(
        model_name,
        trust_remote_code=trust_remote_code,
        truncate_dim=params["dimensions"],
    )


def open_ondisk(dir: Path) -> faiss.Index:
    # without IO_FLAG_ONDISK_SAME_DIR, read_index gets on-disk indices in working dir
    return faiss.read_index(str(dir / "index.faiss"), faiss.IO_FLAG_ONDISK_SAME_DIR)
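

# Loads the faiss-id -> OpenAlex-id mapping as a datasets.Dataset, attaches the
# on-disk faiss index to it, and applies the highest-recall search parameters whose
# measured exec_time stays under the search-time budget.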
def get_index(dir: Path, search_time_s: float) -> Dataset:
    # NOTE: use a private attr to load the index with IO_FLAG_ONDISK_SAME_DIR!
    index: Dataset = Dataset.from_parquet(str(dir / "ids.parquet"))  # type: ignore
    faiss_index = open_ondisk(dir)
    index._indexes["embedding"] = FaissIndex(None, None, None, faiss_index)

    with open(dir / "params.json", "r") as f:
        params: Params = json.load(f)
    under = [p for p in params["optimal_params"] if p["exec_time"] < search_time_s]
    optimal = max(under, key=(lambda p: p["recall"]))
    optimal_string = optimal["param_string"]

    ps = faiss.ParameterSpace()
    ps.initialize(faiss_index)
    ps.set_index_parameters(faiss_index, optimal_string)

    return index
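

# Hydrates up to 100 OpenAlex IDs into Work objects with a single /works request,
# preserving the order of the input IDs.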
def execute_request(ids: list[str], mailto: str | None) -> list[Work]:
    if len(ids) > 100:
        raise ValueError("querying /works endpoint with more than 100 works")

    # query the /works endpoint with a specific list of IDs and fields
    search_filter = f"openalex_id:{'|'.join(ids)}"
    search_select = ",".join(["id"] + Work.get_raw_fields())
    params = {"filter": search_filter, "select": search_select, "per-page": 100}
    if mailto is not None:
        params["mailto"] = mailto
    response = requests.get("https://api.openalex.org/works", params)
    response.raise_for_status()

    # the response is not necessarily ordered, so order them
    response = {d["id"]: Work.from_dict(d) for d in response.json()["results"]}
    return [response[id_] for id_ in ids]


def collapse_newlines(x: str) -> str:
    return x.replace("\r\n", " ").replace("\n", " ").replace("\r", " ")
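

# Renders the neighbors as Markdown: linked title, up to three authors, a truncated
# abstract, and a metadata line (citations, DOI, OA link, and similarity or distance).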
def format_response(
    neighbors: list[Work], distances: list[float], calculate_similarity: bool = False
) -> str:
    result_string = ""
    for work, distance in zip(neighbors, distances):
        entry_string = "## "

        if work.title and work.doi:
            entry_string += f"[{collapse_newlines(work.title)}]({work.doi})"
        elif work.title:
            entry_string += f"{collapse_newlines(work.title)}"
        elif work.doi:
            entry_string += f"[No title]({work.doi})"
        else:
            entry_string += "No title"

        entry_string += "\n\n**"

        if len(work.authors) >= 3:  # truncate to 3 if necessary
            entry_string += ", ".join(work.authors[:3]) + ", ..."
        elif work.authors:
            entry_string += ", ".join(work.authors)
        else:
            entry_string += "No author"

        entry_string += f", {work.year}"

        if work.journal_name:
            entry_string += " - " + work.journal_name

        entry_string += "**\n\n"

        if work.abstract:
            abstract = collapse_newlines(work.abstract)
            if len(abstract) > 2000:
                abstract = abstract[:2000] + "..."
            entry_string += abstract
        else:
            entry_string += "No abstract"

        entry_string += "\n\n*"

        meta: list[tuple[str, str]] = []
        if work.citations:  # don't tack "Cited-by count: 0" on someone's work
            meta.append(("Cited-by count", str(work.citations)))
        if work.doi:
            meta.append(("DOI", work.doi.replace("https://doi.org/", "")))
        if work.oa_url:
            # use the netloc for readability, but use the full link if it's not found
            netloc = urlparse(work.oa_url).netloc
            if netloc:
                if netloc.startswith("www."):
                    link_text = netloc[len("www."):]
                else:
                    link_text = netloc
            else:
                link_text = work.oa_url
            meta.append(("OA", f"[{link_text}]({work.oa_url})"))
        if calculate_similarity:
            # if query and result are unit vectors, the cosine sim is 1 - dist^2 / 2
            meta.append(("Similarity", f"{1 - distance / 2:.2f}"))  # faiss gives dist^2
        else:
            meta.append(("Distance", f"{distance:.2f}"))
        entry_string += (" " * 4).join(": ".join(tup) for tup in meta)

        entry_string += "*\n"

        result_string += entry_string

    return result_string
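

# main() wires everything together: read configuration from environment variables
# (MODEL_NAME, PROMPT_NAME, TRUST_REMOTE_CODE, FP16, DIR, REPO, SEARCH_TIME_S, K,
# MAILTO), fetch or locate the index, load the embedding model, and serve the
# Gradio UI.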
def main():
    # TODO: figure out some better defaults?
    model_name = get_env_var("MODEL_NAME", default="all-MiniLM-L6-v2")
    prompt_name = get_env_var("PROMPT_NAME")
    trust_remote_code = get_env_var("TRUST_REMOTE_CODE", bool, default=False)
    fp16 = get_env_var("FP16", bool, default=False)
    dir = get_env_var("DIR", Path)
    repo = get_env_var("REPO", str)
    search_time_s = get_env_var("SEARCH_TIME_S", float, default=1)
    k = get_env_var("K", int, default=20)  # TODO: can't go higher than 20 yet
    mailto = get_env_var("MAILTO", str, None)

    if dir is None:  # acquire the index if it's not local
        if repo is None:
            repo = "colonelwatch/abstracts-faiss"
        dir = Path(snapshot_download(repo, repo_type="dataset")) / "index"
    elif repo is not None:
        print('warning: used "REPO" and also "DIR", ignoring "REPO"...', file=stderr)

    normalize, model = get_model(model_name, dir, trust_remote_code)
    index = get_index(dir, search_time_s)

    # follow model.encode logic for acquiring the prompt
    if prompt_name is None and model.default_prompt_name is not None:
        prompt_name = model.default_prompt_name
        if not isinstance(prompt_name, str):
            raise TypeError("invalid prompt name type")
    prompt: str | None = model.prompts[prompt_name] if prompt_name is not None else None

    # follow model.encode logic for setting extra_features
    extra_features: dict[str, Any] = {}
    if prompt is not None:
        tokenized = model.tokenize([prompt])
        if "input_ids" in tokenized:
            extra_features["prompt_length"] = tokenized["input_ids"].shape[-1] - 1

    model.eval()
    if torch.cuda.is_available():
        model = model.half().cuda() if fp16 else model.bfloat16().cuda()
        # TODO: if huggingface datasets exposes an fp16 gpu option, use it here
    elif fp16:
        print('warning: used "FP16" on CPU-only system, ignoring...', file=stderr)
    model.compile(mode="reduce-overhead")
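
    # Forward pass only: tokenization happens on the CPU side in encode_string, so
    # under ZeroGPU only this function needs to run on the GPU.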
    def encode_tokens(features: dict[str, Any]) -> npt.NDArray[np.float32]:
        # move the tokenized features (a dict of tensors) to the device, non-blocking
        features = {
            k: v.to(model.device, non_blocking=True) for k, v in features.items()
        } | extra_features

        with torch.no_grad():
            out_features = model.forward(features)
            embeddings = out_features["sentence_embedding"]
            embeddings = embeddings[0]
            if model.truncate_dim:
                embeddings = embeddings[: model.truncate_dim]
            if normalize:
                embeddings = torch.nn.functional.normalize(embeddings, dim=0)

        return embeddings.cpu().float().numpy()  # faiss expects CPU float32 numpy arr

    if spaces:
        encode_tokens = spaces.GPU(encode_tokens)

    def encode_string(query: str) -> npt.NDArray[np.float32]:
        if prompt:
            query = prompt + query
        tokens = model.tokenize([query])
        return encode_tokens(tokens)
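
    # Full query pipeline: embed the string, k-NN search against the faiss index,
    # hydrate the hits through the OpenAlex API, then format them as Markdown.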
    def search(query: str) -> str:
        query_embedding = encode_string(query)
        distances, faiss_ids = index.search("embedding", query_embedding, k)

        openalex_ids = index[faiss_ids]["id"]
        works = execute_request(openalex_ids, mailto)
        return format_response(works, distances, calculate_similarity=normalize)

    with gr.Blocks() as demo:
        # figure out the words to describe the quantity
        n_entries = len(index)
        n_digits = int(log10(n_entries))
        divisor, postfix = {
            0: (1, ""),
            1: (1000, " thousand"),
            2: (1000000, " million"),
            3: (1000000000, " billion"),
        }[n_digits // 3]
        significand = n_entries / divisor
        significand = round(significand, 1 if (n_digits % 3 == 1) else None)
        quantity = str(significand) + postfix

        # split the (huggingface) model name and get the link
        model_publisher, model_human_name = model_name.split("/")
        model_link = f"https://huggingface.co/{model_publisher}/{model_human_name}"
| gr.Markdown("# abstracts-index") | |
| gr.Markdown( | |
| f"Explore {quantity} academic publications selected from the " | |
| "[OpenAlex](https://openalex.org) dataset (as of January 1st, 2025) with " | |
| "semantic search, not keyword search. This project is an index of the " | |
| "embeddings generated from their titles and abstracts. The embeddings were " | |
| f"generated using the [{model_human_name}]({model_link}) model, and the " | |
| "index was built using the " | |
| "[faiss](https://github.com/facebookresearch/faiss) module. The build " | |
| "scripts and more information available at the main repo " | |
| "[abstracts-search](https://github.com/colonelwatch/abstracts-search) on " | |
| "Github." | |
| ) | |
        query = gr.Textbox(
            lines=1, placeholder="Enter your query here", show_label=False
        )
        btn = gr.Button("Search")
        results = gr.Markdown(
            latex_delimiters=[
                {"left": "$$", "right": "$$", "display": False},
                {"left": "$", "right": "$", "display": False},
            ],
            container=True,
        )

        # NOTE: ZeroGPU doesn't seem to support batching
        query.submit(search, inputs=[query], outputs=[results])
        btn.click(search, inputs=[query], outputs=[results])

    demo.queue()
    demo.launch()


if __name__ == "__main__":
    main()