Ayemos's picture
fix typo
f8a8921
from typing import List, Tuple
import gradio as gr
import numpy as np
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
# Prefer the first CUDA GPU when one is available; otherwise run on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Checkpoint shared by the tokenizer and the language model ("zuco").
# Alternative "dundee-geco" checkpoint, kept here for reference:
#   "dendee-geco_test-on-zuco1.0_gpt2_tmptoken_TRT_bs32_lr1e6_linearLR"
_CHECKPOINT = "zuco1.0_gpt2_tmptoken_TRT_bs32_lr1e5_linearLR"

tokenizer = GPT2Tokenizer.from_pretrained(_CHECKPOINT)
model = GPT2LMHeadModel.from_pretrained(_CHECKPOINT, return_dict=True)
# Move the weights to the selected device once, at import time.
model.to(device)
def calculate_surprisals(
    input_text: str, normalize_surprisals: bool = True
) -> Tuple[float, List[Tuple[str, float]]]:
    """Score every token of ``input_text`` by its surprisal under the model.

    The surprisal of a token is the negative log-probability (as produced by
    ``torch.log_softmax``) that the model assigns to it given the tokens
    before it. The first token has no left context, so it is reported as 0.

    Uses the module-level ``tokenizer``, ``model`` and ``device``.

    Args:
        input_text: Text to tokenize and score.
        normalize_surprisals: When true, the per-token values in the returned
            list are min-max scaled into ``[0, 1]``. The mean is always
            computed from the raw (unscaled) values.

    Returns:
        A ``(mean_surprisal, tokens2surprisal)`` pair: the mean raw surprisal
        over every token except the first (``0.0`` when there is no such
        token), and a list of ``(token, surprisal)`` tuples in text order.
        Empty input yields ``(0.0, [])``.
    """
    raw_tokens = tokenizer.tokenize(input_text)
    if not raw_tokens:
        # Nothing to score (e.g. an empty textbox); do not run the model on
        # an empty sequence.
        return 0.0, []

    # Strip the "Ġ" marker from the token strings for display.
    # NOTE(review): the "▁" filter looks like a SentencePiece leftover; it is
    # kept unchanged, but any token it drops would shift the token/surprisal
    # alignment below — confirm it can never match for this tokenizer.
    input_tokens = [
        token.replace("Ġ", "")
        for token in raw_tokens
        if token != "▁"
    ]

    input_ids = tokenizer.encode(input_text, return_tensors='pt').to(device)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(input_ids)['logits'].squeeze(0)  # (seq_len, vocab_size)
    logprob = torch.log_softmax(logits, dim=-1)

    # Row t of ``logprob`` is the distribution over the token at position
    # t + 1. The gather index must therefore have one row per position,
    # shape (seq_len - 1, 1); an index of shape (1, seq_len - 1) would read
    # every value from row 0 only.
    next_ids = input_ids[0, 1:].unsqueeze(-1)
    next_logprob = torch.gather(logprob[:-1, :], -1, next_ids).squeeze(-1)

    # can't calculate surprisals for the first token, hence 0
    surprisals = [0.0] + (-next_logprob).tolist()

    scored = surprisals[1:]
    # A single-token input has nothing to average; report 0.0 instead of NaN.
    mean_surprisal = float(np.mean(scored)) if scored else 0.0

    if normalize_surprisals:
        min_surprisal = min(surprisals)
        max_surprisal = max(surprisals)
        span = max_surprisal - min_surprisal
        if span > 0:
            surprisals = [
                (surprisal - min_surprisal) / span for surprisal in surprisals
            ]
        else:
            # All values identical: avoid 0/0 and map everything to 0.
            surprisals = [0.0] * len(surprisals)

    tokens2surprisal: List[Tuple[str, float]] = list(
        zip(input_tokens, surprisals)
    )
    return mean_surprisal, tokens2surprisal
def highlight_token(token: str, score: float) -> str:
    """Wrap ``token`` in an HTML ``<span>`` shaded between white and red.

    Args:
        token: Text to display. It is HTML-escaped before insertion, so
            characters such as ``<`` and ``&`` are shown literally instead of
            being interpreted as markup.
        score: Highlight strength; 0 gives a white background and 1 gives
            pure red. Values outside ``[0, 1]`` are clamped to that range.

    Returns:
        The HTML snippet for the highlighted token.
    """
    import html  # function-scope import keeps the file's import block as-is

    # Clamp so the green/blue channel always stays within 0..255; otherwise
    # the "%02X" formatting would emit an invalid colour code.
    strength = min(1.0, max(0.0, score))
    fade = int(255 * (1 - strength))
    html_color = "#%02X%02X%02X" % (255, fade, fade)
    return '<span style="background-color: {}; color: black">{}</span>'.format(
        html_color, html.escape(token)
    )
def create_highlighted_text(tokens2scores: List[Tuple[str, float]]):
    """Render ``(token, score)`` pairs as one HTML string.

    Each pair is turned into a highlighted span by ``highlight_token`` and
    followed by a single space; the result ends with two line breaks.
    """
    spans = [highlight_token(tok, val) + ' ' for tok, val in tokens2scores]
    return "".join(spans) + "<br><br>"
def main(input_text: str) -> Tuple[float, str]:
    """Gradio callback: score ``input_text`` and build its highlighted HTML.

    Returns the mean surprisal rounded to two decimals together with the
    HTML string in which each token is shaded by its normalized surprisal.
    """
    average, scored_tokens = calculate_surprisals(
        input_text, normalize_surprisals=True
    )
    rendered = create_highlighted_text(scored_tokens)
    return round(average, 2), rendered
if __name__ == "__main__":
    # Build and start the web UI only when this file is run as a script.
    # Components are taken from the top-level ``gr`` namespace throughout
    # (as ``gr.Number`` already was) instead of the legacy ``gr.inputs`` /
    # ``gr.outputs`` namespaces.
    demo = gr.Interface(
        fn=main,
        title="Demo: Highlight text based on eye movement",
        description="Text is highlighted based on surprisal. (The higher the surprisal, the more difficult to read.)",
        inputs=gr.Textbox(
            lines=5,
            label="Text",
            placeholder="Input text here",
        ),
        outputs=[
            gr.Number(label="Surprisal"),
            gr.HTML(label="surprisals by token"),
        ],
        # The last four examples form two pairs that differ in a single word,
        # so their surprisal profiles can be compared directly.
        examples=[
            "This is a sample text.",
            "Many girls insulted themselves.",
            "Many girls insulted herself.",
            "These casseroles disgust Kayla.",
            "These casseroles disgusts Kayla."
        ],
    )
    demo.launch()