"""Streamlit app that searches paper titles by embedding similarity.

The query is embedded with the ``intfloat/multilingual-e5-large`` model and
matched against precomputed title embeddings through a flat L2 FAISS index.
"""

import os

import faiss
import numpy as np
import pandas as pd
import streamlit as st
import torch
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoModel, AutoTokenizer

# Intel OpenMP switch that tolerates more than one copy of the runtime being
# loaded into the process (set after the imports, as in the original script).
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

MODEL_NAME = 'intfloat/multilingual-e5-large'
TITLE_TSV_PATH = 'anlp2024.tsv'
TITLE_EMBEDDINGS_PATH = "anlp2024.npz"


def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    """Mean-pool hidden states over the positions enabled by the mask.

    Positions whose mask value is zero are zeroed out, the states are summed
    over dim 1 and divided by the per-row sum of the mask.
    """
    last_hidden = last_hidden_states.masked_fill(
        ~attention_mask[..., None].bool(), 0.0
    )
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]


@st.cache_resource
def load_model_and_tokenizer():
    """Load the embedding model (in eval mode) and its tokenizer once."""
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModel.from_pretrained(MODEL_NAME)
    model.eval()
    return model, tokenizer


@st.cache_resource
def load_title_data():
    """Read the tab-separated title file into a ``pid`` / ``title`` frame."""
    title_df = pd.read_csv(TITLE_TSV_PATH, names=["pid", "title"], sep="\t")
    return title_df


@st.cache_resource
def load_title_embeddings():
    """Return the ``arr_0`` array stored in the title-embedding archive."""
    npz_comp = np.load(TITLE_EMBEDDINGS_PATH)
    title_embeddings = npz_comp["arr_0"]
    return title_embeddings


@st.cache_resource
def load_index():
    """Build the flat L2 FAISS index over the title embeddings once.

    Streamlit re-executes the script on every interaction; caching the index
    avoids re-adding every embedding on each rerun.
    """
    # FAISS expects a C-contiguous float32 matrix.
    embeddings = np.ascontiguousarray(load_title_embeddings(), dtype=np.float32)
    # Take the dimensionality from the data instead of hard-coding it.
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    return index


def get_retrieval_results(index, input_text, top_k, tokenizer, title_df,
                          model=None) -> pd.DataFrame:
    """Return the ``top_k`` titles nearest to ``input_text``.

    Args:
        index: FAISS index whose vectors were added in the row order of
            ``title_df``.
        input_text: Raw query text; it is prefixed with ``"query: "`` before
            being tokenized.
        top_k: Number of results requested. It is capped at the number of
            vectors held by ``index``.
        tokenizer: Tokenizer matching ``model``.
        title_df: Frame with ``pid`` and ``title`` columns.
        model: Embedding model. When omitted, the cached model from
            ``load_model_and_tokenizer`` is used.

    Returns:
        A frame with columns ``pids`` and ``paper``, ordered as returned by
        the index search. It is empty when the index holds no vectors.
    """
    if model is None:
        model, _ = load_model_and_tokenizer()

    # Asking FAISS for more neighbours than it stores yields -1 placeholders,
    # so never request more than the index contains.
    k = min(int(top_k), index.ntotal)
    if k <= 0:
        return pd.DataFrame({"pids": [], "paper": []})

    batch_dict = tokenizer(
        [f"query: {input_text}"],
        max_length=512,
        padding=True,
        truncation=True,
        return_tensors='pt',
    )
    with torch.no_grad():
        outputs = model(**batch_dict)

    query_embeddings = average_pool(
        outputs.last_hidden_state, batch_dict['attention_mask']
    )
    query_embeddings = F.normalize(query_embeddings, p=2, dim=1)

    _, ids = index.search(x=query_embeddings.detach().numpy().copy(), k=k)

    # Drop any remaining -1 placeholders; FAISS ids are insertion positions,
    # hence positional (iloc) rather than label-based lookup.
    hit_rows = [int(row) for row in ids[0] if row >= 0]
    hits = title_df.iloc[hit_rows]
    return pd.DataFrame(
        {"pids": hits["pid"].tolist(), "paper": hits["title"].tolist()}
    )


def main() -> None:
    """Render the search page and show results when the button is pressed."""
    model, tokenizer = load_model_and_tokenizer()
    title_df = load_title_data()
    index = load_index()

    st.markdown("## NLP2024 論文検索")
    input_text = st.text_input('query', '', placeholder='')
    top_k = st.number_input('top_k', min_value=1, value=10, step=1)

    if st.button('検索'):
        stripped_input_text = input_text.strip()
        if not stripped_input_text:
            # Nothing to embed; ask for a query instead of searching on "".
            st.warning('検索クエリを入力してください。')
            return
        df = get_retrieval_results(
            index, stripped_input_text, top_k, tokenizer, title_df, model=model
        )
        st.table(df)


if __name__ == "__main__":
    main()