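# Provenance (residue from the Hugging Face file page, kept as comments):
#   author: kaisugi
#   commit: b62fe6e ("fix")
#   malware scan: no virus detected
#   file size: 2.48 kB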
import faiss
import numpy as np
import pandas as pd
import streamlit as st
import torch
from torch import Tensor
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
import os
# Ask the Intel/LLVM OpenMP runtime to tolerate being loaded more than once
# instead of aborting; torch and faiss (imported above) may each bring one.
# NOTE(review): this is set after those imports, so confirm it still takes
# effect for whichever runtime is initialised second.
os.environ.update(KMP_DUPLICATE_LIB_OK='True')
def average_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
    """Mean-pool token states over dim 1, ignoring masked-out positions.

    Args:
        last_hidden_states: Per-token hidden states; dim 1 is the axis that is
            averaged away (presumably ``(batch, seq_len, hidden)`` — confirm).
        attention_mask: Mask whose nonzero entries mark positions to keep
            (presumably ``(batch, seq_len)``).

    Returns:
        Sum of the kept hidden states along dim 1 divided by the mask's sum
        along dim 1. A row whose mask is all zeros divides 0 by 0 (NaN).
    """
    keep = attention_mask.unsqueeze(-1).bool()
    # masked_fill (rather than multiplying by the mask) so inf/NaN values in
    # padded positions cannot leak into the average.
    zeroed = last_hidden_states.masked_fill(keep.logical_not(), 0.0)
    token_counts = attention_mask.sum(dim=1).unsqueeze(-1)
    return zeroed.sum(dim=1) / token_counts
@st.cache_resource
def load_model_and_tokenizer():
    """Load the multilingual E5-large encoder and its tokenizer.

    Wrapped in ``st.cache_resource`` so Streamlit loads them once and hands
    back the same objects on later reruns.

    Returns:
        A ``(model, tokenizer)`` tuple; the model has been switched to
        eval mode.
    """
    checkpoint = 'intfloat/multilingual-e5-large'
    tok = AutoTokenizer.from_pretrained(checkpoint)
    encoder = AutoModel.from_pretrained(checkpoint)
    encoder.eval()
    return encoder, tok
@st.cache_resource
def load_title_data():
    """Read the paper list from ``anlp2024.tsv``.

    The file is parsed as tab-separated data with no header row; the two
    columns are named ``pid`` and ``title``. Rows are looked up by FAISS id
    during retrieval, so their order must match the embedding rows.

    Returns:
        A DataFrame with columns ``pid`` and ``title``.
    """
    return pd.read_csv(
        'anlp2024.tsv',
        sep="\t",
        names=["pid", "title"],
    )
@st.cache_resource
def load_title_embeddings():
    """Load the precomputed title-embedding matrix from ``anlp2024.npz``.

    ``np.load`` on an ``.npz`` returns an ``NpzFile`` that keeps the archive
    open; it is used as a context manager so the handle is closed as soon as
    ``arr_0`` has been read into memory.

    Returns:
        The ``arr_0`` array as a C-contiguous ``float32`` array, which is the
        layout ``faiss.Index.add`` requires. Presumably shaped
        ``(num_papers, embedding_dim)`` — confirm against the exporter.
    """
    with np.load("anlp2024.npz") as npz_comp:
        title_embeddings = npz_comp["arr_0"]
    # No copy is made when the stored array is already contiguous float32.
    return np.ascontiguousarray(title_embeddings, dtype=np.float32)
def get_retrieval_results(index, input_text, top_k, tokenizer, title_df, model=None):
    """Embed a query and return the ``top_k`` nearest papers from ``index``.

    Args:
        index: FAISS index over the title embeddings; the ids it returns are
            treated as row positions in ``title_df``.
        input_text: Raw query string. It is prefixed with ``"query: "``
            before encoding (the E5 query/passage convention).
        top_k: Number of results requested; clamped to ``index.ntotal``,
            because FAISS pads short result lists with id -1.
        tokenizer: Hugging Face tokenizer matching ``model``.
        title_df: DataFrame with ``pid`` and ``title`` columns whose row
            order matches the vectors added to ``index``.
        model: Encoder used to embed the query. Defaults to the cached model
            from ``load_model_and_tokenizer()`` instead of relying on a
            global that only exists when the file runs as a script.

    Returns:
        A DataFrame with columns ``pids`` and ``paper``, one row per hit, in
        FAISS ranking order (best first).
    """
    if model is None:
        model, _ = load_model_and_tokenizer()

    batch_dict = tokenizer([f"query: {input_text}"], max_length=512,
                           padding=True, truncation=True, return_tensors='pt')
    with torch.no_grad():
        outputs = model(**batch_dict)
        query_embeddings = average_pool(outputs.last_hidden_state,
                                        batch_dict['attention_mask'])
        query_embeddings = F.normalize(query_embeddings, p=2, dim=1)

    # FAISS asserts k > 0 and returns id -1 for slots beyond ntotal, so never
    # request more neighbours than the index holds.
    k = min(int(top_k), index.ntotal)
    if k <= 0:
        return pd.DataFrame({"pids": [], "paper": []})

    query_matrix = np.ascontiguousarray(
        query_embeddings.detach().cpu().numpy(), dtype=np.float32)
    _, ids = index.search(x=query_matrix, k=k)

    # Drop any remaining -1 sentinels; ids are positions, hence iloc.
    hit_positions = [int(i) for i in ids[0] if i >= 0]
    hits = title_df.iloc[hit_positions]
    return pd.DataFrame({
        "pids": hits["pid"].to_list(),
        "paper": hits["title"].to_list(),
    })
if __name__ == "__main__":
    model, tokenizer = load_model_and_tokenizer()
    title_df = load_title_data()
    title_embeddings = load_title_embeddings()

    # Exact L2 index over the title vectors. The dimension comes from the
    # data instead of a hard-coded 1024, so a different encoder width cannot
    # make faiss reject the vectors.
    # NOTE(review): query vectors are L2-normalised, so L2 ranking equals
    # cosine ranking only if the stored title vectors are normalised too —
    # presumably done at export time; confirm.
    index = faiss.IndexFlatL2(title_embeddings.shape[1])
    index.add(title_embeddings)

    st.markdown("## NLP2024 論文検索")
    input_text = st.text_input('query', '', placeholder='')
    top_k = st.number_input('top_k', min_value=1, value=10, step=1)

    if st.button('検索'):
        stripped_input_text = input_text.strip()
        if not stripped_input_text:
            # An empty query would still be embedded ("query: ") and return
            # arbitrary papers; ask the user for input instead.
            st.warning("検索クエリを入力してください。")
        else:
            df = get_retrieval_results(index, stripped_input_text, top_k,
                                       tokenizer, title_df)
            st.table(df)