from sentence_transformers import util
from nltk.tokenize import sent_tokenize
from nltk import word_tokenize, pos_tag
import torch
import numpy as np
import tqdm
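
# NOTE: sent_tokenize / word_tokenize / pos_tag below require the NLTK tokenizer
# and POS-tagger data, e.g. nltk.download('punkt') and
# nltk.download('averaged_perceptron_tagger').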

def compute_sentencewise_scores(model, query_sents, candidate_sents):
    """
    Compute pairwise cosine similarity between query and candidate sentences.
    """
    q_v, c_v = get_embedding(model, query_sents, candidate_sents)

    return util.cos_sim(q_v, c_v)
    
def get_embedding(model, query_sents, candidate_sents):
    q_v = model.encode(query_sents)
    c_v = model.encode(candidate_sents)
   
    return q_v, c_v
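
# Usage sketch (illustrative only, not part of the original pipeline): any
# SentenceTransformer-style encoder with an `encode` method works here. The
# 'all-MiniLM-L6-v2' checkpoint below is an assumption for demonstration, not
# necessarily the model used upstream.
def _example_sentencewise_scores():
    from sentence_transformers import SentenceTransformer
    model = SentenceTransformer('all-MiniLM-L6-v2')
    query_sents = ["We propose a new attention mechanism for long documents."]
    candidate_sents = [
        "Attention mechanisms have been studied extensively.",
        "We evaluate our method on summarization benchmarks.",
    ]
    # returns a (num_query_sents, num_candidate_sents) cosine-similarity matrix
    scores = compute_sentencewise_scores(model, query_sents, candidate_sents)
    print(scores)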

def get_top_k(score_mat, K=3):
    """
    For each query sentence, pick the indices and scores of the top K candidate sentences to show
    """
    idx = torch.argsort(-score_mat)
    picked_sent = idx[:,:K]
    picked_scores = torch.vstack(
        [score_mat[i,picked_sent[i]] for i in range(picked_sent.shape[0])]
    )
    
    return picked_sent, picked_scores

def get_words(sent):
    """
    Input: list of sentences
    Output: list of lists of words per sentence, a flat list of all words,
    and the index of the starting word of each sentence
    """
    words = []
    sent_start_id = [] # keep track of the word index where each new sentence starts
    counter = 0
    for x in sent:
        w = word_tokenize(x)
        counter += len(w)
        words.append(w)
        sent_start_id.append(counter)
    all_words = [item for sublist in words for item in sublist]
    sent_start_id.pop()
    sent_start_id = [0] + sent_start_id
    assert(len(sent_start_id) == len(sent))
    return words, all_words, sent_start_id

def get_match_phrase(w1, w2, method='pos'):
    """
    Input: lists of words for the query and the candidate text
    Output: binary masks over each word list marking the words that match between the inputs
    """
    mask1 = np.zeros(len(w1))
    mask2 = np.zeros(len(w2))
    if method == 'pos':
        # POS tags that should be considered for matching phrase
        include = [
            'NN',
            'NNS',
            'NNP',
            'NNPS',
            'LS',
            'SYM',
            'FW'
        ]
        pos1 = pos_tag(w1)
        pos2 = pos_tag(w2)
        for i, (w, p) in enumerate(pos2):
            if w.lower() in w1 and p in include:
                j = w1.index(w.lower())
                mask2[i] = 1
                mask1[j] = 1
    return mask1, mask2
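
# Illustrative example (a sketch; exact behavior depends on the tags assigned
# by the NLTK POS tagger): candidate words with an included tag whose lowercased
# form also appears in the query are marked in both masks.
def _example_match_phrase():
    q_words = word_tokenize("we study attention mechanisms for long documents")
    c_words = word_tokenize("Attention is all you need")
    mask_q, mask_c = get_match_phrase(q_words, c_words)
    # 'Attention' should be tagged as a noun and its lowercased form appears in
    # the query, so it should be marked in mask_c (and 'attention' in mask_q)
    print(list(zip(c_words, mask_c)))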

def remove_spaces(words, attrs):
    # Make the output more readable by removing unnecessary spaces introduced by the tokenizer,
    # e.g.
    # 1. spacing around parentheses
    # 2. spacing around single/double quotation marks
    # 3. spacing before commas and periods
    # 4. spacing before possessive quotes
    assert(len(words) == len(attrs))
    word_out, attr_out = [], []
    idx, single_q, double_q = 0, 0, 0
    while idx < len(words):
        # stick to the word that appears right before
        if words[idx] in [',', '.', '%', ')', ':', '?', ';', "'s"]:
            ww = word_out.pop()
            aa = attr_out.pop()
            word_out.append(ww + words[idx])
            attr_out.append(aa)
            idx += 1
        # stick to the word that appears right after
        elif words[idx] in ["("]:
            word_out.append(words[idx] + words[idx+1])
            attr_out.append(attrs[idx+1])
            idx += 2
        # quotes 
        elif words[idx] == '"':
            double_q += 1
            if double_q == 2:
                # this is closing quote: stick to word before
                ww = word_out.pop()
                aa = attr_out.pop()
                word_out.append(ww + words[idx])
                attr_out.append(aa)
                idx += 1
                double_q = 0
            else:
                # this is opening quote: stick to the word after
                word_out.append(words[idx] + words[idx+1])
                attr_out.append(attrs[idx+1])
                idx += 2
        elif words[idx] == "'":
            single_q += 1
            if single_q == 2:
                # this is closing quote: stick to word before
                ww = word_out.pop()
                aa = attr_out.pop()
                word_out.append(ww + words[idx])
                attr_out.append(aa)
                idx += 1
                single_q = 0
            else:
                if words[idx-1][-1] == 's': #possessive quote
                    # stick to the word before, reset counter
                    ww = word_out.pop()
                    aa = attr_out.pop()
                    word_out.append(ww + words[idx])
                    attr_out.append(aa)
                    idx += 1
                    single_q = 0
                else:
                    # this is opening quote: stick to the word after
                    word_out.append(words[idx] + words[idx+1])
                    attr_out.append(attrs[idx+1])
                    idx += 2
        else:
            word_out.append(words[idx])
            attr_out.append(attrs[idx])
            idx += 1         
            
    assert(len(word_out) == len(attr_out))
    return word_out, attr_out
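
# Worked example (a sketch of the expected behavior): tokenizer artifacts such
# as spaces before punctuation and around parentheses are merged back onto the
# neighboring word, and the kept word's attribute is reused.
def _example_remove_spaces():
    words = ['The', 'model', '(', 'ours', ')', 'works', '.']
    attrs = [0, 0, 0, 1, 0, 0, 0]
    merged_words, merged_attrs = remove_spaces(words, attrs)
    # expected: ['The', 'model', '(ours)', 'works.'] with attrs [0, 0, 1, 0]
    print(merged_words, merged_attrs)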

def mark_words(query_sents, words, all_words, sent_start_id, sent_ids, sent_scores):
    """
    Mark the words to be highlighted, both at the sentence level and at the phrase level
    """
    num_query_sent = sent_ids.shape[0]
    num_words = len(all_words)
    
    output = dict()
    output['all_words'] = all_words
    output['words_by_sentence'] = words
    
    # for each query sentence, mark the highlight information
    for i in range(num_query_sent):
        query_words = word_tokenize(query_sents[i])
        is_selected_sent = np.zeros(num_words)
        is_selected_phrase = np.zeros(num_words)
        word_scores = np.zeros(num_words)
        
        # for each selected sentence from the candidate, compile highlight information
        for sid, sscore in zip(sent_ids[i], sent_scores[i]):
            if sid+1 < len(sent_start_id):
                sent_range = (sent_start_id[sid], sent_start_id[sid+1])
                is_selected_sent[sent_range[0]:sent_range[1]] = 1
                word_scores[sent_range[0]:sent_range[1]] = sscore
                _, is_selected_phrase[sent_range[0]:sent_range[1]] = \
                    get_match_phrase(query_words, all_words[sent_range[0]:sent_range[1]])
            else:
                is_selected_sent[sent_start_id[sid]:] = 1
                word_scores[sent_start_id[sid]:] = sscore
                _, is_selected_phrase[sent_start_id[sid]:] = \
                    get_match_phrase(query_words, all_words[sent_start_id[sid]:])
                    
        # give words inside selected phrases a negative score, which renders as a different color in Gradio
        word_scores[is_selected_sent+is_selected_phrase==2] = -0.5
            
        output[i] = {
            'is_selected_sent': is_selected_sent,
            'is_selected_phrase': is_selected_phrase,
            'scores': word_scores
        }

    return output

def get_highlight_info(model, text1, text2, K=None):
    """
    Get highlight information from two texts
    """
    sent1 = sent_tokenize(text1) # query
    sent2 = sent_tokenize(text2) # candidate
    if K is None: # if K is not set, select based on the length of the candidate
        K = int(len(sent2) / 3)
    score_mat = compute_sentencewise_scores(model, sent1, sent2)

    sent_ids, sent_scores = get_top_k(score_mat, K=K)
    words2, all_words2, sent_start_id2 = get_words(sent2)
    info = mark_words(sent1, words2, all_words2, sent_start_id2, sent_ids, sent_scores)
    
    # get top sentence pairs from the query and candidate (score, index_pair)
    top_pair_num = 5
    top_pairs = []
    ii = np.unravel_index(np.argsort(np.array(sent_scores).ravel())[-top_pair_num:], sent_scores.shape)
    for i, j in zip(ii[0][::-1], ii[1][::-1]):
        score = sent_scores[i,j].item()
        index_pair = (i, sent_ids[i,j].item())
        top_pairs.append((score, index_pair)) # list of (score, (sent_id_query, sent_id_candidate))
        
    # convert top_pairs to the corresponding highlights format for the Gradio Interpretation component
    top_pairs_info = dict()
    count = 0
    for s, (sidq, sidc) in top_pairs:
        q_sent = sent1[sidq]
        c_sent = sent2[sidc]
        q_words = word_tokenize(q_sent)
        c_words = word_tokenize(c_sent)
        mask1, mask2 = get_match_phrase(q_words, c_words)
        sc = 0.5
        mask1 *= -sc # mark matching phrases as blue (-1: darkest)
        mask2 *= -sc # mark matching phrases as blue
        assert(len(mask1) == len(q_words) and len(mask2) == len(c_words))
        
        # spacing
        q_words, mask1 = remove_spaces(q_words, mask1)
        c_words, mask2 = remove_spaces(c_words, mask2)
        
        top_pairs_info[count] = {
            'query': {
                'original': q_sent,
                'interpretation': list(zip(q_words, mask1))
            },
            'candidate': {
                'original': c_sent,
                'interpretation': list(zip(c_words, mask2))
            },
            'score': s,
            'sent_idx': (sidq, sidc) 
        }
        count += 1
    
    return sent_ids, sent_scores, info, top_pairs_info
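
# End-to-end sketch (illustrative; assumes the NLTK data and a
# SentenceTransformer checkpoint such as 'all-MiniLM-L6-v2' are available):
def _example_get_highlight_info():
    from sentence_transformers import SentenceTransformer
    model = SentenceTransformer('all-MiniLM-L6-v2')
    query = "We propose a graph-based method for citation recommendation."
    candidate = (
        "Citation recommendation has received growing attention. "
        "We build a graph over papers and authors. "
        "Experiments show strong results on two benchmarks."
    )
    sent_ids, sent_scores, info, top_pairs_info = get_highlight_info(
        model, query, candidate, K=2
    )
    # info holds per-word highlight masks for the candidate; top_pairs_info holds
    # the highest-scoring (query sentence, candidate sentence) pairs for display
    print(top_pairs_info[0]['score'], top_pairs_info[0]['sent_idx'])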

### Document-level operations
# TODO Use specter_MFR
def predict_docscore(doc_model, tokenizer, query, titles, abstracts, batch=20):
    # compute a document-level relevance score for each paper
    
    # gather the text to encode for each paper; concatenating the title is
    # currently disabled, so only the abstract is used
    title_abs = []
    for t, a in zip(titles, abstracts):
        if t is not None and a is not None:
            # title_abs.append(t + ' [SEP] ' + a)
            title_abs.append(a)
            
    num_docs = len(title_abs) 
    no_iter = int(np.ceil(num_docs / batch))
    scores = []
    with torch.no_grad():
        # batch
        for i in tqdm.tqdm(range(no_iter)):
            # preprocess the input
            inputs = tokenizer(
                [query] + title_abs[i*batch:(i+1)*batch], 
                padding=True, 
                truncation=True, 
                return_tensors="pt", 
                max_length=512
            )
            inputs.to(doc_model.device)
            result = doc_model(**inputs)
        
            # take the first token ([CLS]) of each sequence as its embedding
            embeddings = result.last_hidden_state[:, 0, :].detach().cpu().numpy()
        
            # compute cosine similarity
            q_emb = embeddings[0,:]
            p_emb = embeddings[1:,:]
            nn = np.linalg.norm(q_emb) * np.linalg.norm(p_emb, axis=1)
            scores += list(np.dot(p_emb, q_emb) / nn)

    assert(len(scores) == num_docs)
    
    return scores

def compute_document_score(doc_model, tokenizer, query, papers, batch=5):
    titles = []
    abstracts = []
    urls = []
    for p in papers:
        if p['title'] is not None and p['abstract'] is not None:
            titles.append(p['title'])
            abstracts.append(p['abstract'])
            urls.append(p['url'])
    scores = predict_docscore(doc_model, tokenizer, query, titles, abstracts, batch=batch)
    assert(len(scores) == len(abstracts))
    idx_sorted = np.argsort(scores)[::-1]
    
    titles_sorted = [titles[x] for x in idx_sorted]
    abstracts_sorted = [abstracts[x] for x in idx_sorted]
    scores_sorted = [scores[x] for x in idx_sorted]
    urls_sorted = [urls[x] for x in idx_sorted]
    
    return titles_sorted, abstracts_sorted, urls_sorted, scores_sorted
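
# Usage sketch for the document-level scorer (illustrative): the code above
# expects a BERT-style encoder from Hugging Face `transformers`; 'allenai/specter'
# is one plausible choice, but the exact checkpoint is an assumption here
# (see the TODO about specter_MFR above).
def _example_compute_document_score():
    from transformers import AutoModel, AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained('allenai/specter')
    doc_model = AutoModel.from_pretrained('allenai/specter')
    query = "We propose a graph-based method for citation recommendation."
    papers = [
        {'title': 'Citation recommendation with graphs',
         'abstract': 'We build a citation graph over papers and authors.',
         'url': 'https://example.org/paper1'},
        {'title': 'A survey of text summarization',
         'abstract': 'We review recent progress in abstractive summarization.',
         'url': 'https://example.org/paper2'},
    ]
    titles, abstracts, urls, scores = compute_document_score(
        doc_model, tokenizer, query, papers, batch=2
    )
    for t, s in zip(titles, scores):
        print(round(float(s), 3), t)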