import torch
import numpy as np
from transformers import BertTokenizer, AutoTokenizer, BertModel, AutoModel
def encode_sentence(tokenizer, model, tokens):
    # Run BERT over a pre-tokenized sentence and pool wordpieces back into
    # one embedding per original token.
    is_split = []                  # wordpiece count per kept token
    input_tokens = ['[CLS]']
    for token in tokens:
        tmp = tokenizer.tokenize(token)
        # Truncate so that, together with [CLS] and [SEP], the sequence stays
        # within BERT's 512-position limit.
        if len(input_tokens) + len(tmp) >= 511:
            break
        input_tokens.extend(tmp)
        is_split.append(len(tmp))
    input_tokens += ['[SEP]']
    input_ids = tokenizer.convert_tokens_to_ids(input_tokens)
    input_ids = torch.LongTensor([input_ids])
    with torch.no_grad():
        outputs = model(input_ids).last_hidden_state.numpy()
    o1 = outputs[0, :, :]          # (seq_len, hidden_size)
    bertcls = outputs[0, 0, :]     # [CLS] vector
    cls_token = o1[0]              # same [CLS] vector; both are returned
    tokens_emb = []
    i = 1                          # position 0 is [CLS]
    for j in is_split:
        if j == 1:
            tokens_emb.append(o1[i])
            i += 1
        elif j == 0:
            # Token produced no wordpieces (e.g. stray control characters);
            # fall back to a zero vector instead of dividing by zero.
            tokens_emb.append(np.zeros(o1.shape[-1], dtype=o1.dtype))
        else:
            # Mean-pool the wordpieces of a split token.
            tokens_emb.append(sum(o1[i:i + j]) / j)
            # Alternative: max-pool instead of mean-pool.
            # tokens_emb.append(np.max(np.array(o1[i:i + j]), axis=0))
            i += j
    assert len(tokens_emb) == len(is_split)
    return tokens_emb, bertcls, cls_token
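
# A minimal usage sketch for encode_sentence on its own; the checkpoint name
# and the token list below are illustrative assumptions, not part of the
# original script:
#   tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
#   model = AutoModel.from_pretrained('bert-base-uncased')
#   model.eval()
#   tokens_emb, bertcls, cls_token = encode_sentence(
#       tokenizer, model, ['keyphrase', 'extraction'])
#   # tokens_emb: one (hidden_size,) vector per input token;
#   # bertcls and cls_token both hold the [CLS] vector.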
def flat_list(l):
    # Flatten one level of nesting: flat_list([[1, 2], [3]]) -> [1, 2, 3]
    return [x for ll in l for x in ll]
def encode_sentences(token_list, tokenizer, model):
    # token_list holds one list of tokens per document. The flag below is only
    # honoured by tokenizers that define it (e.g. BertJapaneseTokenizer); for
    # other tokenizers the assignment is a harmless no-op.
    tokenizer.do_word_tokenize = False
    document_embeddings = []
    for cnt, tokens in enumerate(token_list):
        tokens_emb, bertcls, cls_token = encode_sentence(tokenizer, model, tokens)
        document_embeddings.append({
            'document_id': cnt,
            'doc_cls': cls_token,
            'doc_bertcls': bertcls,
            'tokens': tokens_emb,
        })
    return document_embeddings
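
# encode_sentences expects one token list per document; a sketch with assumed
# inputs (not from the original script):
#   docs = [['first', 'document'], ['second', 'document']]
#   document_embeddings = encode_sentences(docs, tokenizer, model)
#   document_embeddings[0]['tokens']   # per-token vectors for document 0
#   document_embeddings[0]['doc_cls']  # its [CLS] document vector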
def get_candidate_embeddings(candidate_list, document_embeddings, tokens):
    # candidate_list holds, per document, a list of (phrase, (start, end))
    # pairs whose indices refer to that document's token list (end exclusive).
    document_feats = []
    for cnt, (candidates, document_emb, each_tokens) in enumerate(
            zip(candidate_list, document_embeddings, tokens)):
        sentence_emb = document_emb['tokens']
        tmp_embeddings = []
        tmp_candidate_phrase = []
        for phrase, (i, j) in candidates:
            if j <= i:
                continue  # empty or inverted span
            if j > len(sentence_emb):
                continue  # span falls beyond the truncated sequence
            # Max-pool the token embeddings of the span.
            tmp_embeddings.append(np.max(np.array(sentence_emb[i:j]), axis=0))
            # Alternative: mean-pool instead of max-pool.
            # tmp_embeddings.append(sum(sentence_emb[i:j]) / (j - i))
            tmp_candidate_phrase.append(phrase)
        document_feats.append({
            'document_id': cnt,
            'tokens': each_tokens,
            'candidate_phrases': tmp_candidate_phrase,
            'candidate_phrases_embeddings': tmp_embeddings,
            # Alternative document vector: document_emb['doc_bertcls'].
            'sentence_embeddings': document_emb['doc_cls'],
        })
    return document_feats
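
# End-to-end sketch: the driver below is not part of the original Space; the
# checkpoint name, documents, and candidate spans are illustrative assumptions
# that show the expected input shapes, with cosine similarity against the
# document [CLS] vector as a simple unsupervised ranking signal.
if __name__ == '__main__':
    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    model = AutoModel.from_pretrained('bert-base-uncased')
    model.eval()

    token_list = [
        ['deep', 'learning', 'improves', 'keyphrase', 'extraction'],
        ['graph', 'neural', 'networks', 'model', 'relations'],
    ]
    # One (phrase, (start, end)) list per document; end is exclusive.
    candidate_list = [
        [('deep learning', (0, 2)), ('keyphrase extraction', (3, 5))],
        [('graph neural networks', (0, 3))],
    ]

    document_embeddings = encode_sentences(token_list, tokenizer, model)
    document_feats = get_candidate_embeddings(
        candidate_list, document_embeddings, token_list)

    for feats in document_feats:
        doc_vec = feats['sentence_embeddings']
        for phrase, emb in zip(feats['candidate_phrases'],
                               feats['candidate_phrases_embeddings']):
            sim = np.dot(emb, doc_vec) / (
                np.linalg.norm(emb) * np.linalg.norm(doc_vec) + 1e-12)
            print(feats['document_id'], phrase, round(float(sim), 4))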