import hmac

import pinecone
import requests
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModel

from config import config


def search(text: str, k: int = 5, timeout: float = 30.0):
    """Get the k closest articles to the text.

    Embeds ``text`` with the model held in ``st.session_state`` and posts the
    resulting vector to the Pinecone ``/query`` REST endpoint.

    Args:
        text: Free-text query to embed.
        k: Number of nearest matches to request (sent as ``top_k``).
        timeout: Seconds to wait for the HTTP request before ``requests``
            raises a timeout error, so a stalled endpoint cannot hang the app.

    Returns:
        The decoded JSON body of the response. The caller in this file reads
        its ``"matches"`` list, using each entry's ``id`` and ``score``.

    Raises:
        RuntimeError: If the endpoint answers with a status other than 200.
        requests.exceptions.RequestException: On connection errors/timeouts.
    """
    query_vector = _get_embeddings(text)
    response = requests.post(
        f"https://{config.pinecone_index}-5b18b87.svc.{config.pinecone_env}.pinecone.io/query",
        headers={
            "Api-Key": config.pinecone_api_key,
            "accept": "application/json",
            "content-type": "application/json",
        },
        json={
            "vector": query_vector,
            "top_k": k,
            "includeMetadata": True,
            "includeValues": False,
        },
        timeout=timeout,
    )
    if response.status_code != 200:
        raise RuntimeError(f"Error: {response.status_code} - {response.text}")
    return response.json()


def _get_embeddings(text: str):
    """Return an embedding for ``text`` as a plain Python list.

    Requires ``st.session_state.tokenizer`` and ``st.session_state.model`` to
    have been populated already (the script does this after the password
    check). The first element of the model output is averaged over dim 1,
    squeezed and converted with ``tolist()``.
    """
    encoded = st.session_state.tokenizer(
        text, return_tensors="pt", padding=True, truncation=True
    )
    # Inference only: no gradients are needed.
    with torch.no_grad():
        # presumably (batch, tokens, hidden) -- TODO confirm for config.model_name
        last_hidden_states = st.session_state.model(**encoded)[0]
    return last_hidden_states.mean(dim=1).squeeze().tolist()


def _password_matches(supplied: str, expected) -> bool:
    """Compare the typed password with the configured one in constant time."""
    # A non-string expected value could never equal the typed string anyway.
    if not isinstance(expected, str):
        return False
    return hmac.compare_digest(supplied.encode("utf-8"), expected.encode("utf-8"))


password = st.text_input("Password", type="password")

if _password_matches(password, config.password):
    st.title("PubMed Embeddings")
    st.subheader("Search for a PubMed article and get its id.")

    text = st.text_input("Search for a PubMed article", "Epidemiology of COVID-19")

    with st.spinner("Loading Embedding Model..."):
        # NOTE(review): pinecone-client documents this keyword as `environment`;
        # confirm that `env` is honoured by the installed version.
        pinecone.init(api_key=config.pinecone_api_key, env=config.pinecone_env)
        # Keep heavy objects in session_state so Streamlit reruns reuse them
        # instead of loading them again.
        if "index" not in st.session_state:
            st.session_state.index = pinecone.Index(config.pinecone_index)
        if "tokenizer" not in st.session_state:
            st.session_state.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
        if "model" not in st.session_state:
            st.session_state.model = AutoModel.from_pretrained(config.model_name)

    if st.button("Search"):
        with st.spinner("Searching..."):
            results = search(text)
        for res in results["matches"]:
            st.write(f"{res['id']} - confidence: {res['score']:.2f}")
elif password:
    # Only complain once something was typed; an empty field on first load
    # is not a failed attempt.
    st.write("Password incorrect!")