Ayemos commited on
Commit
774956b
1 Parent(s): 8a9e244
Files changed (2) hide show
  1. app.py +92 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Tuple
2
+
3
+ import gradio as gr
4
+ import numpy as np
5
+ import torch
6
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
7
+
8
+ device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
9
+ tokenizer = GPT2Tokenizer.from_pretrained("dendee-geco_test-on-zuco1.0_gpt2_tmptoken_TRT_bs32_lr1e6_linearLR")
10
+ model = GPT2LMHeadModel.from_pretrained('dendee-geco_test-on-zuco1.0_gpt2_tmptoken_TRT_bs32_lr1e6_linearLR', return_dict=True)
11
+ model.to(device)
12
+
13
+
14
+ def calculate_surprisals(
15
+ input_text: str, normalize_surprisals: bool = True
16
+ ) -> Tuple[float, List[Tuple[str, float]]]:
17
+ input_tokens = [
18
+ token.replace("Ġ", "")
19
+ for token in tokenizer.tokenize(input_text)
20
+ if token != "▁"
21
+ ]
22
+ input_ids = tokenizer.encode(input_text, return_tensors='pt').to(device)
23
+
24
+ logits = model(input_ids)['logits'].squeeze(0)
25
+ # can't calculate surprisals for the first token, hence 0
26
+ surprisals = [0] + (- torch.gather(logits[:-1, :], -1, input_ids[:, 1:]).squeeze(0)).tolist()
27
+ mean_surprisal = np.mean(surprisals[1:])
28
+
29
+ if normalize_surprisals:
30
+ min_surprisal = np.min(surprisals)
31
+ max_surprisal = np.max(surprisals)
32
+ surprisals = [
33
+ (surprisal - min_surprisal) / (max_surprisal - min_surprisal)
34
+ for surprisal in surprisals
35
+ ]
36
+ assert min(surprisals) >= 0
37
+ assert max(surprisals) <= 1
38
+
39
+ tokens2surprisal: List[Tuple[str, float]] = []
40
+ for token, surprisal in zip(input_tokens, surprisals):
41
+ tokens2surprisal.append((token, surprisal))
42
+
43
+ return mean_surprisal, tokens2surprisal
44
+
45
+
46
+ def highlight_token(token: str, score: float):
47
+ html_color = "#%02X%02X%02X" % (255, int(255 * (1 - score)), int(255 * (1 - score)))
48
+ return '<span style="background-color: {}; color: black">{}</span>'.format(
49
+ html_color, token
50
+ )
51
+
52
+
53
+ def create_highlighted_text(tokens2scores: List[Tuple[str, float]]):
54
+ highlighted_text: str = ""
55
+ for token, score in tokens2scores:
56
+ highlighted_text += highlight_token(token, score)
57
+ highlighted_text += "<br><br>"
58
+ return highlighted_text
59
+
60
+
61
+ def main(input_text: str) -> Tuple[float, str]:
62
+ mean_surprisal, tokens2surprisal = calculate_surprisals(
63
+ input_text, normalize_surprisals=True
64
+ )
65
+ highlighted_text = create_highlighted_text(tokens2surprisal)
66
+ return round(mean_surprisal, 2), highlighted_text
67
+
68
+
69
+ if __name__ == "__main__":
70
+ demo = gr.Interface(
71
+ fn=main,
72
+ title="Demo",
73
+ description="The input text is highlighted based on readability. (The higher the surprisal, the more difficult to read.)",
74
+ inputs=gr.inputs.Textbox(
75
+ lines=5,
76
+ label="text",
77
+ placeholder="input text here",
78
+ ),
79
+ outputs=[
80
+ gr.Number(label="surprisals"),
81
+ gr.outputs.HTML(label="surprisals by token"),
82
+ ],
83
+ examples=[
84
+ "This is a sample text.",
85
+ "Many girls insulted themselves.",
86
+ "Many girls insulted herself.",
87
+ "These casserols disgust Kayla.",
88
+ "These casseroles disgusts Kayla."
89
+ ],
90
+ )
91
+
92
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch==1.12.1
2
+ transformers==4.20.0
3
+ sentencepiece==0.1.97