import math
import os
import re

import faiss
import numpy as np
import pandas as pd
import streamlit as st
import torch
from transformers import AutoModel, AutoTokenizer

os.environ['KMP_DUPLICATE_LIB_OK']='True'


@st.cache(allow_output_mutation=True)
def load_model_and_tokenizer():
    """Load the "kaisugi/scitoricsbert" model and tokenizer.

    The model is loaded with ``output_attentions=True`` and switched to
    eval mode. The result is cached by Streamlit.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained("kaisugi/scitoricsbert")
    model = AutoModel.from_pretrained("kaisugi/scitoricsbert", output_attentions=True)
    model.eval()

    return model, tokenizer


@st.cache(allow_output_mutation=True)
def load_sentence_data():
    """Read ``sentence_data_858k.csv.gz`` into a DataFrame (cached).

    Other code in this file reads the ``sentence`` and ``file_id`` columns.
    """
    sentence_df = pd.read_csv("sentence_data_858k.csv.gz")

    return sentence_df


@st.cache(allow_output_mutation=True)
def load_sentence_embeddings_and_index():
    """Load the sentence embeddings and build an IVF-PQ Faiss index on them.

    The embeddings are read from ``sentence_embeddings_858k.npz`` (array
    ``arr_0``) and L2-normalised in place. The index is trained on the
    first 100000 vectors and then populated with all of them.

    Returns:
        A ``(sentence_embeddings, index)`` tuple.
    """
    # Close the archive once the array has been read out of it.
    with np.load("sentence_embeddings_858k.npz") as npz_comp:
        sentence_embeddings = npz_comp["arr_0"]

    faiss.normalize_L2(sentence_embeddings)

    D = 768
    N = 857610
    Xt = sentence_embeddings[:100000]
    X = sentence_embeddings

    # Param of PQ
    M = 16  # The number of sub-vector. Typically this is 8, 16, 32, etc.
    nbits = 8  # bits per sub-vector. This is typically 8, so that each sub-vec is encoded by 1 byte
    # Param of IVF
    nlist = int(math.sqrt(N))  # The number of cells (space partition). Typical value is sqrt(N)
    # Param of HNSW
    hnsw_m = 32  # The number of neighbors for HNSW. This is typically 32

    # Setup
    quantizer = faiss.IndexHNSWFlat(D, hnsw_m)
    index = faiss.IndexIVFPQ(quantizer, D, nlist, M, nbits)

    # Train
    index.train(Xt)

    # Add
    index.add(X)

    # Search
    index.nprobe = 8  # Runtime param. The number of cells that are visited for search.

    return sentence_embeddings, index


@st.cache(allow_output_mutation=True)
def formulaic_phrase_extraction(sentences, model, tokenizer):
    """Flag words whose head-averaged [CLS] attention exceeds a threshold.

    Each sentence is tokenised, word pieces are merged back into whole
    words, and a word is flagged only if every one of its pieces received
    attention above ``THRESHOLD`` at layer ``LAYER``.

    Args:
        sentences: List of sentence strings.
        model: Model called as ``model(**inputs)``; its last output is
            used as the per-layer attentions.
        tokenizer: Tokenizer providing ``batch_encode_plus`` and
            ``tokenize``.

    Returns:
        A list with one output string per input sentence, built from the
        whitespace-split words of that sentence.
    """
    THRESHOLD = 0.01
    LAYER = 10

    # Nothing to annotate; avoid sending an empty batch to the model.
    if not sentences:
        return []

    output_sentences = []

    with torch.no_grad():
        inputs = tokenizer.batch_encode_plus(
            sentences,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors='pt'
        )
        outputs = model(**inputs)

        attention = outputs[-1]
        # attention[LAYER] is (batch, heads, query, key). Select the [CLS]
        # query row (position 0) of EVERY sentence and average over heads,
        # giving one (seq_len,) row per sentence. Indexing only batch
        # element 0 would pair sentence k with query token k of sentence 0.
        cls_attentions = torch.mean(attention[LAYER][:, :, 0, :], dim=1)

    for sentence, cls_attention in zip(sentences, cls_attentions):
        # Drop the first and last positions (the special tokens when the
        # sentence is unpadded); any padding left at the tail is ignored
        # by the zip() below.
        check_bool_arr = (cls_attention > THRESHOLD).tolist()[1:-1]
        tokens = tokenizer.tokenize(sentence)
        cur_tokens = tokens.copy()

        # Repeatedly fold "##" continuation pieces into the preceding
        # token, AND-ing their flags, until no continuation piece remains.
        while True:
            flg = False

            for idx, token in enumerate(cur_tokens):
                if token.startswith("##"):
                    flg = True

                    back_token = token.replace("##", "")
                    front_token = cur_tokens.pop(idx-1)
                    cur_tokens[idx-1] = front_token + back_token

                    back_bool_val = check_bool_arr[idx]
                    front_bool_val = check_bool_arr.pop(idx-1)
                    check_bool_arr[idx-1] = front_bool_val and back_bool_val

            if not flg:
                break

        # NOTE(review): both branches below yield the same text, so flagged
        # words are currently not marked up in any way. Presumably some
        # highlight markup was intended here -- confirm with the author.
        result = " ".join([f'{original_word}' if b else original_word for (b, original_word) in zip(check_bool_arr, sentence.split())])
        output_sentences.append(result)

    return output_sentences


@st.cache(allow_output_mutation=True)
def get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df, exclude_word_list, phrase_annotated=True):
    """Retrieve the sentences nearest to ``input_text`` in the index.

    Args:
        index: Faiss index searched with the query embedding.
        input_text: Query text.
        top_k: Number of neighbours requested from the index (an upper
            bound on the number of results).
        model: Model used to embed the query.
        tokenizer: Tokenizer matching ``model``.
        sentence_df: DataFrame with ``sentence`` and ``file_id`` columns,
            indexed by the ids stored in ``index``.
        exclude_word_list: Words to exclude. A sentence containing any of
            them (as a literal substring) is dropped. Empty strings are
            ignored.
        phrase_annotated: If true, pass the retrieved sentences through
            ``formulaic_phrase_extraction``.

    Returns:
        A ``(retrieved_sentences, retrieved_paper_ids)`` tuple of two
        parallel lists: sentence texts and ACL Anthology links.
    """
    with torch.no_grad():
        inputs = tokenizer.encode_plus(
            input_text,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors='pt'
        )
        outputs = model(**inputs)

        # Use the hidden state at position 0 as the query embedding.
        query_embeddings = outputs.last_hidden_state[:, 0, :][0]
        query_embeddings = query_embeddings.detach().cpu().numpy()
        query_embeddings = query_embeddings / np.linalg.norm(query_embeddings, ord=2)

    _, ids = index.search(x=np.array([query_embeddings]), k=top_k)

    # Build the exclusion pattern once. Words are escaped so that user
    # input such as "c++" or "(" is matched literally instead of raising
    # re.error or being interpreted as regex syntax.
    exclude_words = [word for word in exclude_word_list if word]
    exclude_pattern = None
    if exclude_words:
        exclude_pattern = re.compile('|'.join(re.escape(word) for word in exclude_words))

    retrieved_sentences = []
    retrieved_paper_ids = []

    for sentence_id in ids[0]:
        # Faiss pads the result with -1 when fewer than k neighbours exist.
        if sentence_id < 0:
            continue

        cur_sentence = sentence_df.loc[sentence_id, "sentence"]
        cur_link = f"https://aclanthology.org/{sentence_df.loc[sentence_id, 'file_id']}"

        if exclude_pattern is not None and exclude_pattern.search(cur_sentence):
            continue

        retrieved_sentences.append(cur_sentence)
        retrieved_paper_ids.append(cur_link)

    if phrase_annotated:
        retrieved_sentences = formulaic_phrase_extraction(retrieved_sentences, model, tokenizer)

    return retrieved_sentences, retrieved_paper_ids


if __name__ == "__main__":
    model, tokenizer = load_model_and_tokenizer()
    sentence_df = load_sentence_data()
    sentence_embeddings, index = load_sentence_embeddings_and_index()

    st.markdown("## AI-based Paraphrasing for Academic Writing")

    input_text = st.text_area("text input", "Our model shows good results.", placeholder="Write something here...")
    top_k = st.number_input('top_k (upperbound)', min_value=1, value=30, step=1)
    input_words = st.text_input("exclude words (comma separated)", "good, result")
    agree = st.checkbox('Include phrase annotation')

    if st.button('search'):
        exclude_word_list = [s.strip() for s in input_words.split(",") if s.strip() != ""]
        retrieved_sentences, retrieved_paper_ids = get_retrieval_results(index, input_text, top_k, model, tokenizer, sentence_df, exclude_word_list, phrase_annotated=agree)

        # Render the results as a two-column markdown table.
        table_header = "| sentence | source link |\n|:---|:---|\n"
        table_rows = "".join(
            f"| {retrieved_sentence} | {retrieved_paper_id} |\n"
            for (retrieved_sentence, retrieved_paper_id) in zip(retrieved_sentences, retrieved_paper_ids)
        )
        result_table_markdown = table_header + table_rows

        st.markdown(result_table_markdown, unsafe_allow_html=True)

    st.markdown("---\n#### How this works")
    st.markdown("This app uses ScitoricsBERT [(Sugimoto and Aizawa, 2022)](https://aclanthology.org/2022.sdp-1.7/), a functional sentence representation model, to retrieve sentences that are functionally similar to the input. It also extracts phrasal patterns that accord to the function, by leveraging self-attention patterns within ScitoricsBERT.")