import torch
import numpy as np

from transformers import BertTokenizer, AutoTokenizer, BertModel, AutoModel


def encode_sentence(tokenizer, model, tokens):
    # Re-tokenize each word into WordPiece sub-tokens, remembering how many
    # pieces each original token produced so they can be pooled back later.
    is_split = []
    input_tokens = ['[CLS]']
    for token in tokens:
        tmp = tokenizer.tokenize(token)
        # Leave room for the trailing [SEP] within BERT's 512-token limit.
        if len(input_tokens) + len(tmp) >= 511:
            break
        input_tokens.extend(tmp)
        is_split.append(len(tmp))
    input_tokens += ["[SEP]"]
    input_ids = tokenizer.convert_tokens_to_ids(input_tokens)

    input_ids = torch.LongTensor([input_ids])
    with torch.no_grad():
        outputs = model(input_ids, output_hidden_states=True).last_hidden_state.numpy()
    o1 = outputs[0, :, :]   # (seq_len, hidden_size) for the single input sentence
    bertcls = o1[0]         # [CLS] embedding
    cls_token = o1[0]       # identical to bertcls; kept to preserve the return signature

    # Pool sub-token embeddings back into one vector per original token (mean pooling).
    tokens_emb = []
    i = 1  # position 0 is [CLS]
    for j in is_split:
        if j == 1:
            tokens_emb.append(o1[i])
        else:
            tokens_emb.append(sum(o1[i:i + j]) / j)
            # tokens_emb.append(np.max(np.array(o1[i:i + j]), axis=0))  # max-pooling alternative
        i += j
    assert len(tokens_emb) == len(is_split)
    return tokens_emb, bertcls, cls_token

def flat_list(l):
    # Flatten a list of lists into a single list.
    return [x for ll in l for x in ll]

def encode_sentences(token_list, tokenizer, model):
    # The input is already word-tokenized, so skip the tokenizer's own word
    # segmentation (this flag is honoured by e.g. BertJapaneseTokenizer).
    tokenizer.do_word_tokenize = False

    document_embeddings = []
    for cnt, tokens in enumerate(token_list):
        tokens_emb, bertcls, cls_token = encode_sentence(tokenizer, model, tokens)

        document_embeddings.append({
            'document_id': cnt,
            'doc_cls': cls_token,
            'doc_bertcls': bertcls,
            'tokens': tokens_emb,
        })

    return document_embeddings


def get_cadidate_embeddings(token_list, document_embeddings, tokens):
    document_feats = []
    for cnt, (candidate_phrase, document_emb, each_tokens) in enumerate(
            zip(token_list, document_embeddings, tokens)):
        sentence_emb = document_emb['tokens']

        tmp_embeddings = []
        tmp_candidate_phrase = []

        # Each candidate is (phrase, (start, end)) with an exclusive end index.
        for tmp, (i, j) in candidate_phrase:
            if j <= i:
                continue  # empty or malformed span
            if j > len(sentence_emb):
                break     # span falls past the truncated sentence
            # Max-pool the token embeddings inside the candidate span.
            tmp_embeddings.append(np.max(np.array(sentence_emb[i:j]), axis=0))
            # tmp_embeddings.append(sum(sentence_emb[i:j]) / (j - i))  # mean-pooling alternative
            tmp_candidate_phrase.append(tmp)

        document_feats.append({
            'document_id': cnt,
            'tokens': each_tokens,
            'candidate_phrases': tmp_candidate_phrase,
            'candidate_phrases_embeddings': tmp_embeddings,
            # 'sentence_embeddings': document_emb['doc_bertcls'],
            'sentence_embeddings': document_emb['doc_cls'],
        })
    return document_feats
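

# Minimal usage sketch (not part of the original module): it assumes the
# 'bert-base-uncased' checkpoint and a toy candidate-phrase list purely for
# illustration; in the real pipeline the tokenized documents and candidate
# spans would come from the surrounding keyphrase-extraction code.
if __name__ == '__main__':
    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    model = AutoModel.from_pretrained('bert-base-uncased')
    model.eval()

    # One document, already split into word-level tokens.
    docs_tokens = [['deep', 'learning', 'improves', 'keyphrase', 'extraction', 'results']]
    document_embeddings = encode_sentences(docs_tokens, tokenizer, model)

    # Candidate phrases as (phrase, (start, end)) with exclusive end indices.
    candidates = [[('deep learning', (0, 2)), ('keyphrase extraction', (3, 5))]]
    document_feats = get_cadidate_embeddings(candidates, document_embeddings, docs_tokens)
    print(document_feats[0]['candidate_phrases'])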