Ayemos commited on
Commit
235ab91
1 Parent(s): 774956b
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -7,7 +7,7 @@ from transformers import GPT2LMHeadModel, GPT2Tokenizer
7
 
8
  device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
9
  tokenizer = GPT2Tokenizer.from_pretrained("dendee-geco_test-on-zuco1.0_gpt2_tmptoken_TRT_bs32_lr1e6_linearLR")
10
- model = GPT2LMHeadModel.from_pretrained('dendee-geco_test-on-zuco1.0_gpt2_tmptoken_TRT_bs32_lr1e6_linearLR', return_dict=True)
11
  model.to(device)
12
 
13
 
@@ -21,9 +21,10 @@ def calculate_surprisals(
21
  ]
22
  input_ids = tokenizer.encode(input_text, return_tensors='pt').to(device)
23
 
24
- logits = model(input_ids)['logits'].squeeze(0)
 
25
  # can't calculate surprisals for the first token, hence 0
26
- surprisals = [0] + (- torch.gather(logits[:-1, :], -1, input_ids[:, 1:]).squeeze(0)).tolist()
27
  mean_surprisal = np.mean(surprisals[1:])
28
 
29
  if normalize_surprisals:
@@ -53,7 +54,7 @@ def highlight_token(token: str, score: float):
53
  def create_highlighted_text(tokens2scores: List[Tuple[str, float]]):
54
  highlighted_text: str = ""
55
  for token, score in tokens2scores:
56
- highlighted_text += highlight_token(token, score)
57
  highlighted_text += "<br><br>"
58
  return highlighted_text
59
 
 
7
 
8
  device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
9
  tokenizer = GPT2Tokenizer.from_pretrained("dendee-geco_test-on-zuco1.0_gpt2_tmptoken_TRT_bs32_lr1e6_linearLR")
10
+ model = GPT2LMHeadModel.from_pretrained("dendee-geco_test-on-zuco1.0_gpt2_tmptoken_TRT_bs32_lr1e6_linearLR", return_dict=True)
11
  model.to(device)
12
 
13
 
 
21
  ]
22
  input_ids = tokenizer.encode(input_text, return_tensors='pt').to(device)
23
 
24
+ logits = model(input_ids)['logits'].squeeze(0) # (1, seq_len)
25
+ logprob = torch.log_softmax(logits, dim=-1)
26
  # can't calculate surprisals for the first token, hence 0
27
+ surprisals = [0] + (- torch.gather(logprob[:-1, :], -1, input_ids[:, 1:]).squeeze(0)).tolist()
28
  mean_surprisal = np.mean(surprisals[1:])
29
 
30
  if normalize_surprisals:
 
54
  def create_highlighted_text(tokens2scores: List[Tuple[str, float]]):
55
  highlighted_text: str = ""
56
  for token, score in tokens2scores:
57
+ highlighted_text += highlight_token(token, score) + ' '
58
  highlighted_text += "<br><br>"
59
  return highlighted_text
60