File size: 3,504 Bytes
774956b
 
 
 
 
 
 
 
3dcdd1a
f8a8921
 
3dcdd1a
 
 
 
 
 
774956b
 
 
 
 
 
 
 
 
 
 
 
 
235ab91
 
774956b
235ab91
774956b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235ab91
774956b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef55e76
 
774956b
 
b3e527a
ef55e76
774956b
 
b3e527a
774956b
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import html
from typing import List, Tuple

import gradio as gr
import numpy as np
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Use CUDA device 0 when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Local directory holding the fine-tuned GPT-2 checkpoint (tokenizer + model).
# Another checkpoint directory that can be swapped in here (its name suggests
# Dundee + GECO training data, tested on ZuCo 1.0 -- confirm before use):
#   "dendee-geco_test-on-zuco1.0_gpt2_tmptoken_TRT_bs32_lr1e6_linearLR"
MODEL_DIR = "zuco1.0_gpt2_tmptoken_TRT_bs32_lr1e5_linearLR"

tokenizer = GPT2Tokenizer.from_pretrained(MODEL_DIR)
model = GPT2LMHeadModel.from_pretrained(MODEL_DIR, return_dict=True)
model.to(device)


def calculate_surprisals(
    input_text: str, normalize_surprisals: bool = True
) -> Tuple[float, List[Tuple[str, float]]]:
    """Compute the per-token surprisal of ``input_text`` under the module-level model.

    The surprisal of token ``t_i`` is ``-log p(t_i | t_<i)`` (natural log, from
    ``torch.log_softmax``). The first token has no left context, so it gets a
    placeholder surprisal of 0.

    Args:
        input_text: Text to score.
        normalize_surprisals: If True, min-max scale the per-token surprisals
            (first-token placeholder included) into [0, 1] for display.

    Returns:
        ``(mean_surprisal, tokens2surprisal)``. ``mean_surprisal`` is the raw
        (unnormalized) mean over every token except the first, or 0.0 when the
        text has fewer than two tokens. ``tokens2surprisal`` pairs each token's
        display text with its (optionally normalized) surprisal; it is empty
        for empty input.
    """
    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
    if input_ids.numel() == 0:
        # Nothing to score (e.g. an empty textbox); the model cannot run on an
        # empty sequence.
        return 0.0, []

    ids = input_ids[0]  # (seq_len,)
    # Build display tokens from the very ids that are scored, so tokens and
    # surprisals always align one-to-one. convert_tokens_to_string undoes
    # GPT-2's byte-level encoding ("Ġ" -> " ", multi-byte chars -> real text).
    display_tokens = [
        tokenizer.convert_tokens_to_string([tok]).strip()
        for tok in tokenizer.convert_ids_to_tokens(ids.tolist())
    ]

    with torch.no_grad():  # inference only; don't build an autograd graph
        logits = model(input_ids)["logits"].squeeze(0)  # (seq_len, vocab_size)
    logprob = torch.log_softmax(logits, dim=-1)
    # Row i of logprob is the predicted distribution for token i+1, so pair
    # row i with ids[i+1]. The index must have shape (seq_len - 1, 1) so that
    # gather picks exactly one entry from each row.
    next_token_logprob = logprob[:-1].gather(-1, ids[1:].unsqueeze(-1)).squeeze(-1)
    # can't calculate surprisals for the first token, hence 0
    surprisals: List[float] = [0.0] + (-next_token_logprob).tolist()
    mean_surprisal = float(np.mean(surprisals[1:])) if len(surprisals) > 1 else 0.0

    if normalize_surprisals:
        min_surprisal = min(surprisals)
        value_range = max(surprisals) - min_surprisal
        if value_range > 0:
            surprisals = [(s - min_surprisal) / value_range for s in surprisals]
        else:
            # All values equal (e.g. a single token): avoid dividing by zero.
            surprisals = [0.0] * len(surprisals)

    return mean_surprisal, list(zip(display_tokens, surprisals))


def highlight_token(token: str, score: float):
    """Wrap ``token`` in an HTML ``<span>`` shaded from white (0) to red (1).

    Args:
        token: Token text to display. It is HTML-escaped because it originates
            from user-supplied input text.
        score: Highlight intensity. Values outside [0, 1] are clamped so the
            green/blue channels stay within 0-255.

    Returns:
        The HTML snippet as a string.
    """
    score = min(max(score, 0.0), 1.0)
    fade = int(255 * (1 - score))  # green and blue fade out as score rises
    html_color = "#%02X%02X%02X" % (255, fade, fade)
    return '<span style="background-color: {}; color: black">{}</span>'.format(
        html_color, html.escape(token)
    )


def create_highlighted_text(tokens2scores: List[Tuple[str, float]]):
    """Render ``(token, score)`` pairs as a run of highlighted HTML spans.

    Each span produced by ``highlight_token`` is followed by a single space,
    and the result always ends with two ``<br>`` line breaks (so an empty
    input yields just ``"<br><br>"``).
    """
    # str.join builds the result in one pass instead of repeated += copies.
    spans = "".join(
        highlight_token(token, score) + " " for token, score in tokens2scores
    )
    return spans + "<br><br>"


def main(input_text: str) -> Tuple[float, str]:
    """Gradio callback: score ``input_text`` and render it as highlighted HTML.

    Returns:
        The mean surprisal rounded to two decimal places, and an HTML string in
        which each token is shaded by its min-max normalized surprisal.
    """
    mean_surprisal, token_scores = calculate_surprisals(
        input_text, normalize_surprisals=True
    )
    html_output = create_highlighted_text(token_scores)
    rounded_mean = round(mean_surprisal, 2)
    return rounded_mean, html_output


if __name__ == "__main__":
    # Use the top-level component classes throughout (matching gr.Number);
    # the legacy gr.inputs / gr.outputs namespaces are deprecated and were
    # removed in Gradio 4.
    demo = gr.Interface(
        fn=main,
        title="Demo: Highlight text based on eye movement",
        description="Text is highlighted based on surprisal. (The higher the surprisal, the more difficult to read.)",
        inputs=gr.Textbox(
            lines=5,
            label="Text",
            placeholder="Input text here",
        ),
        outputs=[
            gr.Number(label="Surprisal"),
            gr.HTML(label="surprisals by token"),
        ],
        examples=[
            "This is a sample text.",
            # The remaining examples appear to be grammatical / ungrammatical
            # minimal pairs (reflexive and subject-verb agreement).
            "Many girls insulted themselves.",
            "Many girls insulted herself.",
            "These casseroles disgust Kayla.",
            "These casseroles disgusts Kayla.",
        ],
    )

    demo.launch()