from sentence_transformers import util, SentenceTransformer
from transformers import BertModel
from nltk.tokenize import sent_tokenize
from nltk import word_tokenize, pos_tag
import torch
import numpy as np
import tqdm

def compute_sentencewise_scores(model, query_sents, candidate_sents, tokenizer=None):
    """
    Compute pairwise cosine similarity between query and candidate sentences.
    """
    if isinstance(model, SentenceTransformer):
        # SentenceTransformer-style model
        q_v, c_v = get_embedding(model, query_sents, candidate_sents)
    elif isinstance(model, BertModel):
        # BERT-style model from the transformers library
        inputs = tokenizer(
            query_sents + candidate_sents, 
            padding=True, 
            truncation=True, 
            return_tensors="pt", 
            max_length=512
        )
        inputs.to(model.device)
        result = model(**inputs)
        embeddings = result.last_hidden_state[:, 0, :].detach().cpu().numpy()
        q_v = embeddings[:len(query_sents)]
        c_v = embeddings[len(query_sents):]
    else:
        raise ValueError('model type not supported')
    assert(q_v.shape[1] == c_v.shape[1])
    assert(q_v.shape[0] == len(query_sents))
    assert(c_v.shape[0] == len(candidate_sents))
    return util.cos_sim(q_v, c_v)
    
def get_embedding(model, query_sents, candidate_sents):
    q_v = model.encode(query_sents)
    c_v = model.encode(candidate_sents)
    return q_v, c_v
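
# Example (illustrative; assumes the sentence-transformers package and that
# the model name below resolves — any SentenceTransformer encoder works):
#   model = SentenceTransformer('all-MiniLM-L6-v2')
#   scores = compute_sentencewise_scores(model, ['a query.'], ['cand one.', 'cand two.'])
#   scores.shape  # -> torch.Size([1, 2]); scores[i][j] is cos-sim of query i vs. candidate j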

def get_top_k(score_mat, K=3):
    """
    Pick top K sentences to show
    """
    picked_scores, picked_sent = torch.sort(-score_mat, dim=1)
    picked_sent = picked_sent[:,:K]
    picked_scores = -picked_scores[:,:K]
    return picked_sent, picked_scores
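
# Example (illustrative):
#   scores = torch.tensor([[0.1, 0.9, 0.4]])
#   get_top_k(scores, K=2)  # -> indices [[1, 2]], scores [[0.9, 0.4]]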

def get_words(sent):
    """
    Input: list of sentences
    Output: per-sentence word lists, the flat list of all words, and the
    index of the first word of each sentence within the flat list
    """
    words = []
    sent_start_id = [] # keep track of the word index where each new sentence starts
    counter = 0
    for x in sent:
        w = word_tokenize(x)
        words.append(w)
        sent_start_id.append(counter)
        counter += len(w)
    all_words = [item for sublist in words for item in sublist]
    assert(len(sent_start_id) == len(sent))
    return words, all_words, sent_start_id
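
# Example (illustrative; exact tokens depend on the NLTK tokenizer):
#   get_words(["A cat.", "A dog ran."])
#   -> ([['A', 'cat', '.'], ['A', 'dog', 'ran', '.']],
#       ['A', 'cat', '.', 'A', 'dog', 'ran', '.'],
#       [0, 3])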

def get_match_phrase(w1, w2, method='pos'):
    """
    Input: lists of words for the query and the candidate text
    Output: binary masks marking the matching phrases in each input
    """
    mask1 = np.zeros(len(w1))
    mask2 = np.zeros(len(w2))
    if method == 'pos':
        # POS tags that should be considered for matching phrase
        include = [
            'NN',
            'NNS',
            'NNP',
            'NNPS',
            'LS',
            'SYM',
            'FW'
        ]
        pos1 = pos_tag(w1)
        pos2 = pos_tag(w2)
        for i, (w, p) in enumerate(pos2):
            for j, (w_, p_) in enumerate(pos1):
                if w.lower() == w_.lower() and p in include:
                    mask2[i] = 1
                    mask1[j] = 1
    return mask1, mask2
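
# Example (illustrative; POS tags can vary across NLTK versions):
#   get_match_phrase(['the', 'transformer', 'model'], ['a', 'transformer'])
#   -> (array([0., 1., 0.]), array([0., 1.]))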

def remove_spaces(words, attrs):
    # make the output more readable by removing unnecessary spaces
    # introduced by the tokenizer, e.g.:
    # 1. spaces around parentheses
    # 2. spaces around single/double quotation marks
    # 3. spaces before commas and periods
    # 4. spaces before possessive quotes
    assert(len(words) == len(attrs))
    word_out, attr_out = [], []
    idx, single_q, double_q = 0, 0, 0
    while idx < len(words):
        # stick to the word that appears right before
        if words[idx] in [',', '.', '%', ')', ':', '?', ';', "'s", '”']:
            ww = word_out.pop()
            aa = attr_out.pop()
            word_out.append(ww + words[idx])
            attr_out.append(aa)
            idx += 1
        # stick to the word that appears right after
        elif words[idx] in ["(", '“']:
            word_out.append(words[idx] + words[idx+1])
            attr_out.append(attrs[idx+1])
            idx += 2
        # quotes 
        elif words[idx] == '"':
            double_q += 1
            if double_q == 2:
                # this is closing quote: stick to word before
                ww = word_out.pop()
                aa = attr_out.pop()
                word_out.append(ww + words[idx])
                attr_out.append(aa)
                idx += 1
                double_q = 0
            else:
                # this is opening quote: stick to the word after
                word_out.append(words[idx] + words[idx+1])
                attr_out.append(attrs[idx+1])
                idx += 2
        elif words[idx] == "'":
            single_q += 1
            if single_q == 2:
                # this is closing quote: stick to word before
                ww = word_out.pop()
                aa = attr_out.pop()
                word_out.append(ww + words[idx])
                attr_out.append(aa)
                idx += 1
                single_q = 0
            else:
                if words[idx-1][-1] == 's': #possessive quote
                    # stick to the word before, reset counter
                    ww = word_out.pop()
                    aa = attr_out.pop()
                    word_out.append(ww + words[idx])
                    attr_out.append(aa)
                    idx += 1
                    single_q = 0
                else:
                    # this is opening quote: stick to the word after
                    word_out.append(words[idx] + words[idx+1])
                    attr_out.append(attrs[idx+1])
                    idx += 2
        elif words[idx] == '``':
            # this is opening quote: stick to the word after, but change to real double quote
            word_out.append('"' + words[idx+1])
            attr_out.append(attrs[idx+1])
            idx += 2
        elif words[idx] == "''":
            # this is closing quote: stick to word before, but change to real double quote
            ww = word_out.pop()
            aa = attr_out.pop()
            word_out.append(ww + '"')
            attr_out.append(aa)
            idx += 1
        else:
            word_out.append(words[idx])
            attr_out.append(attrs[idx])
            idx += 1         
            
    assert(len(word_out) == len(attr_out))
    return word_out, attr_out
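
# Example (illustrative):
#   remove_spaces(['Hello', ',', 'world', '.'], [0.5, 0.0, 0.2, 0.0])
#   -> (['Hello,', 'world.'], [0.5, 0.2])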

def scale_scores(arr, vmin=0.1, vmax=1):
    # rescale the positive attributions to lie between vmin and vmax,
    # while keeping entries that are exactly 0 at 0
    pos = arr[arr > 0]
    pos_max, pos_min = np.max(pos), np.min(pos)
    if pos_max == pos_min:
        # degenerate case: all positive scores equal; avoid division by zero
        return np.where(arr > 0, vmax, 0.0)
    out = (arr - pos_min) / (pos_max - pos_min) * (vmax - vmin) + vmin
    out[arr == 0.0] = 0.0
    return out
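
# Example (illustrative):
#   scale_scores(np.array([0.0, 0.2, 0.6]))  # -> array([0. , 0.1, 1. ])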

def mark_words(query_sents, words, all_words, sent_start_id, sent_ids, sent_scores):
    """
    Mark the words to be highlighted, both at the sentence level and the phrase level
    """
    num_query_sent = sent_ids.shape[0]
    num_cand_sent = sent_ids.shape[1]
    num_words = len(all_words)
    
    output = dict()
    output['all_words'] = all_words
    output['words_by_sentence'] = words
    
    # for each query sentence, mark the highlight information
    for i in range(num_query_sent):
        output[i] = dict()
        for j in range(1, num_cand_sent+1): # for each possible number of selected candidate sentences
            query_words = word_tokenize(query_sents[i])
            is_selected_sent = np.zeros(num_words)
            is_selected_phrase = np.zeros(num_words)
            word_scores = np.zeros(num_words)
            
            # for each selected sentences from the candidate, compile information
            for sid, sscore in zip(sent_ids[i][:j], sent_scores[i][:j]):
                if sid+1 < len(sent_start_id):
                    sent_range = (sent_start_id[sid], sent_start_id[sid+1])
                    is_selected_sent[sent_range[0]:sent_range[1]] = 1
                    word_scores[sent_range[0]:sent_range[1]] = sscore
                    _, is_selected_phrase[sent_range[0]:sent_range[1]] = \
                        get_match_phrase(query_words, all_words[sent_range[0]:sent_range[1]])
                else:
                    is_selected_sent[sent_start_id[sid]:] = 1
                    word_scores[sent_start_id[sid]:] = sscore
                    _, is_selected_phrase[sent_start_id[sid]:] = \
                        get_match_phrase(query_words, all_words[sent_start_id[sid]:])
            
            # scale the word_scores: maximum value gets the darkest, minimum value gets the lightest color
            if j > 1:
                word_scores = scale_scores(word_scores)
                        
            # mark matched phrases within selected sentences with a negative score
            # (negative values render in a different color in Gradio)
            word_scores[is_selected_sent+is_selected_phrase==2] = -0.5
                
            output[i][j] = {
                'is_selected_sent': is_selected_sent,
                'is_selected_phrase': is_selected_phrase,
                'scores': word_scores
            }

    return output

def get_highlight_info(model, tokenizer, text1, text2, K=None, top_pair_num=5):
    """
    Get highlight information from two texts
    """
    sent1 = sent_tokenize(text1) # query
    sent2 = sent_tokenize(text2) # candidate
    score_mat = compute_sentencewise_scores(model, sent1, sent2, tokenizer=tokenizer)
    
    if K is None: # if K is not set, get all information
        K = score_mat.shape[1]
        
    sent_ids, sent_scores = get_top_k(score_mat, K=K)
    words2, all_words2, sent_start_id2 = get_words(sent2)
    info = mark_words(sent1, words2, all_words2, sent_start_id2, sent_ids, sent_scores)
    
    # get top sentence pairs from the query and candidate (score, index_pair) to show upfront
    top_pairs = []
    ii = np.unravel_index(np.argsort(np.array(sent_scores).ravel())[-top_pair_num:], sent_scores.shape)
    for i, j in zip(ii[0][::-1], ii[1][::-1]):
        score = sent_scores[i,j].item()
        index_pair = (i, sent_ids[i,j].item())
        top_pairs.append((score, index_pair)) # list of (score, (sent_id_query, sent_id_candidate))
        
    # convert top_pairs to the corresponding highlight format for the Gradio Interpretation component
    top_pairs_info = dict()
    count = 0
    for s, (sidq, sidc) in top_pairs:
        q_sent = sent1[sidq]
        c_sent = sent2[sidc]
        q_words = word_tokenize(q_sent)
        c_words = word_tokenize(c_sent)
        mask1, mask2 = get_match_phrase(q_words, c_words)
        sc = 0.5
        mask1 *= -sc # mark matching phrases as blue (on a scale where -1 is darkest)
        mask2 *= -sc # mark matching phrases as blue
        assert(len(mask1) == len(q_words) and len(mask2) == len(c_words))
        
        # spacing
        q_words, mask1 = remove_spaces(q_words, mask1)
        c_words, mask2 = remove_spaces(c_words, mask2)
        
        top_pairs_info[count] = {
            'query': {
                'original': q_sent,
                'interpretation': list(zip(q_words, mask1))
            },
            'candidate': {
                'original': c_sent,
                'interpretation': list(zip(c_words, mask2))
            },
            'score': s,
            'sent_idx': (sidq, sidc) 
        }
        count += 1
    
    return sent_ids, sent_scores, info, top_pairs_info
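
# Usage sketch (illustrative; the model choice is an assumption):
#   model = SentenceTransformer('all-MiniLM-L6-v2')
#   _, _, info, top_pairs_info = get_highlight_info(model, None, q_text, c_text, K=3)
# info[i][j]['scores'] holds per-word highlight scores for query sentence i
# when the top-j candidate sentences are selected; top_pairs_info[k] holds
# the k-th highest-scoring (query, candidate) sentence pair.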

### Document-level operations
def predict_docscore(doc_model, tokenizer, query, titles, abstracts, batch=20):
    # compute a document-level similarity score between the query and each paper
    
    # concatenate title and abstract
    title_abs = []
    for t, a in zip(titles, abstracts):
        if t is not None and a is not None:
            title_abs.append(t + ' [SEP] ' + a) # title + abstract
            
    num_docs = len(title_abs) 
    no_iter = int(np.ceil(num_docs / batch))
    scores = []
    with torch.no_grad():
        # process the documents in batches
        for i in tqdm.tqdm(range(no_iter)):
            # preprocess the input
            inputs = tokenizer(
                [query] + title_abs[i*batch:(i+1)*batch], 
                padding=True, 
                truncation=True, 
                return_tensors="pt", 
                max_length=512
            )
            inputs.to(doc_model.device)
            result = doc_model(**inputs)
        
            # take the first ([CLS]) token embedding of each sequence
            embeddings = result.last_hidden_state[:, 0, :].detach().cpu().numpy()
        
            # compute cosine similarity
            q_emb = embeddings[0,:]
            p_emb = embeddings[1:,:]
            nn = np.linalg.norm(q_emb) * np.linalg.norm(p_emb, axis=1)
            scores += list(np.dot(p_emb, q_emb) / nn)

    assert(len(scores) == num_docs)
    
    return scores
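
# Example (illustrative; doc_model and tokenizer are assumed to be a BERT-style
# encoder and its matching tokenizer, e.g. loaded via
# transformers.AutoModel/AutoTokenizer from a checkpoint such as allenai/specter):
#   scores = predict_docscore(doc_model, tokenizer, query, titles, abstracts, batch=20)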

def compute_document_score(doc_model, tokenizer, query_title, query_abs, papers, batch=5):
    scores = []
    titles = []
    abstracts = []
    urls = []
    years = []
    citations = []
    for p in papers:
        if p['title'] is not None and p['abstract'] is not None:
            titles.append(p['title'])
            abstracts.append(p['abstract'])
            urls.append(p['url'])
            years.append(p['year'])
            citations.append(p['citationCount'])
    if query_title == '':
        query = query_abs
    else:
        query = query_title + ' [SEP] ' + query_abs # feed in submission title and abstract
    scores = predict_docscore(doc_model, tokenizer, query, titles, abstracts, batch=batch)
    assert(len(scores) == len(abstracts))
    idx_sorted = np.argsort(scores)[::-1]
    
    titles_sorted = [titles[x] for x in idx_sorted]
    abstracts_sorted = [abstracts[x] for x in idx_sorted]
    scores_sorted = [scores[x] for x in idx_sorted]
    urls_sorted = [urls[x] for x in idx_sorted]
    years_sorted = [years[x] for x in idx_sorted]
    citations_sorted = [citations[x] for x in idx_sorted]
    
    return titles_sorted, abstracts_sorted, urls_sorted, scores_sorted, years_sorted, citations_sorted
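
if __name__ == '__main__':
    # Minimal smoke test: a sketch under assumed dependencies, not part of the
    # original app. The model name is an assumption; any SentenceTransformer
    # encoder should work. NLTK resource names vary across versions, so both
    # old and new names are requested (failed downloads are non-fatal).
    import nltk
    for resource in ['punkt', 'punkt_tab',
                     'averaged_perceptron_tagger', 'averaged_perceptron_tagger_eng']:
        nltk.download(resource, quiet=True)

    model = SentenceTransformer('all-MiniLM-L6-v2')
    query = 'We study sentence similarity. Our method uses transformer encoders.'
    candidate = ('Transformer encoders produce contextual embeddings. '
                 'They are widely used to measure sentence similarity.')
    sent_ids, sent_scores, info, top_pairs_info = get_highlight_info(
        model, tokenizer=None, text1=query, text2=candidate, K=2
    )
    print('sentence id matrix:\n', sent_ids)
    print('top pair score:', top_pairs_info[0]['score'])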