cifkao commited on
Commit
0ca3572
1 Parent(s): 0441206

Allow overwrite

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -316,10 +316,10 @@ def run_context_length_probing(
316
  logprobs = torch.cat([unigram_logprobs.unsqueeze(0), logprobs], dim=0)
317
 
318
  if metric == "NLL loss":
319
- scores = nll_score(logprobs=logprobs, labels=label_ids)
320
  elif metric == "KL divergence":
321
- scores = kl_div_score(logprobs, labels=label_ids)
322
- del logprobs # possibly destroyed by the score computation to save memory
323
 
324
  scores = (-scores).diff(dim=0).transpose(0, 1)
325
  scores = scores.nan_to_num()
 
316
  logprobs = torch.cat([unigram_logprobs.unsqueeze(0), logprobs], dim=0)
317
 
318
  if metric == "NLL loss":
319
+ scores = nll_score(logprobs=logprobs, labels=label_ids, allow_overwrite=True)
320
  elif metric == "KL divergence":
321
+ scores = kl_div_score(logprobs, labels=label_ids, allow_overwrite=True)
322
+ del logprobs # possibly overwritten by the score computation to save memory
323
 
324
  scores = (-scores).diff(dim=0).transpose(0, 1)
325
  scores = scores.nan_to_num()