from typing import List, Tuple

import gradio as gr
import numpy as np
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# dundee-geco
# tokenizer = GPT2Tokenizer.from_pretrained("dendee-geco_test-on-zuco1.0_gpt2_tmptoken_TRT_bs32_lr1e6_linearLR")
# model = GPT2LMHeadModel.from_pretrained("dendee-geco_test-on-zuco1.0_gpt2_tmptoken_TRT_bs32_lr1e6_linearLR", return_dict=True)

# zuco
tokenizer = GPT2Tokenizer.from_pretrained("zuco1.0_gpt2_tmptoken_TRT_bs32_lr1e5_linearLR")
model = GPT2LMHeadModel.from_pretrained("zuco1.0_gpt2_tmptoken_TRT_bs32_lr1e5_linearLR", return_dict=True)
model.to(device)


def calculate_surprisals(
    input_text: str, normalize_surprisals: bool = True
) -> Tuple[float, List[Tuple[str, float]]]:
    # Strip the GPT-2 whitespace marker so tokens display cleanly in the UI.
    input_tokens = [
        token.replace("Ġ", "")
        for token in tokenizer.tokenize(input_text)
        if token != "▁"
    ]
    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)

    logits = model(input_ids)["logits"].squeeze(0)  # (seq_len, vocab_size)
    logprob = torch.log_softmax(logits, dim=-1)

    # Surprisal of token t_i is -log p(t_i | t_1 ... t_{i-1}). The first token
    # has no left context, so its surprisal is set to 0.
    surprisals = [0.0] + (
        -torch.gather(logprob[:-1, :], -1, input_ids[0, 1:].unsqueeze(-1)).squeeze(-1)
    ).tolist()
    mean_surprisal = np.mean(surprisals[1:])

    if normalize_surprisals:
        # Min-max normalize to [0, 1] so the scores map directly to highlight colors.
        min_surprisal = np.min(surprisals)
        max_surprisal = np.max(surprisals)
        if max_surprisal > min_surprisal:  # guard against division by zero
            surprisals = [
                (surprisal - min_surprisal) / (max_surprisal - min_surprisal)
                for surprisal in surprisals
            ]
        assert min(surprisals) >= 0
        assert max(surprisals) <= 1

    tokens2surprisal: List[Tuple[str, float]] = []
    for token, surprisal in zip(input_tokens, surprisals):
        tokens2surprisal.append((token, surprisal))

    return mean_surprisal, tokens2surprisal
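

# Usage sketch (not part of the demo flow): calculate_surprisals() can also be
# called directly, e.g. from a Python session, without launching the Gradio
# interface. Kept as a comment so it does not run at import time; it assumes
# the fine-tuned checkpoint referenced above is available locally.
#
#     mean_surprisal, tokens2surprisal = calculate_surprisals(
#         "Many girls insulted herself.", normalize_surprisals=True
#     )
#     for token, surprisal in tokens2surprisal:
#         print(f"{token}\t{surprisal:.2f}")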


def highlight_token(token: str, score: float):
    # Higher score -> stronger red: the green and blue channels fade toward 0.
    html_color = "#%02X%02X%02X" % (255, int(255 * (1 - score)), int(255 * (1 - score)))
    return '<span style="background-color: {};">{}</span>'.format(html_color, token)


def create_highlighted_text(tokens2scores: List[Tuple[str, float]]):
    highlighted_text: str = ""
    for token, score in tokens2scores:
        highlighted_text += highlight_token(token, score) + " "
    highlighted_text += "<br><br>"
    return highlighted_text


def main(input_text: str) -> Tuple[float, str]:
    mean_surprisal, tokens2surprisal = calculate_surprisals(
        input_text, normalize_surprisals=True
    )
    highlighted_text = create_highlighted_text(tokens2surprisal)
    return round(mean_surprisal, 2), highlighted_text


if __name__ == "__main__":
    demo = gr.Interface(
        fn=main,
        title="Demo: Highlight text based on eye movement",
        description="Text is highlighted based on surprisal. (The higher the surprisal, the harder the text is to read.)",
        inputs=gr.Textbox(
            lines=5,
            label="Text",
            placeholder="Input text here",
        ),
        outputs=[
            gr.Number(label="Mean surprisal"),
            gr.HTML(label="Surprisals by token"),
        ],
        examples=[
            "This is a sample text.",
            "Many girls insulted themselves.",
            "Many girls insulted herself.",
            "These casseroles disgust Kayla.",
            "These casseroles disgusts Kayla.",
        ],
    )
    demo.launch()