import os

# Allow duplicate OpenMP runtimes (faiss and PyTorch each bundle one) instead of
# aborting with "OMP: Error #15"; set before importing faiss/torch so the flag
# is seen when those libraries load.
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

import faiss
import numpy as np
import pandas as pd
import streamlit as st
import torch
from transformers import AutoModel, AutoTokenizer


@st.cache(allow_output_mutation=True)
def load_model_and_tokenizer():
    tokenizer = AutoTokenizer.from_pretrained("cl-tohoku/bert-base-japanese-v2")
    model = AutoModel.from_pretrained("kaisugi/anlp_embedding_model")
    model.eval()
    return model, tokenizer


@st.cache(allow_output_mutation=True)
def load_title_data():
    title_df = pd.read_csv("anlp2023.csv")
    return title_df


@st.cache(allow_output_mutation=True)
def load_title_embeddings():
    npz_comp = np.load("anlp_title_embeddings.npz")
    title_embeddings = npz_comp["arr_0"]
    return title_embeddings


@st.cache
def get_retrieval_results(index, input_text, top_k, tokenizer, title_df):
    with torch.no_grad():
        # Encode the query title and take its [CLS] token embedding (768 dims)
        inputs = tokenizer(
            input_text,
            padding=True,
            truncation=True,
            return_tensors="pt",
            max_length=512,
        )
        # `model` is the module-level global loaded under __main__ below
        outputs = model(**inputs)
        query_embeddings = outputs.last_hidden_state[:, 0, :][0]
        query_embeddings = query_embeddings.detach().cpu().numpy()

    # Exact nearest-neighbor search over the precomputed title embeddings
    _, ids = index.search(x=np.array([query_embeddings]), k=top_k)

    retrieved_titles = []
    retrieved_pids = []
    for idx in ids[0]:
        retrieved_titles.append(title_df.loc[idx, "title"])
        retrieved_pids.append(title_df.loc[idx, "pid"])

    df = pd.DataFrame({"pid": retrieved_pids, "paper": retrieved_titles})
    return df


if __name__ == "__main__":
    model, tokenizer = load_model_and_tokenizer()
    title_df = load_title_data()
    title_embeddings = load_title_embeddings()

    # Exact L2 index over 768-dim BERT [CLS] embeddings
    index = faiss.IndexFlatL2(768)
    index.add(title_embeddings)

    st.markdown("## NLP2023 類似論文検索")  # "NLP2023 similar-paper search"
    input_text = st.text_input(
        "input",
        "",
        placeholder="ここに論文のタイトルを入力してください",  # "Enter a paper title here"
    )
    top_k = st.number_input("top_k", min_value=1, value=10, step=1)

    if st.button("検索"):  # "Search"
        stripped_input_text = input_text.strip()
        df = get_retrieval_results(index, stripped_input_text, top_k, tokenizer, title_df)
        st.table(df)
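

# ---------------------------------------------------------------------------
# Sketch (an assumption, not part of the app above): one way the
# "anlp_title_embeddings.npz" file read by load_title_embeddings() could be
# produced offline. It assumes "anlp2023.csv" has a "title" column and reuses
# the same tokenizer, model, and [CLS] pooling as get_retrieval_results(), so
# the resulting 768-dim vectors match faiss.IndexFlatL2(768). The function
# name and batch_size are illustrative only; run it once in a separate
# process, e.g.
#     build_title_embeddings("anlp2023.csv", "anlp_title_embeddings.npz")
# ---------------------------------------------------------------------------
def build_title_embeddings(csv_path, out_path, batch_size=32):
    tokenizer = AutoTokenizer.from_pretrained("cl-tohoku/bert-base-japanese-v2")
    model = AutoModel.from_pretrained("kaisugi/anlp_embedding_model")
    model.eval()

    titles = pd.read_csv(csv_path)["title"].tolist()
    embeddings = []
    with torch.no_grad():
        for start in range(0, len(titles), batch_size):
            batch = titles[start:start + batch_size]
            inputs = tokenizer(
                batch,
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=512,
            )
            outputs = model(**inputs)
            # One [CLS] embedding per title, matching the query encoding above
            embeddings.append(outputs.last_hidden_state[:, 0, :].cpu().numpy())

    # np.savez stores a positional array under the key "arr_0",
    # which is the key load_title_embeddings() reads back
    np.savez(out_path, np.concatenate(embeddings, axis=0).astype(np.float32))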