from sentence_transformers import util
from nltk.tokenize import sent_tokenize
from nltk import word_tokenize, pos_tag
import torch
import numpy as np
import tqdm


def compute_sentencewise_scores(model, query_sents, candidate_sents):
    """Return the cosine-similarity matrix between two lists of sentences.

    Args:
        model: sentence encoder exposing ``encode`` (see ``get_embedding``).
        query_sents: list of query sentences.
        candidate_sents: list of candidate sentences.

    Returns:
        ``util.cos_sim(query_embeddings, candidate_embeddings)``; row ``i``
        holds the similarities of query sentence ``i`` to each candidate
        sentence.
    """
    q_v, c_v = get_embedding(model, query_sents, candidate_sents)
    return util.cos_sim(q_v, c_v)


def get_embedding(model, query_sents, candidate_sents):
    """Encode query and candidate sentences with ``model.encode``.

    Returns:
        (q_v, c_v): whatever ``model.encode`` returns for each list.
    """
    q_v = model.encode(query_sents)
    c_v = model.encode(candidate_sents)
    return q_v, c_v


def get_top_k(score_mat, K=3):
    """Pick the K highest-scoring candidate sentences for each query sentence.

    Args:
        score_mat: 2-D score tensor, one row per query sentence and one column
            per candidate sentence.
        K: number of columns to keep per row; a K larger than the number of
            columns keeps every column.

    Returns:
        (picked_sent, picked_scores): column indices ordered by descending
        score, and the scores at those indices. Both have shape
        ``(rows, min(K, cols))``.
    """
    idx = torch.argsort(-score_mat)  # descending order along the last dim
    picked_sent = idx[:, :K]
    # Same values as looking up score_mat[i, picked_sent[i]] row by row, but
    # also works when score_mat has no rows (stacking an empty list fails).
    picked_scores = torch.gather(score_mat, 1, picked_sent)
    return picked_sent, picked_scores


def get_words(sent):
    """Tokenize sentences into words.

    Args:
        sent: list of sentences.

    Returns:
        (words, all_words, sent_start_id):
            words: list of token lists, one per sentence;
            all_words: all tokens concatenated in order;
            sent_start_id: index into ``all_words`` where each sentence starts.
        An empty ``sent`` yields three empty lists.
    """
    words = [word_tokenize(x) for x in sent]  # tokenize each sentence once
    all_words = [w for sentence_words in words for w in sentence_words]

    sent_start_id = []
    counter = 0
    for sentence_words in words:
        sent_start_id.append(counter)
        counter += len(sentence_words)

    assert len(sent_start_id) == len(sent)
    return words, all_words, sent_start_id


# POS tags a candidate word must have to be highlighted as a matching phrase.
_MATCH_POS_TAGS = frozenset({'NN', 'NNS', 'NNP', 'NNPS', 'LS', 'SYM', 'FW'})


def get_match_phrase(w1, w2, method='pos'):
    """Mark words shared between the query and candidate token lists.

    With ``method='pos'``, a candidate word is marked when both hold:
      - its POS tag is in ``_MATCH_POS_TAGS``;
      - it occurs in ``w1`` ignoring case.
    The first such position in ``w1`` is marked too.

    Args:
        w1: query tokens.
        w2: candidate tokens.
        method: matching strategy. Only ``'pos'`` is implemented; any other
            value returns all-zero masks.

    Returns:
        (mask1, mask2): float 0/1 arrays of length ``len(w1)`` and ``len(w2)``.
    """
    mask1 = np.zeros(len(w1))
    mask2 = np.zeros(len(w2))
    if method != 'pos' or not w1 or not w2:
        return mask1, mask2

    # Compare lower-cased forms on BOTH sides. Comparing a lower-cased
    # candidate word against raw query tokens would never match capitalised
    # words such as proper nouns.
    first_pos = {}
    for j, w in enumerate(w1):
        first_pos.setdefault(w.lower(), j)

    for i, (w, tag) in enumerate(pos_tag(w2)):
        j = first_pos.get(w.lower())
        if j is not None and tag in _MATCH_POS_TAGS:
            mask2[i] = 1
            mask1[j] = 1
    return mask1, mask2


# Tokens glued onto the word before them when re-joining tokenizer output.
_ATTACH_TO_PREVIOUS = (',', '.', '%', ')', ':', '?', ';', "'s")


def remove_spaces(words, attrs):
    """Re-join tokenizer output into readable words, carrying attributes along.

    Removes spacing the tokenizer introduced:
      1. around parentheses;
      2. around single/double quotation marks;
      3. before commas, periods and similar punctuation;
      4. before possessive quotes ("'s" and a trailing "'").

    A token glued onto the previous word keeps that word's attribute. An
    opening token glued onto the next word takes the next word's attribute.
    A token with no neighbour to attach to is emitted unchanged instead of
    raising, e.g. a leading "," or a trailing "(".

    Args:
        words: list of tokens.
        attrs: per-token attributes, same length as ``words``.

    Returns:
        (word_out, attr_out): merged words and their attributes (equal length).
    """
    assert len(words) == len(attrs)
    n = len(words)
    word_out, attr_out = [], []

    def attach_to_previous(i):
        # Glue words[i] onto the last emitted word; consumes one token.
        if word_out:
            word_out[-1] += words[i]
        else:
            word_out.append(words[i])
            attr_out.append(attrs[i])
        return 1

    def attach_to_next(i):
        # Glue words[i] onto words[i + 1]; consumes two tokens.
        if i + 1 >= n:
            word_out.append(words[i])
            attr_out.append(attrs[i])
            return 1
        word_out.append(words[i] + words[i + 1])
        attr_out.append(attrs[i + 1])
        return 2

    idx, single_q, double_q = 0, 0, 0
    while idx < n:
        token = words[idx]
        if token in _ATTACH_TO_PREVIOUS:
            idx += attach_to_previous(idx)
        elif token == '(':
            idx += attach_to_next(idx)
        elif token == '"':
            # NOTE(review): NLTK's word_tokenize usually rewrites double quotes
            # as `` / '', which this branch does not see -- confirm whether
            # those should be merged as well.
            double_q += 1
            if double_q == 2:  # closing quote
                idx += attach_to_previous(idx)
                double_q = 0
            else:  # opening quote
                idx += attach_to_next(idx)
        elif token == "'":
            single_q += 1
            if single_q == 2:  # closing quote
                idx += attach_to_previous(idx)
                single_q = 0
            elif idx > 0 and words[idx - 1].endswith('s'):
                # Possessive quote (e.g. "papers '"): attach and reset counter.
                idx += attach_to_previous(idx)
                single_q = 0
            else:  # opening quote
                idx += attach_to_next(idx)
        else:
            word_out.append(token)
            attr_out.append(attrs[idx])
            idx += 1

    assert len(word_out) == len(attr_out)
    return word_out, attr_out


def mark_words(query_sents, words, all_words, sent_start_id, sent_ids, sent_scores):
    """Mark highlighted candidate words, both by sentence and by phrase.

    Args:
        query_sents: list of query sentences.
        words: candidate tokens grouped by sentence (from ``get_words``).
        all_words: flat list of candidate tokens (from ``get_words``).
        sent_start_id: start index in ``all_words`` of each candidate sentence.
        sent_ids: selected candidate sentence indices, one row per query
            sentence (from ``get_top_k``).
        sent_scores: scores matching ``sent_ids``.

    Returns:
        dict with keys:
            'all_words', 'words_by_sentence': the inputs, passed through;
            one int key per query sentence ``i``: a dict of per-word arrays
            'is_selected_sent', 'is_selected_phrase' and 'scores'.
        In 'scores', a word that is both in a selected sentence and a matched
        phrase is set to -0.5 (rendered in a different colour in Gradio).
    """
    num_query_sent = sent_ids.shape[0]
    num_words = len(all_words)
    output = {
        'all_words': all_words,
        'words_by_sentence': words,
    }

    for i in range(num_query_sent):
        query_words = word_tokenize(query_sents[i])
        is_selected_sent = np.zeros(num_words)
        is_selected_phrase = np.zeros(num_words)
        word_scores = np.zeros(num_words)

        # Compile information for each candidate sentence selected for query i.
        for sid, sscore in zip(sent_ids[i], sent_scores[i]):
            sid = int(sid)
            start = sent_start_id[sid]
            # The last sentence runs to the end of the word list.
            end = sent_start_id[sid + 1] if sid + 1 < len(sent_start_id) else num_words
            is_selected_sent[start:end] = 1
            word_scores[start:end] = float(sscore)
            _, phrase_mask = get_match_phrase(query_words, all_words[start:end])
            is_selected_phrase[start:end] = phrase_mask

        # Selected phrases get their own score (-0.5 => different colour in gradio).
        word_scores[is_selected_sent + is_selected_phrase == 2] = -0.5
        output[i] = {
            'is_selected_sent': is_selected_sent,
            'is_selected_phrase': is_selected_phrase,
            'scores': word_scores,
        }
    return output


def _top_sentence_pairs(sent_ids, sent_scores, top_pair_num):
    """Return up to ``top_pair_num`` best pairs, best first.

    Each pair is ``(score, (query_sentence_idx, candidate_sentence_idx))``.
    """
    flat = np.asarray(sent_scores).ravel()
    order = np.argsort(flat)[-top_pair_num:][::-1]  # highest scores first
    rows, cols = np.unravel_index(order, tuple(sent_scores.shape))
    pairs = []
    for i, j in zip(rows, cols):
        i, j = int(i), int(j)
        pairs.append((sent_scores[i, j].item(), (i, sent_ids[i, j].item())))
    return pairs


def _pair_highlight(q_sent, c_sent, score, sidq, sidc):
    """Build the Gradio Interpretation entry for one query/candidate sentence pair."""
    q_words = word_tokenize(q_sent)
    c_words = word_tokenize(c_sent)
    mask1, mask2 = get_match_phrase(q_words, c_words)
    sc = 0.5
    mask1 *= -sc  # mark matching phrases as blue (-1: darkest)
    mask2 *= -sc  # mark matching phrases as blue
    assert len(mask1) == len(q_words) and len(mask2) == len(c_words)

    # Remove tokenizer spacing for readability.
    q_words, mask1 = remove_spaces(q_words, mask1)
    c_words, mask2 = remove_spaces(c_words, mask2)
    return {
        'query': {
            'original': q_sent,
            'interpretation': list(zip(q_words, mask1)),
        },
        'candidate': {
            'original': c_sent,
            'interpretation': list(zip(c_words, mask2)),
        },
        'score': score,
        'sent_idx': (sidq, sidc),
    }


def get_highlight_info(model, text1, text2, K=None, top_pair_num=5):
    """Compute highlight information between a query text and a candidate text.

    Args:
        model: sentence encoder passed to ``compute_sentencewise_scores``.
        text1: query text.
        text2: candidate text.
        K: candidate sentences to select per query sentence. Defaults to
            about one third of the candidate's sentences, and at least 1.
        top_pair_num: number of best query/candidate sentence pairs to return
            in ``top_pairs_info``.

    Returns:
        (sent_ids, sent_scores, info, top_pairs_info):
            sent_ids, sent_scores: output of ``get_top_k``;
            info: output of ``mark_words`` for the candidate's words;
            top_pairs_info: dict ``rank -> pair entry`` for Gradio. Each entry
                has 'query', 'candidate', 'score' and 'sent_idx'.
    """
    sent1 = sent_tokenize(text1)  # query
    sent2 = sent_tokenize(text2)  # candidate
    if K is None:
        # Select based on the candidate length. Clamp to >= 1 so candidates
        # with fewer than three sentences still get a highlighted sentence.
        K = max(1, len(sent2) // 3)

    score_mat = compute_sentencewise_scores(model, sent1, sent2)
    sent_ids, sent_scores = get_top_k(score_mat, K=K)
    words2, all_words2, sent_start_id2 = get_words(sent2)
    info = mark_words(sent1, words2, all_words2, sent_start_id2, sent_ids, sent_scores)

    # Best sentence pairs as (score, (sent_id_query, sent_id_candidate)),
    # converted to the highlight format of the Gradio Interpretation component.
    top_pairs = _top_sentence_pairs(sent_ids, sent_scores, top_pair_num)
    top_pairs_info = {
        count: _pair_highlight(sent1[sidq], sent2[sidc], s, sidq, sidc)
        for count, (s, (sidq, sidc)) in enumerate(top_pairs)
    }
    return sent_ids, sent_scores, info, top_pairs_info


### Document-level operations

# TODO Use specter_MFR
def predict_docscore(doc_model, tokenizer, query, titles, abstracts, batch=20):
    """Score documents against a query by cosine similarity of first-token embeddings.

    Args:
        doc_model: model called with the tokenizer output; its
            ``last_hidden_state[:, 0, :]`` is used as the embedding.
        tokenizer: tokenizer producing PyTorch tensors for ``doc_model``.
        query: query text.
        titles, abstracts: parallel lists. Entries where either is None are
            skipped.
        batch: number of documents encoded per forward pass.

    Returns:
        list of cosine similarities, one per kept document, in input order.
    """
    # Only the abstract is embedded; an earlier variant used
    # ``title + ' [SEP] ' + abstract``.
    title_abs = [a for t, a in zip(titles, abstracts) if t is not None and a is not None]
    num_docs = len(title_abs)
    no_iter = int(np.ceil(num_docs / batch))
    scores = []
    with torch.no_grad():
        for i in tqdm.tqdm(range(no_iter)):
            # The query is prepended as row 0 of every batch.
            inputs = tokenizer(
                [query] + title_abs[i * batch:(i + 1) * batch],
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=512,
            )
            inputs.to(doc_model.device)
            result = doc_model(**inputs)
            # Take the first token of each sequence as its embedding.
            embeddings = result.last_hidden_state[:, 0, :].detach().cpu().numpy()
            # Cosine similarity between the query (row 0) and each document.
            q_emb = embeddings[0, :]
            p_emb = embeddings[1:, :]
            nn = np.linalg.norm(q_emb) * np.linalg.norm(p_emb, axis=1)
            scores.extend(np.dot(p_emb, q_emb) / nn)
    assert len(scores) == num_docs
    return scores


def compute_document_score(doc_model, tokenizer, query, papers, batch=5):
    """Rank papers by document score against the query, best first.

    Args:
        doc_model, tokenizer, query, batch: forwarded to ``predict_docscore``.
        papers: iterable of mappings read via 'title', 'abstract' and 'url'.
            Papers whose title or abstract is None are dropped.

    Returns:
        (titles, abstracts, urls, scores): parallel lists sorted by
        descending score.
    """
    titles, abstracts, urls = [], [], []
    for p in papers:
        if p['title'] is not None and p['abstract'] is not None:
            titles.append(p['title'])
            abstracts.append(p['abstract'])
            urls.append(p['url'])

    scores = predict_docscore(doc_model, tokenizer, query, titles, abstracts, batch=batch)
    assert len(scores) == len(abstracts)

    idx_sorted = np.argsort(scores)[::-1]
    titles_sorted = [titles[x] for x in idx_sorted]
    abstracts_sorted = [abstracts[x] for x in idx_sorted]
    scores_sorted = [scores[x] for x in idx_sorted]
    urls_sorted = [urls[x] for x in idx_sorted]
    return titles_sorted, abstracts_sorted, urls_sorted, scores_sorted