fukatani commited on
Commit
6d0fa58
1 Parent(s): 7822f48
Files changed (1) hide show
  1. app.py +59 -8
app.py CHANGED
@@ -1,24 +1,75 @@
1
- import streamlit as st
2
-
3
  from japanese.embedding import encode_sentences, get_cadidate_embeddings
4
  from japanese.tokenizer import extract_keyphrase_candidates
5
  from japanese.ranker import DirectedCentralityRnak
6
 
 
7
  from transformers import AutoTokenizer
8
- from transformers import AutoModel
9
-
10
 
11
- if __name__ == '__main__':
12
  # load model
13
  model = AutoModel.from_pretrained('cl-tohoku/bert-base-japanese')
14
  tokenizer = AutoTokenizer.from_pretrained('cl-tohoku/bert-base-japanese')
15
 
16
- text = st.text_input("origin", "紀元前509年、第7代の王タルクィニウス・スペルブスを追放し共和制を敷いたローマだが、問題は山積していた。まず、王に代わった執政官(コンスル)が元老院の意向で決められるようになったこと、またその被選挙権が40歳以上に限定されていたことから、若い市民を中心としてタルクィニウスを王位に復する王政復古の企みが起こった。これは失敗して、初代執政官ルキウス・ユニウス・ブルトゥスは、彼自身の息子ティトゥスを含む陰謀への参加者を処刑した。ラテン同盟諸都市やエトルリア諸都市との同盟は、これらの都市とローマ王との同盟という形であったため、王の追放で当然に同盟は解消され、対立関係となった。")
17
  tokens, keyphrases = extract_keyphrase_candidates(text, tokenizer)
18
 
19
  document_embs = encode_sentences([tokens], tokenizer, model)
20
  document_feats = get_cadidate_embeddings([keyphrases], document_embs, [tokens])
21
  ranker = DirectedCentralityRnak(document_feats, beta=0.1, lambda1=1, lambda2=0.9, alpha=1.2, processors=8)
22
- phrases = ranker.extract_summary()
23
- st.write(phrases)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
 
 
 
1
  from japanese.embedding import encode_sentences, get_cadidate_embeddings
2
  from japanese.tokenizer import extract_keyphrase_candidates
3
  from japanese.ranker import DirectedCentralityRnak
4
 
5
+ import torch
6
  from transformers import AutoTokenizer
7
+ from transformers import AutoModel, AutoModelForMaskedLM
 
8
 
9
+ def extract_keyphrase(text):
10
  # load model
11
  model = AutoModel.from_pretrained('cl-tohoku/bert-base-japanese')
12
  tokenizer = AutoTokenizer.from_pretrained('cl-tohoku/bert-base-japanese')
13
 
 
14
  tokens, keyphrases = extract_keyphrase_candidates(text, tokenizer)
15
 
16
  document_embs = encode_sentences([tokens], tokenizer, model)
17
  document_feats = get_cadidate_embeddings([keyphrases], document_embs, [tokens])
18
  ranker = DirectedCentralityRnak(document_feats, beta=0.1, lambda1=1, lambda2=0.9, alpha=1.2, processors=8)
19
+ return ranker.extract_summary()[0]
20
+
21
+
22
+ def preparation(tokenized_text, mask):
23
+ # [CLS],[SEP]の挿入
24
+ tokenized_text.insert(0, '[CLS]') # 単語リストの先頭に[CLS]を付ける
25
+ tokenized_text.append('[SEP]') # 単語リストの最後に[SEP]を付ける
26
+
27
+ maru = []
28
+ for i, word in enumerate(tokenized_text):
29
+ if word == '。' and i != len(tokenized_text) - 2: # 「。」の位置検出
30
+ maru.append(i)
31
+
32
+ for i, loc in enumerate(maru):
33
+ tokenized_text.insert(loc + 1 + i, '[SEP]') # 単語リストの「。」の次に[SEP]を挿入する
34
+
35
+ # 「□」を[MASK]に置き換え 
36
+ mask_index = []
37
+ for index, word in enumerate(tokenized_text):
38
+ if word == mask: # 「□」の位置検出
39
+ tokenized_text[index] = '[MASK]'
40
+ mask_index.append(index)
41
+
42
+ return tokenized_text, mask_index
43
+
44
+
45
+ def mask_prediction(text, mask_word):
46
+ model = AutoModelForMaskedLM.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')
47
+ tokenizer = AutoTokenizer.from_pretrained('cl-tohoku/bert-base-japanese')
48
+
49
+ tokens, _ = extract_keyphrase_candidates(text, tokenizer)
50
+ tokenized_text = tokenizer.tokenize(text)
51
+ tokenized_text, mask_index = preparation(tokenized_text, mask_word) # [CLS],[SEP],[MASK]の追加
52
+ tokens = tokenizer.convert_tokens_to_ids(tokenized_text) # IDリストに変換
53
+ tokens_tensor = torch.tensor([tokens]) # IDテンソルに変換
54
+
55
+ model.eval()
56
+ with torch.no_grad():
57
+ outputs = model(tokens_tensor)
58
+ predictions = outputs[0]
59
+ for i in range(len(mask_index)):
60
+ _, predicted_indexes = torch.topk(predictions[0, mask_index[i]], k=5)
61
+ predicted_tokens = tokenizer.convert_ids_to_tokens(predicted_indexes.tolist())
62
+ return predicted_tokens
63
+
64
+
65
+ if __name__ == '__main__':
66
+ text = st.text_input("origin", "ギリシア人ポリュビオスは,著書『歴史』の中で,ローマ共和政の国制(政治体制)を優れたものと評価している。彼によれば,その国制には,コンスルという王制的要素,元老院という共和制的要素,民衆という民主制的要素が存在しており,これら三者が互いに協調や牽制をしあって均衡しているというのである。ローマ人はこの政治体制を誇りとしており,それは,彼らが自らの国家を指して呼んだ「ローマの元老院と民衆」という名称からも読み取ることができる。")
67
+ phrases = extract_keyphrase(text)
68
+ for phrase in phrases:
69
+ for word in phrase.split("_"):
70
+ distracters = mask_prediction(text, word)
71
+ if distracters is None:
72
+ continue
73
+ for distracter in distracters:
74
+ st.write(text.replace(word, distracter))
75