# kaisugi's picture
# fix
# 3d40d6a
import faiss
import numpy as np
import pandas as pd
import streamlit as st
import torch
from torch import Tensor
from transformers import AutoModel, AutoTokenizer
import os
# NOTE(review): KMP_DUPLICATE_LIB_OK is presumably set so the process keeps
# running when more than one OpenMP runtime gets loaded (e.g. one bundled with
# faiss and one with torch) — confirm, and confirm that setting it after the
# imports above still takes effect on the target platform.
os.environ['KMP_DUPLICATE_LIB_OK']='True'
def average_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
    """Mean-pool hidden states along dim 1, ignoring masked-out positions.

    Positions where ``attention_mask`` is zero are zeroed before summing, and
    the sum is divided by the number of non-zero mask entries per row.

    Args:
        last_hidden_states: Hidden states; assumed to be
            ``(batch, seq_len, hidden)`` — TODO confirm against the caller.
        attention_mask: Mask whose non-zero entries mark positions to keep;
            assumed to be ``(batch, seq_len)``.

    Returns:
        The masked mean over dim 1 of ``last_hidden_states``.
    """
    keep = attention_mask.unsqueeze(-1).bool()
    zeroed = last_hidden_states.masked_fill(~keep, 0.0)
    summed = zeroed.sum(dim=1)
    counts = attention_mask.sum(dim=1).unsqueeze(-1)
    return summed / counts
@st.cache_resource
def load_model_and_tokenizer():
    """Load the multilingual-e5-large tokenizer and model (cached by Streamlit).

    Returns:
        A ``(model, tokenizer)`` tuple; the model has been switched to
        evaluation mode via ``eval()``.
    """
    checkpoint = 'intfloat/multilingual-e5-large'
    e5_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    e5_model = AutoModel.from_pretrained(checkpoint)
    # Inference only: put the model in evaluation mode before handing it out.
    e5_model.eval()
    return e5_model, e5_tokenizer
@st.cache_resource
def load_title_data():
    """Read ``anlp2024.tsv`` into a DataFrame (cached by Streamlit).

    The file is parsed as tab-separated with no header row; the two columns
    are named ``pid`` and ``title``.
    """
    return pd.read_csv('anlp2024.tsv', sep="\t", names=["pid", "title"])
@st.cache_resource
def load_title_embeddings():
    """Load the title embedding matrix from ``anlp2024.npz`` (cached by Streamlit).

    Returns:
        The array stored under the key ``arr_0`` of the archive.
    """
    # np.load on an .npz returns a lazily-read archive that holds an open file
    # handle; the context manager closes it once the array has been read.
    with np.load("anlp2024.npz") as archive:
        title_embeddings = archive["arr_0"]
    return title_embeddings
@st.cache_data
def get_retrieval_results(index, input_text, top_k, tokenizer, title_df):
    """Return the titles whose embeddings are nearest to the query text.

    The query is tokenised, encoded with the module-level ``model`` (it is
    read as a global, not passed as a parameter), mean-pooled with
    ``average_pool``, L2-normalised and searched against ``index``.

    Args:
        index: Object exposing ``search(x=..., k=...)`` that returns a
            ``(distances, ids)`` pair; the caller passes a FAISS index.
        input_text: Query text; it is prefixed with ``"query: "`` before
            tokenisation.
        top_k: Number of neighbours to request.
        tokenizer: Callable returning a mapping that contains
            ``attention_mask`` and can be unpacked into ``model``.
        title_df: DataFrame with ``pid`` and ``title`` columns whose index
            labels match the ids returned by ``index``.

    Returns:
        A DataFrame with columns ``pids`` and ``paper``, one row per hit, in
        the order the hits were returned by ``index.search``.
    """
    # NOTE(review): st.cache_data hashes every argument whose name does not
    # start with an underscore — confirm that ``index`` and ``tokenizer`` are
    # hashable with the deployed Streamlit/FAISS versions.
    batch_dict = tokenizer(f"query: {input_text}", max_length=512, padding=True, truncation=True, return_tensors='pt')
    with torch.no_grad():
        outputs = model(**batch_dict)
        query_embeddings = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
        # Fully-qualified call: the file imports ``torch`` but never binds ``F``.
        query_embeddings = torch.nn.functional.normalize(query_embeddings, p=2, dim=1)

    # FAISS expects a 2-D float32 matrix of shape (n_queries, dim). The pooled
    # tensor already has one row per query, so convert it directly instead of
    # wrapping it in another list (which would add a third dimension).
    query_matrix = np.ascontiguousarray(
        query_embeddings.detach().cpu().numpy(), dtype=np.float32
    )
    _, ids = index.search(x=query_matrix, k=int(top_k))

    # FAISS pads the id row with -1 when it holds fewer than k vectors; drop
    # those so the label lookup below cannot raise KeyError.
    hit_ids = [int(i) for i in ids[0] if i >= 0]
    hits = title_df.loc[hit_ids, ["pid", "title"]]
    return pd.DataFrame({"pids": hits["pid"].tolist(), "paper": hits["title"].tolist()})
if __name__ == "__main__":
    # Cached resources. ``model`` must keep this exact name because
    # get_retrieval_results reads it as a module-level global.
    model, tokenizer = load_model_and_tokenizer()
    title_df = load_title_data()
    title_embeddings = load_title_embeddings()

    # Exact L2 index over 1024-dimensional vectors, filled with the title
    # embeddings.
    index = faiss.IndexFlatL2(1024)
    index.add(title_embeddings)

    # --- UI ---
    st.markdown("## NLP2024 類似論文検索")
    input_text = st.text_input('input', '', placeholder='ここに論文のタイトルを入力してください')
    top_k = st.number_input('top_k', min_value=1, value=10, step=1)

    search_clicked = st.button('検索')
    if search_clicked:
        # Surrounding whitespace is removed before the query is embedded.
        results = get_retrieval_results(index, input_text.strip(), top_k, tokenizer, title_df)
        st.table(results)