fukatani's picture
a
efd0f7c
raw
history blame
2.9 kB
import torch
import numpy as np
from transformers import BertTokenizer, AutoTokenizer, BertModel, AutoModel
def encode_sentence(tokenizer, model, tokens):
    """Encode one pre-tokenized sentence and return per-word embeddings.

    Each word in ``tokens`` is split into sub-word pieces with
    ``tokenizer.tokenize``.  The pieces are wrapped in ``[CLS]`` ... ``[SEP]``
    and run through ``model`` once.  The last-layer hidden states of a word's
    pieces are averaged to give one vector per word.

    Words are consumed left to right until adding the next one would make the
    piece sequence (including ``[CLS]``) reach 511 entries.  Remaining words
    are dropped, so the returned list may be shorter than ``tokens``.

    Args:
        tokenizer: object providing ``tokenize`` and ``convert_tokens_to_ids``
            (e.g. a Hugging Face BERT tokenizer).
        model: callable taking a ``(1, seq_len)`` LongTensor of ids and
            returning an object with a ``last_hidden_state`` tensor.
        tokens: sequence of word strings.

    Returns:
        ``(tokens_emb, bertcls, cls_token)``:
        ``tokens_emb`` is a list of 1-D numpy arrays, one per encoded word.
        ``bertcls`` and ``cls_token`` are the same vector: the last-layer
        hidden state at the ``[CLS]`` position.
    """
    # Hugging Face tokenizers expose ``unk_token``; fall back to BERT's literal.
    unk_token = getattr(tokenizer, "unk_token", None) or "[UNK]"

    input_tokens = ['[CLS]']
    is_split = []  # number of sub-word pieces for each encoded word
    for token in tokens:
        pieces = tokenizer.tokenize(token)
        if not pieces:
            # Some strings (whitespace, control characters) tokenize to
            # nothing.  Averaging zero pieces would divide by zero, and
            # skipping the word would shift every later word index, so give
            # it a single [UNK] piece instead.
            pieces = [unk_token]
        # Stop early so that [CLS] + pieces + [SEP] stays within 512 positions.
        if len(input_tokens) + len(pieces) >= 511:
            break
        input_tokens.extend(pieces)
        is_split.append(len(pieces))
    input_tokens.append("[SEP]")

    input_ids = torch.tensor(
        [tokenizer.convert_tokens_to_ids(input_tokens)], dtype=torch.long)
    # Hugging Face models expose ``.device``; move the ids there so a model on
    # GPU works.  Plain callables without the attribute are used as-is.
    device = getattr(model, "device", None)
    if device is not None:
        input_ids = input_ids.to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        hidden = model(input_ids).last_hidden_state
    outputs = hidden.detach().cpu().numpy()  # (batch=1, seq_len, hidden)

    bertcls = outputs[0, 0, :]
    sentence = outputs[0]  # (seq_len, hidden)
    cls_token = sentence[0]

    tokens_emb = []
    start = 1  # position 0 is [CLS]
    for n_pieces in is_split:
        # Mean-pool the word's sub-word pieces (a single piece is returned as-is).
        tokens_emb.append(sentence[start:start + n_pieces].mean(axis=0))
        start += n_pieces

    assert len(tokens_emb) == len(is_split)
    return tokens_emb, bertcls, cls_token
def flat_list(l):
    """Concatenate the items of each inner iterable of ``l`` into one list."""
    flattened = []
    for inner in l:
        flattened.extend(inner)
    return flattened
def encode_sentences(token_list, tokenizer, model):
    """Encode every pre-tokenized sentence in ``token_list`` with ``encode_sentence``.

    Side effect: sets ``tokenizer.do_word_tokenize = False`` and leaves it set.
    NOTE(review): this presumably stops tokenizers that support the flag from
    word-splitting tokens that are already split; confirm for the tokenizer
    in use.

    Returns:
        A list with one dict per sentence, in input order.  Each dict holds:
        ``document_id``, the 0-based position of the sentence;
        ``doc_cls`` and ``doc_bertcls``, the two [CLS] vectors returned by
        ``encode_sentence``;
        ``tokens``, the list of per-word embeddings.
    """
    tokenizer.do_word_tokenize = False
    embeddings = []
    for doc_id, sentence_tokens in enumerate(token_list):
        word_vecs, bert_cls_vec, cls_vec = encode_sentence(
            tokenizer, model, sentence_tokens)
        embeddings.append({
            'document_id': doc_id,
            'doc_cls': cls_vec,
            'doc_bertcls': bert_cls_vec,
            "tokens": word_vecs,
        })
    return embeddings
def get_cadidate_embeddings(token_list, document_embeddings, tokens):
    """Build max-pooled embeddings for each document's candidate phrases.

    Args:
        token_list: per document, an iterable of ``(phrase, (i, j))`` pairs.
            ``[i, j)`` is a half-open range of word indices into that
            document's per-word embeddings.
        document_embeddings: per document, a dict with at least ``'tokens'``
            (a list of per-word vectors) and ``'doc_cls'``.
        tokens: per document, the word tokens; copied into the result as-is.

    The three sequences are walked in lockstep with ``zip``, so the output is
    as long as the shortest of them.

    Some candidates are skipped:
    - those with an empty range (``j <= i``);
    - those whose range extends past the embedded words
      (``j > len(document_emb['tokens'])``).
    The remaining phrases keep their input order.

    Returns:
        A list of dicts, one per document, with these keys:
        ``document_id``, the document's 0-based position;
        ``tokens``;
        ``candidate_phrases``;
        ``candidate_phrases_embeddings``, the element-wise max over each
        phrase's word vectors;
        ``sentence_embeddings``, the document's ``'doc_cls'`` vector.
    """
    document_feats = []
    for doc_id, (candidate_phrase, document_emb, each_tokens) in enumerate(
            zip(token_list, document_embeddings, tokens)):
        sentence_emb = document_emb['tokens']
        n_words = len(sentence_emb)
        candidate_phrases = []
        candidate_phrases_embeddings = []
        for phrase, (i, j) in candidate_phrase:
            # The slice below is end-exclusive, so j == n_words is a valid
            # phrase ending on the last word.  ``continue`` (not ``break``)
            # so one out-of-range candidate does not discard later in-range
            # ones when candidates are not ordered by end index.
            if j <= i or j > n_words:
                continue
            candidate_phrases_embeddings.append(
                np.max(np.array(sentence_emb[i:j]), axis=0))
            candidate_phrases.append(phrase)
        document_feats.append({
            'document_id': doc_id,
            'tokens': each_tokens,
            'candidate_phrases': candidate_phrases,
            'candidate_phrases_embeddings': candidate_phrases_embeddings,
            'sentence_embeddings': document_emb['doc_cls'],
        })
    return document_feats