Spaces:
Running
Running
Allow overwrite
Browse files
app.py
CHANGED
@@ -316,10 +316,10 @@ def run_context_length_probing(
|
|
316 |
logprobs = torch.cat([unigram_logprobs.unsqueeze(0), logprobs], dim=0)
|
317 |
|
318 |
if metric == "NLL loss":
|
319 |
-
scores = nll_score(logprobs=logprobs, labels=label_ids)
|
320 |
elif metric == "KL divergence":
|
321 |
-
scores = kl_div_score(logprobs, labels=label_ids)
|
322 |
-
del logprobs # possibly
|
323 |
|
324 |
scores = (-scores).diff(dim=0).transpose(0, 1)
|
325 |
scores = scores.nan_to_num()
|
|
|
316 |
logprobs = torch.cat([unigram_logprobs.unsqueeze(0), logprobs], dim=0)
|
317 |
|
318 |
if metric == "NLL loss":
|
319 |
+
scores = nll_score(logprobs=logprobs, labels=label_ids, allow_overwrite=True)
|
320 |
elif metric == "KL divergence":
|
321 |
+
scores = kl_div_score(logprobs, labels=label_ids, allow_overwrite=True)
|
322 |
+
del logprobs # possibly overwritten by the score computation to save memory
|
323 |
|
324 |
scores = (-scores).diff(dim=0).transpose(0, 1)
|
325 |
scores = scores.nan_to_num()
|