fukatani's picture
aa
a95f303
from japanese.embedding import encode_sentences, get_cadidate_embeddings
from japanese.tokenizer import extract_keyphrase_candidates
from japanese.ranker import DirectedCentralityRnak
import streamlit as st
import torch
from transformers import AutoTokenizer
from transformers import AutoModel, AutoModelForMaskedLM
def extract_keyphrase(text):
# load model
model = AutoModel.from_pretrained('cl-tohoku/bert-base-japanese')
tokenizer = AutoTokenizer.from_pretrained('cl-tohoku/bert-base-japanese')
tokens, keyphrases = extract_keyphrase_candidates(text, tokenizer)
document_embs = encode_sentences([tokens], tokenizer, model)
document_feats = get_cadidate_embeddings([keyphrases], document_embs, [tokens])
ranker = DirectedCentralityRnak(document_feats, beta=0.1, lambda1=1, lambda2=0.9, alpha=1.2, processors=8)
return ranker.extract_summary()[0]
def preparation(tokenized_text, mask):
# [CLS],[SEP]の挿入
tokenized_text.insert(0, '[CLS]') # 単語リストの先頭に[CLS]を付ける
tokenized_text.append('[SEP]') # 単語リストの最後に[SEP]を付ける
maru = []
for i, word in enumerate(tokenized_text):
if word == '。' and i != len(tokenized_text) - 2: # 「。」の位置検出
maru.append(i)
for i, loc in enumerate(maru):
tokenized_text.insert(loc + 1 + i, '[SEP]') # 単語リストの「。」の次に[SEP]を挿入する
# 「□」を[MASK]に置き換え 
mask_index = []
for index, word in enumerate(tokenized_text):
if word == mask: # 「□」の位置検出
tokenized_text[index] = '[MASK]'
mask_index.append(index)
return tokenized_text, mask_index
def mask_prediction(text, mask_word):
model = AutoModelForMaskedLM.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')
tokenizer = AutoTokenizer.from_pretrained('cl-tohoku/bert-base-japanese')
tokens, _ = extract_keyphrase_candidates(text, tokenizer)
tokenized_text = tokenizer.tokenize(text)
tokenized_text, mask_index = preparation(tokenized_text, mask_word) # [CLS],[SEP],[MASK]の追加
tokens = tokenizer.convert_tokens_to_ids(tokenized_text) # IDリストに変換
tokens_tensor = torch.tensor([tokens]) # IDテンソルに変換
model.eval()
with torch.no_grad():
outputs = model(tokens_tensor)
predictions = outputs[0]
for i in range(len(mask_index)):
_, predicted_indexes = torch.topk(predictions[0, mask_index[i]], k=5)
predicted_tokens = tokenizer.convert_ids_to_tokens(predicted_indexes.tolist())
return predicted_tokens
if __name__ == '__main__':
text = st.text_input("origin", "ギリシア人ポリュビオスは,著書『歴史』の中で,ローマ共和政の国制(政治体制)を優れたものと評価している。彼によれば,その国制には,コンスルという王制的要素,元老院という共和制的要素,民衆という民主制的要素が存在しており,これら三者が互いに協調や牽制をしあって均衡しているというのである。ローマ人はこの政治体制を誇りとしており,それは,彼らが自らの国家を指して呼んだ「ローマの元老院と民衆」という名称からも読み取ることができる。")
phrases = extract_keyphrase(text)
for phrase in phrases:
for word in phrase.split("_"):
distracters = mask_prediction(text, word)
if distracters is None:
continue
for distracter in distracters:
st.write(text.replace(word, distracter))