Add files using upload-large-folder tool
- scripts/decode/en-ja/llama2/beam_search.sh +19 -0
- scripts/decode/en-ja/llama2/greedy_inference.sh +13 -0
- scripts/decode/en-ja/llama2/hf_inference.sh +13 -0
- scripts/decode/en-ja/llama2/top_p_inference.sh +17 -0
- scripts/decode/en-ja/llama2/top_p_inference_1.sh +20 -0
- scripts/decode/en-ja/llama2/top_p_inference_2.sh +21 -0
- scripts/decode/en-ja/mistral-ve/top_p_inference.sh +16 -0
- scripts/decode/en-ja/mistral-ve/top_p_inference_cpo.sh +17 -0
- scripts/decode/en-ja/mistral/top_p_inference_2.sh +20 -0
- scripts/yans/lm-evaluation-harness/.github/workflows/new_tasks.yml +72 -0
- scripts/yans/lm-evaluation-harness/.github/workflows/publish.yml +78 -0
- scripts/yans/lm-evaluation-harness/.github/workflows/unit_tests.yml +95 -0
- scripts/yans/lm-evaluation-harness/lm_eval/api/__init__.py +0 -0
- scripts/yans/lm-evaluation-harness/lm_eval/api/__pycache__/__init__.cpython-310.pyc +0 -0
- scripts/yans/lm-evaluation-harness/lm_eval/api/__pycache__/filter.cpython-310.pyc +0 -0
- scripts/yans/lm-evaluation-harness/lm_eval/api/__pycache__/group.cpython-310.pyc +0 -0
- scripts/yans/lm-evaluation-harness/lm_eval/api/__pycache__/instance.cpython-310.pyc +0 -0
- scripts/yans/lm-evaluation-harness/lm_eval/api/__pycache__/metrics.cpython-310.pyc +0 -0
- scripts/yans/lm-evaluation-harness/lm_eval/api/__pycache__/model.cpython-310.pyc +0 -0
- scripts/yans/lm-evaluation-harness/lm_eval/api/__pycache__/registry.cpython-310.pyc +0 -0
- scripts/yans/lm-evaluation-harness/lm_eval/api/__pycache__/samplers.cpython-310.pyc +0 -0
- scripts/yans/lm-evaluation-harness/lm_eval/api/__pycache__/task.cpython-310.pyc +0 -0
- scripts/yans/lm-evaluation-harness/lm_eval/api/filter.py +56 -0
- scripts/yans/lm-evaluation-harness/lm_eval/api/group.py +117 -0
- scripts/yans/lm-evaluation-harness/lm_eval/api/instance.py +38 -0
- scripts/yans/lm-evaluation-harness/lm_eval/api/metrics.py +570 -0
- scripts/yans/lm-evaluation-harness/lm_eval/api/model.py +385 -0
- scripts/yans/lm-evaluation-harness/lm_eval/api/registry.py +192 -0
- scripts/yans/lm-evaluation-harness/lm_eval/api/samplers.py +198 -0
- scripts/yans/lm-evaluation-harness/lm_eval/api/task.py +1674 -0
- scripts/yans/lm-evaluation-harness/lm_eval/models/__init__.py +28 -0
- scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/__init__.cpython-310.pyc +0 -0
- scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/anthropic_llms.cpython-310.pyc +0 -0
- scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/api_models.cpython-310.pyc +0 -0
- scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/dummy.cpython-310.pyc +0 -0
- scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/gguf.cpython-310.pyc +0 -0
- scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/huggingface.cpython-310.pyc +0 -0
- scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/mamba_lm.cpython-310.pyc +0 -0
- scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/nemo_lm.cpython-310.pyc +0 -0
- scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/neuralmagic.cpython-310.pyc +0 -0
- scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/neuron_optimum.cpython-310.pyc +0 -0
- scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/openai_completions.cpython-310.pyc +0 -0
- scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/optimum_lm.cpython-310.pyc +0 -0
- scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/textsynth.cpython-310.pyc +0 -0
- scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/utils.cpython-310.pyc +0 -0
- scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/vllm_causallms.cpython-310.pyc +0 -0
- scripts/yans/lm-evaluation-harness/lm_eval/models/anthropic_llms.py +362 -0
- scripts/yans/lm-evaluation-harness/lm_eval/models/api_models.py +641 -0
- scripts/yans/lm-evaluation-harness/lm_eval/models/huggingface.py +1356 -0
- scripts/yans/lm-evaluation-harness/lm_eval/models/nemo_lm.py +537 -0
scripts/decode/en-ja/llama2/beam_search.sh
ADDED
@@ -0,0 +1,19 @@
+set -eux
+LLM_RECIPES_DIR=/code/llm-recipes
+source $LLM_RECIPES_DIR/scripts/wmt2024/tokens.sh
+
+MAX_INPUT_TOKENS=158
+BEAM_SIZE=50
+
+python /code/llm-recipes/tools/hf_inference_distrubuted.py \
+    --model /work/models/additiona_trained_hf/llama2-en-ja-continuous-pretrained-v0-dev-finetune-chunked-docs-all-averaged-841-845 \
+    -i /work/wmt2024_test/LLM/wmttest2024.src.sentence_splited.with_template.en-ja.en.jsonl \
+    -o /work/translation/wmt2024_test/en-ja/llama2-beam \
+    -g 0 1 2 3 4 5 6 7 \
+    --attn_implementation sdpa \
+    --dynamic_max_new_token_ratio 3.0 \
+    --num_return_sequences ${BEAM_SIZE} \
+    --num_beams ${BEAM_SIZE} \
+    --max_input_tokens ${MAX_INPUT_TOKENS} \
+    -b 158
+
scripts/decode/en-ja/llama2/greedy_inference.sh
ADDED
@@ -0,0 +1,13 @@
+LLM_RECIPES_DIR=/code/llm-recipes
+source $LLM_RECIPES_DIR/scripts/wmt2024/tokens.sh
+
+python /code/llm-recipes/tools/hf_inference.py \
+    --model /work/models/translation_finetuned_hf/mistral-llm-recipes-en-ja-continuous-pretrained-v1-dev-finetune-chunked-docs-all-averaged-71-75 \
+    -i /work/wmt2024_test/LLM/wmttest2024.src.sentence_splited.with_template.en-ja.en.jsonl \
+    -o /work/translation/wmt24_test/en-ja/mistral-greedy \
+    -g 0 \
+    -b 4096 \
+    --dynamic_max_new_token_ratio 3.0
+
+echo "Done!"
+
scripts/decode/en-ja/llama2/hf_inference.sh
ADDED
@@ -0,0 +1,13 @@
+LLM_RECIPES_DIR=/code/llm-recipes
+source $LLM_RECIPES_DIR/scripts/wmt2024/tokens.sh
+
+python /code/llm-recipes/tools/hf_inference.py \
+    --model /work/models/translation_finetuned_hf/mistral-llm-recipes-en-ja-continuous-pretrained-v1-dev-finetune-chunked-docs-all-averaged-71-75 \
+    -i /work/wmt2024_test/LLM/wmttest2024.src.sentence_splited.with_template.en-ja.en.jsonl \
+    -o /work/translation/wmt24_test/en-ja/mistral-greedy \
+    -g 0 \
+    -b 4096 \
+    --dynamic_max_new_token_ratio 3.0
+
+echo "Done!"
+
scripts/decode/en-ja/llama2/top_p_inference.sh
ADDED
@@ -0,0 +1,17 @@
+set -eux
+LLM_RECIPES_DIR=/code/llm-recipes
+source $LLM_RECIPES_DIR/scripts/wmt2024/tokens.sh
+
+i=4
+GPU_ID=4
+python /code/llm-recipes/tools/hf_inference.py \
+    --model /work/models/translation_finetuned_hf/mistral-llm-recipes-en-ja-continuous-pretrained-v1-dev-finetune-chunked-docs-all-averaged-71-75 \
+    -i /work/wmt2024_test/LLM/split/en-ja/wmttest2024.src.sentence_splited.with_template.en-ja.en.jsonl.0${i} \
+    -o /work/translation/wmt24_test/en-ja/mistral-top-p-0.95/split_0${i} \
+    -g ${GPU_ID} \
+    -b 500 \
+    --attn_implementation sdpa \
+    --dynamic_max_new_token_ratio 3.0 \
+    --num_return_sequences 100 \
+    --do_sample \
+    --top_p 0.95 &
scripts/decode/en-ja/llama2/top_p_inference_1.sh
ADDED
@@ -0,0 +1,20 @@
+set -eux
+LLM_RECIPES_DIR=/code/llm-recipes
+source $LLM_RECIPES_DIR/scripts/wmt2024/tokens.sh
+
+for i in `seq 0 6`; do
+    python /code/llm-recipes/tools/hf_inference.py \
+        --model /work/models/additiona_trained_hf/llama2-en-ja-continuous-pretrained-v0-dev-finetune-chunked-docs-all-averaged-841-845 \
+        -i /work/wmt2024_test/LLM/split/en-ja/wmttest2024.src.sentence_splited.with_template.en-ja.en.jsonl.0${i} \
+        -o /work/translation/wmt24_test/en-ja/llama2-top-p-0.95/split_0${i} \
+        -g ${i} \
+        -b 158 \
+        --attn_implementation sdpa \
+        --dynamic_max_new_token_ratio 3.0 \
+        --num_return_sequences 50 \
+        --do_sample \
+        --top_p 0.95 \
+        --max_input_tokens 158 &
+done
+wait
+
scripts/decode/en-ja/llama2/top_p_inference_2.sh
ADDED
@@ -0,0 +1,21 @@
+set -eux
+LLM_RECIPES_DIR=/code/llm-recipes
+source $LLM_RECIPES_DIR/scripts/wmt2024/tokens.sh
+
+for i in `seq 7 9`; do
+    GPU_ID=$((i-5))
+    python /code/llm-recipes/tools/hf_inference.py \
+        --model /work/models/additiona_trained_hf/llama2-en-ja-continuous-pretrained-v0-dev-finetune-chunked-docs-all-averaged-841-845 \
+        -i /work/wmt2024_test/LLM/split/en-ja/wmttest2024.src.sentence_splited.with_template.en-ja.en.jsonl.0${i} \
+        -o /work/translation/wmt24_test/en-ja/llama2-top-p-0.95/split_0${i} \
+        -g ${GPU_ID} \
+        -b 158 \
+        --attn_implementation sdpa \
+        --dynamic_max_new_token_ratio 3.0 \
+        --num_return_sequences 50 \
+        --do_sample \
+        --top_p 0.95 \
+        --max_input_tokens 158 &
+done
+wait
+
scripts/decode/en-ja/mistral-ve/top_p_inference.sh
ADDED
@@ -0,0 +1,16 @@
+set -eux
+LLM_RECIPES_DIR=/code/llm-recipes
+source $LLM_RECIPES_DIR/scripts/wmt2024/tokens.sh
+
+python /code/llm-recipes/tools/hf_inference_distrubuted.py \
+    --model /work/models/translation_finetuned_hf/mistral-llm-recipes-en-ja-continuous-pretrained-v1-dev-finetune-ve-sim-chunked-docs-all-averaged-596-600 \
+    -i /work/wmt2024_test/LLM/wmttest2024.src.sentence_splited.with_template.en-ja.en.jsonl \
+    -o /work/translation/wmt2024_test/en-ja/mistral-ve-top-p-0.95 \
+    -g 0 1 2 3 4 5 6 7 \
+    -b 125 \
+    --attn_implementation sdpa \
+    --dynamic_max_new_token_ratio 2.0 \
+    --num_return_sequences 80 \
+    --do_sample \
+    --top_p 0.95 \
+    --max_input_tokens 125
scripts/decode/en-ja/mistral-ve/top_p_inference_cpo.sh
ADDED
@@ -0,0 +1,17 @@
+set -eux
+LLM_RECIPES_DIR=/code/llm-recipes
+source $LLM_RECIPES_DIR/scripts/wmt2024/tokens.sh
+
+python /code/llm-recipes/tools/hf_inference_distrubuted.py \
+    --model /work/models/translation_finetuned_hf/mistral-llm-recipes-en-ja-continuous-pretrained-v1-dev-finetune-ve-sim-chunked-docs-all-averaged-596-600 \
+    -i /work/wmt2024_test/LLM/wmttest2024.src.sentence_splited.with_template.en-ja.en.jsonl \
+    -o /work/translation/wmt2024_test/en-ja/mistral-ve-top-p-0.95-cpo \
+    -p /work/models/dpo/mistral-llm-recipes-en-ja-continuous-pretrained-v1-dev-finetune-ve-sim-chunked-docs-all-cpo-lora/checkpoint-200 \
+    -g 0 1 2 3 4 5 6 7 \
+    -b 125 \
+    --attn_implementation sdpa \
+    --dynamic_max_new_token_ratio 2.0 \
+    --num_return_sequences 80 \
+    --do_sample \
+    --top_p 0.95 \
+    --max_input_tokens 125 \
scripts/decode/en-ja/mistral/top_p_inference_2.sh
ADDED
@@ -0,0 +1,20 @@
+set -eux
+LLM_RECIPES_DIR=/code/llm-recipes
+source $LLM_RECIPES_DIR/scripts/wmt2024/tokens.sh
+
+for i in `seq 8 9`; do
+    # minus 2 for gpu id
+    GPU_ID=$((i-2))
+    python /code/llm-recipes/tools/hf_inference.py \
+        --model /work/models/translation_finetuned_hf/mistral-llm-recipes-en-ja-continuous-pretrained-v1-dev-finetune-chunked-docs-all-averaged-71-75 \
+        -i /work/wmt2024_test/LLM/split/en-ja/wmttest2024.src.sentence_splited.with_template.en-ja.en.jsonl.0${i} \
+        -o /work/translation/wmt24_test/en-ja/mistral-top-p-0.95/split_0${i} \
+        -g ${GPU_ID} \
+        -b 400 \
+        --attn_implementation sdpa \
+        --dynamic_max_new_token_ratio 3.0 \
+        --num_return_sequences 100 \
+        --do_sample \
+        --top_p 0.95 &
+done
+wait
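Note on the decoding flags used by the scripts above: `--do_sample`, `--top_p`, `--num_return_sequences` and `--num_beams` are passed to the in-house hf_inference.py / hf_inference_distrubuted.py tools, whose internals are not shown in this diff. As a rough, illustrative sketch only (not those tools' actual code), they generally correspond to the standard Hugging Face `generate` arguments of the same names; the model path and prompt below are placeholders.

# Illustrative sketch, not part of the diff. Assumes a local causal LM checkpoint.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "/work/models/..."  # placeholder for one of the checkpoints referenced above
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, attn_implementation="sdpa")

inputs = tokenizer("Translate to Japanese: Hello.", return_tensors="pt")

# top-p (nucleus) sampling with several candidates per source, as in the top_p_inference_*.sh scripts
sampled = model.generate(
    **inputs, do_sample=True, top_p=0.95, num_return_sequences=8, max_new_tokens=128
)

# beam search returning every beam, as in beam_search.sh
beamed = model.generate(
    **inputs, num_beams=8, num_return_sequences=8, max_new_tokens=128
)

print(tokenizer.batch_decode(sampled, skip_special_tokens=True))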
scripts/yans/lm-evaluation-harness/.github/workflows/new_tasks.yml
ADDED
@@ -0,0 +1,72 @@
+name: Tasks Modified
+
+on:
+  push:
+    branches:
+      - 'main'
+  pull_request:
+    branches:
+      - 'main'
+  workflow_dispatch:
+# comment/edit out the above to stop/change the triggers
+jobs:
+  changed_files:
+    runs-on: ubuntu-latest  # windows-latest || macos-latest
+    timeout-minutes: 120
+    name: Scan for changed tasks
+    steps:
+      - name: checkout
+        uses: actions/checkout@v3
+        with:
+          fetch-depth: 2 # OR "2" -> To retrieve the preceding commit.
+
+      # Uses the tj-actions/changed-files action to check for changes.
+      # Outputs provided here: https://github.com/tj-actions/changed-files#outputs
+      # The `files_yaml` input optionally takes a yaml string to specify filters,
+      # and prepends the filter name to the standard output names.
+      - name: Check task folders
+        id: changed-tasks
+        uses: tj-actions/changed-files@v44.5.2
+        with:
+          # tasks checks the tasks folder and api checks the api folder for changes
+          files_yaml: |
+            tasks:
+              - lm_eval/tasks/**
+            api:
+              - lm_eval/api/**
+          write_output_files: true
+
+      # The next step is optional; the files are written to the workspace by default (above).
+      # so it's just for debugging
+      - name: Run Tests
+        if: steps.changed-tasks.outputs.tasks_any_modified == 'true' || steps.changed-tasks.outputs.api_any_modified == 'true'
+        run: |
+          echo .github/outputs/tasks_all_changed_and_modified_files.txt >> 'GITHUB_ENV'
+          echo "One or more test file(s) has changed."
+          echo "List of all the files that have changed: ${{ steps.changed-tasks.outputs.tasks_all_modified_files }}"
+
+      - name: Set up Python 3.9
+        if: steps.changed-tasks.outputs.tasks_any_modified == 'true' || steps.changed-tasks.outputs.api_any_modified == 'true'
+        uses: actions/setup-python@v4
+        with:
+          python-version: 3.9
+          cache: 'pip'
+          cache-dependency-path: setup.py
+      - name: Install dependencies
+        if: steps.changed-tasks.outputs.tasks_any_modified == 'true' || steps.changed-tasks.outputs.api_any_modified == 'true'
+        run: |
+          python -m pip install --upgrade pip
+          pip install -e '.[dev,ifeval]' --extra-index-url https://download.pytorch.org/whl/cpu
+          # Install optional git dependencies
+          # pip install bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt
+          # if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
+      - name: Test with pytest
+        # if new tasks are added, run tests on them
+        if: steps.changed-tasks.outputs.tasks_any_modified == 'true'
+        run: python -m pytest tests/test_tasks.py -s -vv
+      # if api is modified, run tests on it
+      - name: Test more tasks with pytest
+        env:
+          API: true
+        if: steps.changed-tasks.outputs.api_any_modified == 'true'
+        run: python -m pytest tests/test_tasks.py -s -vv
scripts/yans/lm-evaluation-harness/.github/workflows/publish.yml
ADDED
@@ -0,0 +1,78 @@
+name: Publish Python distribution to PyPI
+
+on:
+  push:
+    tags:
+      - '*'
+
+jobs:
+  build:
+    name: Build distribution
+    runs-on: ubuntu-latest
+
+    steps:
+      - uses: actions/checkout@v4
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: "3.x"
+
+      - name: Install pypa/build
+        run: >-
+          python3 -m
+          pip install
+          build
+          --user
+      - name: Build a binary wheel and a source tarball
+        run: python3 -m build
+      - name: Store the distribution packages
+        uses: actions/upload-artifact@v3
+        with:
+          name: python-package-distributions
+          path: dist/
+
+  publish-to-pypi:
+    name: >-
+      Publish Python distribution to PyPI
+    if: startsWith(github.ref, 'refs/tags/')  # only publish to PyPI on tag pushes
+    needs:
+      - build
+    runs-on: ubuntu-latest
+    environment:
+      name: pypi
+      url: https://pypi.org/p/lm_eval
+    permissions:
+      id-token: write  # IMPORTANT: mandatory for trusted publishing
+
+    steps:
+      - name: Download all the dists
+        uses: actions/download-artifact@v3
+        with:
+          name: python-package-distributions
+          path: dist/
+      - name: Publish distribution to PyPI
+        uses: pypa/gh-action-pypi-publish@release/v1
+
+  publish-to-testpypi:
+    name: Publish Python distribution to TestPyPI
+    needs:
+      - build
+    runs-on: ubuntu-latest
+
+    environment:
+      name: testpypi
+      url: https://test.pypi.org/p/lm_eval
+
+    permissions:
+      id-token: write  # IMPORTANT: mandatory for trusted publishing
+
+    steps:
+      - name: Download all the dists
+        uses: actions/download-artifact@v3
+        with:
+          name: python-package-distributions
+          path: dist/
+      - name: Publish distribution to TestPyPI
+        uses: pypa/gh-action-pypi-publish@release/v1
+        with:
+          repository-url: https://test.pypi.org/legacy/
scripts/yans/lm-evaluation-harness/.github/workflows/unit_tests.yml
ADDED
@@ -0,0 +1,95 @@
+# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
+# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
+# just comment out unwanted steps to turn off the test.
+name: Unit Tests
+
+on:
+  push:
+    branches:
+      - 'main'
+  pull_request:
+    branches:
+      - 'main'
+  workflow_dispatch:
+# Jobs run concurrently and steps run sequentially within a job.
+# jobs: linter and cpu_tests. Add more jobs/steps as required.
+jobs:
+  linter:
+    name: Linters
+    runs-on: ubuntu-latest
+    timeout-minutes: 5
+
+    steps:
+      - name: Checkout Code
+        uses: actions/checkout@v4
+      - name: Set up Python 3.8
+        uses: actions/setup-python@v5
+        with:
+          python-version: 3.8
+          cache: pip
+          cache-dependency-path: pyproject.toml
+      - name: Pre-Commit
+        env:
+          SKIP: "no-commit-to-branch,mypy"
+
+        uses: pre-commit/action@v3.0.1
+#      # mypy turned off for now
+#      - name: Lint with mypy
+#        run: mypy . --ignore-missing-imports --check-untyped-defs --explicit-package-bases --warn-unreachable
+  # Job 2
+  testcpu:
+    name: CPU Tests
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: [ "3.8", "3.9", "3.10", "3.11" ]
+    timeout-minutes: 30
+    steps:
+      - name: Checkout Code
+        uses: actions/checkout@v4
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+          cache: pip
+          cache-dependency-path: pyproject.toml
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install -e '.[dev,sentencepiece,api]' --extra-index-url https://download.pytorch.org/whl/cpu
+          # Install optional git dependencies
+          # pip install bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt
+          # if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
+      - name: Test with pytest
+        run: python -m pytest --showlocals -s -vv -n=auto --ignore=tests/models/test_neuralmagic.py --ignore=tests/models/test_openvino.py
+      - name: Archive artifacts
+        uses: actions/upload-artifact@v3
+        with:
+          name: output_results
+          path: |
+            test_logs/*
+  testmodels:
+    name: External LM Tests
+    runs-on: ubuntu-latest
+    timeout-minutes: 30
+    steps:
+      - name: Checkout Code
+        uses: actions/checkout@v4
+      - name: Set up Python 3.8
+        uses: actions/setup-python@v5
+        with:
+          python-version: 3.8
+          cache: pip
+          cache-dependency-path: pyproject.toml
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install -e '.[dev,optimum,deepsparse,sparseml,api]' --extra-index-url https://download.pytorch.org/whl/cpu
+      - name: Test with pytest
+        run: python -m pytest tests/models --showlocals -s -vv
+      - name: Archive artifacts
+        uses: actions/upload-artifact@v3
+        with:
+          name: output_results
+          path: |
+            test_logs/*
scripts/yans/lm-evaluation-harness/lm_eval/api/__init__.py
ADDED
File without changes
scripts/yans/lm-evaluation-harness/lm_eval/api/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (160 Bytes)
scripts/yans/lm-evaluation-harness/lm_eval/api/__pycache__/filter.cpython-310.pyc
ADDED
Binary file (2.72 kB)
scripts/yans/lm-evaluation-harness/lm_eval/api/__pycache__/group.cpython-310.pyc
ADDED
Binary file (4.61 kB)
scripts/yans/lm-evaluation-harness/lm_eval/api/__pycache__/instance.cpython-310.pyc
ADDED
Binary file (1.51 kB)
scripts/yans/lm-evaluation-harness/lm_eval/api/__pycache__/metrics.cpython-310.pyc
ADDED
Binary file (13.2 kB)
scripts/yans/lm-evaluation-harness/lm_eval/api/__pycache__/model.cpython-310.pyc
ADDED
Binary file (14.1 kB)
scripts/yans/lm-evaluation-harness/lm_eval/api/__pycache__/registry.cpython-310.pyc
ADDED
Binary file (5.11 kB)
scripts/yans/lm-evaluation-harness/lm_eval/api/__pycache__/samplers.cpython-310.pyc
ADDED
Binary file (4.81 kB)
scripts/yans/lm-evaluation-harness/lm_eval/api/__pycache__/task.cpython-310.pyc
ADDED
Binary file (43.5 kB)
scripts/yans/lm-evaluation-harness/lm_eval/api/filter.py
ADDED
@@ -0,0 +1,56 @@
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Callable, Iterable, List, Union
+
+from lm_eval.api.instance import Instance
+
+
+class Filter(ABC):
+    """
+    Filter classes operate on a per-task level.
+    They take all model outputs (`instance.resps` for all `task.instances`)
+    across all instances of a task, and perform operations.
+    In a single run, one can configure any number of separate filters or lists of filters.
+
+    """
+
+    def __init__(self, **kwargs) -> None:
+        """
+        Can define custom behavior here, if an individual instantiation of a Filter class should have state.
+        """
+
+    @abstractmethod
+    def apply(self, resps: Union[List, Iterable], docs: List[dict]) -> Iterable:
+        """
+        Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects.
+        Should return the list of (filtered) response lists *in the same order as they were input*, e.g.
+        if pass in [<inst.resps for instance 0>, <inst.resps for instance 1>] should return
+        [<filtered resps for instance 0>, <filtered resps for instance 1>]
+        """
+        return resps
+
+
+@dataclass
+class FilterEnsemble:
+    """
+    FilterEnsemble creates a pipeline applying multiple filters.
+    Its intended usage is to stack multiple post-processing steps in order.
+    `task.apply_filters` should use a list of FilterEnsemble classes that it stores, to apply each
+    pipeline separately.
+    """
+
+    name: str
+    filters: List[Callable[[], Filter]]
+
+    def apply(self, instances: List[Instance]) -> None:
+        resps, docs = zip(*((inst.resps, inst.doc) for inst in instances))
+        resps, docs = list(resps), list(docs)
+
+        for f in self.filters:
+            # apply filters in sequence
+            resps = f().apply(resps, docs)
+
+        # add the end results after filtering to filtered_requests of their respective source instances.
+        # has key `self.name`: each FilterEnsemble applied in a given run should use a different name.
+        for inst, resp in zip(instances, resps):
+            inst.filtered_resps[self.name] = resp
ADDED
@@ -0,0 +1,117 @@
|
+import abc
+from dataclasses import asdict, dataclass
+from inspect import getsource
+from typing import Any, Callable, List, Optional, Union
+
+
+@dataclass
+class AggMetricConfig(dict):
+    metric: Optional[str] = None
+    aggregation: Optional[str] = "mean"
+    weight_by_size: Optional[str] = False
+    # list of filter names which should be incorporated into the aggregated metric.
+    filter_list: Optional[Union[str, list]] = "none"
+
+    def __post_init__(self):
+        if self.aggregation != "mean":
+            raise ValueError(
+                f"Currently, only 'mean' is supported for automatically aggregating scores across groups' subtasks. Got '{self.aggregation}'."
+            )
+
+        if isinstance(self.filter_list, str):
+            self.filter_list = [self.filter_list]
+
+
+@dataclass
+class GroupConfig(dict):
+    group: Optional[str] = None
+    group_alias: Optional[str] = None
+    task: Optional[Union[str, list]] = None
+    aggregate_metric_list: Optional[
+        Union[List[AggMetricConfig], AggMetricConfig, dict]
+    ] = None
+    metadata: Optional[dict] = (
+        None  # by default, not used in the code. allows for users to pass arbitrary info to tasks
+    )
+
+    def __getitem__(self, item):
+        return getattr(self, item)
+
+    def __setitem__(self, item, value):
+        return setattr(self, item, value)
+
+    def __post_init__(self):
+        if self.aggregate_metric_list is not None:
+            if isinstance(self.aggregate_metric_list, dict):
+                self.aggregate_metric_list = [self.aggregate_metric_list]
+
+            self.aggregate_metric_list = [
+                AggMetricConfig(**item) if isinstance(item, dict) else item
+                for item in self.aggregate_metric_list
+            ]
+
+    def to_dict(self, keep_callable: bool = False) -> dict:
+        """dumps the current config as a dictionary object, as a printable format.
+        null fields will not be printed.
+        Used for dumping results alongside full task configuration
+
+        :return: dict
+            A printable dictionary version of the TaskConfig object.
+
+        # TODO: should any default value in the TaskConfig not be printed?
+        """
+        cfg_dict = asdict(self)
+        # remove values that are `None`
+        for k, v in list(cfg_dict.items()):
+            if callable(v):
+                cfg_dict[k] = self.serialize_function(v, keep_callable=keep_callable)
+        return cfg_dict
+
+    def serialize_function(
+        self, value: Union[Callable, str], keep_callable=False
+    ) -> Union[Callable, str]:
+        """Serializes a given function or string.
+
+        If 'keep_callable' is True, the original callable is returned.
+        Otherwise, attempts to return the source code of the callable using 'getsource'.
+        """
+        if keep_callable:
+            return value
+        else:
+            try:
+                return getsource(value)
+            except (TypeError, OSError):
+                return str(value)
+
+
+class ConfigurableGroup(abc.ABC):
+    def __init__(
+        self,
+        config: Optional[dict] = None,
+    ) -> None:
+        self._config = GroupConfig(**config)
+
+    @property
+    def group(self):
+        return self._config.group
+
+    @property
+    def group_alias(self):
+        return self._config.group_alias
+
+    @property
+    def version(self):
+        return self._config.version
+
+    @property
+    def config(self):
+        return self._config.to_dict()
+
+    @property
+    def group_name(self) -> Any:
+        return self._config.group
+
+    def __repr__(self):
+        return (
+            f"ConfigurableGroup(group={self.group}," f"group_alias={self.group_alias})"
+        )
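For context (not part of the diff), a group configuration like the one above is typically built from a plain dict; the sketch below is illustrative only and the group/task names are made up.

# Illustrative sketch only, based on the GroupConfig / ConfigurableGroup classes above.
cfg = {
    "group": "translation_suite",                  # hypothetical group name
    "task": ["task_a", "task_b"],                  # hypothetical subtask names
    "aggregate_metric_list": {"metric": "acc", "aggregation": "mean"},
}
group = ConfigurableGroup(config=cfg)
print(group.group_name)  # "translation_suite"
print(group.config)      # dict dump produced by GroupConfig.to_dict()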
scripts/yans/lm-evaluation-harness/lm_eval/api/instance.py
ADDED
@@ -0,0 +1,38 @@
+from dataclasses import dataclass, field
+from typing import Literal, Optional, Tuple
+
+
+OutputType = Literal[
+    "loglikelihood", "loglikelihood_rolling", "generate_until", "multiple_choice"
+]
+
+
+@dataclass
+class Instance:
+    request_type: OutputType
+    doc: dict
+    arguments: tuple
+    idx: int
+    metadata: Tuple[Optional[str], Optional[int], Optional[int]] = field(
+        default_factory=lambda: (None, None, None)
+    )
+    resps: list = field(default_factory=list)
+    filtered_resps: dict = field(default_factory=dict)
+
+    # initialized after init
+    task_name: Optional[str] = None
+    doc_id: Optional[int] = None
+    repeats: Optional[int] = None
+
+    def __post_init__(self) -> None:
+        # unpack metadata field
+        self.task_name, self.doc_id, self.repeats = self.metadata
+
+    @property
+    def args(self):
+        """
+        Returns (string,) where `string` is the string to calculate loglikelihood over
+        """
+        return (
+            self.arguments if isinstance(self.arguments, tuple) else (self.arguments,)
+        )
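For context (not part of the diff), constructing an Instance looks like the sketch below; the document contents and argument strings are made up for illustration.

# Illustrative sketch only, based on the Instance dataclass above.
inst = Instance(
    request_type="loglikelihood",
    doc={"question": "2+2=", "answer": "4"},
    arguments=("2+2=", " 4"),        # (context, continuation)
    idx=0,
    metadata=("my_task", 0, 1),      # (task_name, doc_id, repeats)
)
print(inst.task_name)  # "my_task", unpacked from metadata in __post_init__
print(inst.args)       # ("2+2=", " 4"), arguments returned as a tuple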
scripts/yans/lm-evaluation-harness/lm_eval/api/metrics.py
ADDED
@@ -0,0 +1,570 @@
+import logging
+import math
+import random
+import re
+import string
+from collections.abc import Iterable
+from typing import List
+
+import numpy as np
+import sacrebleu
+
+from lm_eval.api.registry import register_aggregation, register_metric
+
+
+eval_logger = logging.getLogger("lm-eval")
+
+
+# Register Aggregations First
+@register_aggregation("bypass")
+def bypass_agg(arr):
+    return 999
+
+
+@register_aggregation("mean")
+def mean(arr):
+    return sum(arr) / len(arr)
+
+
+@register_aggregation("median")
+def median(arr):
+    return arr[len(arr) // 2]
+
+
+# Certain metrics must be calculated across all documents in a benchmark.
+# We use them as aggregation metrics, paired with no-op passthrough metric fns.
+@register_aggregation("perplexity")
+def perplexity(items):
+    return math.exp(-mean(items))
+
+
+@register_aggregation("weighted_perplexity")
+def weighted_perplexity(items):
+    return math.exp(-weighted_mean(items))
+
+
+@register_aggregation("bits_per_byte")
+def bits_per_byte(items):
+    return -weighted_mean(items) / math.log(2)
+
+
+@register_aggregation("f1")
+def f1_score(items):
+    from sklearn.metrics import f1_score
+
+    unzipped_list = list(zip(*items))
+    golds = unzipped_list[0]
+    preds = unzipped_list[1]
+    fscore = f1_score(golds, preds)
+
+    return np.max(fscore)
+
+
+@register_aggregation("matthews_corrcoef")
+def matthews_corrcoef(items):
+    from sklearn.metrics import matthews_corrcoef
+
+    unzipped_list = list(zip(*items))
+    golds = unzipped_list[0]
+    preds = unzipped_list[1]
+    return matthews_corrcoef(golds, preds)
+
+
+@register_aggregation("bleu")
+def bleu(items):
+    """The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
+    for evaluating a generated sentence to a reference sentence. It counts matching
+    n-grams in the candidate translation to n-grams in the reference text, where
+    1-gram or unigram would be each token and a bigram comparison would be each
+    word pair. The comparison is made regardless of word order
+    Source: https://machinelearningmastery.com/calculate-bleu-score-for-text-python/
+    Paper: https://www.aclweb.org/anthology/P02-1040/
+
+    Higher is better
+    """
+    refs = list(zip(*items))[0]
+    preds = list(zip(*items))[1]
+    refs, preds = _sacreformat(refs, preds)
+    return sacrebleu.corpus_bleu(preds, refs).score
+
+
+@register_aggregation("chrf")
+def chrf(items):
+    """chrF++ is a tool for automatic evaluation of machine translation output
+    based on character n-gram precision and recall enhanced with word n-grams.
+    Source: https://github.com/m-popovic/chrF
+    Paper: https://www.aclweb.org/anthology/W15-3049.pdf
+
+    Higher is better # TODO I think
+    """
+    refs = list(zip(*items))[0]
+    preds = list(zip(*items))[1]
+    refs, preds = _sacreformat(refs, preds)
+    return sacrebleu.corpus_chrf(preds, refs).score
+
+
+@register_aggregation("ter")
+def ter(items):
+    """Translation Error Rate is an error metric for machine translation that
+    measures the number of edits required to change a system output into one
+    of the references
+    Source: http://www.cs.umd.edu/~snover/tercom/
+    Paper: http://mt-archive.info/AMTA-2006-Snover.pdf
+
+    Lower is better
+    """
+    refs = list(zip(*items))[0]
+    preds = list(zip(*items))[1]
+    refs, preds = _sacreformat(refs, preds)
+    return sacrebleu.corpus_ter(preds, refs).score
+
+
+@register_aggregation("brier_score")
+def brier_score(items):  # This is a passthrough function
+    gold, predictions = list(zip(*items))
+    bs, num_class = np.array(predictions).shape
+
+    gold = list(gold)
+    gold_one_hot = np.eye(num_class)[gold]
+    return np.mean(np.sum((predictions - gold_one_hot) ** 2, axis=1))
+
+
+@register_metric(
+    metric="brier_score",
+    higher_is_better=False,
+    output_type=["multiple_choice"],
+    aggregation="brier_score",
+)
+def brier_score_fn(items):  # This is a passthrough function
+    return items
+
+
+@register_metric(
+    metric="acc",
+    higher_is_better=True,
+    output_type=["loglikelihood", "multiple_choice"],
+    aggregation="mean",
+)
+def acc_fn(items):  # This is a passthrough function
+    return items
+
+
+@register_metric(
+    metric="acc_norm",
+    higher_is_better=True,
+    output_type=["loglikelihood", "multiple_choice"],
+    aggregation="mean",
+)
+def acc_norm_fn(items):  # This is a passthrough function
+    return items
+
+
+@register_metric(
+    metric="acc_mutual_info",
+    higher_is_better=True,
+    output_type="multiple_choice",
+    aggregation="mean",
+)
+def acc_mutual_info_fn(items):  # This is a passthrough function
+    return items
+
+
+### the code used in the `exact_match_hf_evaluate` function is ported from
+### https://github.com/huggingface/evaluate/blob/main/metrics/exact_match/exact_match.py
+### which is under the apache license.
+
+# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+def exact_match_hf_evaluate(
+    predictions,
+    references,
+    regexes_to_ignore=None,
+    ignore_case=False,
+    ignore_punctuation=False,
+    ignore_numbers=False,
+):
+    if regexes_to_ignore is not None:
+        for s in regexes_to_ignore:
+            predictions = np.array([re.sub(s, "", x) for x in predictions])
+            references = np.array([re.sub(s, "", x) for x in references])
+    else:
+        predictions = np.asarray(predictions)
+        references = np.asarray(references)
+
+    if ignore_case:
+        predictions = np.char.lower(predictions)
+        references = np.char.lower(references)
+
+    if ignore_punctuation:
+        repl_table = string.punctuation.maketrans("", "", string.punctuation)
+        predictions = np.char.translate(predictions, table=repl_table)
+        references = np.char.translate(references, table=repl_table)
+
+    if ignore_numbers:
+        repl_table = string.digits.maketrans("", "", string.digits)
+        predictions = np.char.translate(predictions, table=repl_table)
+        references = np.char.translate(references, table=repl_table)
+
+    score_list = predictions == references
+
+    return {"exact_match": np.mean(score_list)}
+
+
+###
+
+
+@register_metric(
+    metric="exact_match",
+    higher_is_better=True,
+    output_type="generate_until",
+    aggregation="mean",
+)
+def exact_match_fn(**kwargs):
+    return exact_match_hf_evaluate(**kwargs)
+
+
+@register_metric(
+    metric="perplexity",
+    higher_is_better=False,
+    output_type="loglikelihood",
+    aggregation="perplexity",
+)
+def perplexity_fn(items):  # This is a passthrough function
+    return items
+
+
+@register_metric(
+    metric="word_perplexity",
+    higher_is_better=False,
+    output_type="loglikelihood_rolling",
+    aggregation="weighted_perplexity",
+)
+def word_perplexity_fn(items):  # This is a passthrough function
+    return items
+
+
+@register_metric(
+    metric="byte_perplexity",
+    higher_is_better=False,
+    output_type="loglikelihood_rolling",
+    aggregation="weighted_perplexity",
+)
+def byte_perplexity_fn(items):  # This is a passthrough function
+    return items
+
+
+@register_metric(
+    metric="bits_per_byte",
+    higher_is_better=False,
+    output_type="loglikelihood_rolling",
+    aggregation="bits_per_byte",
+)
+def bits_per_byte_fn(items):  # This is a passthrough function
+    return items
+
+
+def pop_stddev(arr):
+    mu = mean(arr)
+    return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / len(arr))
+
+
+def sample_stddev(arr):
+    mu = mean(arr)
+    return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / (len(arr) - 1))
+
+
+def mean_stderr(arr):
+    return sample_stddev(arr) / math.sqrt(len(arr))
+
+
+@register_metric(
+    metric="bypass",
+    higher_is_better=True,
+    output_type=["loglikelihood", "multiple_choice", "generate_until"],
+    aggregation="bypass",
+)
+def bypass(items):
+    return None
+
+
+@register_metric(
+    metric="mcc",
+    higher_is_better=True,
+    output_type="multiple_choice",
+    aggregation="matthews_corrcoef",
+)
+def mcc_fn(items):  # This is a passthrough function
+    return items
+
+
+@register_metric(
+    metric="f1",
+    higher_is_better=True,
+    output_type="multiple_choice",
+    aggregation="f1",
+)
+def f1_fn(items):  # This is a passthrough function
+    return items
+
+
+@register_metric(
+    metric="bleu",
+    higher_is_better=True,
+    output_type="generate_until",
+    aggregation="bleu",
+)
+def bleu_fn(items):  # This is a passthrough function
+    return items
+
+
+@register_metric(
+    metric="chrf",
+    higher_is_better=True,
+    output_type="generate_until",
+    aggregation="chrf",
+)
+def chrf_fn(items):  # This is a passthrough function
+    return items
+
+
+@register_metric(
+    metric="ter",
+    higher_is_better=True,
+    output_type="generate_until",
+    aggregation="ter",
+)
+def ter_fn(items):  # This is a passthrough function
+    return items
+
+
+@register_metric(
+    metric="acc_all",
+    higher_is_better=True,
+    output_type="loglikelihood",
+    aggregation="mean",
+)
+def acc_all(items):
+    # Only count as correct if all answers are labeled correctly for each question
+    question_scoring_dict = {}
+    preds = list(zip(*items))[0]
+    docs = list(zip(*items))[1]
+
+    for doc, pred in zip(docs, preds):
+        paragraph_id = doc["idx"]["paragraph"]
+        question_id = doc["idx"]["question"]
+        if (paragraph_id, question_id) not in question_scoring_dict:
+            question_scoring_dict[(paragraph_id, question_id)] = []
+
+        gold_label = doc["label"] == 1
+
+        question_scoring_dict[(paragraph_id, question_id)].append(gold_label == pred)
+    acc = np.mean([int(all(x)) for x in question_scoring_dict.values()])
+    return acc
+
+
+def acc_all_stderr(items):
+    # Only count as correct if all answers are labeled correctly for each question
+    question_scoring_dict = {}
+    preds = list(zip(*items))[0]
+    docs = list(zip(*items))[1]
+
+    for doc, pred in zip(docs, preds):
+        question_id = doc["idx"]["question"]
+        if question_id not in question_scoring_dict:
+            question_scoring_dict[question_id] = []
+
+        gold_label = doc["label"] == 1
+        question_scoring_dict[question_id].append(gold_label == pred)
+
+    acc = mean_stderr([int(all(x)) for x in question_scoring_dict.values()])
+    return acc
+
+
+def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
+    """Compute max metric between prediction and each ground truth."""
+    scores_for_ground_truths = []
+    for ground_truth in ground_truths:
+        score = metric_fn(prediction, ground_truth)
+        scores_for_ground_truths.append(score)
+    return max(scores_for_ground_truths)
+
+
+def weighted_mean(items):
+    a, b = zip(*items)
+    return sum(a) / sum(b)
+
+
+def is_non_str_iterable(obj):
+    return isinstance(obj, Iterable) and not isinstance(obj, str)
+
+
+def _sacreformat(refs, preds):
+    """Format refs and preds for sacrebleu corpus calculation. It is very particular"""
+    # Sacrebleu expects (List[str], List[List[str])
+    # e.g. sacrebleu.corpus_bleu([pred_t], [[ref1_stream], [ref2_stream], ...])
+
+    # Note [ref1_stream] is the first reference for each pred.
+    # So lists are size N and (M, N) for N preds and M possible refs for each pred
+    # This is a different order of dimensions that I would expect
+
+    # We expect refs to be List[str] or List[List[str]], the outer list corresponding to preds
+    # Must become List[List[str]] with the inner list corresponding to preds
+    if not is_non_str_iterable(refs):
+        refs = list(refs)
+    if not is_non_str_iterable(refs[0]):
+        refs = [[ref] for ref in refs]
+    refs = list(zip(*refs))
+    # Note the number of refs in each ref list much match the number of preds
+
+    # We expect preds to be List[str] or List[List[str]]. Must become List[str]
+    if not is_non_str_iterable(preds):
+        preds = list(preds)
+    if is_non_str_iterable(preds[0]):
+        assert len(preds[0]) == 1, f"Pred must be a str, was {preds[0]}"
+        preds = [pred[0] for pred in preds]
+
+    return refs, preds
+
+
+# stderr stuff
+
+
+class _bootstrap_internal:
+    def __init__(self, f, n) -> None:
+        self.f = f
+        self.n = n
+
+    def __call__(self, v):
+        i, xs = v
+        rnd = random.Random()
+        rnd.seed(i)
+        res = []
+        for _ in range(self.n):
+            res.append(self.f(rnd.choices(xs, k=len(xs))))
+        return res
+
+
+def bootstrap_stderr(f, xs, iters):
+    import multiprocessing as mp
+
+    pool = mp.Pool(mp.cpu_count())
+    # this gives a biased estimate of the stderr (i.e w/ the mean, it gives something
+    # equivalent to stderr calculated without Bessel's correction in the stddev.
+    # Unfortunately, I haven't been able to figure out what the right correction is
+    # to make the bootstrap unbiased - i considered multiplying by sqrt(n/(n-1)) but
+    # that would be ad-hoc and I can't prove that that would actually be an unbiased estimator)
+    # Thankfully, shouldn't matter because our samples are pretty big usually anyways
+    res = []
+    chunk_size = min(1000, iters)
+    from tqdm import tqdm
+
+    print("bootstrapping for stddev:", f.__name__)
+    for bootstrap in tqdm(
+        pool.imap(
+            _bootstrap_internal(f, chunk_size),
+            [(i, xs) for i in range(iters // chunk_size)],
+        ),
+        total=iters // chunk_size,
+    ):
+        # sample w replacement
+        res.extend(bootstrap)
+
+    pool.close()
+    return sample_stddev(res)
+
+
+def stderr_for_metric(metric, bootstrap_iters: int):
+    if bootstrap_iters <= 0:
+        # return no function (don't compute stderr) if bootstrap iters = 0
+        return None
+
+    bootstrappable = [
+        median,
+        matthews_corrcoef,
+        f1_score,
+        perplexity,
+        bleu,
+        chrf,
+        ter,
+    ]
+
+    if metric in bootstrappable:
+        return lambda x: bootstrap_stderr(metric, x, iters=bootstrap_iters)
+
+    stderr = {mean: mean_stderr, acc_all: acc_all_stderr}
+
+    return stderr.get(metric, None)
+
+
+def pooled_sample_stderr(stderrs: List[float], sizes: List[int]):
+    # Used to aggregate bootstrapped stderrs across subtasks in a group,
+    # when we are weighting by the size of each subtask.
+    #
+
+    assert len(stderrs) == len(sizes)
+
+    # formula source: https://en.wikipedia.org/wiki/Pooled_variance
+    # and: https://stats.stackexchange.com/a/4841331
+    # this empirically seems to match running `stderr_for_metric` on all instances
+    # from the subtasks concatenated with each other.
+    pooled_sample_var = (
+        sum([(size - 1) * stderr**2 * size for size, stderr in zip(sizes, stderrs)])
+    ) / (sum(sizes) - len(sizes))
+
+    return np.sqrt(pooled_sample_var / sum(sizes))
+
+
+def combined_sample_stderr(stderrs: List[float], sizes: List[int], metrics=None):
+    assert (
+        metrics is not None
+    ), "Need to pass a list of each subtask's metric for this stderr aggregation"
+    assert len(stderrs) == len(sizes) and len(sizes) == len(metrics)
+
+    # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1390 for more documentation.
+    # This formula depends on sample means.
+    # removed because it seems to give erroneously huge stderrs for groupings of tasks
+    # and does not seem to match up with bootstrap-calculated stderrs for groups.
+
+    ### don't use this unless a statistician has told you it's the right thing to do ###
+
+    # accumulators: we'll aggregate pairwise N - 1 times
+    variance = stderrs[0] ** 2
+    curr_size = sizes[0]
+    curr_score = metrics[0]
+
+    for stderr, size, score in zip(stderrs[1:], sizes[1:], metrics[1:]):
+        curr_score = ((curr_score * curr_size) + (score * size)) / (
+            curr_size + size
+        )  # NOTE: this assumes our aggregation fn is "mean"
+
+        variance = ((curr_size - 1) * variance + (size - 1) * (stderr**2)) / (
+            curr_size + size - 1
+        ) + curr_size * size / ((curr_size + size) * (curr_size + size - 1)) * (
+            curr_score - score
+        ) ** 2
+
+    return np.sqrt(variance)
+
+
+def aggregate_subtask_metrics(metrics, sizes, weight_by_size=True):
+    # A helper function that is used to aggregate
+    # subtask scores cross-task.
+    # TODO: does not hold for non-mean aggregations
+    if not weight_by_size:
+        sizes = [1] * len(sizes)
+
+    assert len(metrics) == len(sizes)
+
+    return sum([metric * size for metric, size in zip(metrics, sizes)]) / sum(sizes)
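For context (not part of the diff), the group-aggregation helpers at the end of the file combine per-subtask scores and stderrs; the numbers in the sketch below are made up for illustration.

# Illustrative sketch only, using aggregate_subtask_metrics and pooled_sample_stderr above.
sub_scores = [0.72, 0.65, 0.80]      # per-subtask mean accuracy (made up)
sub_sizes = [100, 250, 50]           # number of documents per subtask (made up)
sub_stderrs = [0.045, 0.030, 0.056]  # per-subtask standard errors (made up)

# Size-weighted group score: (0.72*100 + 0.65*250 + 0.80*50) / 400 = 0.68625
group_score = aggregate_subtask_metrics(sub_scores, sub_sizes, weight_by_size=True)

# Pooled standard error across the subtasks, matching the size-weighted mean above.
group_stderr = pooled_sample_stderr(sub_stderrs, sub_sizes)
print(group_score, group_stderr)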
scripts/yans/lm-evaluation-harness/lm_eval/api/model.py
ADDED
@@ -0,0 +1,385 @@
import abc
import hashlib
import json
import logging
import os
from typing import Dict, List, Optional, Tuple, Type, TypeVar

import transformers
from sqlitedict import SqliteDict
from tqdm import tqdm

from lm_eval import utils


eval_logger = logging.getLogger("lm-eval")

T = TypeVar("T", bound="LM")


class LM(abc.ABC):
    def __init__(self) -> None:
        """Defines the interface that should be implemented by all LM subclasses.
        LMs are assumed to take text (strings) as input and yield strings as output
        (inputs/outputs should be tokenization-agnostic.)

        """
        # set rank and world size to a single process, by default.
        self._rank = 0
        self._world_size = 1
        self.cache_hook = CacheHook(None)

    @abc.abstractmethod
    def loglikelihood(self, requests) -> List[Tuple[float, bool]]:
        """Compute log-likelihood of generating a continuation from a context.
        Downstream tasks should attempt to use loglikelihood instead of other
        LM calls whenever possible.

        :param requests: list[Instance]
            A list of Instance objects, with property `args` which returns a tuple (context, continuation).
            `context: str`
                Context string. Implementations of LM must be able to handle an
                empty context string.
            `continuation: str`
                The continuation over which log likelihood will be calculated. If
                there is a word boundary, the space should be in the continuation.
                For example, context="hello" continuation=" world" is correct.

        :return: list[tuple[float, bool]]
            A list of pairs (logprob, isgreedy)
            `logprob: float`
                The log probability of `continuation`.
            `isgreedy`:
                Whether `continuation` would be generated by greedy sampling from `context`.
        """
        pass

    @abc.abstractmethod
    def loglikelihood_rolling(self, requests) -> List[float]:
        """Compute full log-likelihood of a string, with no truncation, for perplexity computation
        - We will use the full max context length of the model.
        - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
          the max context length.
        - IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
          which may simply concatenate multiple documents together.
        - IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
          multiple chunks, the last input will still a full-sized context.
          Example:
            Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
            Prefix: BOS/EOS
            Max context length: 4
            Resulting input/prediction pairs:

                INPUT:  BOS   0   1   2
                PRED:     0   1   2   3

                INPUT:    3   4   5   6
                PRED:     4   5   6   7

                INPUT:    5   6   7   8
                PRED:             8   9

          Observe that:
            1. Each token is predicted exactly once
            2. For the last pair, we provide the full context, but only score the last two tokens

        :param requests: list[Instance]
            A list of Instance objects with property `args` which returns a tuple (context,).
            string: str
                String for which we are computing overall loglikelihood
        :return: list[tuple[float]]
            A list of tuples (logprob,)
            logprob: float
                The log probability of `context` conditioned on the BOS/EOS token.
                Can also be overridden for custom cases by `prefix_token_id`.
        """
        pass

    # TODO: Add an optional max length
    @abc.abstractmethod
    def generate_until(self, requests) -> List[str]:
        """Generate greedily until a stopping sequence

        :param requests: list[Instance]
            A list of Instance objects with property `args` which returns a tuple (context, gen_kwargs).
            context: str
                Context string
            gen_kwargs: dict
                A dictionary of keyword arguments to pass to the generation function e.g. top_k, until, etc.
        :return: list[str]
            A list of model generated continuations.
            continuation: str
                The generated continuation.
        """
        pass

    def apply_chat_template(self, chat_history: List[Dict[str, str]]) -> str:
        """
        Defines how to transform few-shot examples provided as chat history into a format that can be used as input to the LM.

        :param chat_history: list[dict[str, str]]
            A list of dictionaries with keys 'role' and 'content'.
            Values are strings representing the role name and the content of the message, respectively.
        :return: str
            A string representing the chat history in a format that can be used as input to the LM.
        """
        raise NotImplementedError(
            "To use this model with chat templates, please implement the 'apply_chat_template' method for your model type."
        )

    @classmethod
    def create_from_arg_string(
        cls: Type[T], arg_string: str, additional_config: Optional[dict] = None
    ) -> T:
        """
        Creates an instance of the LM class using the given argument string and additional config.

        Parameters:
        - arg_string: A string containing arguments in the format key1=value1,key2=value2.
        - additional_config: Optional dictionary containing additional configuration parameters.

        Returns:
        - Instance of the LM class.
        """
        additional_config = {} if additional_config is None else additional_config
        args = utils.simple_parse_args_string(arg_string)
        args2 = {k: v for k, v in additional_config.items() if v is not None}
        return cls(**args, **args2)

    @classmethod
    def create_from_arg_obj(
        cls: Type[T], arg_dict: dict, additional_config: Optional[dict] = None
    ) -> T:
        """
        Creates an instance of the LM class using the given arg_obj

        Parameters:
        - arg_obj: A dict containing arguments in the format key1=value1,key2=value2.
        - additional_config: Optional dictionary containing additional configuration parameters.

        Returns:
        - Instance of the LM class.
        """

        additional_config = {} if additional_config is None else additional_config
        additional_config = {
            k: v for k, v in additional_config.items() if v is not None
        }

        return cls(**arg_dict, **additional_config)

    @property
    def rank(self):
        # used in the case of parallelism. Hardcoded to
        # ensure no errors arise using API models which do
        # not support multi-device parallelism nor expect it.
        return self._rank

    @property
    def world_size(self):
        # used in the case of parallelism. Hardcoded to
        # ensure no errors arise using API models which do
        # not support multi-device parallelism nor expect it.
        return self._world_size

    @property
    def tokenizer_name(self) -> str:
        """Must be defined for LM subclasses which implement Chat Templating.
        Should return the name of the tokenizer or chat template used.
        Used only to properly fingerprint caches when requests are being cached with `--cache_requests`, otherwise not used.
        """
        raise NotImplementedError(
            "To use this model with chat templates, please implement the 'tokenizer_name' property."
        )

    @property
    def chat_template(self) -> str:
        """Must be defined for LM subclasses that implement Chat Templating.
        Should return the structure of the chat template applied to user/assistant messages.
        This is used only to save in the experiment results for reproducibility.
        """
        raise NotImplementedError(
            "To use this model with chat templates, please implement the 'chat_template' property."
        )

    def set_cache_hook(self, cache_hook) -> None:
        self.cache_hook = cache_hook


### SQLite-based caching of LM responses
def hash_args(attr, args):
    dat = json.dumps([attr] + list(args))
    return hashlib.sha256(dat.encode("utf-8")).hexdigest()


class CacheHook:
    def __init__(self, cachinglm) -> None:
        if cachinglm is None:
            self.dbdict = None
            return

        self.dbdict = cachinglm.dbdict

    def add_partial(self, attr, req, res) -> None:
        if self.dbdict is None:
            return
        hsh = hash_args(attr, req)
        self.dbdict[hsh] = res


class CachingLM:
    def __init__(self, lm, cache_db) -> None:
        """LM wrapper that returns cached results if they exist, and uses the underlying LM if not.

        :param lm: LM
            Underlying LM
        :param cache_db: str
            Path to cache db
        """
        self.lm = lm
        self.cache_db = cache_db
        if os.path.dirname(cache_db):
            os.makedirs(os.path.dirname(cache_db), exist_ok=True)
        self.dbdict = SqliteDict(cache_db, autocommit=True)

        # add hook to lm
        lm.set_cache_hook(self.get_cache_hook())

    def __getattr__(self, attr: str):
        lm_attr = getattr(self.lm, attr)
        if attr not in ["loglikelihood", "loglikelihood_rolling", "generate_until"]:
            eval_logger.debug(f"Passing through attribute '{attr}' to underlying LM")
            return lm_attr

        def fn(requests):
            res = []
            remaining_reqs = []
            warned = False
            # figure out which ones are cached and which ones are new
            eval_logger.info(
                f"Loading '{attr}' responses from cache '{self.cache_db}' where possible..."
            )
            for req in tqdm(requests, desc="Checking cached requests"):
                hsh = hash_args(attr, req.args)
                if attr == "generate_until" and req.args[1].get("do_sample", False):
                    # when we are doing non-greedy generation, don't use the cache
                    # (else every "randomly sampled" generation would be identical for repeats > 1).
                    if not warned:
                        eval_logger.warning(
                            f"Arguments to lm.generate_until() '{req.args[1]}' include non-deterministic sampling. Caching will not be performed for such requests."
                        )
                        warned = True
                    res.append(None)
                    remaining_reqs.append(req)
                elif hsh in self.dbdict:
                    ob = self.dbdict[hsh]

                    assert ob is not None

                    res.append(ob)
                else:
                    res.append(None)
                    remaining_reqs.append(req)
            eval_logger.info(
                f"Cached requests: {len(requests) - len(remaining_reqs)}, Requests remaining: {len(remaining_reqs)}"
            )
            # actually run the LM on the requests that do not have cached results
            rem_res = getattr(self.lm, attr)(remaining_reqs)

            # stick the new ones back into the list and also cache any of the new ones
            resptr = 0
            for req, r in zip(remaining_reqs, rem_res):
                while res[resptr] is not None:
                    resptr += 1

                res[resptr] = r

                # caching
                hsh = hash_args(attr, req.args)
                self.dbdict[hsh] = r
            self.dbdict.commit()

            return res

        return fn

    def get_cache_hook(self):
        return CacheHook(self)


class TemplateLM(LM):
    """
    A class acting as intermediary between the LM base class
    and boilerplate often included in other LM subclasses.
    """

    @property
    @abc.abstractmethod
    def eot_token_id(self):
        pass

    @property
    def prefix_token_id(self):
        # it is used as prefix for loglikelihood
        return self.eot_token_id

    @abc.abstractmethod
    def tok_encode(self, string: str, **kwargs) -> List[int]:
        """
        Tokenize a string using the model's tokenizer and return a list of token IDs.
        """
        pass

    @abc.abstractmethod
    def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]:
        pass

    def _encode_pair(
        self, context: str, continuation: str
    ) -> Tuple[List[int], List[int]]:
        n_spaces = len(context) - len(context.rstrip())
        if n_spaces > 0:
            continuation = context[-n_spaces:] + continuation
            context = context[:-n_spaces]

        model_class = getattr(self, "AUTO_MODEL_CLASS", None)

        if model_class == transformers.AutoModelForSeq2SeqLM:
            context_enc = self.tok_encode(context)
            continuation_enc = self.tok_encode(continuation, add_special_tokens=False)
        else:
            whole_enc = self.tok_encode(context + continuation)
            context_enc = self.tok_encode(context)

            context_enc_len = len(context_enc)
            continuation_enc = whole_enc[context_enc_len:]

        return context_enc, continuation_enc

    def loglikelihood(
        self, requests, disable_tqdm: bool = False
    ) -> List[Tuple[float, bool]]:
        new_reqs = []
        for context, continuation in [req.args for req in requests]:
            if context == "":
                # BOS or EOS as context
                context_enc, continuation_enc = (
                    [self.prefix_token_id],
                    self.tok_encode(continuation),
                )
            else:
                context_enc, continuation_enc = self._encode_pair(context, continuation)

            new_reqs.append(((context, continuation), context_enc, continuation_enc))

        return self._loglikelihood_tokens(new_reqs, disable_tqdm=disable_tqdm)

    @abc.abstractmethod
    def loglikelihood_rolling(
        self, requests, disable_tqdm: bool = False
    ) -> List[float]:
        pass

    @abc.abstractmethod
    def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
        pass
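For reference, a minimal sketch (not part of the uploaded files) of how the `LM` interface above is typically subclassed and wrapped with `CachingLM`. The `EchoLM` class, its constant outputs, and the cache path are assumptions made for illustration.

# Illustrative only: a toy LM subclass plus the SQLite-backed caching wrapper.
from typing import List, Tuple

from lm_eval.api.model import LM, CachingLM


class EchoLM(LM):
    def loglikelihood(self, requests) -> List[Tuple[float, bool]]:
        # one (logprob, is_greedy) pair per request
        return [(0.0, True) for _ in requests]

    def loglikelihood_rolling(self, requests) -> List[float]:
        return [0.0 for _ in requests]

    def generate_until(self, requests) -> List[str]:
        # echo each request's context back as the "generation"
        return [req.args[0] for req in requests]


lm = CachingLM(EchoLM(), "lm_cache/echo.db")  # repeat runs are served from the SQLite cache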
scripts/yans/lm-evaluation-harness/lm_eval/api/registry.py
ADDED
@@ -0,0 +1,192 @@
import logging
from typing import Callable, Dict

import evaluate as hf_evaluate

from lm_eval.api.model import LM


eval_logger = logging.getLogger("lm-eval")

MODEL_REGISTRY = {}


def register_model(*names):
    # either pass a list or a single alias.
    # function receives them as a tuple of strings

    def decorate(cls):
        for name in names:
            assert issubclass(
                cls, LM
            ), f"Model '{name}' ({cls.__name__}) must extend LM class"

            assert (
                name not in MODEL_REGISTRY
            ), f"Model named '{name}' conflicts with existing model! Please register with a non-conflicting alias instead."

            MODEL_REGISTRY[name] = cls
        return cls

    return decorate


def get_model(model_name):
    try:
        return MODEL_REGISTRY[model_name]
    except KeyError:
        raise ValueError(
            f"Attempted to load model '{model_name}', but no model for this name found! Supported model names: {', '.join(MODEL_REGISTRY.keys())}"
        )


TASK_REGISTRY = {}
GROUP_REGISTRY = {}
ALL_TASKS = set()
func2task_index = {}


def register_task(name):
    def decorate(fn):
        assert (
            name not in TASK_REGISTRY
        ), f"task named '{name}' conflicts with existing registered task!"

        TASK_REGISTRY[name] = fn
        ALL_TASKS.add(name)
        func2task_index[fn.__name__] = name
        return fn

    return decorate


def register_group(name):
    def decorate(fn):
        func_name = func2task_index[fn.__name__]
        if name in GROUP_REGISTRY:
            GROUP_REGISTRY[name].append(func_name)
        else:
            GROUP_REGISTRY[name] = [func_name]
            ALL_TASKS.add(name)
        return fn

    return decorate


OUTPUT_TYPE_REGISTRY = {}
METRIC_REGISTRY = {}
METRIC_AGGREGATION_REGISTRY = {}
AGGREGATION_REGISTRY: Dict[str, Callable[[], Dict[str, Callable]]] = {}
HIGHER_IS_BETTER_REGISTRY = {}
FILTER_REGISTRY = {}

DEFAULT_METRIC_REGISTRY = {
    "loglikelihood": [
        "perplexity",
        "acc",
    ],
    "loglikelihood_rolling": ["word_perplexity", "byte_perplexity", "bits_per_byte"],
    "multiple_choice": ["acc", "acc_norm"],
    "generate_until": ["exact_match"],
}


def register_metric(**args):
    # TODO: do we want to enforce a certain interface to registered metrics?
    def decorate(fn):
        assert "metric" in args
        name = args["metric"]

        for key, registry in [
            ("metric", METRIC_REGISTRY),
            ("higher_is_better", HIGHER_IS_BETTER_REGISTRY),
            ("aggregation", METRIC_AGGREGATION_REGISTRY),
        ]:
            if key in args:
                value = args[key]
                assert (
                    value not in registry
                ), f"{key} named '{value}' conflicts with existing registered {key}!"

                if key == "metric":
                    registry[name] = fn
                elif key == "aggregation":
                    registry[name] = AGGREGATION_REGISTRY[value]
                else:
                    registry[name] = value

        return fn

    return decorate


def get_metric(name: str, hf_evaluate_metric=False) -> Callable:
    if not hf_evaluate_metric:
        if name in METRIC_REGISTRY:
            return METRIC_REGISTRY[name]
        else:
            eval_logger.warning(
                f"Could not find registered metric '{name}' in lm-eval, searching in HF Evaluate library..."
            )

    try:
        metric_object = hf_evaluate.load(name)
        return metric_object.compute
    except Exception:
        eval_logger.error(
            f"{name} not found in the evaluate library! Please check https://huggingface.co/evaluate-metric",
        )


def register_aggregation(name: str):
    def decorate(fn):
        assert (
            name not in AGGREGATION_REGISTRY
        ), f"aggregation named '{name}' conflicts with existing registered aggregation!"

        AGGREGATION_REGISTRY[name] = fn
        return fn

    return decorate


def get_aggregation(name: str) -> Callable[[], Dict[str, Callable]]:
    try:
        return AGGREGATION_REGISTRY[name]
    except KeyError:
        eval_logger.warning(f"{name} not a registered aggregation metric!")


def get_metric_aggregation(name: str) -> Callable[[], Dict[str, Callable]]:
    try:
        return METRIC_AGGREGATION_REGISTRY[name]
    except KeyError:
        eval_logger.warning(f"{name} metric is not assigned a default aggregation!")


def is_higher_better(metric_name) -> bool:
    try:
        return HIGHER_IS_BETTER_REGISTRY[metric_name]
    except KeyError:
        eval_logger.warning(
            f"higher_is_better not specified for metric '{metric_name}'!"
        )


def register_filter(name):
    def decorate(cls):
        if name in FILTER_REGISTRY:
            eval_logger.info(
                f"Registering filter `{name}` that is already in Registry {FILTER_REGISTRY}"
            )
        FILTER_REGISTRY[name] = cls
        return cls

    return decorate


def get_filter(filter_name: str) -> type:
    try:
        return FILTER_REGISTRY[filter_name]
    except KeyError:
        eval_logger.warning(f"filter `{filter_name}` is not registered!")
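For reference, a minimal sketch (not part of the uploaded files) of how the registry decorators above are used. The metric name "char_len_diff", the aggregation name "mean_of_floats", and the function signatures are assumptions for illustration; the aggregation must be registered before a metric refers to it, since `register_metric` resolves it from `AGGREGATION_REGISTRY` at decoration time.

# Illustrative only: register a custom metric with an aggregation, then look both up.
from lm_eval.api.registry import (
    get_metric,
    get_metric_aggregation,
    register_aggregation,
    register_metric,
)


@register_aggregation("mean_of_floats")
def mean_of_floats(items):
    return sum(items) / len(items)


@register_metric(
    metric="char_len_diff",
    higher_is_better=False,
    aggregation="mean_of_floats",
)
def char_len_diff(references, predictions):
    # toy per-document score: absolute character-length difference
    return abs(len(references[0]) - len(predictions[0]))


metric_fn = get_metric("char_len_diff")
agg_fn = get_metric_aggregation("char_len_diff")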
scripts/yans/lm-evaluation-harness/lm_eval/api/samplers.py
ADDED
@@ -0,0 +1,198 @@
from functools import partial

import datasets


class ContextSampler:
    def __init__(self, docs, task, fewshot_indices=None, rnd=None) -> None:
        self.rnd = rnd
        if not self.rnd:
            raise ValueError(
                "A `random.Random` generator argument must be provided to `rnd` of FewShotSampler!"
            )

        self.task = task
        self.config = task._config

        self.target_delimiter = self.config.target_delimiter
        self.fewshot_delimiter = self.config.fewshot_delimiter

        if (
            self.config.fewshot_config is not None
            and self.config.fewshot_config.get("doc_to_text", None) is not None
        ):
            self.doc_to_text = partial(
                self.task.doc_to_text,
                doc_to_text=self.config.fewshot_config.get("doc_to_text", None),
            )
        else:
            self.doc_to_text = self.task.doc_to_text

        if (
            self.config.fewshot_config is not None
            and self.config.fewshot_config.get("doc_to_target", None) is not None
        ):
            self.doc_to_target = partial(
                self.task.doc_to_target,
                doc_to_target=self.config.fewshot_config.get("doc_to_target", None),
            )
        else:
            self.doc_to_target = self.task.doc_to_target

        if (
            self.config.fewshot_config is not None
            and self.config.fewshot_config.get("doc_to_choice", None) is not None
        ):
            self.doc_to_choice = partial(
                self.task.doc_to_choice,
                doc_to_choice=self.config.fewshot_config.get("doc_to_choice", None),
            )
        else:
            self.doc_to_choice = self.task.doc_to_choice

        self.docs = docs  # HF dataset split, provided by task._fewshot_docs()
        if fewshot_indices:  # subset few-shot docs from
            if not isinstance(self.docs, datasets.Dataset):
                raise ValueError(
                    "Got `fewshot_indices` but fewshot_docs are not a HF dataset. Don't use both `fewshot_indices` and a user-defined few-shot sample list simultaneously"
                )
            self.docs = self.docs.select(fewshot_indices)

    def get_context(self, doc, num_fewshot):
        # draw an extra fewshot sample if using same split as evaluating on
        n_samples = (
            num_fewshot + 1
            if self.config.fewshot_split == self.config.test_split
            else num_fewshot
        )

        # draw `n_samples` docs from fewshot_docs
        fewshotex = self.sample(n_samples)

        # get rid of the doc that's the one we're evaluating, if it's in the fewshot
        # TODO: should we just stop people from using fewshot from same split as evaluating?
        selected_docs = [x for x in fewshotex if x != doc][:num_fewshot]

        labeled_examples = ""
        for doc in selected_docs:
            doc_content = self.doc_to_text(doc)
            doc_target = self.doc_to_target(doc)
            labeled_examples += (
                doc_content
                if self.config.doc_to_choice is None or isinstance(doc_content, str)
                else self.doc_to_choice(doc)[doc_content]
            )
            labeled_examples += self.target_delimiter
            if doc_target != "":
                labeled_examples += (
                    str(doc_target[0])
                    if isinstance(doc_target, list)
                    else doc_target
                    if self.config.doc_to_choice is None or isinstance(doc_target, str)
                    else str(self.doc_to_choice(doc)[doc_target])
                )
            labeled_examples += self.fewshot_delimiter

        return labeled_examples

    def get_chat_context(
        self,
        doc,
        num_fewshot,
        fewshot_as_multiturn: bool = False,
    ):
        chat_history = []
        # draw an extra fewshot sample if using same split as evaluating on
        n_samples = (
            num_fewshot + 1
            if self.config.fewshot_split == self.config.test_split
            else num_fewshot
        )
        # draw `n_samples` docs from fewshot_docs
        fewshotex = self.sample(n_samples)

        # get rid of the doc that's the one we're evaluating, if it's in the fewshot
        # TODO: should we just stop people from using fewshot from same split as evaluating?
        selected_docs = [x for x in fewshotex if x != doc][:num_fewshot]

        if fewshot_as_multiturn:
            for doc in selected_docs:
                doc_content = self.doc_to_text(doc)
                doc_target = self.doc_to_target(doc)
                chat_history.append(
                    {
                        "role": "user",
                        "content": doc_content
                        if self.config.doc_to_choice is None
                        or isinstance(doc_content, str)
                        else self.doc_to_choice(doc)[doc_content],
                    }
                )
                chat_history.append(
                    {
                        "role": "assistant",
                        "content": str(doc_target[0])
                        if isinstance(doc_target, list)
                        else doc_target
                        if self.config.doc_to_choice is None
                        or isinstance(doc_target, str)
                        else str(self.doc_to_choice(doc)[doc_target]),
                    }
                )
        else:
            # get fewshot context as one user turn
            chat_history.append(
                {"role": "user", "content": self.get_context(doc, num_fewshot)}
            )

        return chat_history

    def sample(self, n):
        """
        Draw `n` samples from our fewshot docs. This method should be overridden by subclasses.
        """

        return self.rnd.sample(self.docs, n)


class FirstNSampler(ContextSampler):
    def sample(self, n) -> None:
        """
        Draw the first `n` samples in order from the specified split.
        Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.
        """
        assert (
            n <= len(self.docs)
        ), f"Error: number of fewshot samples requested exceeds the {len(self.docs)} that are available."
        return self.docs[:n]


class BalancedSampler(ContextSampler):
    def sample(self, n) -> None:
        """
        TODO: this should return approximately class-balanced samples from our fewshot examples.
        TODO: what order should they be in? maybe random?
        """

        pass


class ManualSampler(ContextSampler):
    def sample(self, n) -> None:
        """ """
        pass


SAMPLER_REGISTRY = {
    "default": ContextSampler,
    "first_n": FirstNSampler,
}


def get_sampler(name):
    try:
        return SAMPLER_REGISTRY[name]
    except KeyError:
        raise ValueError(
            f"Attempted to use contextsampler '{name}', but no sampling strategy for this name found! Supported model names: {', '.join(SAMPLER_REGISTRY.keys())}"
        )
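For reference, a minimal sketch (not part of the uploaded files) of resolving a sampler class by name and building a few-shot context. `my_task` and `doc` stand in for a configured task instance and one of its documents; they are assumptions of this sketch.

# Illustrative only: pick a sampler strategy and draw few-shot examples.
import random

from lm_eval.api.samplers import get_sampler

sampler_cls = get_sampler("first_n")  # -> FirstNSampler
# sampler = sampler_cls(docs=my_task.fewshot_docs(), task=my_task, rnd=random.Random(1234))
# context = sampler.get_context(doc, num_fewshot=5)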
scripts/yans/lm-evaluation-harness/lm_eval/api/task.py
ADDED
@@ -0,0 +1,1674 @@
1 |
+
import abc
|
2 |
+
import ast
|
3 |
+
import logging
|
4 |
+
import random
|
5 |
+
import re
|
6 |
+
from collections.abc import Callable
|
7 |
+
from copy import deepcopy
|
8 |
+
from dataclasses import asdict, dataclass
|
9 |
+
from inspect import getsource
|
10 |
+
from typing import (
|
11 |
+
Any,
|
12 |
+
Dict,
|
13 |
+
Iterable,
|
14 |
+
Iterator,
|
15 |
+
List,
|
16 |
+
Literal,
|
17 |
+
Mapping,
|
18 |
+
Optional,
|
19 |
+
Tuple,
|
20 |
+
Union,
|
21 |
+
)
|
22 |
+
|
23 |
+
import datasets
|
24 |
+
import numpy as np
|
25 |
+
from tqdm import tqdm
|
26 |
+
|
27 |
+
from lm_eval import utils
|
28 |
+
from lm_eval.api import samplers
|
29 |
+
from lm_eval.api.instance import Instance, OutputType
|
30 |
+
from lm_eval.api.metrics import bits_per_byte, mean, weighted_perplexity
|
31 |
+
from lm_eval.api.registry import (
|
32 |
+
AGGREGATION_REGISTRY,
|
33 |
+
DEFAULT_METRIC_REGISTRY,
|
34 |
+
get_aggregation,
|
35 |
+
get_metric,
|
36 |
+
get_metric_aggregation,
|
37 |
+
is_higher_better,
|
38 |
+
)
|
39 |
+
from lm_eval.caching.cache import load_from_cache, save_to_cache
|
40 |
+
from lm_eval.filters import build_filter_ensemble
|
41 |
+
from lm_eval.prompts import get_prompt
|
42 |
+
|
43 |
+
|
44 |
+
ALL_OUTPUT_TYPES = [
|
45 |
+
"loglikelihood",
|
46 |
+
"multiple_choice",
|
47 |
+
"loglikelihood_rolling",
|
48 |
+
"generate_until",
|
49 |
+
]
|
50 |
+
|
51 |
+
eval_logger = logging.getLogger("lm-eval")
|
52 |
+
|
53 |
+
|
54 |
+
@dataclass
|
55 |
+
class TaskConfig(dict):
|
56 |
+
# task naming/registry
|
57 |
+
task: Optional[str] = None
|
58 |
+
task_alias: Optional[str] = None
|
59 |
+
tag: Optional[Union[str, list]] = None
|
60 |
+
group: Optional[Union[str, list]] = None
|
61 |
+
# HF dataset options.
|
62 |
+
# which dataset to use,
|
63 |
+
# and what splits for what purpose
|
64 |
+
dataset_path: Optional[str] = None
|
65 |
+
dataset_name: Optional[str] = None
|
66 |
+
dataset_kwargs: Optional[dict] = None
|
67 |
+
training_split: Optional[str] = None
|
68 |
+
validation_split: Optional[str] = None
|
69 |
+
test_split: Optional[str] = None
|
70 |
+
fewshot_split: Optional[str] = (
|
71 |
+
None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?)
|
72 |
+
)
|
73 |
+
# formatting / prompting options.
|
74 |
+
# see docs/advanced_task_guide.md for more info
|
75 |
+
process_docs: Optional[Callable] = None
|
76 |
+
doc_to_text: Optional[Union[Callable, str]] = None
|
77 |
+
doc_to_target: Optional[Union[Callable, str]] = None
|
78 |
+
doc_to_choice: Optional[Union[Callable, str, dict, list]] = None
|
79 |
+
process_results: Optional[Union[Callable, str]] = None
|
80 |
+
use_prompt: Optional[str] = None
|
81 |
+
description: str = ""
|
82 |
+
target_delimiter: str = " "
|
83 |
+
fewshot_delimiter: str = "\n\n"
|
84 |
+
fewshot_config: Optional[dict] = None
|
85 |
+
# runtime configuration options
|
86 |
+
num_fewshot: Optional[int] = None
|
87 |
+
# scoring options
|
88 |
+
metric_list: Optional[list] = None
|
89 |
+
output_type: OutputType = "generate_until"
|
90 |
+
generation_kwargs: Optional[dict] = None
|
91 |
+
repeats: int = 1
|
92 |
+
filter_list: Optional[Union[str, list]] = None
|
93 |
+
should_decontaminate: bool = False
|
94 |
+
doc_to_decontamination_query: Optional[str] = None
|
95 |
+
metadata: Optional[dict] = (
|
96 |
+
None # by default, not used in the code. allows for users to pass arbitrary info to tasks
|
97 |
+
)
|
98 |
+
|
99 |
+
def __post_init__(self) -> None:
|
100 |
+
if self.group is not None:
|
101 |
+
eval_logger.warning(
|
102 |
+
"A task YAML file was found to contain a `group` key. Groups which provide aggregate scores over several subtasks now require a separate config file--if not aggregating, you may want to use the `tag` config option instead within your config. Setting `group` within a TaskConfig will be deprecated in v0.4.4. Please see https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/task_guide.md for more information."
|
103 |
+
)
|
104 |
+
|
105 |
+
if self.tag is None:
|
106 |
+
self.tag = self.group
|
107 |
+
else:
|
108 |
+
raise ValueError(
|
109 |
+
"Got both a `group` and `tag` entry within a TaskConfig. Please use one or the other--`group` values will be deprecated in v0.4.4."
|
110 |
+
)
|
111 |
+
|
112 |
+
if self.generation_kwargs is not None:
|
113 |
+
if self.output_type != "generate_until":
|
114 |
+
eval_logger.warning(
|
115 |
+
f"[{self.task}] passed `generation_kwargs`, but not using `output_type: generate_until`!"
|
116 |
+
)
|
117 |
+
|
118 |
+
if "temperature" in self.generation_kwargs:
|
119 |
+
self.generation_kwargs["temperature"] = float(
|
120 |
+
self.generation_kwargs["temperature"]
|
121 |
+
)
|
122 |
+
|
123 |
+
if "until" not in self.generation_kwargs:
|
124 |
+
self.generation_kwargs["until"] = [self.fewshot_delimiter]
|
125 |
+
else:
|
126 |
+
if self.output_type == "generate_until":
|
127 |
+
# ensure that we greedily generate in absence of explicit arguments otherwise
|
128 |
+
self.generation_kwargs = {
|
129 |
+
"until": (
|
130 |
+
None
|
131 |
+
if self.fewshot_delimiter is None
|
132 |
+
else [self.fewshot_delimiter]
|
133 |
+
),
|
134 |
+
"do_sample": False,
|
135 |
+
}
|
136 |
+
|
137 |
+
def __getitem__(self, item):
|
138 |
+
return getattr(self, item)
|
139 |
+
|
140 |
+
def __setitem__(self, item, value):
|
141 |
+
return setattr(self, item, value)
|
142 |
+
|
143 |
+
def to_dict(self, keep_callable: bool = False) -> dict:
|
144 |
+
"""dumps the current config as a dictionary object, as a printable format.
|
145 |
+
null fields will not be printed.
|
146 |
+
Used for dumping results alongside full task configuration
|
147 |
+
|
148 |
+
:return: dict
|
149 |
+
A printable dictionary version of the TaskConfig object.
|
150 |
+
|
151 |
+
# TODO: should any default value in the TaskConfig not be printed?
|
152 |
+
"""
|
153 |
+
cfg_dict = asdict(self)
|
154 |
+
# remove values that are `None`
|
155 |
+
for k, v in list(cfg_dict.items()):
|
156 |
+
if v is None:
|
157 |
+
cfg_dict.pop(k)
|
158 |
+
elif k == "metric_list":
|
159 |
+
for metric_dict in v:
|
160 |
+
for metric_key, metric_value in metric_dict.items():
|
161 |
+
if callable(metric_value):
|
162 |
+
metric_dict[metric_key] = self.serialize_function(
|
163 |
+
metric_value, keep_callable=keep_callable
|
164 |
+
)
|
165 |
+
cfg_dict[k] = v
|
166 |
+
elif callable(v):
|
167 |
+
cfg_dict[k] = self.serialize_function(v, keep_callable=keep_callable)
|
168 |
+
return cfg_dict
|
169 |
+
|
170 |
+
def serialize_function(
|
171 |
+
self, value: Union[Callable, str], keep_callable=False
|
172 |
+
) -> Union[Callable, str]:
|
173 |
+
"""Serializes a given function or string.
|
174 |
+
|
175 |
+
If 'keep_callable' is True, the original callable is returned.
|
176 |
+
Otherwise, attempts to return the source code of the callable using 'getsource'.
|
177 |
+
"""
|
178 |
+
if keep_callable:
|
179 |
+
return value
|
180 |
+
else:
|
181 |
+
try:
|
182 |
+
return getsource(value)
|
183 |
+
except (TypeError, OSError):
|
184 |
+
return str(value)
|
185 |
+
|
186 |
+
|
187 |
+
class Task(abc.ABC):
|
188 |
+
"""A task represents an entire benchmark including its dataset, problems,
|
189 |
+
answers, and evaluation methods. See BoolQ for a simple example implementation
|
190 |
+
|
191 |
+
A `doc` can be any python object which represents one instance of evaluation.
|
192 |
+
This is usually a dictionary e.g.
|
193 |
+
{"question": ..., "answer": ...} or
|
194 |
+
{"question": ..., question, answer)
|
195 |
+
"""
|
196 |
+
|
197 |
+
VERSION: Optional[Union[int, str]] = None
|
198 |
+
|
199 |
+
# The name of the `Task` benchmark as denoted in the HuggingFace datasets Hub
|
200 |
+
# or a path to a custom `datasets` loading script.
|
201 |
+
DATASET_PATH: Optional[str] = None
|
202 |
+
|
203 |
+
# The name of a subset within `DATASET_PATH`.
|
204 |
+
DATASET_NAME: Optional[str] = None
|
205 |
+
|
206 |
+
OUTPUT_TYPE: Optional[OutputType] = None
|
207 |
+
|
208 |
+
def __init__(
|
209 |
+
self,
|
210 |
+
data_dir: Optional[str] = None,
|
211 |
+
cache_dir: Optional[str] = None,
|
212 |
+
download_mode: Optional[datasets.DownloadMode] = None,
|
213 |
+
config: Optional[Mapping] = None, # Union[dict, TaskConfig]
|
214 |
+
) -> None:
|
215 |
+
"""
|
216 |
+
:param data_dir: str
|
217 |
+
Stores the path to a local folder containing the `Task`'s data files.
|
218 |
+
Use this to specify the path to manually downloaded data (usually when
|
219 |
+
the dataset is not publicly accessible).
|
220 |
+
:param cache_dir: str
|
221 |
+
The directory to read/write the `Task` dataset. This follows the
|
222 |
+
HuggingFace `datasets` API with the default cache directory located at:
|
223 |
+
`~/.cache/huggingface/datasets`
|
224 |
+
NOTE: You can change the cache location globally for a given process
|
225 |
+
to another directory:
|
226 |
+
`export HF_DATASETS_CACHE="/path/to/another/directory"`
|
227 |
+
:param download_mode: datasets.DownloadMode
|
228 |
+
How to treat pre-existing `Task` downloads and data.
|
229 |
+
- `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
|
230 |
+
Reuse download and reuse dataset.
|
231 |
+
- `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
|
232 |
+
Reuse download with fresh dataset.
|
233 |
+
- `datasets.DownloadMode.FORCE_REDOWNLOAD`
|
234 |
+
Fresh download and fresh dataset.
|
235 |
+
"""
|
236 |
+
self.download(data_dir, cache_dir, download_mode)
|
237 |
+
self._training_docs: Optional[list] = None
|
238 |
+
self._fewshot_docs: Optional[list] = None
|
239 |
+
self._instances: Optional[List[Instance]] = None
|
240 |
+
|
241 |
+
self._config: TaskConfig = TaskConfig({**config}) if config else TaskConfig()
|
242 |
+
|
243 |
+
self._filters = [build_filter_ensemble("none", [["take_first", None]])]
|
244 |
+
self.fewshot_rnd: Optional[random.Random] = (
|
245 |
+
None # purposely induce errors in case of improper usage
|
246 |
+
)
|
247 |
+
|
248 |
+
def download(
|
249 |
+
self,
|
250 |
+
data_dir: Optional[str] = None,
|
251 |
+
cache_dir: Optional[str] = None,
|
252 |
+
download_mode=None,
|
253 |
+
) -> None:
|
254 |
+
"""Downloads and returns the task dataset.
|
255 |
+
Override this method to download the dataset from a custom API.
|
256 |
+
|
257 |
+
:param data_dir: str
|
258 |
+
Stores the path to a local folder containing the `Task`'s data files.
|
259 |
+
Use this to specify the path to manually downloaded data (usually when
|
260 |
+
the dataset is not publicly accessible).
|
261 |
+
:param cache_dir: str
|
262 |
+
The directory to read/write the `Task` dataset. This follows the
|
263 |
+
HuggingFace `datasets` API with the default cache directory located at:
|
264 |
+
`~/.cache/huggingface/datasets`
|
265 |
+
NOTE: You can change the cache location globally for a given process
|
266 |
+
by setting the shell environment variable, `HF_DATASETS_CACHE`,
|
267 |
+
to another directory:
|
268 |
+
`export HF_DATASETS_CACHE="/path/to/another/directory"`
|
269 |
+
:param download_mode: datasets.DownloadMode
|
270 |
+
How to treat pre-existing `Task` downloads and data.
|
271 |
+
- `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
|
272 |
+
Reuse download and reuse dataset.
|
273 |
+
- `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
|
274 |
+
Reuse download with fresh dataset.
|
275 |
+
- `datasets.DownloadMode.FORCE_REDOWNLOAD`
|
276 |
+
Fresh download and fresh dataset.
|
277 |
+
"""
|
278 |
+
self.dataset = datasets.load_dataset(
|
279 |
+
path=self.DATASET_PATH,
|
280 |
+
name=self.DATASET_NAME,
|
281 |
+
data_dir=data_dir,
|
282 |
+
cache_dir=cache_dir,
|
283 |
+
download_mode=download_mode,
|
284 |
+
)
|
285 |
+
|
286 |
+
@property
|
287 |
+
def config(self) -> TaskConfig:
|
288 |
+
"""Returns the TaskConfig associated with this class."""
|
289 |
+
return self._config
|
290 |
+
|
291 |
+
@abc.abstractmethod
|
292 |
+
def has_training_docs(self):
|
293 |
+
"""Whether the task has a training set"""
|
294 |
+
pass
|
295 |
+
|
296 |
+
@abc.abstractmethod
|
297 |
+
def has_validation_docs(self):
|
298 |
+
"""Whether the task has a validation set"""
|
299 |
+
pass
|
300 |
+
|
301 |
+
@abc.abstractmethod
|
302 |
+
def has_test_docs(self):
|
303 |
+
"""Whether the task has a test set"""
|
304 |
+
pass
|
305 |
+
|
306 |
+
def training_docs(self) -> Iterable:
|
307 |
+
"""
|
308 |
+
:return: Iterable[obj]
|
309 |
+
A iterable of any object, that doc_to_text can handle
|
310 |
+
"""
|
311 |
+
return []
|
312 |
+
|
313 |
+
def validation_docs(self) -> Iterable:
|
314 |
+
"""
|
315 |
+
:return: Iterable[obj]
|
316 |
+
A iterable of any object, that doc_to_text can handle
|
317 |
+
"""
|
318 |
+
return []
|
319 |
+
|
320 |
+
def test_docs(self) -> Iterable:
|
321 |
+
"""
|
322 |
+
:return: Iterable[obj]
|
323 |
+
A iterable of any object, that doc_to_text can handle
|
324 |
+
"""
|
325 |
+
return []
|
326 |
+
|
327 |
+
def fewshot_docs(self) -> Iterable:
|
328 |
+
"""
|
329 |
+
:return: Iterable[obj]
|
330 |
+
A iterable of any object, that doc_to_text can handle
|
331 |
+
"""
|
332 |
+
if self.has_training_docs():
|
333 |
+
return self.training_docs()
|
334 |
+
elif self.has_validation_docs():
|
335 |
+
return self.validation_docs()
|
336 |
+
else:
|
337 |
+
eval_logger.warning(
|
338 |
+
f"[Task: {self.config.task}] has_training_docs and has_validation_docs are False"
|
339 |
+
", using test_docs as fewshot_docs but this is not recommended."
|
340 |
+
)
|
341 |
+
return self.test_docs()
|
342 |
+
|
343 |
+
def _process_doc(self, doc: dict) -> dict:
|
344 |
+
"""
|
345 |
+
Override this to process (detokenize, strip, replace, etc.) individual
|
346 |
+
documents. This can be used in a map over documents of a data split.
|
347 |
+
E.g. `map(self._process_doc, self.dataset["validation"])`
|
348 |
+
|
349 |
+
:return: dict
|
350 |
+
The processed version of the specified `doc`.
|
351 |
+
"""
|
352 |
+
return doc
|
353 |
+
|
354 |
+
@property
|
355 |
+
def instances(self) -> List[Instance]:
|
356 |
+
"""After calling `task.build_all_requests()`, tasks
|
357 |
+
maintain a list of the dataset instances which will be evaluated.
|
358 |
+
"""
|
359 |
+
return self._instances
|
360 |
+
|
361 |
+
def fewshot_examples(self, k, rnd):
|
362 |
+
if self._training_docs is None:
|
363 |
+
self._training_docs = list(self.training_docs())
|
364 |
+
|
365 |
+
return rnd.sample(self._training_docs, k)
|
366 |
+
|
367 |
+
def doc_to_decontamination_query(self, doc):
|
368 |
+
raise NotImplementedError(
|
369 |
+
"Override doc_to_decontamination_query with document specific decontamination query."
|
370 |
+
)
|
371 |
+
|
372 |
+
@abc.abstractmethod
|
373 |
+
def doc_to_text(self, doc):
|
374 |
+
pass
|
375 |
+
|
376 |
+
@abc.abstractmethod
|
377 |
+
def doc_to_target(self, doc):
|
378 |
+
pass
|
379 |
+
|
380 |
+
def build_all_requests(
|
381 |
+
self,
|
382 |
+
*,
|
383 |
+
limit: Union[int, None] = None,
|
384 |
+
rank: int = 0,
|
385 |
+
world_size: int = 1,
|
386 |
+
cache_requests: bool = False,
|
387 |
+
rewrite_requests_cache: bool = False,
|
388 |
+
system_instruction: Optional[str] = None,
|
389 |
+
apply_chat_template: bool = False,
|
390 |
+
fewshot_as_multiturn: bool = False,
|
391 |
+
chat_template: Optional[Callable] = None,
|
392 |
+
tokenizer_name: str = "",
|
393 |
+
) -> None:
|
394 |
+
"""Build a set of Instances for a task, and store them in task.instances"""
|
395 |
+
|
396 |
+
# used with caching
|
397 |
+
og_limit = limit
|
398 |
+
|
399 |
+
cache_key = f"requests-{self._config.task}-{self.config.num_fewshot}shot-rank{rank}-world_size{world_size}"
|
400 |
+
cache_key += "-chat_template" if apply_chat_template else ""
|
401 |
+
cache_key += "-fewshot_as_multiturn" if fewshot_as_multiturn else ""
|
402 |
+
cache_key += (
|
403 |
+
f"-system_prompt_hash{utils.hash_string(system_instruction)}"
|
404 |
+
if system_instruction is not None
|
405 |
+
else ""
|
406 |
+
)
|
407 |
+
cache_key += f"-tokenizer{tokenizer_name}"
|
408 |
+
|
409 |
+
cached_instances = load_from_cache(file_name=cache_key)
|
410 |
+
|
411 |
+
if cache_requests and cached_instances and not rewrite_requests_cache:
|
412 |
+
cached_instances = cached_instances[:limit]
|
413 |
+
|
414 |
+
flattened_instances = [
|
415 |
+
instance
|
416 |
+
for instance_group in cached_instances
|
417 |
+
for instance in instance_group
|
418 |
+
]
|
419 |
+
|
420 |
+
self._instances = flattened_instances
|
421 |
+
return
|
422 |
+
|
423 |
+
eval_logger.info(f"Building contexts for {self.config.task} on rank {rank}...")
|
424 |
+
|
425 |
+
instances = []
|
426 |
+
|
427 |
+
# process all documents when caching is specified for simplicity
|
428 |
+
if (
|
429 |
+
cache_requests
|
430 |
+
and (not cached_instances or rewrite_requests_cache)
|
431 |
+
and limit is not None
|
432 |
+
):
|
433 |
+
limit = None
|
434 |
+
|
435 |
+
doc_id_docs = list(
|
436 |
+
self.doc_iterator(rank=rank, limit=limit, world_size=world_size)
|
437 |
+
)
|
438 |
+
|
439 |
+
num_docs = len(doc_id_docs)
|
440 |
+
|
441 |
+
for doc_id, doc in tqdm(
|
442 |
+
doc_id_docs,
|
443 |
+
total=num_docs,
|
444 |
+
):
|
445 |
+
# sample fewshot context #TODO: need to offset doc_id by rank now!
|
446 |
+
fewshot_ctx = self.fewshot_context(
|
447 |
+
doc,
|
448 |
+
0 if self.config.num_fewshot is None else self.config.num_fewshot,
|
449 |
+
system_instruction,
|
450 |
+
apply_chat_template,
|
451 |
+
fewshot_as_multiturn,
|
452 |
+
chat_template,
|
453 |
+
)
|
454 |
+
|
455 |
+
# TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute
|
456 |
+
inst = self.construct_requests(
|
457 |
+
doc=doc,
|
458 |
+
ctx=fewshot_ctx,
|
459 |
+
metadata=(self.config["task"], doc_id, self.config.repeats),
|
460 |
+
)
|
461 |
+
|
462 |
+
if not isinstance(inst, list):
|
463 |
+
inst = [inst]
|
464 |
+
|
465 |
+
instances.append(inst)
|
466 |
+
|
467 |
+
# now flatten, this is to allow slicing to work with pickles
|
468 |
+
|
469 |
+
sliced_instances = instances[:og_limit]
|
470 |
+
|
471 |
+
flattened_instances = [
|
472 |
+
instance
|
473 |
+
for instance_group in sliced_instances
|
474 |
+
for instance in instance_group
|
475 |
+
]
|
476 |
+
|
477 |
+
self._instances = flattened_instances
|
478 |
+
|
479 |
+
if len(self._instances) == 0:
|
480 |
+
raise ValueError("task.build_requests() did not find any docs!")
|
481 |
+
|
482 |
+
if cache_requests and (not cached_instances or rewrite_requests_cache):
|
483 |
+
save_to_cache(file_name=cache_key, obj=instances)
|
484 |
+
|
    @abc.abstractmethod
    def construct_requests(self, doc, ctx, **kwargs):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        :param doc_idx: int
            The index of a document within `self.test_docs()` or `self.validation_docs()`,
            whichever is the main split used.
        :param repeats: int
            TODO: update this docstring
            The number of times each instance in a dataset is inferred on. Defaults to 1,
            can be increased for techniques like majority voting.
        """
        pass

    @abc.abstractmethod
    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
        pass

    @abc.abstractmethod
    def aggregation(self):
        """
        :returns: {str: [metric_score] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metric scores
        """
        pass

    @abc.abstractmethod
    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        pass

    def get_config(self, key: str) -> Any:
        return getattr(self._config, key, None)

    @classmethod
    def count_bytes(cls, doc):
        """Used for byte-level perplexity metrics in rolling loglikelihood"""
        return len(doc.encode("utf-8"))

    @classmethod
    def count_words(cls, doc):
        """Downstream loglikelihood_rolling perplexity tasks with custom word boundaries should override this!"""
        return len(re.split(r"\s+", doc))

    @utils.positional_deprecated
    def fewshot_context(
        self,
        doc,
        num_fewshot,
        rnd=None,
        description=None,
    ):
        """Returns a fewshot context string that is made up of a prepended description
        (if provided), the `num_fewshot` number of examples, and an appended prompt example.

        :param doc: str
            The document as returned from training_docs, validation_docs, or test_docs.
        :param num_fewshot: int
            The number of fewshot examples to provide in the returned context string.
        :param rnd: random.Random
            The pseudo-random number generator used to randomly sample examples.
            WARNING: This is currently a required arg although it's optionalized with a default `None`.
        :param description: str
            The task's description that will be prepended to the fewshot examples.
        :returns: str
            The fewshot context.
        """
        if rnd is None:
            if self.fewshot_rnd is not None:
                rnd = self.fewshot_rnd
            else:
                raise ValueError(
                    "A `random.Random` generator argument must be provided to `rnd`"
                )

        description = description if description else ""

        if num_fewshot == 0:
            labeled_examples = ""
        else:
            # for sets with no training docs, draw from other set *but ensure no overlap with current doc*
            if self.has_training_docs():
                fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
            else:
                if self._fewshot_docs is None:
                    self._fewshot_docs = list(
                        self.validation_docs()
                        if self.has_validation_docs()
                        else self.test_docs()
                    )

                fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)

                # get rid of the doc that's the one we're evaluating, if it's in the fewshot
                fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]

            labeled_examples = (
                "\n\n".join(
                    [
                        self.doc_to_text(doc) + self.doc_to_target(doc)
                        for doc in fewshotex
                    ]
                )
                + "\n\n"
            )

        example = self.doc_to_text(doc)
        return description + labeled_examples + example
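Note: the four abstract methods above define the whole per-task contract: build requests for the LM, score one document, say how per-document scores are aggregated, and declare metric direction. The toy class below is an editorial sketch, not part of the uploaded file; it does not subclass the real Task and uses plain tuples instead of Instance objects, purely to make the data flow visible end to end.

from statistics import mean


class ToyBoolTask:
    def construct_requests(self, doc, ctx, **kwargs):
        # one (context, continuation) pair per answer option
        return [(ctx, " yes"), (ctx, " no")]

    def process_results(self, doc, results):
        # `results` would be per-request loglikelihoods from the LM
        pred = 0 if results[0] > results[1] else 1
        return {"acc": 1.0 if pred == doc["label"] else 0.0}

    def aggregation(self):
        return {"acc": mean}

    def higher_is_better(self):
        return {"acc": True}


task = ToyBoolTask()
doc = {"question": "Is water wet?", "label": 0}
per_doc = task.process_results(doc, results=[-1.2, -3.4])
print(per_doc, task.aggregation()["acc"]([per_doc["acc"]]))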
    def apply_filters(self) -> Optional[List[Instance]]:
        """Iterates over FilterEnsembles and applies them to instances"""
        if hasattr(self, "_filters"):
            for f in self._filters:
                f.apply(self._instances)
        else:
            eval_logger.warning("No filter defined, passing through instances")
            return self._instances

    def dump_config(self) -> dict:
        """Returns the config as a dictionary."""
        # TODO: this should only return the overrides applied to a non-YAML task's configuration.
        # (num_fewshot)
        return self.config.to_dict()

    def set_config(self, key: str, value: Any, update: bool = False) -> None:
        """Set or update the configuration for a given key."""
        if key is None:
            raise ValueError("Key must be provided.")

        if update:
            current_value = getattr(self._config, key, {})
            if not isinstance(current_value, dict):
                raise TypeError(
                    f"Expected a dict for key '{key}', got {type(current_value).__name__} instead."
                )
            current_value.update(value)
        else:
            setattr(self._config, key, value)

    def override_metric(self, metric_name: str) -> None:
        """
        Override the default metrics used for evaluation with custom metrics.

        Parameters:
        - metric_name (str): The name of the custom metric to override. Should be registered in api.metrics.
        """
        (
            self._metric_fn_list,
            self._aggregation_list,
            self._metric_fn_kwargs,
            self._higher_is_better,
        ) = ({}, {}, {}, {})
        self._metric_fn_list[metric_name] = get_metric(metric_name)
        self._aggregation_list[metric_name] = get_metric_aggregation(metric_name)
        self._higher_is_better[metric_name] = is_higher_better(metric_name)
        self._metric_fn_kwargs[metric_name] = {}
        if not isinstance(self, ConfigurableTask):
            self.process_results = lambda x, y: {metric_name: get_metric(metric_name)}
            self.aggregation = lambda: {
                metric_name: get_metric_aggregation(metric_name)
            }
        setattr(self._config, "metric_list", [{"metric": metric_name}])
        setattr(self._config, "process_results", None)

    def set_fewshot_seed(self, seed: Optional[int] = None) -> None:
        self.fewshot_rnd = random.Random(seed)
        if hasattr(self, "sampler"):
            self.sampler.rnd = self.fewshot_rnd

    @property
    def eval_docs(self) -> Union[datasets.Dataset, List[dict]]:
        if self.has_test_docs():
            return self.test_docs()
        elif self.has_validation_docs():
            return self.validation_docs()
        else:
            raise ValueError(
                f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"
            )

    def doc_iterator(
        self, *, rank: int = 0, limit: Union[int, None] = None, world_size: int = 1
    ) -> Iterator[Tuple[int, Any]]:
        limit = int(limit) if limit else None
        doc_iterator = utils.create_iterator(
            enumerate(self.eval_docs),
            rank=int(rank),
            limit=limit,
            world_size=int(world_size),
        )
        return doc_iterator
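Note: doc_iterator walks enumerate(eval_docs) through utils.create_iterator with rank, limit and world_size. To my reading that amounts to islice-style striding over the enumerated stream, as in this standalone sketch (illustrative values only):

import itertools

docs = [f"doc-{i}" for i in range(10)]
# every `world_size`-th doc starting at index `rank`, stopping at index `limit`
shard = itertools.islice(enumerate(docs), 1, 8, 2)  # rank=1, limit=8, world_size=2
print(list(shard))  # [(1, 'doc-1'), (3, 'doc-3'), (5, 'doc-5'), (7, 'doc-7')]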
class ConfigurableTask(Task):
    VERSION = "Yaml"
    OUTPUT_TYPE = None
    CONFIG = None

    def __init__(
        self,
        data_dir=None,
        cache_dir=None,
        download_mode=None,
        config: Optional[dict] = None,
    ) -> None:  # TODO no super() call here
        # Get pre-configured attributes
        self._config = self.CONFIG

        # Use new configurations if there was no preconfiguration
        if self.config is None:
            self._config = TaskConfig(**config)
        # Overwrite configs
        else:
            if config is not None:
                self._config.__dict__.update(config)

        if self.config is None:
            raise ValueError(
                "Must pass a config to ConfigurableTask, either in cls.CONFIG or `config` kwarg"
            )

        if isinstance(self.config.metadata, dict):
            if "version" in self.config.metadata:
                self.VERSION = self.config.metadata["version"]

        if self.config.output_type is not None:
            if self.config.output_type not in ALL_OUTPUT_TYPES:
                raise ValueError(
                    f"Got invalid output_type '{self.config.output_type}', must be in '{','.join(ALL_OUTPUT_TYPES)}'"
                )
            self.OUTPUT_TYPE = self.config.output_type

        if self.config.dataset_path is not None:
            self.DATASET_PATH = self.config.dataset_path

        if self.config.dataset_name is not None:
            self.DATASET_NAME = self.config.dataset_name

        self._metric_fn_list = {}
        self._metric_fn_kwargs = {}
        self._aggregation_list = {}
        self._higher_is_better = {}

        if self.config.metric_list is None:
            # TODO: handle this in TaskConfig.__post_init__ ?
            _metric_list = DEFAULT_METRIC_REGISTRY[self.config.output_type]

            for metric_name in _metric_list:
                self._metric_fn_list[metric_name] = get_metric(metric_name)
                self._metric_fn_kwargs[metric_name] = {}
                self._aggregation_list[metric_name] = get_metric_aggregation(
                    metric_name
                )
                self._higher_is_better[metric_name] = is_higher_better(metric_name)
        else:
            for metric_config in self.config.metric_list:
                if "metric" not in metric_config:
                    raise ValueError(
                        "'metric' key not provided for an entry in 'metric_list', must be specified!"
                    )
                metric_name = metric_config["metric"]
                kwargs = {
                    key: metric_config[key]
                    for key in metric_config
                    if key
                    not in ["metric", "aggregation", "higher_is_better", "hf_evaluate"]
                }
                hf_evaluate_metric = (
                    "hf_evaluate" in metric_config
                    and metric_config["hf_evaluate"] is True
                )

                if self.config.process_results is not None:
                    self._metric_fn_list[metric_name] = None
                    self._metric_fn_kwargs[metric_name] = {}
                elif callable(metric_name):
                    metric_fn = metric_name.__call__
                    metric_name = metric_name.__name__
                    self._metric_fn_list[metric_name] = metric_fn
                    self._metric_fn_kwargs[metric_name] = kwargs
                else:
                    self._metric_fn_list[metric_name] = get_metric(
                        metric_name, hf_evaluate_metric
                    )
                    self._metric_fn_kwargs[metric_name] = kwargs

                if "aggregation" in metric_config:
                    agg_name = metric_config["aggregation"]
                    if isinstance(agg_name, str):
                        self._aggregation_list[metric_name] = get_aggregation(agg_name)
                    elif callable(agg_name):  # noqa: E721
                        self._aggregation_list[metric_name] = metric_config[
                            "aggregation"
                        ]
                else:
                    INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
                    metric_agg = get_metric_aggregation(metric_name)
                    eval_logger.warning(
                        f"[Task: {self.config.task}] metric {metric_name} is defined, but aggregation is not. "
                        f"using default "
                        f"aggregation={INV_AGG_REGISTRY[metric_agg]}"
                    )
                    self._aggregation_list[metric_name] = metric_agg

                if "higher_is_better" in metric_config:
                    self._higher_is_better[metric_name] = metric_config[
                        "higher_is_better"
                    ]
                else:
                    eval_logger.warning(
                        f"[Task: {self.config.task}] metric {metric_name} is defined, but higher_is_better is not. "
                        f"using default "
                        f"higher_is_better={is_higher_better(metric_name)}"
                    )
                    self._higher_is_better[metric_name] = is_higher_better(metric_name)

        self.download(self.config.dataset_kwargs)
        self._training_docs = None
        self._fewshot_docs = None

        if self.config.filter_list is not None:
            self._filters = []
            for filter_config in self.config.filter_list:
                filter_name = filter_config["name"]
                filter_functions = filter_config["filter"]
                components = []
                for function in filter_functions:
                    kwargs = {
                        key: function[key] for key in function if key != "function"
                    }
                    components.append([function["function"], kwargs])
                filter_pipeline = build_filter_ensemble(filter_name, components)
                self._filters.append(filter_pipeline)
        else:
            self._filters = [build_filter_ensemble("none", [["take_first", None]])]

        if self.config.use_prompt is not None:
            eval_logger.info(f"loading prompt {self.config.use_prompt}")
            self.prompt = get_prompt(
                self.config.use_prompt, self.DATASET_PATH, self.DATASET_NAME
            )
        else:
            self.prompt = None
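Note: a hypothetical config dict of the shape this constructor consumes; the task name, dataset path, splits and templates below are placeholders, not real assets. As the loop above shows, any key in a metric_list entry other than metric, aggregation, higher_is_better or hf_evaluate is forwarded as a metric keyword argument.

example_config = {
    "task": "demo_copa_like",                   # placeholder task name
    "dataset_path": "some_org/some_dataset",    # placeholder dataset
    "test_split": "test",
    "output_type": "multiple_choice",
    "doc_to_text": "{{premise}} Question: {{question}}?",
    "doc_to_target": "label",
    "doc_to_choice": ["choice1", "choice2"],
    "metric_list": [
        {"metric": "acc", "aggregation": "mean", "higher_is_better": True},
        {"metric": "exact_match", "ignore_case": True},  # extra key -> metric kwargs
    ],
}
print(example_config["metric_list"][1])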
        if self.fewshot_docs() is not None:
            self.fewshot_rnd = (
                random.Random()
            )  # setting with no seed, to be overridden at a later time
            config_sampler: Union[str, Callable] = (
                self.config.fewshot_config.get("sampler", "default")
                if self.config.fewshot_config
                else "default"
            )
            if isinstance(config_sampler, str):
                self.sampler = samplers.get_sampler(config_sampler)(
                    list(self.fewshot_docs()), self, rnd=self.fewshot_rnd
                )
            elif callable(config_sampler) and issubclass(
                config_sampler, samplers.ContextSampler
            ):
                self.sampler = config_sampler(
                    docs=list(self.fewshot_docs()), task=self, rnd=self.fewshot_rnd
                )
            else:
                raise TypeError(
                    f"fewshot_config.sampler should be a string or callable of ContextSampler type, "
                    f"not {type(config_sampler)}"
                )

        self.task_docs = self.eval_docs

        # Test One Doc
        self.features = list(self.task_docs.features.keys())
        self.multiple_input = 0
        self.multiple_target = 0
        test_doc = self.task_docs[0]
        test_text = self.doc_to_text(test_doc)
        test_target = self.doc_to_target(test_doc)

        if self.config.doc_to_choice is not None:
            test_choice = self.doc_to_choice(test_doc)
            if not isinstance(test_choice, list):
                eval_logger.error("doc_to_choice must return list")
            else:
                num_choice = len(test_choice)

            if isinstance(test_text, int):
                self.multiple_input = num_choice
        else:
            test_choice = None

        if isinstance(test_target, list):
            self.multiple_target = len(test_target)
        else:
            if (isinstance(test_target, int)) and (test_choice is not None):
                test_target = test_choice[test_target]
            else:
                test_target = str(test_target)

        if test_choice is not None:
            check_choices = test_choice
        else:
            check_choices = [test_target]
        if self.config.doc_to_choice is not None:
            for choice in check_choices:
                choice_has_whitespace = True if choice[0].isspace() else False
                delimiter_has_whitespace = (
                    True
                    if self.config.target_delimiter.rstrip()
                    != self.config.target_delimiter
                    else False
                )

                if delimiter_has_whitespace and choice_has_whitespace:
                    eval_logger.debug(
                        f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" have whitespace'
                    )
                elif (not delimiter_has_whitespace) and (not choice_has_whitespace):
                    eval_logger.debug(
                        f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" do not have whitespace, ignore if the language you are evaluating on does not require/use whitespace'
                    )
    def download(self, dataset_kwargs: Optional[Dict[str, Any]] = None) -> None:
        self.dataset = datasets.load_dataset(
            path=self.DATASET_PATH,
            name=self.DATASET_NAME,
            **dataset_kwargs if dataset_kwargs is not None else {},
        )

    def has_training_docs(self) -> bool:
        if self.config.training_split is not None:
            return True
        else:
            return False

    def has_validation_docs(self) -> bool:
        if self.config.validation_split is not None:
            return True
        else:
            return False

    def has_test_docs(self) -> bool:
        if self.config.test_split is not None:
            return True
        else:
            return False

    def training_docs(self) -> datasets.Dataset:
        if self.has_training_docs():
            if self.config.process_docs is not None:
                return self.config.process_docs(
                    self.dataset[self.config.training_split]
                )
            return self.dataset[self.config.training_split]

    def validation_docs(self) -> datasets.Dataset:
        if self.has_validation_docs():
            if self.config.process_docs is not None:
                return self.config.process_docs(
                    self.dataset[self.config.validation_split]
                )
            return self.dataset[self.config.validation_split]

    def test_docs(self) -> datasets.Dataset:
        if self.has_test_docs():
            if self.config.process_docs is not None:
                return self.config.process_docs(self.dataset[self.config.test_split])
            return self.dataset[self.config.test_split]

    def fewshot_docs(self):
        if self.config.fewshot_split is not None:
            if self.config.process_docs is not None:
                return self.config.process_docs(self.dataset[self.config.fewshot_split])
            return self.dataset[self.config.fewshot_split]
        elif (
            self.config.fewshot_config is not None
            and self.config.fewshot_config.get("samples", None) is not None
        ):
            if isinstance(self.config.fewshot_config["samples"], list):
                return self.config.fewshot_config["samples"]
            elif callable(self.config.fewshot_config["samples"]):
                return self.config.fewshot_config["samples"]()
            else:
                raise Exception(
                    "`fewshot_config['samples']` was incorrectly defined in the configuration. It should be either a list of samples as a dict, or function returning this list."
                )
        else:
            if (self.config.num_fewshot is not None) and (self.config.num_fewshot > 0):
                eval_logger.warning(
                    f"[Task: {self.config.task}] "
                    "num_fewshot > 0 but fewshot_split is None. "
                    "using preconfigured rule."
                )
            return super().fewshot_docs()
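Note: every split accessor above routes its split through the optional process_docs hook before use. A hypothetical hook of that shape (the field name is illustrative, not taken from a real dataset):

import datasets


def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
    # receives one split and returns the cleaned/filtered split
    def _clean(doc):
        doc["question"] = doc["question"].strip()
        return doc

    return dataset.map(_clean)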
    @staticmethod
    def append_target_question(
        labeled_examples: List[Dict[str, str]],
        question: str,
        fewshot_as_multiturn: bool = False,
    ) -> None:
        """Adds a target question to the labeled examples list.
        If fewshot_as_multiturn is True, or labeled_examples is empty, or the last entry is a system turn, appends the question as a new user entry.
        Otherwise, it is appended to the last user entry, ensuring that the conversation alternates between the user and the assistant.
        """
        if not fewshot_as_multiturn:
            # if no messages or last message is system, append as new user entry
            if len(labeled_examples) == 0 or labeled_examples[-1]["role"] == "system":
                labeled_examples.append({"role": "user", "content": question})
            # if last message is user, append to it to avoid two user messages in a row
            else:
                labeled_examples[-1]["content"] += question
        else:
            # if fewshot_as_multiturn is True, append as next user entry (last is always assistant)
            labeled_examples.append({"role": "user", "content": question})

    @utils.positional_deprecated
    def fewshot_context(
        self,
        doc: str,
        num_fewshot: int,
        system_instruction: Optional[str] = None,
        apply_chat_template: bool = False,
        fewshot_as_multiturn: bool = False,
        chat_template: Optional[Callable] = None,
    ) -> str:
        """Returns a fewshot context string that is made up of a prepended description
        (if provided), the `num_fewshot` number of examples, and an appended prompt example.

        :param doc: str
            The document as returned from training_docs, validation_docs, or test_docs.
        :param num_fewshot: int
            The number of fewshot examples to provide in the returned context string.
        :param system_instruction: str
            System instruction to be applied to the prompt.
        :param apply_chat_template: bool
            Whether to apply the chat template to the fewshot context.
        :param fewshot_as_multiturn: bool
            Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
        :param chat_template: Callable
            Chat template to be applied to the fewshot context.
        :returns: str
            The fewshot context.
        """

        if apply_chat_template:
            labeled_examples = []
        else:
            labeled_examples = ""

        # get task description
        if description := self.config.description:
            description = utils.apply_template(self.config.description, doc)

        # create system prompt based on the provided system instruction and description
        if system_instruction is not None and description:
            system_prompt = (
                f"{system_instruction}{self.sampler.fewshot_delimiter}{description}"
            )
        elif system_instruction is not None:
            system_prompt = system_instruction
        elif description:
            system_prompt = description
        else:
            system_prompt = ""

        # add system prompt if specified
        if system_prompt:
            if apply_chat_template:
                labeled_examples.append({"role": "system", "content": system_prompt})
            else:
                labeled_examples = system_prompt

        # if few-shot - append examples after the system prompt
        if num_fewshot > 0:
            if apply_chat_template:
                labeled_examples.extend(
                    self.sampler.get_chat_context(
                        doc, num_fewshot, fewshot_as_multiturn
                    )
                )
            else:
                labeled_examples += self.sampler.get_context(doc, num_fewshot)

        example = self.doc_to_text(doc)
        if apply_chat_template:
            if self.multiple_input:
                return chat_template(labeled_examples)
            if isinstance(example, str):
                self.append_target_question(
                    labeled_examples, example, fewshot_as_multiturn
                )
            # for loglikelihood create a list of questions with appended choices
            elif isinstance(example, list):
                labeled_examples_list = []
                # copy chat history for each example and append the answer
                for ex in example:
                    chat = deepcopy(labeled_examples)
                    self.append_target_question(chat, ex, fewshot_as_multiturn)
                    labeled_examples_list.append(chat_template(chat))
                return labeled_examples_list
            # if example is an integer, append the choice or convert to string
            elif isinstance(example, int):
                if self.config.doc_to_choice is not None:
                    choices = self.doc_to_choice(doc)
                    self.append_target_question(
                        labeled_examples, choices[example], fewshot_as_multiturn
                    )
                else:
                    self.append_target_question(
                        labeled_examples, str(example), fewshot_as_multiturn
                    )
            # return lm.apply_chat_template(labeled_examples)
            return chat_template(labeled_examples)
        else:
            if self.multiple_input:
                return labeled_examples
            if isinstance(example, str):
                return labeled_examples + example
            elif isinstance(example, list):
                return [labeled_examples + ex for ex in example]
            elif isinstance(example, int):
                if self.config.doc_to_choice is not None:
                    choices = self.doc_to_choice(doc)
                    return labeled_examples + choices[example]
                else:
                    return labeled_examples + str(example)
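Note: for chat models the context is a message list rather than a string. The plain-data sketch below (no harness imports, contents invented) contrasts the two shapes append_target_question produces: shots packed into one user turn versus fewshot_as_multiturn alternation.

# single user turn: the target question is appended to the last user message
single_turn = [{"role": "user", "content": "Q: 2+2?\nA: 4\n\nQ: 3+3?\nA: 6\n\n"}]
single_turn[-1]["content"] += "Q: 5+5?\nA:"

# fewshot_as_multiturn: each shot is its own user/assistant exchange,
# and the target question becomes a fresh user turn
multiturn = [
    {"role": "user", "content": "Q: 2+2?\nA:"},
    {"role": "assistant", "content": " 4"},
    {"role": "user", "content": "Q: 3+3?\nA:"},
    {"role": "assistant", "content": " 6"},
    {"role": "user", "content": "Q: 5+5?\nA:"},
]
print(len(single_turn), len(multiturn))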
    def apply_filters(self):
        """Iterates over FilterEnsembles and applies them to instances"""
        if hasattr(self, "_filters"):
            for f in self._filters:
                f.apply(self._instances)
        else:
            eval_logger.warning("No filter defined, passing through instances")
            return self._instances

    def should_decontaminate(self):
        return self.config.should_decontaminate

    def doc_to_decontamination_query(self, doc):
        if self.config.should_decontaminate:
            if self.config.doc_to_decontamination_query is None:
                return self.doc_to_text(doc)
            else:
                doc_to_decontamination_query = self.config.doc_to_decontamination_query
                if doc_to_decontamination_query in self.features:
                    return doc[doc_to_decontamination_query]
                elif callable(doc_to_decontamination_query):
                    return doc_to_decontamination_query(doc)
                else:
                    return ast.literal_eval(
                        utils.apply_template(
                            self.config.doc_to_decontamination_query, doc
                        )
                    )

    def _process_doc(self, doc: dict) -> dict:
        """
        Override this to process (detokenize, strip, replace, etc.) individual
        documents. This can be used in a map over documents of a data split.
        E.g. `map(self._process_doc, self.dataset["validation"])`

        :return: dict
            The processed version of the specified `doc`.
        """
        return doc

    def doc_to_text(self, doc, doc_to_text=None):
        if self.prompt is not None:
            doc_to_text = self.prompt
        elif doc_to_text is not None:
            doc_to_text = doc_to_text
        else:
            doc_to_text = self.config.doc_to_text

        if isinstance(doc_to_text, int):
            return doc_to_text
        elif isinstance(doc_to_text, str):
            if doc_to_text in self.features:
                # if self.config.doc_to_choice is not None:
                #     return self.doc_to_choice(doc)[doc[doc_to_text]]
                # else:
                return doc[doc_to_text]
            else:
                text_string = utils.apply_template(doc_to_text, doc)
                if text_string.isdigit() and self._config.doc_to_choice is not None:
                    return ast.literal_eval(text_string)
                else:
                    return text_string
        elif callable(doc_to_text):
            return doc_to_text(doc)
        # Used when applying a Promptsource template
        elif hasattr(doc_to_text, "apply"):
            applied_prompt = doc_to_text.apply(doc)
            if len(applied_prompt) == 2:
                return applied_prompt[0]
            else:
                eval_logger.warning("Applied prompt returns empty string")
                return self.config.fewshot_delimiter
        else:
            print(type(doc_to_text))
            raise TypeError
    def doc_to_target(self, doc: Mapping, doc_to_target=None) -> Union[int, str, list]:
        if self.prompt is not None:
            doc_to_target = self.prompt
        elif doc_to_target is not None:
            doc_to_target = doc_to_target
        else:
            doc_to_target = self.config.doc_to_target

        if isinstance(doc_to_target, int):
            return doc_to_target
        elif isinstance(doc_to_target, str):
            if doc_to_target in self.features:
                # if self.config.doc_to_choice is not None:
                #     return self.doc_to_choice(doc)[doc[doc_to_target]]
                # else:
                return doc[doc_to_target]
            else:
                target_string = utils.apply_template(doc_to_target, doc)
                if target_string.isdigit() and self._config.doc_to_choice is not None:
                    return ast.literal_eval(target_string)
                elif (
                    len(target_string) >= 2
                    and (target_string[0] == "[")
                    and (target_string[-1] == "]")
                ):
                    try:
                        return ast.literal_eval(target_string)
                    except (SyntaxError, ValueError):
                        return target_string
                else:
                    return target_string
        elif isinstance(doc_to_target, list):
            return doc_to_target
        elif callable(doc_to_target):
            return doc_to_target(doc)
        # Used when applying a Promptsource template
        elif hasattr(doc_to_target, "apply"):
            applied_prompt = doc_to_target.apply(doc)
            if len(applied_prompt) == 2:
                return applied_prompt[1]
            else:
                eval_logger.warning("Applied prompt returns empty string")
                return self.config.fewshot_delimiter
        else:
            raise TypeError

    def doc_to_choice(self, doc: Any, doc_to_choice=None) -> List[str]:
        if self.prompt is not None:
            doc_to_choice = self.prompt
        elif doc_to_choice is not None:
            doc_to_choice = doc_to_choice
        elif self.config.doc_to_choice is None:
            eval_logger.error("doc_to_choice was called but not set in config")
        else:
            doc_to_choice = self.config.doc_to_choice

        if isinstance(doc_to_choice, str):
            if doc_to_choice in self.features:
                return doc[doc_to_choice]
            else:
                return ast.literal_eval(utils.apply_template(doc_to_choice, doc))
        elif isinstance(doc_to_choice, list):
            return doc_to_choice
        elif isinstance(doc_to_choice, dict):
            return list(doc_to_choice.values())
        elif callable(doc_to_choice):
            return doc_to_choice(doc)
        elif hasattr(doc_to_choice, "get_answer_choices_list"):
            return doc_to_choice.get_answer_choices_list(doc)
        else:
            raise TypeError
    def construct_requests(
        self, doc: dict, ctx: str, **kwargs
    ) -> Union[List[Instance], Instance]:
        if self.OUTPUT_TYPE == "loglikelihood":
            arguments = (ctx, self.doc_to_target(doc))
        elif self.OUTPUT_TYPE == "loglikelihood_rolling":
            arguments = (self.doc_to_target(doc),)
        elif self.OUTPUT_TYPE == "multiple_choice":
            choices = self.doc_to_choice(doc)
            target_delimiter = self.config.target_delimiter
            if self.multiple_input:
                # If there are multiple inputs, choices are placed in the ctx
                cont = self.doc_to_target(doc)
                arguments = [
                    (ctx + choice, f"{target_delimiter}{cont}") for choice in choices
                ]
            else:
                # Otherwise they are placed in the continuation
                arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices]

            request_list = [
                Instance(
                    request_type="loglikelihood",
                    doc=doc,
                    arguments=arg,
                    idx=i,
                    **kwargs,
                )
                for i, arg in enumerate(arguments)
            ]
            # TODO: we should raise a warning telling users this will at most ~2x runtime.
            if "acc_mutual_info" in self._metric_fn_list.keys():
                # if we are calculating multiple choice accuracy
                # using mutual information instead of raw loglikelihood as metric, need unconditional lls.

                # here mutual info refers to calculating
                # log(P(choice|ctx) / P(choice)) = log(P(choice|ctx)) - log(P(choice))
                # in other words normalizing by subtracting the unconditional logprob of each choice.
                request_list.extend(
                    [
                        Instance(
                            request_type="loglikelihood",
                            doc=doc,
                            arguments=("", "{}".format(choice)),
                            idx=i,
                            **kwargs,
                        )
                        for i, choice in enumerate(choices)
                    ]
                )
            return request_list

        elif self.OUTPUT_TYPE == "generate_until":
            arguments = (ctx, deepcopy(self.config.generation_kwargs))

        return Instance(
            request_type=self.OUTPUT_TYPE, doc=doc, arguments=arguments, idx=0, **kwargs
        )
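Note: for output_type multiple_choice the method above emits one loglikelihood request per answer option. Shown here as plain (context, continuation) pairs on toy data (the harness wraps each pair in an Instance):

ctx = "Question: The sky is usually what color?\nAnswer:"
choices = ["red", "blue", "green", "yellow"]
target_delimiter = " "
arguments = [(ctx, f"{target_delimiter}{c}") for c in choices]
for arg in arguments:
    print(arg)
# with acc_mutual_info enabled, four extra ("", choice) pairs would be appended
# so each choice is also scored unconditionally (no context).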
    def process_results(self, doc, results):
        if callable(self.config.process_results):
            return self.config.process_results(doc, results)

        result_dict = {}
        use_metric = list(self._metric_fn_list.keys())
        if self.OUTPUT_TYPE == "loglikelihood":
            results = results[0]
            ll, is_greedy = results
            return {
                **({"perplexity": ll} if "perplexity" in use_metric else {}),
                **({"acc": int(is_greedy)} if "acc" in use_metric else {}),
            }
        elif self.OUTPUT_TYPE == "loglikelihood_rolling":
            (loglikelihood,) = results
            _words = self.count_words(self.doc_to_target(doc))
            _bytes = self.count_bytes(self.doc_to_target(doc))
            return {
                **(
                    {"word_perplexity": (loglikelihood, _words)}
                    if "word_perplexity" in use_metric
                    else {}
                ),
                **(
                    {"byte_perplexity": (loglikelihood, _bytes)}
                    if "byte_perplexity" in use_metric
                    else {}
                ),
                **(
                    {"bits_per_byte": (loglikelihood, _bytes)}
                    if "bits_per_byte" in use_metric
                    else {}
                ),
            }
        elif self.OUTPUT_TYPE == "multiple_choice":
            lls, is_greedy = zip(*results)

            # retrieve choices in List[str] form, to compute choice lengths, etc.
            choices = self.doc_to_choice(doc)
            completion_len = np.array([float(len(i)) for i in choices])

            if (
                2 * len(choices) == len(lls)
                and "acc_mutual_info" in self._metric_fn_list.keys()
            ):
                # then we are doing mutual info.
                # this stores the "dryrun" / unconditional answer loglikelihoods
                lls_unconditional = lls[1::2]
                if len(lls_unconditional) != len(choices):
                    raise ValueError
                # and this stores our "regular" conditional loglikelihoods
                lls = lls[::2]

            pred = np.argmax(lls)
            pred_norm = np.argmax(lls / completion_len)

            if self.multiple_input:
                gold = self.doc_to_text(doc)
            else:
                gold = self.doc_to_target(doc)

            gold_index_error = False
            if isinstance(gold, list):
                gold = [i if i < len(choices) else -100 for i in gold]
                if -100 in gold:
                    gold_index_error = True
            else:
                if isinstance(gold, int):
                    gold = gold if gold < len(choices) else -100
                elif isinstance(gold, str):
                    gold = choices.index(gold) if gold in choices else -100

                if gold == -100:
                    gold_index_error = True

            if gold_index_error:
                eval_logger.warning(
                    f"Label index was not in within range of available choices,"
                    f"Sample:\n\n{doc}\n\n"
                )

            if self.multiple_target:
                acc = 1.0 if pred in gold else 0.0
                acc_norm = 1.0 if pred_norm in gold else 0.0
                exact_match = int(any([is_greedy[i] if i != -100 else 0 for i in gold]))
            else:
                acc = 1.0 if pred == gold else 0.0
                acc_norm = 1.0 if pred_norm == gold else 0.0
                # TODO: this gets score of 0 on arc_challenge for pythia-70m. need to test that this works properly
                exact_match = int(is_greedy[gold]) if gold != -100 else 0

            prob_norm = utils.softmax(lls)

            # TODO use keyword arguments to the metric?
            # gold, pred, norm stuff, the original lls,
            result_dict = {
                **({"acc": acc} if "acc" in use_metric else {}),
                **({"f1": (gold, pred)} if "f1" in use_metric else {}),
                **({"mcc": (gold, pred)} if "mcc" in use_metric else {}),
                **({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}),
                **({"exact_match": exact_match} if "exact_match" in use_metric else {}),
                **(
                    {"brier_score": (gold, prob_norm)}
                    if "brier_score" in use_metric
                    else {}
                ),
            }

            if "acc_mutual_info" in use_metric:
                lls_mutual_info = [
                    ll_c - ll_u for ll_c, ll_u in zip(lls, lls_unconditional)
                ]
                acc_mutual_info = 1.0 if np.argmax(lls_mutual_info) == gold else 0.0
                result_dict["acc_mutual_info"] = acc_mutual_info

        elif self.OUTPUT_TYPE == "generate_until":
            gold = self.doc_to_target(doc)
            result = results[0]
            if self.config.doc_to_choice is not None:
                # If you set doc_to_choice,
                # it assumes that doc_to_target returns a number.
                choices = self.doc_to_choice(doc)
                gold = choices[gold]
            # we expect multiple_targets to be a list.
            elif self.multiple_target:
                gold = list(gold)
            elif type(gold) != type(result):
                # cast gold to the same type as result
                gold = type(result)(gold)

            for metric in self._metric_fn_list.keys():
                if self.multiple_target:
                    # in the case where we have multiple targets,
                    # return true if any are true
                    # TODO: this may break for multipLe_target, non zero-or-1 metrics
                    scores = []
                    if not isinstance(gold, list):
                        # sometimes, a multiple_target dataset has exceptions where one doc has only one string answer
                        # print(gold)
                        gold = [gold]
                    if metric == "exact_match":
                        result = [result for _ in range(len(gold))]
                        scores = self._metric_fn_list[metric](
                            references=gold,
                            predictions=result,
                            **self._metric_fn_kwargs[metric],
                        )[metric]
                        result_score = 1.0 if scores > 0.0 else 0.0
                    else:
                        for gold_option in gold:
                            try:
                                result_score = self._metric_fn_list[metric](
                                    references=[gold_option],
                                    predictions=[result],
                                    **self._metric_fn_kwargs[metric],
                                )
                            except (
                                TypeError
                            ):  # TODO: this is hacky and I don't want to do it
                                result_score = self._metric_fn_list[metric](
                                    [gold_option, result]
                                )
                            if isinstance(result_score, dict):
                                # TODO: this handles the case where HF evaluate returns a dict.
                                result_score = result_score[metric]
                            scores.append(result_score)
                        if any(scores):
                            result_score = 1.0
                        else:
                            result_score = 0.0
                else:
                    try:
                        result_score = self._metric_fn_list[metric](
                            references=[gold],
                            predictions=[result],
                            **self._metric_fn_kwargs[metric],
                        )
                    except TypeError:  # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
                        result_score = self._metric_fn_list[metric]([gold, result])
                    if isinstance(result_score, dict):
                        # TODO: this handles the case where HF evaluate returns a dict.
                        result_score = result_score[metric]
                result_dict[metric] = result_score
        else:
            raise ValueError(
                f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ",
                "'loglikelihood', 'loglikelihood_rolling', 'generate_until' or 'multiple_choice'",
            )

        return result_dict
    def aggregation(self) -> dict:
        return self._aggregation_list

    def higher_is_better(self) -> dict:
        return self._higher_is_better

    def get_config(self, key: str) -> Any:
        return getattr(self._config, key, None)

    @property
    def task_name(self) -> Any:
        return getattr(self.config, "task", None)

    def __repr__(self):
        return (
            f"ConfigurableTask(task_name={getattr(self.config, 'task', None)},"
            f"output_type={self.OUTPUT_TYPE},"
            f"num_fewshot={getattr(self.config, 'num_fewshot', None)},"
            f"num_samples={len(self.eval_docs)})"
        )
class MultipleChoiceTask(Task):
    OUTPUT_TYPE = "loglikelihood"

    def doc_to_target(self, doc: dict) -> str:
        return " " + doc["choices"][doc["gold"]]

    def construct_requests(self, doc: dict, ctx: str, **kwargs) -> List[Instance]:
        # TODO: add mutual info here?
        return [
            Instance(
                request_type="loglikelihood",
                doc=doc,
                arguments=(ctx, " {}".format(choice)),
                idx=i,
                **kwargs,
            )
            for i, choice in enumerate(doc["choices"])
        ]

    def process_results(self, doc: dict, results: Iterable[Tuple[float, bool]]) -> dict:
        results = [
            res[0] for res in results
        ]  # only retain loglikelihoods, discard is_greedy TODO: do we need is_greedy anywhere?
        gold = doc["gold"]

        acc = 1.0 if np.argmax(results) == gold else 0.0
        completion_len = np.array([float(len(i)) for i in doc["choices"]])
        acc_norm = 1.0 if np.argmax(results / completion_len) == gold else 0.0

        return {
            "acc": acc,
            "acc_norm": acc_norm,
        }

    def higher_is_better(self) -> dict:
        return {
            "acc": True,
            "acc_norm": True,
        }

    def aggregation(self) -> dict:
        return {
            "acc": mean,
            "acc_norm": mean,
        }
class PerplexityTask(Task):
    OUTPUT_TYPE = "loglikelihood_rolling"

    def has_training_docs(self) -> bool:
        return False

    def fewshot_examples(self, k: int, rnd) -> List:
        if k != 0:
            raise ValueError(
                "The number of fewshot examples must be 0 for perplexity tasks."
            )
        return []

    def fewshot_context(self, doc: dict, num_fewshot: int) -> Literal[""]:
        if num_fewshot != 0:
            raise ValueError(
                "The number of fewshot examples must be 0 for perplexity tasks."
            )

        return ""

    def higher_is_better(self) -> dict:
        return {
            "word_perplexity": False,
            "byte_perplexity": False,
            "bits_per_byte": False,
        }

    def doc_to_decontamination_query(self, doc):
        return doc

    def doc_to_text(self, doc) -> str:
        return ""

    def doc_to_target(self, doc):
        return doc

    def construct_requests(self, doc: dict, ctx: Optional[str], **kwargs):
        if bool(ctx):
            raise ValueError

        return Instance(
            request_type=self.OUTPUT_TYPE,
            doc=doc,
            arguments=(self.doc_to_target(doc),),
            idx=0,
            **kwargs,
        )

    def process_results(self, doc: dict, results: Tuple[float]) -> dict:
        (loglikelihood,) = results
        words = self.count_words(self.doc_to_target(doc))
        bytes_ = self.count_bytes(self.doc_to_target(doc))
        return {
            "word_perplexity": (loglikelihood, words),
            "byte_perplexity": (loglikelihood, bytes_),
            "bits_per_byte": (loglikelihood, bytes_),
        }

    def aggregation(self) -> dict:
        return {
            "word_perplexity": weighted_perplexity,
            "byte_perplexity": weighted_perplexity,
            "bits_per_byte": bits_per_byte,
        }

    @classmethod
    def count_bytes(cls, doc) -> int:
        return len(doc.encode("utf-8"))

    @classmethod
    def count_words(cls, doc) -> int:
        """Downstream tasks with custom word boundaries should override this!"""
        return len(re.split(r"\s+", doc))
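Note: each perplexity metric above returns a (loglikelihood, weight) pair per document and relies on weighted aggregation in lm_eval.api.metrics. The standalone sketch below restates that arithmetic as I understand it, on toy numbers:

import math


def weighted_perplexity(items):
    # items: list of (loglikelihood, weight) pairs; weight = word or byte count
    lls, weights = zip(*items)
    return math.exp(-sum(lls) / sum(weights))


def bits_per_byte(items):
    lls, byte_counts = zip(*items)
    return -sum(lls) / (sum(byte_counts) * math.log(2))


items = [(-120.0, 50), (-80.0, 30)]  # toy per-document pairs
print(weighted_perplexity(items), bits_per_byte(items))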
scripts/yans/lm-evaluation-harness/lm_eval/models/__init__.py
ADDED
@@ -0,0 +1,28 @@
from . import (
    anthropic_llms,
    api_models,
    dummy,
    gguf,
    huggingface,
    mamba_lm,
    nemo_lm,
    neuralmagic,
    neuron_optimum,
    openai_completions,
    optimum_lm,
    textsynth,
    vllm_causallms,
)


# TODO: implement __all__


try:
    # enable hf hub transfer if available
    import hf_transfer  # type: ignore # noqa
    import huggingface_hub.constants  # type: ignore

    huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True
except ImportError:
    pass
scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/__init__.cpython-310.pyc ADDED (binary file, 631 Bytes)
scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/anthropic_llms.cpython-310.pyc ADDED (binary file, 11 kB)
scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/api_models.cpython-310.pyc ADDED (binary file, 16.6 kB)
scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/dummy.cpython-310.pyc ADDED (binary file, 1.58 kB)
scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/gguf.cpython-310.pyc ADDED (binary file, 4.11 kB)
scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/huggingface.cpython-310.pyc ADDED (binary file, 29.8 kB)
scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/mamba_lm.cpython-310.pyc ADDED (binary file, 3.69 kB)
scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/nemo_lm.cpython-310.pyc ADDED (binary file, 13.7 kB)
scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/neuralmagic.cpython-310.pyc ADDED (binary file, 11 kB)
scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/neuron_optimum.cpython-310.pyc ADDED (binary file, 18.3 kB)
scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/openai_completions.cpython-310.pyc ADDED (binary file, 6.39 kB)
scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/optimum_lm.cpython-310.pyc ADDED (binary file, 2.65 kB)
scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/textsynth.cpython-310.pyc ADDED (binary file, 5.23 kB)
scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/utils.cpython-310.pyc ADDED (binary file, 21.3 kB)
scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/vllm_causallms.cpython-310.pyc ADDED (binary file, 14.3 kB)
scripts/yans/lm-evaluation-harness/lm_eval/models/anthropic_llms.py
ADDED
@@ -0,0 +1,362 @@
import os
from functools import cached_property
from typing import Any, Dict, List, Tuple, Union

from tqdm import tqdm

from lm_eval import utils
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from lm_eval.models.openai_completions import LocalCompletionsAPI
from lm_eval.models.utils import retry_on_specific_exceptions


eval_logger = utils.eval_logger


def anthropic_completion(
    client,  #: anthropic.Anthropic,
    model: str,
    prompt: str,
    max_tokens_to_sample: int,
    temperature: float,
    stop: List[str],
    **kwargs: Any,
) -> str:
    """Wrapper function around the Anthropic completion API client with exponential back-off
    in case of RateLimitError.

    params:
        client: anthropic.Anthropic
            Anthropic API client
        model: str
            Anthropic model e.g. 'claude-instant-v1', 'claude-2'
        prompt: str
            Prompt to feed to the model
        max_tokens_to_sample: int
            Maximum number of tokens to sample from the model
        temperature: float
            Sampling temperature
        stop: List[str]
            List of stop sequences
        kwargs: Any
            Additional model_args to pass to the API client
    """

    try:
        import anthropic
    except ModuleNotFoundError:
        raise Exception(
            "attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`",
        )

    def _exception_callback(e: Exception, sleep_time: float) -> None:
        eval_logger.warning(
            f"RateLimitError occurred: {e.__cause__}\n Retrying in {sleep_time} seconds"
        )

    @retry_on_specific_exceptions(
        on_exceptions=[anthropic.RateLimitError],
        max_retries=None,  # retry forever, consider changing
        on_exception_callback=_exception_callback,
    )
    def completion():
        response = client.completions.create(
            prompt=f"{anthropic.HUMAN_PROMPT} {prompt}{anthropic.AI_PROMPT}",
            model=model,
            # NOTE: Claude really likes to do CoT, and overly aggressive stop sequences
            # (e.g. gsm8k's ":") may truncate a lot of the input.
            stop_sequences=[anthropic.HUMAN_PROMPT] + stop,
            max_tokens_to_sample=max_tokens_to_sample,
            temperature=temperature,
            **kwargs,
        )
        return response.completion

    return completion()


def anthropic_chat(
    client,  #: anthropic.Anthropic,
    model: str,
    prompt: str,
    max_tokens: int,
    temperature: float,
    stop: List[str],
    **kwargs: Any,
) -> str:
    """Wrapper function around the Anthropic completion API client with exponential back-off
    in case of RateLimitError.

    params:
        client: anthropic.Anthropic
            Anthropic API client
        model: str
            Anthropic model e.g. 'claude-3-opus-20240229', 'claude-3-sonnet-20240229'
        prompt: str
            Prompt to feed to the model
        max_tokens: int
            Maximum number of tokens to sample from the model
        temperature: float
            Sampling temperature
        stop: List[str]
            List of stop sequences
        kwargs: Any
            Additional model_args to pass to the API client
    """

    try:
        import anthropic
    except ModuleNotFoundError:
        raise Exception(
            "attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`",
        )

    def _exception_callback(e: Exception, sleep_time: float) -> None:
        eval_logger.warning(
            f"RateLimitError occurred: {e.__cause__}\n Retrying in {sleep_time} seconds"
        )

    @retry_on_specific_exceptions(
        on_exceptions=[
            anthropic.RateLimitError,
            anthropic.APIConnectionError,
            anthropic.APIStatusError,
        ],
        max_retries=None,  # retry forever, consider changing
        on_exception_callback=_exception_callback,
    )
    def messages():
        response = client.messages.create(
            model=model,
            max_tokens=max_tokens,
            temperature=temperature,
            messages=[{"role": "user", "content": f"{prompt}"}],
            **kwargs,
        )
        return response.content[0].text

    return messages()


@register_model("anthropic-completions")
class AnthropicLM(LM):
    REQ_CHUNK_SIZE = 20  # TODO: not used

    def __init__(
        self,
        batch_size: int = 1,
        model: str = "claude-2.0",
        max_tokens_to_sample: int = 256,
        temperature: float = 0,  # defaults to 1
        **kwargs,  # top_p, top_k, etc.
    ) -> None:
        """Anthropic API wrapper.

        :param model: str
            Anthropic model e.g. 'claude-instant-v1', 'claude-2'
        :param max_tokens_to_sample: int
            Maximum number of tokens to sample from the model
        :param temperature: float
            Sampling temperature
        :param kwargs: Any
            Additional model_args to pass to the API client
        """
        super().__init__()

        try:
            import anthropic
        except ModuleNotFoundError:
            raise Exception(
                "attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`",
            )

        self.model = model
        # defaults to os.environ.get("ANTHROPIC_API_KEY")
        self.client = anthropic.Anthropic()
        self.temperature = temperature
        self.max_tokens_to_sample = max_tokens_to_sample
        self.tokenizer = self.client.get_tokenizer()
        self.kwargs = kwargs

    @property
    def eot_token_id(self):
        # Not sure but anthropic.HUMAN_PROMPT ?
        raise NotImplementedError("No idea about anthropic tokenization.")

    @property
    def max_length(self) -> int:
        return 2048

    @property
    def max_gen_toks(self) -> int:
        return self.max_tokens_to_sample

    @property
    def batch_size(self):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError("No support for logits.")

    @property
    def device(self):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError("No support for logits.")

    def tok_encode(self, string: str) -> List[int]:
        return self.tokenizer.encode(string).ids

    def tok_decode(self, tokens: List[int]) -> str:
        return self.tokenizer.decode(tokens)

    def _loglikelihood_tokens(self, requests, disable_tqdm: bool = False):
        raise NotImplementedError("No support for logits.")

    def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
        try:
            import anthropic
        except ModuleNotFoundError:
            raise Exception(
                "attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`",
            )

        if not requests:
            return []

        _requests: List[Tuple[str, dict]] = [req.args for req in requests]

        res = []
        for request in tqdm(_requests, disable=disable_tqdm):
            try:
                inp = request[0]
                request_args = request[1]
                # generation_kwargs
                until = request_args.get("until")
                max_gen_toks = request_args.get("max_gen_toks", self.max_length)
                temperature = request_args.get("temperature", self.temperature)
                response = anthropic_completion(
                    client=self.client,
                    model=self.model,
                    prompt=inp,
                    max_tokens_to_sample=max_gen_toks,
                    temperature=temperature,  # TODO: implement non-greedy sampling for Anthropic
                    stop=until,  # type: ignore
                    **self.kwargs,
                )
                res.append(response)

                self.cache_hook.add_partial("generate_until", request, response)
            except anthropic.APIConnectionError as e:  # type: ignore # noqa: F821
                eval_logger.critical(f"Server unreachable: {e.__cause__}")
                break
            except anthropic.APIStatusError as e:  # type: ignore # noqa: F821
                eval_logger.critical(f"API error {e.status_code}: {e.message}")
                break

        return res

    def _model_call(self, inps):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    def _model_generate(self, context, max_length, eos_token_id):
        # Isn't used because we override generate_until
        raise NotImplementedError()

    def loglikelihood(self, requests, disable_tqdm: bool = False):
        raise NotImplementedError("No support for logits.")

    def loglikelihood_rolling(self, requests, disable_tqdm: bool = False):
        raise NotImplementedError("No support for logits.")


@register_model("anthropic-chat", "anthropic-chat-completions")
class AnthropicChat(LocalCompletionsAPI):
    def __init__(
        self,
base_url="https://api.anthropic.com/v1/messages",
|
281 |
+
tokenizer_backend=None,
|
282 |
+
**kwargs,
|
283 |
+
):
|
284 |
+
super().__init__(
|
285 |
+
base_url=base_url, tokenizer_backend=tokenizer_backend, **kwargs
|
286 |
+
)
|
287 |
+
eval_logger.warning(
|
288 |
+
"Chat completions does not support batching. Defaulting to batch size 1."
|
289 |
+
)
|
290 |
+
self._batch_size = 1
|
291 |
+
self.anthropic_version = "2023-06-01"
|
292 |
+
eval_logger.warning(
|
293 |
+
f"Using Anthropic Version: {self.anthropic_version}. Confirm the current version here: https://docs.anthropic.com/en/api/versioning"
|
294 |
+
)
|
295 |
+
|
296 |
+
@cached_property
|
297 |
+
def api_key(self):
|
298 |
+
"""Override this property to return the API key for the API request."""
|
299 |
+
key = os.environ.get("ANTHROPIC_API_KEY", None)
|
300 |
+
if key is None:
|
301 |
+
raise ValueError(
|
302 |
+
"API key not found. Please set the ANTHROPIC_API_KEY environment variable."
|
303 |
+
)
|
304 |
+
return key
|
305 |
+
|
306 |
+
@cached_property
|
307 |
+
def header(self):
|
308 |
+
return {
|
309 |
+
"x-api-key": f"{self.api_key}",
|
310 |
+
"anthropic-version": self.anthropic_version,
|
311 |
+
}
|
312 |
+
|
313 |
+
def _create_payload(
|
314 |
+
self, messages: List[Dict], generate=True, gen_kwargs: dict = None, **kwargs
|
315 |
+
) -> dict:
|
316 |
+
system = (
|
317 |
+
messages[0].get("content") if messages[0].get("role") == "system" else None
|
318 |
+
)
|
319 |
+
if system:
|
320 |
+
messages = messages[1:]
|
321 |
+
gen_kwargs.pop("do_sample", False)
|
322 |
+
max_tokens = gen_kwargs.pop("max_gen_toks", self._max_gen_toks)
|
323 |
+
temperature = gen_kwargs.pop("temperature", 0)
|
324 |
+
stop = gen_kwargs.pop("until", ["\n\nHuman:"])
|
325 |
+
if not isinstance(stop, list):
|
326 |
+
stop = [stop]
|
327 |
+
out = {
|
328 |
+
"messages": messages,
|
329 |
+
"model": self.model,
|
330 |
+
"max_tokens": max_tokens,
|
331 |
+
"temperature": temperature,
|
332 |
+
"stop_sequences": stop,
|
333 |
+
**gen_kwargs,
|
334 |
+
}
|
335 |
+
if system:
|
336 |
+
out["system"] = system
|
337 |
+
return out
|
338 |
+
|
339 |
+
def parse_generations(
|
340 |
+
self, outputs: Union[Dict, List[Dict]], **kwargs
|
341 |
+
) -> List[str]:
|
342 |
+
res = []
|
343 |
+
if not isinstance(outputs, list):
|
344 |
+
outputs = [outputs]
|
345 |
+
for out in outputs:
|
346 |
+
for choices in out["content"]:
|
347 |
+
res.append(choices["text"])
|
348 |
+
return res
|
349 |
+
|
350 |
+
def tok_encode(
|
351 |
+
self,
|
352 |
+
string: str,
|
353 |
+
left_truncate_len=None,
|
354 |
+
add_special_tokens=None,
|
355 |
+
**kwargs,
|
356 |
+
) -> List[str]:
|
357 |
+
return [string]
|
358 |
+
|
359 |
+
def loglikelihood(self, requests, **kwargs):
|
360 |
+
raise NotImplementedError(
|
361 |
+
"Anthropic Chat Completions API does not support the return of loglikelihood"
|
362 |
+
)
|
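(For reference: a minimal usage sketch of the "anthropic-chat-completions" model registered above, not part of the diff itself. It assumes `lm_eval` is installed with the `anthropic` extra, `ANTHROPIC_API_KEY` is set in the environment, and the task/model names are illustrative only.)

# Hedged sketch: invoking the registered Anthropic chat model through lm_eval's Python entry point.
import lm_eval

results = lm_eval.simple_evaluate(
    model="anthropic-chat-completions",           # registry name from @register_model above
    model_args="model=claude-3-sonnet-20240229",  # forwarded to AnthropicChat / TemplateAPI __init__
    tasks=["gsm8k"],                               # example task, swap for any harness task
    num_fewshot=5,
)
print(results["results"])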
scripts/yans/lm-evaluation-harness/lm_eval/models/api_models.py
ADDED
@@ -0,0 +1,641 @@
import abc
import asyncio
import copy
import itertools
import json
from functools import cached_property
from typing import (
    Any,
    Awaitable,
    Callable,
    Dict,
    Iterable,
    List,
    Literal,
    NamedTuple,
    Optional,
    Tuple,
    Union,
)


try:
    import requests
    from aiohttp import ClientSession, TCPConnector
    from tenacity import RetryError, retry, stop_after_attempt, wait_exponential
    from tqdm import tqdm
    from tqdm.asyncio import tqdm_asyncio
except ModuleNotFoundError:
    pass


from importlib.util import find_spec

from lm_eval import utils
from lm_eval.api.instance import Instance
from lm_eval.api.model import TemplateLM
from lm_eval.models.utils import Collator, chunks, configure_pad_token


LogLikelihoodInputs = Tuple[Tuple[str, str], List[int], List[int]]


# utility class to keep track of json encoded chats
class JsonChatStr(NamedTuple):
    prompt: str

    def encode(self, encoding):
        return self.prompt.encode(encoding)


eval_logger = utils.eval_logger


class TemplateAPI(TemplateLM):
    def __init__(
        self,
        model: str = None,
        pretrained: str = None,  # `model` takes precedence over `pretrained` when passed.
        base_url: str = None,
        tokenizer: Optional[str] = None,
        # Logliklehood tasks require a tokenizer to calculate context lengths,
        # however the requests can be sent as a string if the API doesn't support token inputs.
        # use tokenized_requests=False
        tokenizer_backend: Optional[
            Literal["tiktoken", "huggingface", None]
        ] = "huggingface",
        truncate: bool = False,
        # number of concurrent requests. More useful if not batching
        num_concurrent: int = 1,
        max_retries: int = 3,
        max_gen_toks: int = 256,
        batch_size: Union[str, int] = 1,
        seed: int = 1234,
        max_length: Optional[int] = 2048,
        add_bos_token: bool = False,
        custom_prefix_token_id=None,
        # send the requests as tokens or strings
        tokenized_requests=True,
        **kwargs,
    ) -> None:
        super().__init__()
        missing_packages = [
            pkg
            for pkg in ["aiohttp", "tqdm", "tenacity", "requests"]
            if find_spec(pkg) is None
        ]
        if missing_packages:
            raise ModuleNotFoundError(
                f"Attempted to use an API model, but the required packages {missing_packages} are not installed. "
                'Please install these via `pip install lm-eval[api]` or `pip install -e ."[api]"`'
            )
        self.model = model or pretrained
        self.base_url = base_url
        self.tokenizer = tokenizer
        if not isinstance(batch_size, int) and "auto" in batch_size:
            eval_logger.warning(
                "Automatic batch size is not supported for API models. Defaulting to batch size 1."
            )
        elif int(batch_size) > 1:
            eval_logger.warning(
                "Batch size > 1 detected. Ensure your API supports batched requests with varying total sequence lengths."
            )
        self._batch_size = int(batch_size) if batch_size != "auto" else 1
        self._truncate = truncate
        self._max_gen_toks = int(max_gen_toks)
        self._seed = int(seed)
        self.max_length = max_length
        if int(num_concurrent) <= 1:
            eval_logger.info(
                "Concurrent requests are disabled. To enable concurrent requests, set `num_concurrent` > 1."
            )
        self._concurrent = int(num_concurrent)
        self.tokenizer_backend = tokenizer_backend
        self.add_bos_token = add_bos_token
        self.custom_prefix_token_id = custom_prefix_token_id
        self.tokenized_requests = tokenized_requests
        self.max_retries = int(max_retries)

        eval_logger.info(f"Using tokenizer {self.tokenizer_backend}")
        if self.tokenizer_backend is None:
            self.tokenizer = None
            self.tokenized_requests = False
        else:
            if self.tokenizer is None:
                if self.tokenizer_backend == "huggingface":
                    import transformers

                    self.tokenizer = transformers.AutoTokenizer.from_pretrained(
                        self.tokenizer if self.tokenizer else self.model
                    )
                    # Not used as the API will handle padding but to mirror the behavior of the HFLM
                    self.tokenizer = configure_pad_token(self.tokenizer)
                elif self.tokenizer_backend == "tiktoken":
                    try:
                        import tiktoken

                        self.tokenizer = tiktoken.encoding_for_model(self.model)
                    except ModuleNotFoundError as e:
                        raise Exception(
                            "Attempted to use 'openai' LM type, but the package `tiktoken` is not installed. "
                            "Please install it via `pip install lm-eval[api]` or `pip install -e .[api]`."
                        ) from e
                    if "openai" not in self.base_url:
                        eval_logger.warning(
                            f"Passed `base_url={self.base_url}` but using (OpenAI) Tiktoken tokenizer backend. "
                            "Pass `tokenizer_backend=huggingface` and provide the HF tokenizer name if your model does not use Tiktoken."
                        )
            else:
                import transformers

                assert isinstance(tokenizer, str), "tokenizer must be a string"
                self.tokenizer = transformers.AutoTokenizer.from_pretrained(
                    tokenizer,
                )

    @abc.abstractmethod
    def _create_payload(
        self,
        messages: Union[List[List[int]], List[dict], List[str], str],
        *,
        generate: bool = True,
        gen_kwargs: Optional[dict] = None,
        seed: int = 1234,
        **kwargs,
    ) -> dict:
        """This method is responsible for creating the json payload that will be sent to the API."""
        raise NotImplementedError

    def create_message(
        self,
        messages: Union[List[List[int]], List[str], List[JsonChatStr]],
        generate=False,
    ) -> Union[List[List[int]], List[dict], List[str], str]:
        """Helper method to transform the prompt into the expected API input format. messages consist of batched requests"""
        if isinstance(messages[0], JsonChatStr):
            # for chat completions we need to decode the json string to list[dict,...]
            assert (
                self._batch_size == 1
            ), "non-tokenized chat requests are only supported with batch_size=1"
            # list[dict["role":..., "content":...],...]
            return json.loads(messages[0].prompt)

        if not self.tokenized_requests:
            # if messages are tokenized:
            if isinstance(messages[0][0], int):
                # assuming decoding is lossless. However, this is only for logliklehood requests
                # as we need to compute the context length. For generations, we don't need to tokenize.
                messages = self.decode_batch(messages)
            if self._batch_size <= 1:
                # if batch is 1 return str
                return messages[0]
            else:
                # list[str,...]
                return messages

        # list[list[int], ...]
        return messages

    @staticmethod
    @abc.abstractmethod
    def parse_logprobs(
        outputs: Union[Any, List[Any]],
        tokens: List[List[int]] = None,
        ctxlen: List[int] = None,
        **kwargs,
    ) -> List[Tuple[float, bool]]:
        """Method used to parse the logprobs from the (batched) API response. This method should return a list of tuples"""
        raise NotImplementedError

    @staticmethod
    @abc.abstractmethod
    def parse_generations(outputs: Union[Any, List[Any]], **kwargs) -> List[str]:
        """Method used to parse the generations from the (batched) API response. This method should return a list of str"""
        raise NotImplementedError

    @cached_property
    def api_key(self) -> str:
        """Override this property to return the API key for the API request."""
        return ""

    @cached_property
    def header(self) -> dict:
        """Override this property to return the headers for the API request."""
        return {"Authorization": f"Bearer {self.api_key}"}

    @property
    def chat_template(self) -> str:
        """Must be defined for LM subclasses that implement Chat Templating.
        Should return the structure of the chat template applied to user/assistant messages.
        Only used for logging and reproducibility.
        """
        return ""

    @property
    def tokenizer_name(self) -> str:
        """Must be defined for LM subclasses which implement Chat Templating.
        Should return the name of the tokenizer or chat template used.
        Used only to properly fingerprint caches when requests are being cached with `--cache_requests`, otherwise not used.
        """
        return ""

    def apply_chat_template(
        self, chat_history: List[Dict[str, str]]
    ) -> Union[str, JsonChatStr]:
        """Applies a chat template to a list of chat history between user and model."""
        if self.tokenizer_backend == "huggingface" and self.tokenized_requests:
            return self.tokenizer.apply_chat_template(
                chat_history, tokenize=False, add_generation_prompt=True
            )
        else:
            # bit of a hack. We'll load back before sending to the API
            return JsonChatStr(json.dumps(chat_history))

    @cached_property
    def eot_token_id(self) -> Optional[int]:
        if self.tokenizer is None:
            return None
        else:
            if self.tokenizer_backend == "huggingface":
                return self.tokenizer.eos_token_id
            elif self.tokenizer_backend == "tiktoken":
                return self.tokenizer.eot_token

    @cached_property
    def prefix_token_id(self) -> Optional[int]:
        if self.tokenizer is None:
            return None
        else:
            if self.custom_prefix_token_id is not None:
                return self.custom_prefix_token_id
            if self.tokenizer_backend == "huggingface":
                if self.tokenizer.bos_token_id is not None:
                    return self.tokenizer.bos_token_id
                return self.tokenizer.eos_token_id
            else:
                return self.tokenizer.eot_token

    def tok_encode(
        self,
        string: str,
        left_truncate_len: int = None,
        add_special_tokens: bool = False,
        truncation: bool = False,
        **kwargs,
    ) -> Union[List[List[int]], List[int], List[str]]:
        if self.tokenizer_backend is None:
            return [string]
        elif self.tokenizer_backend == "huggingface":
            # by default for CausalLM - false or self.add_bos_token is set
            if not add_special_tokens:
                add_special_tokens = False or self.add_bos_token
            encoding: Union[List[List[int]], List[int]] = self.tokenizer(
                string,
                add_special_tokens=add_special_tokens,
                truncation=truncation,
                return_attention_mask=False,
            ).input_ids

            # left-truncate the encoded context to be at most `left_truncate_len` tokens long
            if left_truncate_len:
                if not isinstance(string, str):
                    encoding = [enc[-left_truncate_len:] for enc in encoding]
                else:
                    encoding = encoding[-left_truncate_len:]

            return encoding

        else:
            try:
                encoding = self.tokenizer.encode(string)
            except Exception:
                encoding = self.tokenizer.encode_batch(string)
            return encoding

    def decode_batch(self, tokens: List[List[int]]) -> List[str]:
        if self.tokenizer_backend == "huggingface":
            return self.tokenizer.batch_decode(tokens)
        elif self.tokenizer_backend == "tiktoken":
            return self.tokenizer.decode_batch(tokens)

    def model_call(
        self,
        messages: Union[List[List[int]], List[str], List[JsonChatStr]],
        *,
        generate: bool = True,
        gen_kwargs: Optional[Dict] = None,
        **kwargs,
    ) -> Optional[dict]:
        # !!! Copy: shared dict for each request, need new object !!!
        gen_kwargs = copy.deepcopy(gen_kwargs)
        try:
            response = requests.post(
                self.base_url,
                json=self._create_payload(
                    self.create_message(messages),
                    generate=generate,
                    gen_kwargs=gen_kwargs,
                    seed=self._seed,
                    **kwargs,
                ),
                headers=self.header,
            )
            if not response.ok:
                eval_logger.warning(
                    f"API request failed with error message: {response.text}. Retrying..."
                )
            response.raise_for_status()
            return response.json()
        except RetryError:
            eval_logger.error(
                "API request failed after multiple retries. Please check the API status."
            )
            return None

    async def amodel_call(
        self,
        session: ClientSession,
        messages: Union[List[List[int]], List[str], List[JsonChatStr]],
        *,
        generate: bool = True,
        cache_keys: list = None,
        ctxlens: Optional[List[int]] = None,
        gen_kwargs: Optional[Dict] = None,
        **kwargs,
    ) -> Union[List[str], List[Tuple[float, bool]], None]:
        # !!! Copy: shared dict for each request, need new object !!!
        gen_kwargs = copy.deepcopy(gen_kwargs)
        payload = self._create_payload(
            self.create_message(messages),
            generate=generate,
            gen_kwargs=gen_kwargs,
            seed=self._seed,
            **kwargs,
        )
        cache_method = "generate_until" if generate else "loglikelihood"
        try:
            async with session.post(
                self.base_url,
                json=payload,
                headers=self.header,
            ) as response:
                if not response.ok:
                    error_text = await response.text()
                    eval_logger.warning(
                        f"API request failed with error message: {error_text}. Retrying..."
                    )
                # raising exception will retry the request
                response.raise_for_status()
                outputs = await response.json()
            answers = (
                self.parse_generations(
                    outputs=outputs,
                )
                if generate
                else self.parse_logprobs(
                    outputs=outputs,
                    tokens=messages,
                    ctxlens=ctxlens,
                )
            )
            if cache_keys:
                for res, cache in zip(answers, cache_keys):
                    self.cache_hook.add_partial(cache_method, cache, res)
            return answers
        # If the retries also fail
        except RetryError:
            eval_logger.error(
                "API request failed after multiple retries. Please check the API status."
            )
            return None

    def batch_logliklehood_requests(
        self, chunks: Iterable[List[LogLikelihoodInputs]]
    ) -> Tuple[List[List[int]], List[int], List[Tuple[str, str]]]:
        inputs = []
        ctxlens = []
        cache_keys = []
        for chunk in chunks:
            for cache_key, context_enc, continuation_enc in chunk:
                inp = (context_enc + continuation_enc)[-(self.max_length) :]
                ctxlen = len(context_enc) - max(
                    0, len(context_enc) + len(continuation_enc) - (self.max_length)
                )

                inputs.append(inp)
                ctxlens.append(ctxlen)
                cache_keys.append(cache_key)
        return inputs, ctxlens, cache_keys

    async def get_batched_requests(
        self,
        requests: list,
        cache_keys: list,
        *,
        generate: bool = True,
        ctxlens: List[int] = None,
        **kwargs,
    ) -> Union[List[List[str]], List[List[Tuple[float, bool]]]]:
        ctxlens = ctxlens if ctxlens else [None] * len(requests)
        conn = TCPConnector(limit=self._concurrent)
        async with ClientSession(connector=conn) as session:
            retry_: Callable[..., Awaitable[Any]] = retry(
                stop=stop_after_attempt(self.max_retries),
                wait=wait_exponential(multiplier=0.5, min=1, max=10),
                reraise=True,
            )(self.amodel_call)
            # Create tasks for each batch of request
            tasks = [
                asyncio.create_task(
                    retry_(
                        session=session,
                        messages=message,
                        cache_keys=cache_key,
                        generate=generate,
                        ctxlens=ctxlen,
                        **kwargs,
                    )
                )
                for message, cache_key, ctxlen in zip(
                    chunks(requests, n=self._batch_size),
                    chunks(cache_keys, n=self._batch_size),
                    chunks(ctxlens, n=self._batch_size),
                )
            ]

            return await tqdm_asyncio.gather(*tasks, desc="Requesting API")

    def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]:
        assert (
            self.tokenizer is not None
        ), "Tokenizer is required for loglikelihood tasks to compute context lengths."
        res = []

        def _collate(req: LogLikelihoodInputs):
            """Defines the key for the sorted method"""
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end

            toks = req[1] + req[2]
            return -len(toks), tuple(toks)

        re_ord = Collator(
            requests,
            sort_fn=_collate,
            group_by=None,
        )
        # if concurrent then we'll batch in the async context
        chunked = re_ord.get_batched(n=self._batch_size if self._concurrent <= 1 else 0)
        if self._concurrent <= 1:
            pbar = tqdm(desc="Requesting API", total=len(requests))
            for chunk in chunked:
                inputs, ctxlens, cache_keys = self.batch_logliklehood_requests([chunk])

                outputs = retry(
                    stop=stop_after_attempt(self.max_retries),
                    wait=wait_exponential(multiplier=0.5, min=1, max=10),
                    reraise=True,
                )(self.model_call)(messages=inputs, generate=False)
                if isinstance(outputs, dict):
                    outputs = [outputs]
                for answer_, cache_key in zip(
                    self.parse_logprobs(
                        outputs=outputs, tokens=inputs, ctxlens=ctxlens
                    ),
                    cache_keys,
                ):
                    if answer_ is not None:
                        res.append(answer_)
                        # partial caching
                        if cache_key is not None:
                            self.cache_hook.add_partial(
                                "loglikelihood", cache_key, answer_
                            )
                    pbar.update(1)
        else:
            inputs, ctxlens, cache_keys = self.batch_logliklehood_requests(chunked)
            res = itertools.chain.from_iterable(
                asyncio.run(
                    self.get_batched_requests(
                        inputs, cache_keys, generate=False, ctxlens=ctxlens
                    )
                )
            )

        return re_ord.get_original(res)

    def generate_until(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[str]:
        res = []

        def _collate_gen(_requests):
            # sort by the length of the non-tokenized contexts
            return -len(_requests[0])

        # Let the API deal with tokenization
        requests, all_gen_kwargs = zip(*(req.args for req in requests))
        if self.tokenized_requests:
            encodings_list = self.tok_encode(
                requests, add_special_tokens=self.add_bos_token
            )
        else:
            encodings_list = [None] * len(requests)
        requests = [
            (a, b, c) for a, b, c in zip(requests, all_gen_kwargs, encodings_list)
        ]

        re_ord = Collator(
            requests,
            sort_fn=_collate_gen,
            group_by="gen_kwargs",
        )
        chunked = re_ord.get_batched(
            n=self._batch_size if self._concurrent <= 1 else 0, batch_fn=None
        )
        if self._concurrent <= 1:
            pbar = tqdm(desc="Requesting API", total=len(requests))
            for chunk in chunked:
                contexts, all_gen_kwargs, encodings_list = zip(*chunk)
                req = encodings_list if self.tokenized_requests else contexts
                outputs = retry(
                    stop=stop_after_attempt(self.max_retries),
                    wait=wait_exponential(multiplier=0.5, min=1, max=10),
                    reraise=True,
                )(self.model_call)(
                    messages=req,
                    generate=True,
                    gen_kwargs=copy.deepcopy(all_gen_kwargs[0]),
                )
                for generated_text, context in zip(
                    self.parse_generations(
                        outputs=outputs,
                        contexts=contexts,
                    ),
                    contexts,
                ):
                    if generated_text is not None:
                        res.append(generated_text)

                        # partial caching
                        if context is not None:
                            self.cache_hook.add_partial(
                                "generate_until",
                                (context, all_gen_kwargs[0]),
                                generated_text,
                            )
                    pbar.update(1)
        else:
            for chunk in chunked:
                contexts, all_gen_kwargs, encodings_list = zip(*chunk)
                req = encodings_list if self.tokenized_requests else contexts
                results = itertools.chain.from_iterable(
                    asyncio.run(
                        self.get_batched_requests(
                            req,
                            cache_keys=[(ctx, all_gen_kwargs[0]) for ctx in contexts],
                            generate=True,
                            gen_kwargs=copy.deepcopy(all_gen_kwargs[0]),
                        )
                    )
                )
                res.extend(results)

        return re_ord.get_original(res)

    def loglikelihood_rolling(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[float]:
        loglikelihoods = []

        for (string,) in tqdm([req.args for req in requests], disable=disable_tqdm):
            rolling_token_windows = list(
                map(
                    utils.make_disjoint_window,
                    utils.get_rolling_token_windows(
                        token_list=self.tok_encode(string),
                        prefix_token=self.prefix_token_id,
                        max_seq_len=self.max_length,
                        context_len=1,
                    ),
                )
            )

            # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
            rolling_token_windows = [(None,) + x for x in rolling_token_windows]

            string_nll = self._loglikelihood_tokens(
                rolling_token_windows,
                disable_tqdm=True,
            )

            # discard is_greedy
            string_nll = [x[0] for x in string_nll]

            string_nll = sum(string_nll)
            loglikelihoods.append(string_nll)
        return loglikelihoods
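(A hedged sketch of how the abstract hooks of TemplateAPI above are typically filled in by a concrete backend. The registry name, endpoint schema, payload fields, and response shape below describe a hypothetical OpenAI-style completions API and are assumptions for illustration, not part of this file.)

# Hypothetical TemplateAPI subclass: implements _create_payload / parse_generations / parse_logprobs.
from typing import Any, List, Optional, Tuple, Union

from lm_eval.api.registry import register_model
from lm_eval.models.api_models import TemplateAPI


@register_model("my-completions-api")  # hypothetical registry name
class MyCompletionsAPI(TemplateAPI):
    def _create_payload(
        self,
        messages: Union[List[List[int]], List[dict], List[str], str],
        *,
        generate: bool = True,
        gen_kwargs: Optional[dict] = None,
        seed: int = 1234,
        **kwargs,
    ) -> dict:
        # Assumed OpenAI-style schema; adapt field names to the real API.
        if generate:
            gen_kwargs = gen_kwargs or {}
            return {
                "model": self.model,
                "prompt": messages,
                "max_tokens": gen_kwargs.pop("max_gen_toks", self._max_gen_toks),
                "temperature": gen_kwargs.pop("temperature", 0),
                "stop": gen_kwargs.pop("until", None),
                "seed": seed,
            }
        # loglikelihood path: echo the prompt with logprobs, generate nothing new
        return {
            "model": self.model,
            "prompt": messages,
            "max_tokens": 0,
            "echo": True,
            "logprobs": 1,
        }

    @staticmethod
    def parse_generations(outputs: Union[Any, List[Any]], **kwargs) -> List[str]:
        outputs = outputs if isinstance(outputs, list) else [outputs]
        return [choice["text"] for out in outputs for choice in out.get("choices", [])]

    @staticmethod
    def parse_logprobs(
        outputs: Union[Any, List[Any]],
        tokens: List[List[int]] = None,
        ctxlens: List[int] = None,
        **kwargs,
    ) -> List[Tuple[float, bool]]:
        # Sum continuation log-probabilities past the context; is_greedy detection is omitted in this sketch.
        res = []
        outputs = outputs if isinstance(outputs, list) else [outputs]
        for out, ctxlen in zip(outputs, ctxlens):
            lps = out["choices"][0]["logprobs"]["token_logprobs"][ctxlen:]
            res.append((sum(lp for lp in lps if lp is not None), False))
        return res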
scripts/yans/lm-evaluation-harness/lm_eval/models/huggingface.py
ADDED
@@ -0,0 +1,1356 @@
import copy
import os
from datetime import timedelta
from pathlib import Path
from typing import Dict, List, Literal, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import transformers
from accelerate import (
    Accelerator,
    InitProcessGroupKwargs,
    find_executable_batch_size,
)
from accelerate.utils import get_max_memory
from huggingface_hub import HfApi
from packaging import version
from peft import PeftModel
from peft import __version__ as PEFT_VERSION
from tqdm import tqdm
from transformers.models.auto.modeling_auto import (
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
)

from lm_eval import utils
from lm_eval.api.instance import Instance
from lm_eval.api.model import TemplateLM
from lm_eval.api.registry import register_model
from lm_eval.models.utils import (
    Collator,
    clear_torch_cache,
    configure_pad_token,
    get_dtype,
    pad_and_concat,
    stop_sequences_criteria,
)


eval_logger = utils.eval_logger


@register_model("hf-auto", "hf", "huggingface")
class HFLM(TemplateLM):
    """
    An abstracted Huggingface model class. Enables usage with both models of
    `transformers.AutoModelForCausalLM` and `transformers.AutoModelForSeq2SeqLM` classes.

    Supports data-parallel multi-GPU with HF Accelerate.
    """

    AUTO_MODEL_CLASS = None
    _DEFAULT_MAX_LENGTH = 2048

    def __init__(
        self,
        pretrained: Union[str, transformers.PreTrainedModel],
        backend: Optional[Literal["default", "causal", "seq2seq"]] = "default",
        # override whether the model should be treated as decoder-only (causal) or encoder-decoder (seq2seq)
        revision: Optional[str] = "main",
        subfolder: Optional[str] = None,
        tokenizer: Optional[
            Union[
                str,
                transformers.PreTrainedTokenizer,
                transformers.PreTrainedTokenizerFast,
            ]
        ] = None,
        truncation: Optional[bool] = False,
        logits_cache: bool = True,
        max_length: Optional[int] = None,
        device: Optional[str] = "cuda",
        dtype: Optional[Union[str, torch.dtype]] = "auto",
        batch_size: Optional[Union[int, str]] = 1,
        max_batch_size: Optional[int] = 64,
        trust_remote_code: Optional[bool] = False,
        use_fast_tokenizer: Optional[bool] = True,
        add_bos_token: Optional[bool] = False,
        prefix_token_id: Optional[int] = None,
        # arguments used for splitting a model across GPUs naively.
        # only used if `parallelize=True`.
        parallelize: Optional[bool] = False,
        max_memory_per_gpu: Optional[Union[int, str]] = None,
        max_cpu_memory: Optional[Union[int, str]] = None,
        offload_folder: Optional[Union[str, os.PathLike]] = "./offload",
        # PEFT, delta weights and quantization options
        peft: Optional[str] = None,
        delta: Optional[str] = None,
        autogptq: Optional[Union[bool, str]] = False,
        **kwargs,
    ) -> None:
        super().__init__()

        # optionally: take in an already-initialized transformers.PreTrainedModel
        if not isinstance(pretrained, str):
            eval_logger.warning(
                "`pretrained` model kwarg is not of type `str`. Many other model arguments may be ignored. Please do not launch via accelerate or use `parallelize=True` if passing an existing model this way."
            )
            assert not parallelize, "`parallelize=True` is not compatible with passing pre-initialized model to `pretrained`"
            self._model = pretrained
            self._device = self._model.device
            self._config = self._model.config
            gpus = 0

        else:
            assert isinstance(device, str)
            assert isinstance(pretrained, str)
            assert isinstance(batch_size, (int, str))

            gpus = torch.cuda.device_count()
            accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
            accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
            if accelerator.num_processes > 1:
                self.accelerator = accelerator

            if "npu" in accelerator.device.type:
                gpus = torch.npu.device_count()

            # using one process with no model parallelism
            if not (parallelize or accelerator.num_processes > 1):
                # use user-passed device
                device_list = set(
                    ["cuda", "cpu"]
                    + [f"cuda:{i}" for i in range(gpus)]
                    + ["mps", "mps:0"]
                    + [f"npu:{i}" for i in range(gpus)]
                )
                if device and device in device_list:
                    self._device = torch.device(device)
                    eval_logger.info(f"Using device '{device}'")
                    if device in ("mps", "mps:0") and version.parse(
                        torch.__version__
                    ) < version.parse("2.1"):
                        raise RuntimeError(
                            f"mps requires torch >= 2.1. You have {torch.__version__}"
                        )
                else:
                    eval_logger.info("Device not specified")
                    eval_logger.info(f"Cuda Available? {torch.cuda.is_available()}")
                    self._device = (
                        torch.device("cuda")
                        if torch.cuda.is_available()
                        else torch.device("cpu")
                    )
            else:  # Parallelism managed by accelerate
                if device != "cuda":
                    eval_logger.info(
                        f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model."
                    )
                # TODO: include in warning that `load_in_8bit` etc. affect this too
                self._device = (
                    self.accelerator.device
                    if hasattr(self, "accelerator")
                    else torch.device(device)
                )

            revision = str(revision)  # cast to string if not already one
            # TODO: update this to be less of a hack once subfolder is fixed in HF
            revision = revision + ("/" + subfolder if subfolder is not None else "")

            self._get_config(
                pretrained,
                revision=revision,
                trust_remote_code=trust_remote_code,
            )

        # determine which of 'causal' and 'seq2seq' backends to use
        self._get_backend(
            config=self.config, backend=backend, trust_remote_code=trust_remote_code
        )

        # load tokenizer so we know tokenizer vocabulary size before loading model and PEFT
        self._create_tokenizer(
            pretrained,
            tokenizer,
            revision=revision,
            trust_remote_code=trust_remote_code,
            use_fast_tokenizer=use_fast_tokenizer,
        )

        # if we passed `pretrained` as a string, initialize our model now
        if isinstance(pretrained, str):
            self._create_model(
                pretrained=pretrained,
                revision=revision,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                parallelize=parallelize,
                gpus=gpus,
                max_memory_per_gpu=max_memory_per_gpu,
                max_cpu_memory=max_cpu_memory,
                offload_folder=offload_folder,
                peft=peft,
                delta=delta,
                autogptq=autogptq,
                **kwargs,
            )

        # access self._model through self.model property outside this method
        if isinstance(self.model, torch.nn.Module):
            self.model.eval()
            self.model.tie_weights()

        self.truncation = truncation
        self.logits_cache = logits_cache
        self.vocab_size = self.tokenizer.vocab_size
        # select (or create) a pad token to use
        self.tokenizer = configure_pad_token(self.tokenizer, model_config=self.config)

        self.add_bos_token = add_bos_token
        if "gemma" in getattr(self.config, "model_type", ""):
            self.add_bos_token = True
            eval_logger.info(
                f"Model type is '{self.config.model_type}', part of the Gemma family--a BOS token will be used as Gemma underperforms without it."
            )

        self._max_length = max_length
        self.pretrained = pretrained
        self.delta = delta
        self.peft = peft
        self.revision = revision
        self.batch_schedule = 1
        self.batch_sizes = {}
        self.max_batch_size = max_batch_size

        if str(batch_size).startswith("auto"):
            batch_size = batch_size.split(":")
            self.batch_size_per_gpu = batch_size[0]
            self.batch_schedule = float(batch_size[1]) if len(batch_size) > 1 else 1
        else:
            self.batch_size_per_gpu = int(batch_size)

        if isinstance(pretrained, str):
            if gpus >= 1 or str(self.device) == "mps":
                # TODO: can remove this whole snippet except in the mps case, perhaps?
                if not (parallelize or autogptq or hasattr(self, "accelerator")):
                    # place model onto device requested manually,
                    # if not using HF Accelerate or device_map
                    # or any other option that preloads model onto device
                    try:
                        self.model.to(self.device)
                    except ValueError:
                        eval_logger.debug(
                            "Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes` or `device_map` is provided. If the desired GPU is being used, this message is safe to ignore."
                        )
            # multigpu data-parallel support when launched with accelerate
            if gpus > 1:
                if accelerator.num_processes > 1:
                    if parallelize:
                        eval_logger.warning(
                            "You are both using a HF Accelerate `device_map` (`--model_args parallelize=True`) and launching via `accelerate launch`. This will attempt to do model and data parallelism depending on the resources available."
                        )
                    elif gpus > accelerator.num_processes:
                        eval_logger.warning(
                            "WARNING: The number of total system GPUs does not match the number of spawned processes. "
                            "If you would like to use data parallelism, please launch the script "
                            "with 'accelerate launch *script*'. "
                            f"Current run will proceed with {accelerator.num_processes} devices."
                        )
                    if self.accelerator.is_local_main_process:
                        eval_logger.info(
                            f"Using {gpus} devices with data parallelism"
                        )

                    self._device = torch.device(f"{accelerator.device}")
                    self.accelerator = accelerator

                    self._rank = self.accelerator.local_process_index
                    self._world_size = self.accelerator.num_processes
                else:
                    # if we aren't launching via accelerate, ditch
                    self._rank = 0
                    self._world_size = 1
        else:
            # if a PreTrainedModel was passed into HFLM, we forgo distributed setup.
            eval_logger.warning(
                "Passed an already-initialized model through `pretrained`, assuming single-process call to evaluate() or custom distributed integration"
            )
            self._rank = 0
            self._world_size = 1

        self.custom_prefix_token_id = prefix_token_id
        if prefix_token_id is not None:
            eval_logger.info(
                f"Loglikelihood prefix token id used in evaluation: {self.prefix_token_id}"
            )

    def _get_accelerate_args(
        self,
        parallelize: bool = None,
        device_map: Optional[str] = "auto",
        max_memory_per_gpu: Optional[Union[int, str]] = None,
        max_cpu_memory: Optional[Union[int, str]] = None,
        offload_folder: Optional[str] = "./offload",
        gpus: Optional[int] = None,
    ) -> dict:
        """Returns the kwargs needed to apply `accelerate` in `AutoModel.from_pretrained`."""
        num_local_processes = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
        num_machines = int(os.environ.get("WORLD_SIZE", 0)) // num_local_processes
        if (
            num_machines == 0
            and hasattr(self, "accelerator")
            and self.accelerator is not None
        ):
            eval_logger.info(
                "We are not in a distributed setting for accelerate. Setting model_parallel to False."
            )
            parallelize = False

        if parallelize is None:
            # If parallelism is unset by the user, we automatically assign model parallelism
            # if enough extra GPUs are available
            max_memory_all_gpus = get_max_memory()
            # We just want gpu, not cpu, max memory
            if "cpu" in max_memory_all_gpus:
                del max_memory_all_gpus["cpu"]
            parallelize = bool(num_local_processes < len(max_memory_all_gpus))
            eval_logger.info(
                f"Setting model parallel to {parallelize} since "
                f"the number of local processes is {num_local_processes} "
                f"and the number of GPUs is {len(max_memory_all_gpus)}"
            )

        args = {}
        if parallelize:  # Model parallelism will be used
            max_memory = {}
            if max_memory_per_gpu is not None:  # Using the provided memory requirements
                max_memory_per_gpu_map = {
                    device_idx: max_memory_per_gpu for device_idx in range(gpus)
                }
            else:  # Estimating the possible memory requirements
                max_memory_all_gpus = get_max_memory()
                if "cpu" in max_memory_all_gpus:
                    del max_memory_all_gpus["cpu"]
                if not hasattr(self, "accelerator"):
                    max_memory_per_gpu_map = {
                        k: v for k, v in max_memory_all_gpus.items()
                    }
                else:
                    # use only 1 / num_processes of the GPUs if we are running under accelerate launch
                    max_memory_per_gpu_map = {
                        k: v
                        for k, v in max_memory_all_gpus.items()
                        if k % num_local_processes
                        == (self.accelerator.process_index % num_local_processes)
                    }
            args["max_memory"] = max_memory_per_gpu_map
            args["device_map"] = "auto"
            eval_logger.info(
                f"Model parallel was set to True, setting max memory per GPU to {max_memory_per_gpu_map} and device map to 'auto'"
            )

            if max_cpu_memory is not None:
                max_memory["cpu"] = max_cpu_memory

            args["offload_folder"] = offload_folder
        elif (
            device_map is None
        ):  # No model parallelism, we use the default provided device for our model
            if hasattr(self, "accelerator"):
                device_map = {"": f"{self.accelerator.device}"}
            else:
                device_map = {"": str(self.device)}
            args["max_memory"] = None
            args["device_map"] = device_map
            eval_logger.info(
                f"Model parallel was set to False, max memory was not set, and device map was set to {device_map}"
            )
        else:
            args["max_memory"] = None
            args["device_map"] = None
            eval_logger.info("Model parallel was set to False.")

        return args

    @property
    def config(self):
        # return the associated transformers.AutoConfig for the given pretrained model.
        return self._config

    @property
    def model(self):
        # returns the model, unwrapping it if using Accelerate
        if hasattr(self, "accelerator"):
            return self.accelerator.unwrap_model(self._model)
        else:
            return self._model

    @property
    def eot_token_id(self):
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

    @property
    def prefix_token_id(self):
        # it is used as prefix for loglikelihood
        if self.custom_prefix_token_id is not None:
            return self.custom_prefix_token_id
        if self.tokenizer.bos_token_id is not None:
            return self.tokenizer.bos_token_id
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        if self._max_length:  # if max length manually set, return it
            return self._max_length
        seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
        for attr in seqlen_config_attrs:
            if hasattr(self.model.config, attr):
                return getattr(self.model.config, attr)
        if hasattr(self.tokenizer, "model_max_length"):
            if self.tokenizer.model_max_length == 1000000000000000019884624838656:
                return self._DEFAULT_MAX_LENGTH
            return self.tokenizer.model_max_length
        return self._DEFAULT_MAX_LENGTH

    @property
    def max_gen_toks(self) -> int:
        return 256

    @property
    def batch_size(self):
        return self.batch_size_per_gpu

    @property
    def device(self):
        return self._device

    @property
    def rank(self):
        return self._rank

    @property
    def world_size(self):
        return self._world_size

    @property
    def tokenizer_name(self) -> str:
        return self.tokenizer.name_or_path.replace("/", "__")

    @property
    def chat_template(self) -> str:
        if self.tokenizer.chat_template is not None:
            return self.tokenizer.chat_template
        return self.tokenizer.default_chat_template

    def _get_backend(
        self,
        config: Union[transformers.PretrainedConfig, transformers.AutoConfig],
        backend: Optional[Literal["default", "causal", "seq2seq"]] = "default",
        trust_remote_code: Optional[bool] = False,
    ) -> None:
        """
        Helper method during initialization.
        Determines the backend ("causal" (decoder-only) or "seq2seq" (encoder-decoder))
        model type to be used.
        """
        assert backend in ["default", "causal", "seq2seq"]

        if backend != "default":
            # if we've settled on non-default backend, use that manually
            if backend == "causal":
                self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
            elif backend == "seq2seq":
                self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM
            eval_logger.info(
                f"Overrode HF model backend type, and using type '{backend}'"
            )
        else:
            # determine and use the default HF backend for this model, based on its config + metadata.
            if (
                getattr(config, "model_type")
                in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
            ):
                # first check if model type is listed under seq2seq models, since some
                # models like MBart are listed in both seq2seq and causal mistakenly in HF transformers.
                # these special cases should be treated as seq2seq models.
                self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM
            elif (
                getattr(self.config, "model_type") in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
            ):
                self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
            else:
                if not trust_remote_code:
                    eval_logger.warning(
                        "HF model type is neither marked as CausalLM or Seq2SeqLM. \
                        This is expected if your model requires `trust_remote_code=True` but may be an error otherwise."
                    )
                # if model type is neither in HF transformers causal or seq2seq model registries
                # then we default to AutoModelForCausalLM
                self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM

        assert self.AUTO_MODEL_CLASS in [
            transformers.AutoModelForCausalLM,
            transformers.AutoModelForSeq2SeqLM,
        ]
        return None

    def _get_config(
        self,
        pretrained: str,
        revision: str = "main",
        trust_remote_code: bool = False,
    ) -> None:
        self._config = transformers.AutoConfig.from_pretrained(
            pretrained,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )

    def _create_model(
        self,
        pretrained: str,
        revision: Optional[str] = "main",
        dtype: Optional[Union[str, torch.dtype]] = "auto",
        trust_remote_code: Optional[bool] = False,
        # arguments used for splitting a model across GPUs naively.
        # only used if `parallelize=True`.
        # (accelerate naive PP (device_map) options)
        parallelize: Optional[bool] = False,
        gpus: Optional[int] = None,
        max_memory_per_gpu: Optional[Union[int, str]] = None,
        max_cpu_memory: Optional[Union[int, str]] = None,
        offload_folder: Optional[str] = "./offload",
        # PEFT, delta weights and quantization options
        peft: Optional[str] = None,
        delta: Optional[str] = None,
        autogptq: Optional[Union[bool, str]] = False,
        **kwargs,
    ) -> None:
        """
        Initializes an HF or HF-compatible PreTrainedModel from scratch
|
533 |
+
inside HFLM, using the kwargs passed into self.__init__().
|
534 |
+
|
535 |
+
Also handles functionality such as AutoGPTQ usage and PEFT wrapping.
|
536 |
+
|
537 |
+
For future similar extensions to AutoGPTQ that are not core to HF's ecosystem,
|
538 |
+
(such as PyTorch models that are nearly, but not quite, fully mirroring
|
539 |
+
HF's public interface relied on in this HFLM class)
|
540 |
+
please consider subclassing HFLM and overriding this and other methods as needed.
|
541 |
+
"""
|
542 |
+
|
543 |
+
model_kwargs = kwargs if kwargs else {}
|
544 |
+
|
545 |
+
model_kwargs.update(
|
546 |
+
self._get_accelerate_args(
|
547 |
+
parallelize=parallelize,
|
548 |
+
device_map=kwargs.get("device_map", None),
|
549 |
+
max_memory_per_gpu=max_memory_per_gpu,
|
550 |
+
max_cpu_memory=max_cpu_memory,
|
551 |
+
offload_folder=offload_folder,
|
552 |
+
gpus=gpus,
|
553 |
+
)
|
554 |
+
)
|
555 |
+
|
556 |
+
if not autogptq:
|
557 |
+
if model_kwargs.get("load_in_4bit", None):
|
558 |
+
assert (
|
559 |
+
transformers.__version__ >= "4.30.0"
|
560 |
+
), "load_in_4bit requires transformers >= 4.30.0"
|
561 |
+
if transformers.__version__ >= "4.30.0":
|
562 |
+
if model_kwargs.get("load_in_4bit", None):
|
563 |
+
if model_kwargs.get("bnb_4bit_compute_dtype", None):
|
564 |
+
model_kwargs["bnb_4bit_compute_dtype"] = get_dtype(
|
565 |
+
model_kwargs["bnb_4bit_compute_dtype"]
|
566 |
+
)
|
567 |
+
|
568 |
+
self._model = self.AUTO_MODEL_CLASS.from_pretrained(
|
569 |
+
pretrained,
|
570 |
+
revision=revision,
|
571 |
+
torch_dtype=get_dtype(dtype),
|
572 |
+
trust_remote_code=trust_remote_code,
|
573 |
+
**model_kwargs,
|
574 |
+
)
|
575 |
+
else:
|
576 |
+
try:
|
577 |
+
from auto_gptq import AutoGPTQForCausalLM
|
578 |
+
except ModuleNotFoundError:
|
579 |
+
raise Exception(
|
580 |
+
"Tried to load auto_gptq, but auto-gptq is not installed ",
|
581 |
+
"please install auto-gptq via pip install lm-eval[gptq] or pip install -e .[gptq]",
|
582 |
+
)
|
583 |
+
|
584 |
+
self._model = AutoGPTQForCausalLM.from_quantized(
|
585 |
+
pretrained,
|
586 |
+
trust_remote_code=trust_remote_code,
|
587 |
+
model_basename=None if autogptq is True else Path(autogptq).stem,
|
588 |
+
use_safetensors=True
|
589 |
+
if autogptq is True
|
590 |
+
else autogptq.endswith(".safetensors"),
|
591 |
+
**model_kwargs,
|
592 |
+
)
|
593 |
+
|
594 |
+
if peft and delta:
|
595 |
+
raise ValueError(
|
596 |
+
"Cannot use both 'peft' and 'delta' options at the same time."
|
597 |
+
)
|
598 |
+
|
599 |
+
if peft:
|
600 |
+
if model_kwargs.get("load_in_4bit", None):
|
601 |
+
if version.parse(PEFT_VERSION) < version.parse("0.4.0"):
|
602 |
+
raise AssertionError("load_in_4bit requires peft >= 0.4.0")
|
603 |
+
if self._model.config.vocab_size != len(self.tokenizer):
|
604 |
+
# resize model for LoRAs with added tokens
|
605 |
+
self._model.resize_token_embeddings(len(self.tokenizer))
|
606 |
+
eval_logger.info(
|
607 |
+
f"Model config indicates vocab_size='{self._model.config.vocab_size}', but found tokenizer with vocab size '{len(self.tokenizer)}'. Resizing model embedding layer..."
|
608 |
+
)
|
609 |
+
self._model = PeftModel.from_pretrained(
|
610 |
+
self._model, peft, revision=revision
|
611 |
+
)
|
612 |
+
elif delta:
|
613 |
+
if autogptq:
|
614 |
+
eval_logger.warning(
|
615 |
+
"Delta weights might trigger unexpected behavior when used with AutoGPTQ."
|
616 |
+
)
|
617 |
+
_model_delta = self.AUTO_MODEL_CLASS.from_pretrained(
|
618 |
+
delta,
|
619 |
+
revision=revision,
|
620 |
+
torch_dtype=get_dtype(dtype),
|
621 |
+
trust_remote_code=trust_remote_code,
|
622 |
+
**model_kwargs,
|
623 |
+
)
|
624 |
+
for name, param in self._model.state_dict().items():
|
625 |
+
try:
|
626 |
+
param.data += _model_delta.state_dict()[name]
|
627 |
+
except KeyError:
|
628 |
+
raise KeyError(f"Delta model is missing weights for layer: {name}")
|
629 |
+
except Exception as e:
|
630 |
+
raise RuntimeError(
|
631 |
+
f"Failed to add delta weights to layer {name}. Error: {e}"
|
632 |
+
)
|
633 |
+
|
634 |
+
del _model_delta
|
635 |
+
|
636 |
+
return None
|
637 |
+
|
638 |
+
def _create_tokenizer(
|
639 |
+
self,
|
640 |
+
pretrained: Union[str, transformers.PreTrainedModel],
|
641 |
+
tokenizer: Optional[
|
642 |
+
Union[
|
643 |
+
str,
|
644 |
+
transformers.PreTrainedTokenizer,
|
645 |
+
transformers.PreTrainedTokenizerFast,
|
646 |
+
]
|
647 |
+
],
|
648 |
+
revision: Optional[str] = "main",
|
649 |
+
trust_remote_code: Optional[bool] = False,
|
650 |
+
use_fast_tokenizer: Optional[bool] = True,
|
651 |
+
) -> None:
|
652 |
+
"""
|
653 |
+
Helper method during initialization.
|
654 |
+
|
655 |
+
Create a tokenizer object corresponding to the correct
|
656 |
+
tokenizer for value of `pretrained`, or use the pre-initialized tokenizer passed.
|
657 |
+
"""
|
658 |
+
|
659 |
+
if tokenizer:
|
660 |
+
if isinstance(tokenizer, str):
|
661 |
+
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
|
662 |
+
tokenizer,
|
663 |
+
revision=revision,
|
664 |
+
trust_remote_code=trust_remote_code,
|
665 |
+
use_fast=use_fast_tokenizer,
|
666 |
+
)
|
667 |
+
else:
|
668 |
+
assert isinstance(
|
669 |
+
tokenizer, transformers.PreTrainedTokenizer
|
670 |
+
) or isinstance(tokenizer, transformers.PreTrainedTokenizerFast)
|
671 |
+
self.tokenizer = tokenizer
|
672 |
+
else:
|
673 |
+
# Get tokenizer based on 'pretrained'
|
674 |
+
if isinstance(pretrained, str):
|
675 |
+
model_name = pretrained
|
676 |
+
else:
|
677 |
+
# get the HF hub name via accessor on model
|
678 |
+
model_name = self.model.name_or_path
|
679 |
+
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
|
680 |
+
model_name,
|
681 |
+
revision=revision,
|
682 |
+
trust_remote_code=trust_remote_code,
|
683 |
+
use_fast=use_fast_tokenizer,
|
684 |
+
)
|
685 |
+
return None
|
686 |
+
|
687 |
+
def _detect_batch_size(self, requests=None, pos: int = 0):
|
688 |
+
if requests:
|
689 |
+
_, context_enc, continuation_enc = requests[pos]
|
690 |
+
max_length = len(
|
691 |
+
(context_enc + continuation_enc)[-(self.max_length + 1) :][:-1]
|
692 |
+
)
|
693 |
+
max_context_enc = len(context_enc[-(self.max_length + 1) :])
|
694 |
+
max_cont_enc = len(continuation_enc[-(self.max_length + 1) :])
|
695 |
+
else:
|
696 |
+
max_length = self.max_length
|
697 |
+
max_context_enc = max_length
|
698 |
+
max_cont_enc = max_length
|
699 |
+
|
700 |
+
# if OOM, then halves batch_size and tries again
|
701 |
+
@find_executable_batch_size(starting_batch_size=self.max_batch_size)
|
702 |
+
def forward_batch(batch_size):
|
703 |
+
if self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
|
704 |
+
length = max(max_context_enc, max_cont_enc)
|
705 |
+
batched_conts = torch.ones(
|
706 |
+
(batch_size, length), device=self.device
|
707 |
+
).long()
|
708 |
+
test_batch = torch.ones((batch_size, length), device=self.device).long()
|
709 |
+
call_kwargs = {
|
710 |
+
"attn_mask": test_batch,
|
711 |
+
"labels": batched_conts,
|
712 |
+
}
|
713 |
+
else:
|
714 |
+
call_kwargs = {}
|
715 |
+
test_batch = torch.ones(
|
716 |
+
(batch_size, max_length), device=self.device
|
717 |
+
).long()
|
718 |
+
for _ in range(5):
|
719 |
+
out = F.log_softmax(self._model_call(test_batch, **call_kwargs), dim=-1) # noqa: F841
|
720 |
+
|
721 |
+
return batch_size
|
722 |
+
|
723 |
+
try:
|
724 |
+
batch_size = forward_batch()
|
725 |
+
except RuntimeError as e:
|
726 |
+
if "No executable batch size found" in str(e):
|
727 |
+
batch_size = 1
|
728 |
+
else:
|
729 |
+
raise
|
730 |
+
|
731 |
+
if self.world_size > 1:
|
732 |
+
# if multi-GPU, always take minimum over all selected batch sizes
|
733 |
+
max_rnk_bs = torch.tensor([batch_size], device=self.device)
|
734 |
+
gathered = (
|
735 |
+
self.accelerator.gather(max_rnk_bs).cpu().detach().numpy().tolist()
|
736 |
+
)
|
737 |
+
batch_size = min(gathered)
|
738 |
+
clear_torch_cache()
|
739 |
+
return batch_size
|
740 |
+
|
741 |
+
clear_torch_cache()
|
742 |
+
return batch_size
|
743 |
+
|
744 |
+
def tok_encode(
|
745 |
+
self, string: str, left_truncate_len=None, add_special_tokens=None
|
746 |
+
) -> List[int]:
|
747 |
+
""" """
|
748 |
+
# default for None - empty dict, use predefined tokenizer param
|
749 |
+
# used for all models except for CausalLM or predefined value
|
750 |
+
special_tokens_kwargs = {}
|
751 |
+
|
752 |
+
# by default for CausalLM - false or self.add_bos_token is set
|
753 |
+
if add_special_tokens is None:
|
754 |
+
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
|
755 |
+
special_tokens_kwargs = {
|
756 |
+
"add_special_tokens": False or self.add_bos_token
|
757 |
+
}
|
758 |
+
# otherwise the method explicitly defines the value
|
759 |
+
else:
|
760 |
+
special_tokens_kwargs = {"add_special_tokens": add_special_tokens}
|
761 |
+
|
762 |
+
encoding = self.tokenizer.encode(string, **special_tokens_kwargs)
|
763 |
+
|
764 |
+
# left-truncate the encoded context to be at most `left_truncate_len` tokens long
|
765 |
+
if left_truncate_len:
|
766 |
+
encoding = encoding[-left_truncate_len:]
|
767 |
+
|
768 |
+
return encoding
|
769 |
+
|
770 |
+
def tok_batch_encode(
|
771 |
+
self,
|
772 |
+
strings: List[str],
|
773 |
+
padding_side: str = "left",
|
774 |
+
left_truncate_len: int = None,
|
775 |
+
truncation: bool = False,
|
776 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
777 |
+
# encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode.
|
778 |
+
old_padding_side = self.tokenizer.padding_side
|
779 |
+
self.tokenizer.padding_side = padding_side
|
780 |
+
|
781 |
+
add_special_tokens = {}
|
782 |
+
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
|
783 |
+
add_special_tokens = {"add_special_tokens": False or self.add_bos_token}
|
784 |
+
|
785 |
+
encoding = self.tokenizer(
|
786 |
+
strings,
|
787 |
+
truncation=truncation,
|
788 |
+
padding="longest",
|
789 |
+
return_tensors="pt",
|
790 |
+
**add_special_tokens,
|
791 |
+
)
|
792 |
+
if left_truncate_len:
|
793 |
+
encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:]
|
794 |
+
encoding["attention_mask"] = encoding["attention_mask"][
|
795 |
+
:, -left_truncate_len:
|
796 |
+
]
|
797 |
+
self.tokenizer.padding_side = old_padding_side
|
798 |
+
|
799 |
+
return encoding["input_ids"], encoding["attention_mask"]
|
800 |
+
|
801 |
+
def tok_decode(self, tokens, skip_special_tokens=True):
|
802 |
+
return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
|
803 |
+
|
804 |
+
def _model_call(self, inps, attn_mask=None, labels=None):
|
805 |
+
"""
|
806 |
+
:param inps: torch.Tensor
|
807 |
+
A torch tensor of shape [batch, (sequence_ctx + sequence_cont)] or of shape
|
808 |
+
[batch, sequence_ctx]. the size of sequence may vary from call to call
|
809 |
+
:param attn_mask: torch.Tensor, optional
|
810 |
+
A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. Only passed
|
811 |
+
(and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM
|
812 |
+
:param labels: torch.Tensor, optional
|
813 |
+
A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. Only passed
|
814 |
+
(and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM
|
815 |
+
:return
|
816 |
+
A torch tensor of shape [batch, sequence, vocab] with the
|
817 |
+
logits returned from the model's decoder
|
818 |
+
"""
|
819 |
+
with torch.no_grad():
|
820 |
+
if attn_mask is not None or labels is not None:
|
821 |
+
assert attn_mask is not None and labels is not None
|
822 |
+
assert self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM
|
823 |
+
return self.model(
|
824 |
+
input_ids=inps, attention_mask=attn_mask, labels=labels
|
825 |
+
).logits
|
826 |
+
else:
|
827 |
+
assert self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
|
828 |
+
return self.model(inps).logits
|
829 |
+
|
830 |
+
def _model_generate(self, context, max_length, stop, **generation_kwargs):
|
831 |
+
# temperature = 0.0 if not set
|
832 |
+
# if do_sample is false and temp==0.0:
|
833 |
+
# remove temperature, as do_sample=False takes care of this
|
834 |
+
# and we don't want a warning from HF
|
835 |
+
generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
|
836 |
+
do_sample = generation_kwargs.get("do_sample", None)
|
837 |
+
|
838 |
+
# The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies
|
839 |
+
if generation_kwargs.get("temperature") == 0.0 and do_sample is None:
|
840 |
+
generation_kwargs["do_sample"] = do_sample = False
|
841 |
+
|
842 |
+
if do_sample is False and generation_kwargs.get("temperature") == 0.0:
|
843 |
+
generation_kwargs.pop("temperature")
|
844 |
+
# build stopping criteria
|
845 |
+
stopping_criteria = stop_sequences_criteria(
|
846 |
+
self.tokenizer, stop, context.shape[1], context.shape[0]
|
847 |
+
)
|
848 |
+
return self.model.generate(
|
849 |
+
input_ids=context,
|
850 |
+
max_length=max_length,
|
851 |
+
stopping_criteria=stopping_criteria,
|
852 |
+
pad_token_id=self.tokenizer.pad_token_id,
|
853 |
+
use_cache=True,
|
854 |
+
**generation_kwargs,
|
855 |
+
)
|
856 |
+
|
857 |
+
def _select_cont_toks(
|
858 |
+
self, logits: torch.Tensor, contlen: int = None, inplen: int = None
|
859 |
+
) -> torch.Tensor:
|
860 |
+
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
|
861 |
+
assert (
|
862 |
+
contlen and inplen
|
863 |
+
), "Must pass input len and cont. len to select scored logits for causal LM"
|
864 |
+
# discard right-padding.
|
865 |
+
# also discard the input/context tokens. we'll only score continuations.
|
866 |
+
logits = logits[inplen - contlen : inplen]
|
867 |
+
elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
|
868 |
+
assert (
|
869 |
+
contlen and not inplen
|
870 |
+
), "Selecting scored logits for Seq2SeqLM requires only cont. len"
|
871 |
+
# only discard right-padding.
|
872 |
+
# the logits input to this fn only contain decoder-side tokens.
|
873 |
+
logits = logits[:contlen]
|
874 |
+
|
875 |
+
return logits
|
876 |
+
|
877 |
+
def loglikelihood_rolling(
|
878 |
+
self, requests: List[Instance], disable_tqdm: bool = False
|
879 |
+
) -> List[float]:
|
880 |
+
loglikelihoods = []
|
881 |
+
|
882 |
+
adaptive_batch_size = None
|
883 |
+
if self.batch_size == "auto":
|
884 |
+
# using rolling window with maximum context
|
885 |
+
print("Passed argument batch_size = auto. Detecting largest batch size")
|
886 |
+
batch_size = self._detect_batch_size()
|
887 |
+
print(f"Determined Largest batch size: {batch_size}")
|
888 |
+
adaptive_batch_size = batch_size
|
889 |
+
|
890 |
+
for (string,) in tqdm(
|
891 |
+
[req.args for req in requests], disable=(disable_tqdm or (self.rank != 0))
|
892 |
+
):
|
893 |
+
rolling_token_windows = list(
|
894 |
+
map(
|
895 |
+
utils.make_disjoint_window,
|
896 |
+
utils.get_rolling_token_windows(
|
897 |
+
token_list=self.tok_encode(string),
|
898 |
+
prefix_token=self.prefix_token_id,
|
899 |
+
max_seq_len=self.max_length,
|
900 |
+
context_len=1,
|
901 |
+
),
|
902 |
+
)
|
903 |
+
)
|
904 |
+
|
905 |
+
# TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
|
906 |
+
rolling_token_windows = [(None,) + x for x in rolling_token_windows]
|
907 |
+
|
908 |
+
pad_amnt = 0
|
909 |
+
if self.world_size > 1:
|
910 |
+
# We pad out the external document-level iterator so the inner iterator doesn't hang
|
911 |
+
mytensor = torch.tensor(len(rolling_token_windows), device=self.device)
|
912 |
+
gathered = (
|
913 |
+
self.accelerator.gather(mytensor).cpu().detach().numpy().tolist()
|
914 |
+
)
|
915 |
+
|
916 |
+
pad_amnt = max(gathered) - gathered[self.rank]
|
917 |
+
if pad_amnt > 0:
|
918 |
+
rolling_token_windows += pad_amnt * [rolling_token_windows[0]]
|
919 |
+
|
920 |
+
string_nll = self._loglikelihood_tokens(
|
921 |
+
requests=rolling_token_windows,
|
922 |
+
disable_tqdm=True,
|
923 |
+
override_bs=adaptive_batch_size,
|
924 |
+
)
|
925 |
+
|
926 |
+
if (self.world_size > 1) and (pad_amnt > 0):
|
927 |
+
string_nll = [x[0] for x in string_nll[:-pad_amnt]]
|
928 |
+
else:
|
929 |
+
# discard is_greedy
|
930 |
+
string_nll = [x[0] for x in string_nll]
|
931 |
+
|
932 |
+
string_nll = sum(string_nll)
|
933 |
+
loglikelihoods.append(string_nll)
|
934 |
+
|
935 |
+
return loglikelihoods
|
936 |
+
|
937 |
+
def _batch_scheduler(self, pos, n_reordered_requests):
|
938 |
+
sched = pos // int(len(n_reordered_requests) / self.batch_schedule)
|
939 |
+
if sched in self.batch_sizes:
|
940 |
+
return self.batch_sizes[sched]
|
941 |
+
if (len(self.batch_sizes) > 1) and (
|
942 |
+
self.batch_sizes[sched - 1] == self.max_batch_size
|
943 |
+
):
|
944 |
+
# if previous batch size is already maximal, skip recomputation
|
945 |
+
self.batch_sizes[sched] = self.max_batch_size
|
946 |
+
return self.batch_sizes[sched]
|
947 |
+
print(
|
948 |
+
f"Passed argument batch_size = auto:{self.batch_schedule}. Detecting largest batch size"
|
949 |
+
)
|
950 |
+
self.batch_sizes[sched] = self._detect_batch_size(n_reordered_requests, pos)
|
951 |
+
print(f"Determined largest batch size: {self.batch_sizes[sched]}")
|
952 |
+
return self.batch_sizes[sched]
|
953 |
+
|
954 |
+
def _loglikelihood_tokens(
|
955 |
+
self,
|
956 |
+
requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
|
957 |
+
disable_tqdm: bool = False,
|
958 |
+
override_bs: int = None,
|
959 |
+
) -> List[Tuple[float, bool]]:
|
960 |
+
# TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
|
961 |
+
res = []
|
962 |
+
|
963 |
+
def _collate(req: Tuple[Tuple[str, str], List[int], List[int]]):
|
964 |
+
"""Defines the key for the sorted method"""
|
965 |
+
# the negative sign on len(toks) sorts descending - this has a few advantages:
|
966 |
+
# - time estimates will always be over not underestimates, which is more useful for planning
|
967 |
+
# - to know the size of a batch when going through the list, you know the first one is always the batch
|
968 |
+
# padded context length. this is useful to simplify the batching logic and more importantly to make
|
969 |
+
# automatic adaptive batches much much easier to implement
|
970 |
+
# - any OOMs will happen right away rather than near the end
|
971 |
+
|
972 |
+
toks = req[1] + req[2]
|
973 |
+
return -len(toks), tuple(toks)
|
974 |
+
|
975 |
+
def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
|
976 |
+
"""Defines the key to group and lookup one-token continuations"""
|
977 |
+
# Use with group_by="contexts" (optional)"
|
978 |
+
# allows for the creation of a lookup, so we can reuse logits in case of one-token continuations.
|
979 |
+
# speeds up some multiple-choice tasks proportionally to the number of choices.
|
980 |
+
# groups requests by context+continuation[:-1] and infer on one request/group.
|
981 |
+
return req[-2] + req[-1][:-1]
|
982 |
+
|
983 |
+
re_ord = Collator(
|
984 |
+
requests,
|
985 |
+
sort_fn=_collate,
|
986 |
+
group_by="contexts"
|
987 |
+
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
|
988 |
+
and self.logits_cache
|
989 |
+
else None,
|
990 |
+
group_fn=_lookup_one_token_cont,
|
991 |
+
)
|
992 |
+
|
993 |
+
# automatic (variable) batch size detection for vectorization
|
994 |
+
# pull longest context sample from request
|
995 |
+
n_reordered_requests = len(re_ord)
|
996 |
+
batch_size = (
|
997 |
+
self.batch_size
|
998 |
+
if self.batch_size != "auto"
|
999 |
+
else override_bs
|
1000 |
+
if override_bs is not None
|
1001 |
+
else 0
|
1002 |
+
)
|
1003 |
+
batch_fn = (
|
1004 |
+
self._batch_scheduler
|
1005 |
+
if self.batch_size == "auto"
|
1006 |
+
and n_reordered_requests > 0
|
1007 |
+
and not override_bs
|
1008 |
+
else None
|
1009 |
+
)
|
1010 |
+
|
1011 |
+
chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn)
|
1012 |
+
pbar = tqdm(
|
1013 |
+
total=len(requests),
|
1014 |
+
disable=(disable_tqdm or (self.rank != 0)),
|
1015 |
+
desc="Running loglikelihood requests",
|
1016 |
+
)
|
1017 |
+
for chunk in chunks:
|
1018 |
+
inps = []
|
1019 |
+
cont_toks_list = []
|
1020 |
+
inplens = []
|
1021 |
+
|
1022 |
+
conts = []
|
1023 |
+
encoder_attns = []
|
1024 |
+
|
1025 |
+
padding_len_inp = None
|
1026 |
+
padding_len_cont = None
|
1027 |
+
# because vectorizing is annoying, we first convert each (context, continuation) pair to padded
|
1028 |
+
# tensors, then we pack them together into a batch, call the model, and then pick it all apart
|
1029 |
+
# again because vectorizing is annoying
|
1030 |
+
|
1031 |
+
for _, context_enc, continuation_enc in chunk:
|
1032 |
+
# sanity check
|
1033 |
+
assert len(context_enc) > 0
|
1034 |
+
assert len(continuation_enc) > 0
|
1035 |
+
assert len(continuation_enc) <= self.max_length
|
1036 |
+
|
1037 |
+
# how this all works (illustrated on a causal decoder-only setup):
|
1038 |
+
# CTX CONT
|
1039 |
+
# inp 0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1]
|
1040 |
+
# model \ \
|
1041 |
+
# logits 1 2 3|4 5 6 7 8 9 <- the ctx half gets tossed out by the
|
1042 |
+
# cont_toks 4 5 6 7 8 9 [:, -len(continuation_enc):, :self.vocab_size] slice
|
1043 |
+
|
1044 |
+
# when too long to fit in context, truncate from the left
|
1045 |
+
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
|
1046 |
+
inp = torch.tensor(
|
1047 |
+
(context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
|
1048 |
+
dtype=torch.long,
|
1049 |
+
device=self.device,
|
1050 |
+
)
|
1051 |
+
(inplen,) = inp.shape
|
1052 |
+
elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
|
1053 |
+
inp = torch.tensor(
|
1054 |
+
(context_enc)[-self.max_length :],
|
1055 |
+
dtype=torch.long,
|
1056 |
+
device=self.device,
|
1057 |
+
)
|
1058 |
+
(inplen,) = inp.shape
|
1059 |
+
|
1060 |
+
# build encoder attn masks
|
1061 |
+
encoder_attns.append(torch.ones_like(inp))
|
1062 |
+
|
1063 |
+
cont = torch.tensor(
|
1064 |
+
(continuation_enc)[-self.max_length :],
|
1065 |
+
# TODO: left-shift these?
|
1066 |
+
# TODO: our code assumes we never end up truncating conts for either model type
|
1067 |
+
dtype=torch.long,
|
1068 |
+
device=self.device,
|
1069 |
+
)
|
1070 |
+
(contlen,) = cont.shape
|
1071 |
+
|
1072 |
+
conts.append(cont)
|
1073 |
+
|
1074 |
+
padding_len_cont = (
|
1075 |
+
max(padding_len_cont, contlen)
|
1076 |
+
if padding_len_cont is not None
|
1077 |
+
else contlen
|
1078 |
+
)
|
1079 |
+
|
1080 |
+
padding_len_inp = (
|
1081 |
+
max(padding_len_inp, inplen)
|
1082 |
+
if padding_len_inp is not None
|
1083 |
+
else inplen
|
1084 |
+
)
|
1085 |
+
|
1086 |
+
inps.append(inp) # [1, inp_length]
|
1087 |
+
cont_toks_list.append(continuation_enc)
|
1088 |
+
inplens.append(inplen)
|
1089 |
+
|
1090 |
+
# create encoder attn mask and batched conts, if seq2seq
|
1091 |
+
call_kwargs = {}
|
1092 |
+
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
|
1093 |
+
batched_inps = pad_and_concat(
|
1094 |
+
padding_len_inp, inps, padding_side="right"
|
1095 |
+
) # [batch, padding_len_inp]
|
1096 |
+
elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
|
1097 |
+
# TODO: left-pad encoder inps and mask?
|
1098 |
+
batched_inps = pad_and_concat(
|
1099 |
+
padding_len_inp, inps
|
1100 |
+
) # [batch, padding_len_inp]
|
1101 |
+
batched_conts = pad_and_concat(
|
1102 |
+
padding_len_cont, conts
|
1103 |
+
) # [batch, padding_len_cont]
|
1104 |
+
batched_encoder_mask = pad_and_concat(
|
1105 |
+
padding_len_inp, encoder_attns
|
1106 |
+
) # [batch, padding_len_inp]
|
1107 |
+
call_kwargs = {
|
1108 |
+
"attn_mask": batched_encoder_mask,
|
1109 |
+
"labels": batched_conts,
|
1110 |
+
}
|
1111 |
+
|
1112 |
+
multi_logits = F.log_softmax(
|
1113 |
+
self._model_call(batched_inps, **call_kwargs), dim=-1
|
1114 |
+
) # [batch, padding_length (inp or cont), vocab]
|
1115 |
+
|
1116 |
+
for (request_str, ctx_tokens, _), logits, inplen, cont_toks in zip(
|
1117 |
+
chunk, multi_logits, inplens, cont_toks_list
|
1118 |
+
):
|
1119 |
+
# Slice to original seq length
|
1120 |
+
contlen = len(cont_toks)
|
1121 |
+
# take only logits in the continuation
|
1122 |
+
# (discard context toks if decoder-only ; discard right-padding)
|
1123 |
+
# also discards + checks for "virtual tokens" in the causal LM's input window
|
1124 |
+
# from prompt/prefix tuning tokens, if applicable
|
1125 |
+
ctx_len = (
|
1126 |
+
inplen + (logits.shape[0] - padding_len_inp)
|
1127 |
+
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
|
1128 |
+
else None
|
1129 |
+
)
|
1130 |
+
logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len)
|
1131 |
+
logits = logits.unsqueeze(0) # [1, seq, vocab]
|
1132 |
+
|
1133 |
+
# Check if per-token argmax is exactly equal to continuation
|
1134 |
+
greedy_tokens = logits.argmax(dim=-1)
|
1135 |
+
|
1136 |
+
# check for one-token continuation cache hits.
|
1137 |
+
# noop in case group_by != "contexts" or no cache hit and returns the
|
1138 |
+
# original args. Otherwise, expands the logits batch dimension and yields each
|
1139 |
+
# batch along with matching continuation tokens and prompt strings.
|
1140 |
+
# logits -> [1, seq, vocab]
|
1141 |
+
for request_str, cont_toks, logits in re_ord.get_cache(
|
1142 |
+
req_str=request_str,
|
1143 |
+
cxt_toks=ctx_tokens,
|
1144 |
+
cont_toks=cont_toks,
|
1145 |
+
logits=logits,
|
1146 |
+
):
|
1147 |
+
cont_toks = torch.tensor(
|
1148 |
+
cont_toks, dtype=torch.long, device=self.device
|
1149 |
+
).unsqueeze(0) # [1, seq]
|
1150 |
+
max_equal = (greedy_tokens == cont_toks).all()
|
1151 |
+
|
1152 |
+
# Obtain log-probs at the corresponding continuation token indices
|
1153 |
+
# last_token_slice = logits[:, -1, :].squeeze(0).tolist()
|
1154 |
+
logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
|
1155 |
+
-1
|
1156 |
+
) # [1, seq]
|
1157 |
+
|
1158 |
+
# Answer: (log prob, is-exact-match)
|
1159 |
+
answer = (float(logits.sum()), bool(max_equal))
|
1160 |
+
|
1161 |
+
res.append(answer)
|
1162 |
+
|
1163 |
+
self.cache_hook.add_partial("loglikelihood", request_str, answer)
|
1164 |
+
pbar.update(1)
|
1165 |
+
|
1166 |
+
pbar.close()
|
1167 |
+
|
1168 |
+
return re_ord.get_original(res)
|
1169 |
+
|
1170 |
+
def generate_until(
|
1171 |
+
self, requests: List[Instance], disable_tqdm: bool = False
|
1172 |
+
) -> List[str]:
|
1173 |
+
res = []
|
1174 |
+
|
1175 |
+
def _collate(req: Tuple[str, dict]):
|
1176 |
+
"""Defines the key for the sorted method"""
|
1177 |
+
# the negative sign on len(toks) sorts descending - this has a few advantages:
|
1178 |
+
# - time estimates will always be over not underestimates, which is more useful for planning
|
1179 |
+
# - to know the size of a batch when going through the list, you know the first one is always the batch
|
1180 |
+
# padded context length. this is useful to simplify the batching logic and more importantly to make
|
1181 |
+
# automatic adaptive batches much much easier to implement
|
1182 |
+
# - any OOMs will happen right away rather than near the end
|
1183 |
+
toks = self.tok_encode(req[0])
|
1184 |
+
return -len(toks), req[0]
|
1185 |
+
|
1186 |
+
pbar = tqdm(
|
1187 |
+
total=len(requests),
|
1188 |
+
disable=(disable_tqdm or (self.rank != 0)),
|
1189 |
+
desc="Running generate_until requests",
|
1190 |
+
)
|
1191 |
+
adaptive_batch_size = None
|
1192 |
+
if self.batch_size == "auto":
|
1193 |
+
# using rolling window with maximum context
|
1194 |
+
print("Passed argument batch_size = auto. Detecting largest batch size")
|
1195 |
+
batch_size = self._detect_batch_size()
|
1196 |
+
print(f"Determined Largest batch size: {batch_size}")
|
1197 |
+
adaptive_batch_size = batch_size
|
1198 |
+
# for each different set of kwargs, we execute all requests, by batch.
|
1199 |
+
batch_size = (
|
1200 |
+
self.batch_size
|
1201 |
+
if self.batch_size != "auto"
|
1202 |
+
else adaptive_batch_size
|
1203 |
+
if adaptive_batch_size is not None
|
1204 |
+
else 0
|
1205 |
+
)
|
1206 |
+
batch_fn = (
|
1207 |
+
self._batch_scheduler
|
1208 |
+
if self.batch_size == "auto" and not adaptive_batch_size
|
1209 |
+
else None
|
1210 |
+
)
|
1211 |
+
|
1212 |
+
# we group requests by their generation_kwargs,
|
1213 |
+
# so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
|
1214 |
+
# in the same batch.
|
1215 |
+
# group_fn=lambda x: x[1] -> x=(context, gen_kwargs)
|
1216 |
+
re_ords = Collator(
|
1217 |
+
[reg.args for reg in requests],
|
1218 |
+
sort_fn=_collate,
|
1219 |
+
group_by="gen_kwargs",
|
1220 |
+
group_fn=lambda x: x[1],
|
1221 |
+
)
|
1222 |
+
chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn)
|
1223 |
+
for chunk in chunks:
|
1224 |
+
contexts, all_gen_kwargs = zip(*chunk)
|
1225 |
+
# we assume all gen kwargs in the batch are the same
|
1226 |
+
# this is safe to assume because the `grouper` object ensures it.
|
1227 |
+
gen_kwargs = all_gen_kwargs[0]
|
1228 |
+
# unpack our keyword arguments.
|
1229 |
+
until = None
|
1230 |
+
if isinstance(gen_kwargs, dict):
|
1231 |
+
kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
|
1232 |
+
if "until" in kwargs.keys():
|
1233 |
+
until = kwargs.pop("until")
|
1234 |
+
if isinstance(until, str):
|
1235 |
+
until = [until]
|
1236 |
+
elif not isinstance(until, list):
|
1237 |
+
raise ValueError(
|
1238 |
+
f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
|
1239 |
+
)
|
1240 |
+
else:
|
1241 |
+
raise ValueError(
|
1242 |
+
f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
|
1243 |
+
)
|
1244 |
+
# add EOS token to stop sequences
|
1245 |
+
eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
|
1246 |
+
if not until:
|
1247 |
+
until = [eos]
|
1248 |
+
else:
|
1249 |
+
until.append(eos)
|
1250 |
+
if "max_gen_toks" in kwargs.keys():
|
1251 |
+
max_gen_toks = kwargs.pop("max_gen_toks")
|
1252 |
+
else:
|
1253 |
+
max_gen_toks = self.max_gen_toks
|
1254 |
+
|
1255 |
+
# set the max length in tokens of inputs ("context_enc")
|
1256 |
+
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
|
1257 |
+
# max len for inputs = max length, minus room to generate the max new tokens
|
1258 |
+
max_ctx_len = self.max_length - max_gen_toks
|
1259 |
+
elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
|
1260 |
+
# max len for inputs = encoder's whole max_length
|
1261 |
+
max_ctx_len = self.max_length
|
1262 |
+
|
1263 |
+
# encode, pad, and truncate contexts for this batch
|
1264 |
+
context_enc, attn_masks = self.tok_batch_encode(
|
1265 |
+
contexts,
|
1266 |
+
left_truncate_len=max_ctx_len,
|
1267 |
+
truncation=self.truncation,
|
1268 |
+
)
|
1269 |
+
context_enc = context_enc.to(self.device)
|
1270 |
+
attn_masks = attn_masks.to(self.device)
|
1271 |
+
|
1272 |
+
if "max_length" not in kwargs:
|
1273 |
+
kwargs["max_length"] = context_enc.shape[1] + max_gen_toks
|
1274 |
+
|
1275 |
+
# perform batched generation
|
1276 |
+
cont = self._model_generate(
|
1277 |
+
context=context_enc,
|
1278 |
+
attention_mask=attn_masks,
|
1279 |
+
stop=until,
|
1280 |
+
**kwargs,
|
1281 |
+
)
|
1282 |
+
|
1283 |
+
cont_toks_list = cont.tolist()
|
1284 |
+
for cont_toks, context in zip(cont_toks_list, contexts):
|
1285 |
+
# discard context + left-padding toks if using causal decoder-only LM
|
1286 |
+
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
|
1287 |
+
cont_toks = cont_toks[context_enc.shape[1] :]
|
1288 |
+
|
1289 |
+
s = self.tok_decode(cont_toks)
|
1290 |
+
|
1291 |
+
# use secondary stop seqs to cut off should-have-been-stopped content post-hoc
|
1292 |
+
for term in until:
|
1293 |
+
if len(term) > 0:
|
1294 |
+
# ignore '' separator,
|
1295 |
+
# for seq2seq case where self.tok_decode(self.eot_token_id) = ''
|
1296 |
+
s = s.split(term)[0]
|
1297 |
+
|
1298 |
+
res.append(s)
|
1299 |
+
|
1300 |
+
self.cache_hook.add_partial("generate_until", (context, gen_kwargs), s)
|
1301 |
+
pbar.update(1)
|
1302 |
+
# reorder this group of results back to original unsorted form
|
1303 |
+
res = re_ords.get_original(res)
|
1304 |
+
|
1305 |
+
pbar.close()
|
1306 |
+
|
1307 |
+
return res

    def apply_chat_template(self, chat_history: List[Dict[str, str]]) -> str:
        """
        Method to apply a chat template to a list of chat history between user and model.
        """
        return self.tokenizer.apply_chat_template(
            chat_history, tokenize=False, add_generation_prompt=True
        )

    def get_model_info(self) -> dict:
        """
        Method to get Hugging Face model information for experiment reproducibility.
        """

        def get_model_num_params(model) -> int:
            if hasattr(model, "num_parameters"):
                return model.num_parameters()
            if hasattr(model, "parameters"):
                return sum(p.numel() for p in model.parameters())
            else:
                return -1

        def get_model_dtype(model) -> str:
            if hasattr(model, "dtype"):
                return model.dtype
            else:
                return ""

        def get_model_sha(pretrained: str, revision: str) -> str:
            try:
                model_info = HfApi().model_info(repo_id=pretrained, revision=revision)
                return model_info.sha
            except Exception as e:
                eval_logger.warn(
                    f"Failed to get model SHA for {pretrained} at revision {revision}. Error: {e}"
                )
                return ""

        model_info = {
            "model_num_parameters": get_model_num_params(self._model),
            "model_dtype": get_model_dtype(self._model),
            "model_revision": self.revision,
            "model_sha": get_model_sha(self.pretrained, self.revision),
        }
        if self.peft:
            model_info["peft_sha"] = get_model_sha(self.peft, self.revision)
        if self.delta:
            model_info["delta_sha"] = get_model_sha(self.delta, self.revision)
        return model_info
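For orientation, here is a minimal usage sketch of the HFLM class added above. It is a hedged illustration, not part of the committed file: it assumes this checkout mirrors the upstream lm-evaluation-harness layout (HFLM importable from lm_eval.models.huggingface, simple_evaluate accepting a pre-built LM instance), and the model name, task, and limit are placeholder choices.

# Usage sketch (assumptions noted above; not part of the diff).
import lm_eval
from lm_eval.models.huggingface import HFLM

# batch_size="auto" exercises _detect_batch_size(); dtype and the accelerate-related
# kwargs feed _create_model() / _get_accelerate_args() shown in the diff above.
lm = HFLM(
    pretrained="gpt2",   # placeholder model name
    dtype="auto",
    batch_size="auto",
)

# simple_evaluate drives the loglikelihood / generate_until methods defined above.
results = lm_eval.simple_evaluate(model=lm, tasks=["hellaswag"], limit=8)
print(results["results"])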
scripts/yans/lm-evaluation-harness/lm_eval/models/nemo_lm.py
ADDED
@@ -0,0 +1,537 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import pathlib
from copy import deepcopy
from typing import List, Literal

import filelock
import numpy as np
import torch
from tqdm import tqdm

from lm_eval.api.instance import Instance
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from lm_eval.models.utils import Collator
from lm_eval.utils import (
    eval_logger,
    get_rolling_token_windows,
    make_disjoint_window,
    simple_parse_args_string,
)


def _patch_pretrained_cfg(
    pretrained_cfg, trainer, tensor_model_parallel_size, pipeline_model_parallel_size
):
    try:
        import omegaconf
    except ModuleNotFoundError:
        raise Exception(
            "Attempted to use 'nemo_lm' model type, but package `nemo` is not installed"
            "Please install nemo following the instructions in the README: either with a NVIDIA PyTorch or NeMo container, "
            "or installing nemo following https://github.com/NVIDIA/NeMo.",
        )

    omegaconf.OmegaConf.set_struct(pretrained_cfg, True)
    with omegaconf.open_dict(pretrained_cfg):
        attributes_to_update = {
            "sequence_parallel": False,
            "activations_checkpoint_granularity": None,
            "activations_checkpoint_method": None,
            "precision": trainer.precision,
            "global_batch_size": None,
            "tensor_model_parallel_size": tensor_model_parallel_size,
            "pipeline_model_parallel_size": pipeline_model_parallel_size,
            "apply_rope_fusion": False,
        }
        for name, value in attributes_to_update.items():
            if hasattr(pretrained_cfg, name):
                pretrained_cfg[name] = value
    return pretrained_cfg


def _get_target_from_class(target_class) -> str:
    return f"{target_class.__module__}.{target_class.__name__}"


def load_model(
|
72 |
+
model_path: str,
|
73 |
+
trainer,
|
74 |
+
tensor_model_parallel_size: int,
|
75 |
+
pipeline_model_parallel_size: int,
|
76 |
+
) -> torch.nn.Module:
|
77 |
+
try:
|
78 |
+
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import (
|
79 |
+
MegatronGPTModel,
|
80 |
+
)
|
81 |
+
from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector
|
82 |
+
except ModuleNotFoundError:
|
83 |
+
raise Exception(
|
84 |
+
"Attempted to use 'nemo_lm' model type, but package `nemo` is not installed"
|
85 |
+
"Please install nemo following the instructions in the README: either with a NVIDIA PyTorch or NeMo container, "
|
86 |
+
"or installing nemo following https://github.com/NVIDIA/NeMo.",
|
87 |
+
)
|
88 |
+
model_path = pathlib.Path(model_path)
|
89 |
+
|
90 |
+
save_restore_connector = NLPSaveRestoreConnector()
|
91 |
+
if model_path.is_dir():
|
92 |
+
save_restore_connector.model_extracted_dir = model_path.as_posix()
|
93 |
+
pretrained_cfg = save_restore_connector.restore_from(
|
94 |
+
None, model_path.as_posix(), return_config=True, trainer=trainer
|
95 |
+
)
|
96 |
+
if not hasattr(pretrained_cfg, "target"):
|
97 |
+
pretrained_cfg["target"] = _get_target_from_class(MegatronGPTModel)
|
98 |
+
|
99 |
+
pretrained_cfg = _patch_pretrained_cfg(
|
100 |
+
pretrained_cfg,
|
101 |
+
trainer,
|
102 |
+
tensor_model_parallel_size=tensor_model_parallel_size,
|
103 |
+
pipeline_model_parallel_size=pipeline_model_parallel_size,
|
104 |
+
)
|
105 |
+
|
106 |
+
model_to_load_path = model_path
|
107 |
+
override_config = pretrained_cfg
|
108 |
+
|
109 |
+
module_name, class_name = override_config.target.rsplit(".", 1)
|
110 |
+
model_class = getattr(importlib.import_module(module_name), class_name)
|
111 |
+
|
112 |
+
# monkeypatch _build_tokenizer method to be process-safe
|
113 |
+
tokenizer_lock = filelock.FileLock(f"/tmp/{model_path.name}.tokenizer.lock")
|
114 |
+
|
115 |
+
def _synced_build_tokenizer(self):
|
116 |
+
with tokenizer_lock:
|
117 |
+
self._original_build_tokenizer()
|
118 |
+
|
119 |
+
model_class._original_build_tokenizer = model_class._build_tokenizer
|
120 |
+
model_class._build_tokenizer = _synced_build_tokenizer
|
121 |
+
|
122 |
+
model = model_class.restore_from(
|
123 |
+
restore_path=model_to_load_path.as_posix(),
|
124 |
+
trainer=trainer,
|
125 |
+
override_config_path=override_config,
|
126 |
+
save_restore_connector=save_restore_connector,
|
127 |
+
map_location=f"cuda:{trainer.local_rank}",
|
128 |
+
)
|
129 |
+
|
130 |
+
model.freeze()
|
131 |
+
model.training = False
|
132 |
+
try:
|
133 |
+
# Have to turn off activations_checkpoint_method for inference
|
134 |
+
model.model.language_model.encoder.activations_checkpoint_method = None
|
135 |
+
except AttributeError:
|
136 |
+
pass
|
137 |
+
return model
|
138 |
+
|
139 |
+
|
140 |
+
def setup_distributed_environment(trainer):
|
141 |
+
try:
|
142 |
+
from nemo.utils.app_state import AppState
|
143 |
+
except ModuleNotFoundError:
|
144 |
+
raise Exception(
|
145 |
+
"Attempted to use 'nemo_lm' model type, but package `nemo` is not installed"
|
146 |
+
"Please install nemo following the instructions in the README: either with a NVIDIA PyTorch or NeMo container, "
|
147 |
+
"or installing nemo following https://github.com/NVIDIA/NeMo.",
|
148 |
+
)
|
149 |
+
|
150 |
+
def dummy():
|
151 |
+
return
|
152 |
+
|
153 |
+
if trainer.strategy.launcher is not None:
|
154 |
+
trainer.strategy.launcher.launch(dummy, trainer=trainer)
|
155 |
+
trainer.strategy.setup_environment()
|
156 |
+
|
157 |
+
app_state = AppState()
|
158 |
+
|
159 |
+
return app_state
|
160 |
+
|
161 |
+
|
162 |
+
@register_model("nemo_lm")
|
163 |
+
class NeMoLM(LM):
|
164 |
+
def __init__(
|
165 |
+
self,
|
166 |
+
path: str,
|
167 |
+
max_length: int = 4096,
|
168 |
+
batch_size: int = 1,
|
169 |
+
max_gen_toks: int = 256,
|
170 |
+
devices: int = 1,
|
171 |
+
num_nodes: int = 1,
|
172 |
+
tensor_model_parallel_size: int = 1,
|
173 |
+
pipeline_model_parallel_size: int = 1,
|
174 |
+
precision: Literal[
|
175 |
+
"16-mixed",
|
176 |
+
"bf16-mixed",
|
177 |
+
"32-true",
|
178 |
+
"64-true",
|
179 |
+
64,
|
180 |
+
32,
|
181 |
+
16,
|
182 |
+
"64",
|
183 |
+
"32",
|
184 |
+
"16",
|
185 |
+
"bf16",
|
186 |
+
] = "bf16",
|
187 |
+
**kwargs,
|
188 |
+
):
|
189 |
+
try:
|
190 |
+
from nemo.collections.nlp.modules.common.text_generation_utils import (
|
191 |
+
generate,
|
192 |
+
)
|
193 |
+
from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy
|
194 |
+
from pytorch_lightning.trainer.trainer import Trainer
|
195 |
+
|
196 |
+
self.generate = generate
|
197 |
+
except ModuleNotFoundError:
|
198 |
+
raise Exception(
|
199 |
+
"Attempted to use 'nemo_lm' model type, but package `nemo` is not installed"
|
200 |
+
"Please install nemo following the instructions in the README: either with a NVIDIA PyTorch or NeMo container, "
|
201 |
+
"or installing nemo following https://github.com/NVIDIA/NeMo.",
|
202 |
+
)
|
203 |
+
|
204 |
+
super().__init__()
|
205 |
+
|
206 |
+
if (
|
207 |
+
tensor_model_parallel_size == 1
|
208 |
+
and pipeline_model_parallel_size == 1
|
209 |
+
and devices > 1
|
210 |
+
):
|
211 |
+
eval_logger.info(
|
212 |
+
f"The number of data replicas for evaluation is {devices}."
|
213 |
+
)
|
214 |
+
eval_logger.info(f"The total number of devices is {devices}.")
|
215 |
+
eval_logger.info(
|
216 |
+
"No tensor parallelism or pipeline parallelism is applied."
|
217 |
+
)
|
218 |
+
|
219 |
+
elif tensor_model_parallel_size * pipeline_model_parallel_size == devices:
|
220 |
+
eval_logger.info(
|
221 |
+
f"Setting tensor parallelism to {tensor_model_parallel_size} and pipeline parallelism to {pipeline_model_parallel_size}."
|
222 |
+
)
|
223 |
+
eval_logger.info(f"The total number of devices is {devices}.")
|
224 |
+
eval_logger.info("No data parallelism is applied.")
|
225 |
+
|
226 |
+
else:
|
227 |
+
raise ValueError(
|
228 |
+
"Please set the product of tensor_model_parallel_size and pipeline_model_parallel_size"
|
229 |
+
"equal to the specified number of devices."
|
230 |
+
)
|
231 |
+
|
232 |
+
if num_nodes > 1:
|
233 |
+
raise ValueError(
|
234 |
+
"A number of nodes greater than 1 is not supported yet. Please set num_nodes as 1."
|
235 |
+
)
|
236 |
+
|
237 |
+
trainer = Trainer(
|
238 |
+
strategy=NLPDDPStrategy(),
|
239 |
+
devices=devices,
|
240 |
+
accelerator="gpu",
|
241 |
+
num_nodes=num_nodes,
|
242 |
+
precision=precision,
|
243 |
+
logger=False,
|
244 |
+
enable_checkpointing=False,
|
245 |
+
use_distributed_sampler=False,
|
246 |
+
)
|
247 |
+
# Modify the following flags only for data replication
|
248 |
+
if (
|
249 |
+
tensor_model_parallel_size == 1
|
250 |
+
and pipeline_model_parallel_size == 1
|
251 |
+
and devices > 1
|
252 |
+
):
|
253 |
+
self._device = torch.device(f"cuda:{trainer.global_rank}")
|
254 |
+
self._rank = trainer.global_rank
|
255 |
+
self._world_size = trainer.world_size
|
256 |
+
self.model = load_model(
|
257 |
+
path,
|
258 |
+
trainer,
|
259 |
+
tensor_model_parallel_size=tensor_model_parallel_size,
|
260 |
+
pipeline_model_parallel_size=pipeline_model_parallel_size,
|
261 |
+
).cuda()
|
262 |
+
self.tokenizer = self.model.tokenizer
|
263 |
+
self.app_state = setup_distributed_environment(trainer)
|
264 |
+
|
265 |
+
self._max_length = max_length
|
266 |
+
self._batch_size = int(batch_size)
|
267 |
+
self._max_gen_toks = max_gen_toks
|
268 |
+
|
269 |
+
@classmethod
|
270 |
+
def create_from_arg_string(cls, arg_string, additional_config=None):
|
271 |
+
args = simple_parse_args_string(arg_string)
|
272 |
+
if additional_config:
|
273 |
+
args["batch_size"] = additional_config.get("batch_size", 1)
|
274 |
+
|
275 |
+
return cls(**args)
|
276 |
+
|
277 |
+
@property
|
278 |
+
def eot_token_id(self):
|
279 |
+
try:
|
280 |
+
return self.tokenizer.eos_id
|
281 |
+
except AttributeError:
|
282 |
+
return None
|
283 |
+
|
284 |
+
@property
|
285 |
+
def max_length(self):
|
286 |
+
return self._max_length
|
287 |
+
|
288 |
+
@property
|
289 |
+
def max_gen_toks(self):
|
290 |
+
return self._max_gen_toks
|
291 |
+
|
292 |
+
@property
|
293 |
+
def batch_size(self):
|
294 |
+
return self._batch_size
|
295 |
+
|
296 |
+
@property
|
297 |
+
def device(self):
|
298 |
+
return self._device
|
299 |
+
|
300 |
+
@property
|
301 |
+
def rank(self):
|
302 |
+
return self._rank
|
303 |
+
|
304 |
+
@property
|
305 |
+
def world_size(self):
|
306 |
+
return self._world_size
|
307 |
+
|
308 |
+
@property
|
309 |
+
def accelerator(self):
|
310 |
+
return self._Accelerator(self.world_size)
|
311 |
+
|
312 |
+
class _Accelerator:
|
313 |
+
def __init__(self, world_size):
|
314 |
+
self.world_size = world_size
|
315 |
+
|
316 |
+
def wait_for_everyone(self):
|
317 |
+
torch.distributed.barrier()
|
318 |
+
|
319 |
+
def gather(self, local_tensor):
|
320 |
+
gathered_tensors = [
|
321 |
+
torch.zeros(1, dtype=local_tensor.dtype).cuda()
|
322 |
+
for _ in range(self.world_size)
|
323 |
+
]
|
324 |
+
torch.distributed.all_gather(gathered_tensors, local_tensor)
|
325 |
+
return torch.cat(gathered_tensors)
|
326 |
+
|
327 |
+
def tok_encode(self, string: str):
|
328 |
+
return self.tokenizer.text_to_ids(string)
|
329 |
+
|
330 |
+
def tok_decode(self, tokens):
|
331 |
+
return self.tokenizer.ids_to_text(tokens)
|
332 |
+
|
333 |
+
def _encode_pair(self, context, continuation):
|
334 |
+
n_spaces = len(context) - len(context.rstrip())
|
335 |
+
if n_spaces > 0:
|
336 |
+
continuation = context[-n_spaces:] + continuation
|
337 |
+
context = context[:-n_spaces]
|
338 |
+
whole_enc = self.tok_encode(context + continuation)
|
339 |
+
context_enc = self.tok_encode(context)
|
340 |
+
context_enc_len = len(context_enc)
|
341 |
+
continuation_enc = whole_enc[context_enc_len:]
|
342 |
+
return context_enc, continuation_enc
|
343 |
+
|
344 |
+
def loglikelihood(self, requests):
|
345 |
+
new_reqs = []
|
346 |
+
for context, continuation in [req.args for req in requests]:
|
347 |
+
if context == "":
|
348 |
+
# end of text as context
|
349 |
+
context_enc, continuation_enc = (
|
350 |
+
[self.eot_token_id],
|
351 |
+
self.tok_encode(continuation),
|
352 |
+
)
|
353 |
+
else:
|
354 |
+
context_enc, continuation_enc = self._encode_pair(context, continuation)
|
355 |
+
|
356 |
+
new_reqs.append(((context, continuation), context_enc, continuation_enc))
|
357 |
+
|
358 |
+
return self._loglikelihood_tokens(new_reqs)
|
359 |
+
|
360 |
+
def loglikelihood_rolling(
|
361 |
+
self, requests: List[Instance], disable_tqdm: bool = False
|
362 |
+
) -> List[float]:
|
363 |
+
loglikelihoods = []
|
364 |
+
|
365 |
+
for (string,) in tqdm([req.args for req in requests], disable=disable_tqdm):
|
366 |
+
rolling_token_windows = list(
|
367 |
+
map(
|
368 |
+
make_disjoint_window,
|
369 |
+
get_rolling_token_windows(
|
370 |
+
token_list=self.tok_encode(string),
|
371 |
+
prefix_token=self.eot_token_id,
|
372 |
+
max_seq_len=self.max_length - 1,
|
373 |
+
context_len=1,
|
374 |
+
),
|
375 |
+
)
|
376 |
+
)
|
377 |
+
|
378 |
+
rolling_token_windows = [(None,) + x for x in rolling_token_windows]
|
379 |
+
|
380 |
+
string_nll = self._loglikelihood_tokens(
|
381 |
+
rolling_token_windows,
|
382 |
+
)
|
383 |
+
|
384 |
+
# discard is_greedy
|
385 |
+
string_nll = [x[0] for x in string_nll]
|
386 |
+
|
387 |
+
string_nll = sum(string_nll)
|
388 |
+
loglikelihoods.append(string_nll)
|
389 |
+
return loglikelihoods
|
390 |
+
|
391 |
+
def _loglikelihood_tokens(self, requests, disable_tqdm=False):
|
392 |
+
res = []
|
393 |
+
|
394 |
+
def _collate(x):
|
395 |
+
toks = x[1] + x[2]
|
396 |
+
return -len(toks), tuple(toks)
|
397 |
+
|
398 |
+
re_ord = Collator(requests, sort_fn=_collate)
|
399 |
+
chunks = re_ord.get_batched(n=self.batch_size, batch_fn=None)
|
400 |
+
pbar = tqdm(
|
401 |
+
total=len(requests),
|
402 |
+
disable=(disable_tqdm or (self.rank != 0)),
|
403 |
+
desc="Running loglikelihood requests",
|
404 |
+
)
|
405 |
+
for chunk in chunks:
|
406 |
+
inps = []
|
407 |
+
ctxlens = []
|
408 |
+
contlens = []
|
409 |
+
|
410 |
+
for _, context_enc, continuation_enc in chunk:
|
411 |
+
# Leave one token for generation. Tokens_to_generate = 0 breaks NeMo.
|
412 |
+
inp = (context_enc + continuation_enc)[-(self.max_length - 1) :]
|
413 |
+
|
414 |
+
ctxlen = len(context_enc) - max(
|
415 |
+
0, len(context_enc) + len(continuation_enc) - (self.max_length - 1)
|
416 |
+
)
|
417 |
+
ctxlens.append(ctxlen)
|
418 |
+
contlens.append(len(continuation_enc))
|
419 |
+
|
420 |
+
inps.append(self.tok_decode(inp))
|
421 |
+
|
422 |
+
output = self.generate(
|
423 |
+
self.model,
|
424 |
+
inputs=inps,
|
425 |
+
tokens_to_generate=1,
|
426 |
+
min_tokens_to_generate=1,
|
427 |
+
compute_logprob=True,
|
428 |
+
all_probs=True,
|
429 |
+
)
|
430 |
+
|
431 |
+
batch_token_ids = np.asarray(output["token_ids"])[:, :-1]
|
432 |
+
batch_logprobs = output["logprob"][:, :-1]
|
433 |
+
batch_full_logprob = output["full_logprob"][:, :-1, :]
|
434 |
+
|
435 |
+
# Compute greedy tokens for entire batch rather than calling it with proper ctxlen for each sample.
|
436 |
+
# Additional tokens for each sample will be trimmed later.
|
437 |
+
min_ctxlen = min(ctxlens)
|
438 |
+
|
439 |
+
# Use min_ctxlen-1 instead of min_ctxlen since full_logprobs are not returns for the first token.
|
440 |
+
batch_greedy_tokens = (
|
441 |
+
torch.argmax(batch_full_logprob[:, min_ctxlen - 1 :, :], -1)
|
442 |
+
.cpu()
|
443 |
+
.numpy()
|
444 |
+
)
|
445 |
+
|
446 |
+
for token_ids, greedy_tokens, logprobs, ctxlen, contlen, (
|
447 |
+
cache_key,
|
448 |
+
_,
|
449 |
+
_,
|
450 |
+
) in zip(
|
451 |
+
batch_token_ids,
|
452 |
+
batch_greedy_tokens,
|
453 |
+
batch_logprobs,
|
454 |
+
ctxlens,
|
455 |
+
contlens,
|
456 |
+
chunk,
|
457 |
+
):
|
458 |
+
# Trim at contlen since shorter contexts in a batch will have more than one token generated.
|
459 |
+
# Use ctxlen-1 instead of ctxlen same as for full_logprob in batch_greedy_tokens calculation
|
460 |
+
logprobs = (logprobs[ctxlen - 1 :])[:contlen]
|
461 |
+
logprob = sum(logprobs).tolist()
|
462 |
+
|
463 |
+
continuation_tokens = (token_ids[ctxlen:])[:contlen]
|
464 |
+
len_diff = ctxlen - min_ctxlen
|
465 |
+
is_greedy = continuation_tokens == (greedy_tokens[len_diff:])[:contlen]
|
466 |
+
if not isinstance(is_greedy, bool):
|
467 |
+
is_greedy = is_greedy.all()
|
468 |
+
answer = (logprob, is_greedy)
|
469 |
+
|
470 |
+
if cache_key is not None:
|
471 |
+
self.cache_hook.add_partial("loglikelihood", cache_key, answer)
|
472 |
+
|
473 |
+
res.append(answer)
|
474 |
+
pbar.update(1)
|
475 |
+
|
476 |
+
pbar.close()
|
477 |
+
|
478 |
+
return re_ord.get_original(res)
|
479 |
+
|
480 |
+
def generate_until(self, requests):
|
481 |
+
if not requests:
|
482 |
+
return []
|
483 |
+
res = []
|
484 |
+
|
485 |
+
def get_until(req_args):
|
486 |
+
until = req_args.get("until", [])
|
487 |
+
until = deepcopy(until) # prevent from modifying req_args for cache_key
|
488 |
+
if self.tokenizer.ids_to_tokens([self.eot_token_id])[0] not in until:
|
489 |
+
until.append(self.tokenizer.ids_to_tokens([self.eot_token_id])[0])
|
490 |
+
return until
|
491 |
+
|
492 |
+
def _collate(x):
|
493 |
+
toks = self.tok_encode(x[0])
|
494 |
+
return len(toks), x[0]
|
495 |
+
|
496 |
+
re_ords = Collator(
|
497 |
+
[reg.args for reg in requests], sort_fn=_collate, group_by="gen_kwargs"
|
498 |
+
)
|
499 |
+
chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
|
500 |
+
for chunk in chunks:
|
501 |
+
contexts, all_gen_kwargs = zip(*chunk)
|
502 |
+
# we assume all gen kwargs in the batch are the same
|
503 |
+
# this is safe to assume because the `grouper` object ensures it.
|
504 |
+
req_args = all_gen_kwargs[0]
|
505 |
+
# unpack our keyword arguments.
|
506 |
+
until = get_until(req_args)
|
507 |
+
max_gen_toks = req_args.get("max_gen_toks", self.max_gen_toks)
|
508 |
+
|
509 |
+
remaining_length = self.max_length - max_gen_toks
|
510 |
+
contexts = []
|
511 |
+
for context, _ in chunk:
|
512 |
+
encoded_context = self.tok_encode(context)
|
513 |
+
encoded_context = encoded_context[-remaining_length:]
|
514 |
+
contexts.append(self.tok_decode(encoded_context))
|
515 |
+
|
516 |
+
output = self.generate(
|
517 |
+
self.model,
|
518 |
+
inputs=contexts,
|
519 |
+
tokens_to_generate=max_gen_toks,
|
520 |
+
end_strings=until,
|
521 |
+
greedy=True,
|
522 |
+
)
|
523 |
+
|
524 |
+
answers = output["sentences"]
|
525 |
+
|
526 |
+
continuations = []
|
527 |
+
for context, answer in zip(contexts, answers):
|
528 |
+
continuations.append(answer[len(context) :])
|
529 |
+
|
530 |
+
for term in until:
|
531 |
+
continuations = [answer.split(term)[0] for answer in continuations]
|
532 |
+
|
533 |
+
for request, answer in zip(chunk, continuations):
|
534 |
+
self.cache_hook.add_partial("greedy_until", request, answer)
|
535 |
+
res.append(answer)
|
536 |
+
|
537 |
+
return re_ords.get_original(res)
|
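For orientation, here is a minimal invocation sketch for the nemo_lm wrapper added above. It is a hedged illustration under assumptions: the `nemo` package is installed, a NeMo checkpoint exists at the placeholder path, and a single device is used so no torchrun/multi-process launch is needed; the task and limit are placeholders.

# Usage sketch (assumptions noted above): the string name "nemo_lm" is resolved
# through @register_model("nemo_lm") above, and model_args is parsed by
# create_from_arg_string() defined in the class.
import lm_eval

results = lm_eval.simple_evaluate(
    model="nemo_lm",
    model_args="path=/path/to/model.nemo,devices=1,max_length=4096",  # placeholder checkpoint path
    tasks=["lambada_openai"],
    limit=8,
)
print(results["results"])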