Ayemos committed
Commit c603b9e
1 parent: d3eb806

initial commit

Files changed (2)
  1. app.py +100 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,100 @@
+ from typing import List, Tuple
+
+ import gradio as gr
+ import numpy as np
+ import torch
+ from transformers import AutoModelForCausalLM, T5Tokenizer
+
+ device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
+ tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-gpt2-medium")
+ tokenizer.do_lower_case = True  # lowercasing is recommended for this tokenizer
+
+ model = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt2-medium")
+ model.to(device)
+
+
+ def calculate_surprisals(
+     input_text: str, normalize_surprisals: bool = True
+ ) -> Tuple[float, List[Tuple[str, float]]]:
+     # Tokenize, dropping standalone SentencePiece whitespace markers ("▁").
+     input_tokens = [
+         token.replace("▁", "")
+         for token in tokenizer.tokenize(input_text)
+         if token != "▁"
+     ]
+     input_ids = tokenizer.encode(
+         "<s>" + input_text, add_special_tokens=False, return_tensors="pt"
+     ).to(device)
+
+     # Inference only, so skip gradient tracking.
+     with torch.no_grad():
+         logits = model(input_ids)["logits"].squeeze(0)
+
+     surprisals = []
+     for i in range(logits.shape[0] - 1):
+         # Skip the standalone whitespace token (id 9), mirroring the filter above.
+         if input_ids[0][i + 1] == 9:
+             continue
+         logit = logits[i]
+         prob = torch.softmax(logit, dim=0)
+         neg_logprob = -torch.log(prob)
+         # Surprisal of the observed next token given its preceding context.
+         surprisals.append(neg_logprob[input_ids[0][i + 1]].item())
+     mean_surprisal = np.mean(surprisals)
+
+     if normalize_surprisals:
+         # Min-max normalize the per-token surprisals into [0, 1] for coloring.
+         min_surprisal = np.min(surprisals)
+         max_surprisal = np.max(surprisals)
+         surprisals = [
+             (surprisal - min_surprisal) / (max_surprisal - min_surprisal)
+             for surprisal in surprisals
+         ]
+         assert min(surprisals) >= 0
+         assert max(surprisals) <= 1
+
+     tokens2surprisal: List[Tuple[str, float]] = []
+     for token, surprisal in zip(input_tokens, surprisals):
+         tokens2surprisal.append((token, surprisal))
+
+     return mean_surprisal, tokens2surprisal
+
+
+ def highlight_token(token: str, score: float):
+     # Score 0 renders white, score 1 renders red (green/blue channels fade out).
+     html_color = "#%02X%02X%02X" % (255, int(255 * (1 - score)), int(255 * (1 - score)))
+     return '<span style="background-color: {}; color: black">{}</span>'.format(
+         html_color, token
+     )
+
+
+ def create_highlighted_text(tokens2scores: List[Tuple[str, float]]):
+     highlighted_text: str = ""
+     for token, score in tokens2scores:
+         highlighted_text += highlight_token(token, score)
+     highlighted_text += "<br><br>"
+     return highlighted_text
+
+
+ def main(input_text: str) -> str:
+     mean_surprisal, tokens2surprisal = calculate_surprisals(
+         input_text, normalize_surprisals=True
+     )
+     highlighted_text = create_highlighted_text(tokens2surprisal)
+     # return mean_surprisal, highlighted_text
+     return highlighted_text
+
+
+ if __name__ == "__main__":
+     demo = gr.Interface(
+         fn=main,
+         title="読みにくい箇所を検出するAI(デモ)",  # "An AI that detects hard-to-read passages (demo)"
+         description="テキストを入力すると、読みにくさに応じてハイライトされて出力されます。",  # "Enter text and it is output highlighted according to how hard it is to read."
+         inputs=gr.inputs.Textbox(
+             lines=5, label="テキスト", placeholder="ここにテキストを入力してください。"  # "Text" / "Enter text here."
+         ),
+         outputs=[
+             gr.outputs.HTML(label="トークン毎サプライザル"),  # "Per-token surprisal"
+         ],
+     )
+     demo.launch()
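
For reference, here is a minimal sketch of how the new calculate_surprisals function might be exercised outside the Gradio UI. The script name and sample sentence are illustrative, not part of the commit, and importing app triggers its module-level model download:

# check_surprisal.py -- illustrative sketch, not part of this commit.
# Importing app loads rinna/japanese-gpt2-medium at module scope.
from app import calculate_surprisals

# Per-token surprisal is -log p(token | preceding tokens), in nats;
# normalize_surprisals=False returns the raw values.
mean_surprisal, tokens2surprisal = calculate_surprisals(
    "吾輩は猫である。", normalize_surprisals=False
)
print(f"mean surprisal: {mean_surprisal:.3f} nats")
for token, surprisal in tokens2surprisal:
    print(f"{token}\t{surprisal:.3f}")
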
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ torch==1.12.1
+ transformers==4.20.0
+
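
Two dependencies are notably absent: sentencepiece, which T5Tokenizer needs at load time, and gradio itself, which on Hugging Face Spaces is typically supplied by the Space's SDK configuration rather than by requirements.txt. A small sanity check for the environment (an illustrative sketch, not part of the commit):

# env_check.py -- illustrative sketch; verifies the installed versions
# match the pins in requirements.txt before launching the app.
import sentencepiece  # required by T5Tokenizer but not pinned above
import torch
import transformers

assert torch.__version__.startswith("1.12.1"), torch.__version__
assert transformers.__version__ == "4.20.0", transformers.__version__
print("environment OK:", torch.__version__, transformers.__version__)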