koichi12 committed
Commit 1d13cae · verified · 1 Parent(s): 84a9380

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. scripts/decode/en-ja/llama2/beam_search.sh +19 -0
  2. scripts/decode/en-ja/llama2/greedy_inference.sh +13 -0
  3. scripts/decode/en-ja/llama2/hf_inference.sh +13 -0
  4. scripts/decode/en-ja/llama2/top_p_inference.sh +17 -0
  5. scripts/decode/en-ja/llama2/top_p_inference_1.sh +20 -0
  6. scripts/decode/en-ja/llama2/top_p_inference_2.sh +21 -0
  7. scripts/decode/en-ja/mistral-ve/top_p_inference.sh +16 -0
  8. scripts/decode/en-ja/mistral-ve/top_p_inference_cpo.sh +17 -0
  9. scripts/decode/en-ja/mistral/top_p_inference_2.sh +20 -0
  10. scripts/yans/lm-evaluation-harness/.github/workflows/new_tasks.yml +72 -0
  11. scripts/yans/lm-evaluation-harness/.github/workflows/publish.yml +78 -0
  12. scripts/yans/lm-evaluation-harness/.github/workflows/unit_tests.yml +95 -0
  13. scripts/yans/lm-evaluation-harness/lm_eval/api/__init__.py +0 -0
  14. scripts/yans/lm-evaluation-harness/lm_eval/api/__pycache__/__init__.cpython-310.pyc +0 -0
  15. scripts/yans/lm-evaluation-harness/lm_eval/api/__pycache__/filter.cpython-310.pyc +0 -0
  16. scripts/yans/lm-evaluation-harness/lm_eval/api/__pycache__/group.cpython-310.pyc +0 -0
  17. scripts/yans/lm-evaluation-harness/lm_eval/api/__pycache__/instance.cpython-310.pyc +0 -0
  18. scripts/yans/lm-evaluation-harness/lm_eval/api/__pycache__/metrics.cpython-310.pyc +0 -0
  19. scripts/yans/lm-evaluation-harness/lm_eval/api/__pycache__/model.cpython-310.pyc +0 -0
  20. scripts/yans/lm-evaluation-harness/lm_eval/api/__pycache__/registry.cpython-310.pyc +0 -0
  21. scripts/yans/lm-evaluation-harness/lm_eval/api/__pycache__/samplers.cpython-310.pyc +0 -0
  22. scripts/yans/lm-evaluation-harness/lm_eval/api/__pycache__/task.cpython-310.pyc +0 -0
  23. scripts/yans/lm-evaluation-harness/lm_eval/api/filter.py +56 -0
  24. scripts/yans/lm-evaluation-harness/lm_eval/api/group.py +117 -0
  25. scripts/yans/lm-evaluation-harness/lm_eval/api/instance.py +38 -0
  26. scripts/yans/lm-evaluation-harness/lm_eval/api/metrics.py +570 -0
  27. scripts/yans/lm-evaluation-harness/lm_eval/api/model.py +385 -0
  28. scripts/yans/lm-evaluation-harness/lm_eval/api/registry.py +192 -0
  29. scripts/yans/lm-evaluation-harness/lm_eval/api/samplers.py +198 -0
  30. scripts/yans/lm-evaluation-harness/lm_eval/api/task.py +1674 -0
  31. scripts/yans/lm-evaluation-harness/lm_eval/models/__init__.py +28 -0
  32. scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/__init__.cpython-310.pyc +0 -0
  33. scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/anthropic_llms.cpython-310.pyc +0 -0
  34. scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/api_models.cpython-310.pyc +0 -0
  35. scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/dummy.cpython-310.pyc +0 -0
  36. scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/gguf.cpython-310.pyc +0 -0
  37. scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/huggingface.cpython-310.pyc +0 -0
  38. scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/mamba_lm.cpython-310.pyc +0 -0
  39. scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/nemo_lm.cpython-310.pyc +0 -0
  40. scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/neuralmagic.cpython-310.pyc +0 -0
  41. scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/neuron_optimum.cpython-310.pyc +0 -0
  42. scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/openai_completions.cpython-310.pyc +0 -0
  43. scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/optimum_lm.cpython-310.pyc +0 -0
  44. scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/textsynth.cpython-310.pyc +0 -0
  45. scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/utils.cpython-310.pyc +0 -0
  46. scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/vllm_causallms.cpython-310.pyc +0 -0
  47. scripts/yans/lm-evaluation-harness/lm_eval/models/anthropic_llms.py +362 -0
  48. scripts/yans/lm-evaluation-harness/lm_eval/models/api_models.py +641 -0
  49. scripts/yans/lm-evaluation-harness/lm_eval/models/huggingface.py +1356 -0
  50. scripts/yans/lm-evaluation-harness/lm_eval/models/nemo_lm.py +537 -0
scripts/decode/en-ja/llama2/beam_search.sh ADDED
@@ -0,0 +1,19 @@
+ set -eux
+ LLM_RECIPES_DIR=/code/llm-recipes
+ source $LLM_RECIPES_DIR/scripts/wmt2024/tokens.sh
+
+ MAX_INPUT_TOKENS=158
+ BEAM_SIZE=50
+
+ python /code/llm-recipes/tools/hf_inference_distrubuted.py \
+     --model /work/models/additiona_trained_hf/llama2-en-ja-continuous-pretrained-v0-dev-finetune-chunked-docs-all-averaged-841-845 \
+     -i /work/wmt2024_test/LLM/wmttest2024.src.sentence_splited.with_template.en-ja.en.jsonl \
+     -o /work/translation/wmt2024_test/en-ja/llama2-beam \
+     -g 0 1 2 3 4 5 6 7 \
+     --attn_implementation sdpa \
+     --dynamic_max_new_token_ratio 3.0 \
+     --num_return_sequences ${BEAM_SIZE} \
+     --num_beams ${BEAM_SIZE} \
+     --max_input_tokens ${MAX_INPUT_TOKENS} \
+     -b 158
+
scripts/decode/en-ja/llama2/greedy_inference.sh ADDED
@@ -0,0 +1,13 @@
+ LLM_RECIPES_DIR=/code/llm-recipes
+ source $LLM_RECIPES_DIR/scripts/wmt2024/tokens.sh
+
+ python /code/llm-recipes/tools/hf_inference.py \
+     --model /work/models/translation_finetuned_hf/mistral-llm-recipes-en-ja-continuous-pretrained-v1-dev-finetune-chunked-docs-all-averaged-71-75 \
+     -i /work/wmt2024_test/LLM/wmttest2024.src.sentence_splited.with_template.en-ja.en.jsonl \
+     -o /work/translation/wmt24_test/en-ja/mistral-greedy \
+     -g 0 \
+     -b 4096 \
+     --dynamic_max_new_token_ratio 3.0
+
+ echo "Done!"
+
scripts/decode/en-ja/llama2/hf_inference.sh ADDED
@@ -0,0 +1,13 @@
+ LLM_RECIPES_DIR=/code/llm-recipes
+ source $LLM_RECIPES_DIR/scripts/wmt2024/tokens.sh
+
+ python /code/llm-recipes/tools/hf_inference.py \
+     --model /work/models/translation_finetuned_hf/mistral-llm-recipes-en-ja-continuous-pretrained-v1-dev-finetune-chunked-docs-all-averaged-71-75 \
+     -i /work/wmt2024_test/LLM/wmttest2024.src.sentence_splited.with_template.en-ja.en.jsonl \
+     -o /work/translation/wmt24_test/en-ja/mistral-greedy \
+     -g 0 \
+     -b 4096 \
+     --dynamic_max_new_token_ratio 3.0
+
+ echo "Done!"
+
scripts/decode/en-ja/llama2/top_p_inference.sh ADDED
@@ -0,0 +1,17 @@
+ set -eux
+ LLM_RECIPES_DIR=/code/llm-recipes
+ source $LLM_RECIPES_DIR/scripts/wmt2024/tokens.sh
+
+ i=4
+ GPU_ID=4
+ python /code/llm-recipes/tools/hf_inference.py \
+     --model /work/models/translation_finetuned_hf/mistral-llm-recipes-en-ja-continuous-pretrained-v1-dev-finetune-chunked-docs-all-averaged-71-75 \
+     -i /work/wmt2024_test/LLM/split/en-ja/wmttest2024.src.sentence_splited.with_template.en-ja.en.jsonl.0${i} \
+     -o /work/translation/wmt24_test/en-ja/mistral-top-p-0.95/split_0${i} \
+     -g ${GPU_ID} \
+     -b 500 \
+     --attn_implementation sdpa \
+     --dynamic_max_new_token_ratio 3.0 \
+     --num_return_sequences 100 \
+     --do_sample \
+     --top_p 0.95 &
scripts/decode/en-ja/llama2/top_p_inference_1.sh ADDED
@@ -0,0 +1,20 @@
+ set -eux
+ LLM_RECIPES_DIR=/code/llm-recipes
+ source $LLM_RECIPES_DIR/scripts/wmt2024/tokens.sh
+
+ for i in `seq 0 6`; do
+     python /code/llm-recipes/tools/hf_inference.py \
+         --model /work/models/additiona_trained_hf/llama2-en-ja-continuous-pretrained-v0-dev-finetune-chunked-docs-all-averaged-841-845 \
+         -i /work/wmt2024_test/LLM/split/en-ja/wmttest2024.src.sentence_splited.with_template.en-ja.en.jsonl.0${i} \
+         -o /work/translation/wmt24_test/en-ja/llama2-top-p-0.95/split_0${i} \
+         -g ${i} \
+         -b 158 \
+         --attn_implementation sdpa \
+         --dynamic_max_new_token_ratio 3.0 \
+         --num_return_sequences 50 \
+         --do_sample \
+         --top_p 0.95 \
+         --max_input_tokens 158 &
+ done
+ wait
+
scripts/decode/en-ja/llama2/top_p_inference_2.sh ADDED
@@ -0,0 +1,21 @@
+ set -eux
+ LLM_RECIPES_DIR=/code/llm-recipes
+ source $LLM_RECIPES_DIR/scripts/wmt2024/tokens.sh
+
+ for i in `seq 7 9`; do
+     GPU_ID=$((i-5))
+     python /code/llm-recipes/tools/hf_inference.py \
+         --model /work/models/additiona_trained_hf/llama2-en-ja-continuous-pretrained-v0-dev-finetune-chunked-docs-all-averaged-841-845 \
+         -i /work/wmt2024_test/LLM/split/en-ja/wmttest2024.src.sentence_splited.with_template.en-ja.en.jsonl.0${i} \
+         -o /work/translation/wmt24_test/en-ja/llama2-top-p-0.95/split_0${i} \
+         -g ${GPU_ID} \
+         -b 158 \
+         --attn_implementation sdpa \
+         --dynamic_max_new_token_ratio 3.0 \
+         --num_return_sequences 50 \
+         --do_sample \
+         --top_p 0.95 \
+         --max_input_tokens 158 &
+ done
+ wait
+
scripts/decode/en-ja/mistral-ve/top_p_inference.sh ADDED
@@ -0,0 +1,16 @@
+ set -eux
+ LLM_RECIPES_DIR=/code/llm-recipes
+ source $LLM_RECIPES_DIR/scripts/wmt2024/tokens.sh
+
+ python /code/llm-recipes/tools/hf_inference_distrubuted.py \
+     --model /work/models/translation_finetuned_hf/mistral-llm-recipes-en-ja-continuous-pretrained-v1-dev-finetune-ve-sim-chunked-docs-all-averaged-596-600 \
+     -i /work/wmt2024_test/LLM/wmttest2024.src.sentence_splited.with_template.en-ja.en.jsonl \
+     -o /work/translation/wmt2024_test/en-ja/mistral-ve-top-p-0.95 \
+     -g 0 1 2 3 4 5 6 7 \
+     -b 125 \
+     --attn_implementation sdpa \
+     --dynamic_max_new_token_ratio 2.0 \
+     --num_return_sequences 80 \
+     --do_sample \
+     --top_p 0.95 \
+     --max_input_tokens 125
scripts/decode/en-ja/mistral-ve/top_p_inference_cpo.sh ADDED
@@ -0,0 +1,17 @@
+ set -eux
+ LLM_RECIPES_DIR=/code/llm-recipes
+ source $LLM_RECIPES_DIR/scripts/wmt2024/tokens.sh
+
+ python /code/llm-recipes/tools/hf_inference_distrubuted.py \
+     --model /work/models/translation_finetuned_hf/mistral-llm-recipes-en-ja-continuous-pretrained-v1-dev-finetune-ve-sim-chunked-docs-all-averaged-596-600 \
+     -i /work/wmt2024_test/LLM/wmttest2024.src.sentence_splited.with_template.en-ja.en.jsonl \
+     -o /work/translation/wmt2024_test/en-ja/mistral-ve-top-p-0.95-cpo \
+     -p /work/models/dpo/mistral-llm-recipes-en-ja-continuous-pretrained-v1-dev-finetune-ve-sim-chunked-docs-all-cpo-lora/checkpoint-200 \
+     -g 0 1 2 3 4 5 6 7 \
+     -b 125 \
+     --attn_implementation sdpa \
+     --dynamic_max_new_token_ratio 2.0 \
+     --num_return_sequences 80 \
+     --do_sample \
+     --top_p 0.95 \
+     --max_input_tokens 125 \
scripts/decode/en-ja/mistral/top_p_inference_2.sh ADDED
@@ -0,0 +1,20 @@
+ set -eux
+ LLM_RECIPES_DIR=/code/llm-recipes
+ source $LLM_RECIPES_DIR/scripts/wmt2024/tokens.sh
+
+ for i in `seq 8 9`; do
+     # minus 2 for gpu id
+     GPU_ID=$((i-2))
+     python /code/llm-recipes/tools/hf_inference.py \
+         --model /work/models/translation_finetuned_hf/mistral-llm-recipes-en-ja-continuous-pretrained-v1-dev-finetune-chunked-docs-all-averaged-71-75 \
+         -i /work/wmt2024_test/LLM/split/en-ja/wmttest2024.src.sentence_splited.with_template.en-ja.en.jsonl.0${i} \
+         -o /work/translation/wmt24_test/en-ja/mistral-top-p-0.95/split_0${i} \
+         -g ${GPU_ID} \
+         -b 400 \
+         --attn_implementation sdpa \
+         --dynamic_max_new_token_ratio 3.0 \
+         --num_return_sequences 100 \
+         --do_sample \
+         --top_p 0.95 &
+ done
+ wait
scripts/yans/lm-evaluation-harness/.github/workflows/new_tasks.yml ADDED
@@ -0,0 +1,72 @@
+ name: Tasks Modified
+
+ on:
+   push:
+     branches:
+       - 'main'
+   pull_request:
+     branches:
+       - 'main'
+   workflow_dispatch:
+ # comment/edit out the above to stop/change the triggers
+ jobs:
+   changed_files:
+     runs-on: ubuntu-latest  # windows-latest || macos-latest
+     timeout-minutes: 120
+     name: Scan for changed tasks
+     steps:
+       - name: checkout
+         uses: actions/checkout@v3
+         with:
+           fetch-depth: 2  # OR "2" -> To retrieve the preceding commit.
+
+       # Uses the tj-actions/changed-files action to check for changes.
+       # Outputs provided here: https://github.com/tj-actions/changed-files#outputs
+       # The `files_yaml` input optionally takes a yaml string to specify filters,
+       # and prepends the filter name to the standard output names.
+       - name: Check task folders
+         id: changed-tasks
+         uses: tj-actions/changed-files@v44.5.2
+         with:
+           # tasks checks the tasks folder and api checks the api folder for changes
+           files_yaml: |
+             tasks:
+               - lm_eval/tasks/**
+             api:
+               - lm_eval/api/**
+           write_output_files: true
+
+       # The next step is optional; the files are written to the workspace by default (above).
+       # so it's just for debugging
+       - name: Run Tests
+         if: steps.changed-tasks.outputs.tasks_any_modified == 'true' || steps.changed-tasks.outputs.api_any_modified == 'true'
+         run: |
+           echo .github/outputs/tasks_all_changed_and_modified_files.txt >> 'GITHUB_ENV'
+           echo "One or more test file(s) has changed."
+           echo "List of all the files that have changed: ${{ steps.changed-tasks.outputs.tasks_all_modified_files }}"
+
+       - name: Set up Python 3.9
+         if: steps.changed-tasks.outputs.tasks_any_modified == 'true' || steps.changed-tasks.outputs.api_any_modified == 'true'
+         uses: actions/setup-python@v4
+         with:
+           python-version: 3.9
+           cache: 'pip'
+           cache-dependency-path: setup.py
+       - name: Install dependencies
+         if: steps.changed-tasks.outputs.tasks_any_modified == 'true' || steps.changed-tasks.outputs.api_any_modified == 'true'
+         run: |
+           python -m pip install --upgrade pip
+           pip install -e '.[dev,ifeval]' --extra-index-url https://download.pytorch.org/whl/cpu
+           # Install optional git dependencies
+           # pip install bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt
+           # if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
+       - name: Test with pytest
+         # if new tasks are added, run tests on them
+         if: steps.changed-tasks.outputs.tasks_any_modified == 'true'
+         run: python -m pytest tests/test_tasks.py -s -vv
+       # if api is modified, run tests on it
+       - name: Test more tasks with pytest
+         env:
+           API: true
+         if: steps.changed-tasks.outputs.api_any_modified == 'true'
+         run: python -m pytest tests/test_tasks.py -s -vv
scripts/yans/lm-evaluation-harness/.github/workflows/publish.yml ADDED
@@ -0,0 +1,78 @@
+ name: Publish Python distribution to PyPI
+
+ on:
+   push:
+     tags:
+       - '*'
+
+ jobs:
+   build:
+     name: Build distribution
+     runs-on: ubuntu-latest
+
+     steps:
+       - uses: actions/checkout@v4
+       - name: Set up Python
+         uses: actions/setup-python@v4
+         with:
+           python-version: "3.x"
+
+       - name: Install pypa/build
+         run: >-
+           python3 -m
+           pip install
+           build
+           --user
+       - name: Build a binary wheel and a source tarball
+         run: python3 -m build
+       - name: Store the distribution packages
+         uses: actions/upload-artifact@v3
+         with:
+           name: python-package-distributions
+           path: dist/
+
+   publish-to-pypi:
+     name: >-
+       Publish Python distribution to PyPI
+     if: startsWith(github.ref, 'refs/tags/')  # only publish to PyPI on tag pushes
+     needs:
+       - build
+     runs-on: ubuntu-latest
+     environment:
+       name: pypi
+       url: https://pypi.org/p/lm_eval
+     permissions:
+       id-token: write  # IMPORTANT: mandatory for trusted publishing
+
+     steps:
+       - name: Download all the dists
+         uses: actions/download-artifact@v3
+         with:
+           name: python-package-distributions
+           path: dist/
+       - name: Publish distribution to PyPI
+         uses: pypa/gh-action-pypi-publish@release/v1
+
+   publish-to-testpypi:
+     name: Publish Python distribution to TestPyPI
+     needs:
+       - build
+     runs-on: ubuntu-latest
+
+     environment:
+       name: testpypi
+       url: https://test.pypi.org/p/lm_eval
+
+     permissions:
+       id-token: write  # IMPORTANT: mandatory for trusted publishing
+
+     steps:
+       - name: Download all the dists
+         uses: actions/download-artifact@v3
+         with:
+           name: python-package-distributions
+           path: dist/
+       - name: Publish distribution to TestPyPI
+         uses: pypa/gh-action-pypi-publish@release/v1
+         with:
+           repository-url: https://test.pypi.org/legacy/
scripts/yans/lm-evaluation-harness/.github/workflows/unit_tests.yml ADDED
@@ -0,0 +1,95 @@
+ # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
+ # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
+ # just comment out unwanted steps to turn off the test.
+ name: Unit Tests
+
+ on:
+   push:
+     branches:
+       - 'main'
+   pull_request:
+     branches:
+       - 'main'
+   workflow_dispatch:
+ # Jobs run concurrently and steps run sequentially within a job.
+ # jobs: linter and cpu_tests. Add more jobs/steps as required.
+ jobs:
+   linter:
+     name: Linters
+     runs-on: ubuntu-latest
+     timeout-minutes: 5
+
+     steps:
+       - name: Checkout Code
+         uses: actions/checkout@v4
+       - name: Set up Python 3.8
+         uses: actions/setup-python@v5
+         with:
+           python-version: 3.8
+           cache: pip
+           cache-dependency-path: pyproject.toml
+       - name: Pre-Commit
+         env:
+           SKIP: "no-commit-to-branch,mypy"
+
+         uses: pre-commit/action@v3.0.1
+       # # mypy turned off for now
+       # - name: Lint with mypy
+       #   run: mypy . --ignore-missing-imports --check-untyped-defs --explicit-package-bases --warn-unreachable
+   # Job 2
+   testcpu:
+     name: CPU Tests
+     runs-on: ubuntu-latest
+     strategy:
+       matrix:
+         python-version: [ "3.8", "3.9", "3.10", "3.11" ]
+     timeout-minutes: 30
+     steps:
+       - name: Checkout Code
+         uses: actions/checkout@v4
+       - name: Set up Python ${{ matrix.python-version }}
+         uses: actions/setup-python@v5
+         with:
+           python-version: ${{ matrix.python-version }}
+           cache: pip
+           cache-dependency-path: pyproject.toml
+       - name: Install dependencies
+         run: |
+           python -m pip install --upgrade pip
+           pip install -e '.[dev,sentencepiece,api]' --extra-index-url https://download.pytorch.org/whl/cpu
+           # Install optional git dependencies
+           # pip install bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt
+           # if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
+       - name: Test with pytest
+         run: python -m pytest --showlocals -s -vv -n=auto --ignore=tests/models/test_neuralmagic.py --ignore=tests/models/test_openvino.py
+       - name: Archive artifacts
+         uses: actions/upload-artifact@v3
+         with:
+           name: output_results
+           path: |
+             test_logs/*
+   testmodels:
+     name: External LM Tests
+     runs-on: ubuntu-latest
+     timeout-minutes: 30
+     steps:
+       - name: Checkout Code
+         uses: actions/checkout@v4
+       - name: Set up Python 3.8
+         uses: actions/setup-python@v5
+         with:
+           python-version: 3.8
+           cache: pip
+           cache-dependency-path: pyproject.toml
+       - name: Install dependencies
+         run: |
+           python -m pip install --upgrade pip
+           pip install -e '.[dev,optimum,deepsparse,sparseml,api]' --extra-index-url https://download.pytorch.org/whl/cpu
+       - name: Test with pytest
+         run: python -m pytest tests/models --showlocals -s -vv
+       - name: Archive artifacts
+         uses: actions/upload-artifact@v3
+         with:
+           name: output_results
+           path: |
+             test_logs/*
scripts/yans/lm-evaluation-harness/lm_eval/api/__init__.py ADDED
File without changes
scripts/yans/lm-evaluation-harness/lm_eval/api/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (160 Bytes).
scripts/yans/lm-evaluation-harness/lm_eval/api/__pycache__/filter.cpython-310.pyc ADDED
Binary file (2.72 kB).
scripts/yans/lm-evaluation-harness/lm_eval/api/__pycache__/group.cpython-310.pyc ADDED
Binary file (4.61 kB).
scripts/yans/lm-evaluation-harness/lm_eval/api/__pycache__/instance.cpython-310.pyc ADDED
Binary file (1.51 kB).
scripts/yans/lm-evaluation-harness/lm_eval/api/__pycache__/metrics.cpython-310.pyc ADDED
Binary file (13.2 kB).
scripts/yans/lm-evaluation-harness/lm_eval/api/__pycache__/model.cpython-310.pyc ADDED
Binary file (14.1 kB).
scripts/yans/lm-evaluation-harness/lm_eval/api/__pycache__/registry.cpython-310.pyc ADDED
Binary file (5.11 kB).
scripts/yans/lm-evaluation-harness/lm_eval/api/__pycache__/samplers.cpython-310.pyc ADDED
Binary file (4.81 kB).
scripts/yans/lm-evaluation-harness/lm_eval/api/__pycache__/task.cpython-310.pyc ADDED
Binary file (43.5 kB).
scripts/yans/lm-evaluation-harness/lm_eval/api/filter.py ADDED
@@ -0,0 +1,56 @@
+ from abc import ABC, abstractmethod
+ from dataclasses import dataclass
+ from typing import Callable, Iterable, List, Union
+
+ from lm_eval.api.instance import Instance
+
+
+ class Filter(ABC):
+     """
+     Filter classes operate on a per-task level.
+     They take all model outputs (`instance.resps` for all `task.instances`)
+     across all instances of a task, and perform operations.
+     In a single run, one can configure any number of separate filters or lists of filters.
+
+     """
+
+     def __init__(self, **kwargs) -> None:
+         """
+         Can define custom behavior here, if an individual instantiation of a Filter class should have state.
+         """
+
+     @abstractmethod
+     def apply(self, resps: Union[List, Iterable], docs: List[dict]) -> Iterable:
+         """
+         Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects.
+         Should return the list of (filtered) response lists *in the same order as they were input*, e.g.
+         if pass in [<inst.resps for instance 0>, <inst.resps for instance 1>] should return
+         [<filtered resps for instance 0>, <filtered resps for instance 1>]
+         """
+         return resps
+
+
+ @dataclass
+ class FilterEnsemble:
+     """
+     FilterEnsemble creates a pipeline applying multiple filters.
+     Its intended usage is to stack multiple post-processing steps in order.
+     `task.apply_filters` should use a list of FilterEnsemble classes that it stores, to apply each
+     pipeline separately.
+     """
+
+     name: str
+     filters: List[Callable[[], Filter]]
+
+     def apply(self, instances: List[Instance]) -> None:
+         resps, docs = zip(*((inst.resps, inst.doc) for inst in instances))
+         resps, docs = list(resps), list(docs)
+
+         for f in self.filters:
+             # apply filters in sequence
+             resps = f().apply(resps, docs)
+
+         # add the end results after filtering to filtered_requests of their respective source instances.
+         # has key `self.name`: each FilterEnsemble applied in a given run should use a different name.
+         for inst, resp in zip(instances, resps):
+             inst.filtered_resps[self.name] = resp
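Illustrative usage sketch (not part of the diff above): one way the filter API shown here can be exercised, assuming the lm_eval package in this repository is importable. TakeFirstFilter and the example instances are hypothetical, not something this commit adds.

from lm_eval.api.filter import Filter, FilterEnsemble
from lm_eval.api.instance import Instance


class TakeFirstFilter(Filter):
    """Hypothetical filter: keep only the first response of each instance."""

    def apply(self, resps, docs):
        return [r[:1] for r in resps]


# two instances, each holding two raw model responses
insts = [
    Instance(request_type="generate_until", doc={}, arguments=("ctx", {}), idx=0),
    Instance(request_type="generate_until", doc={}, arguments=("ctx", {}), idx=1),
]
insts[0].resps = ["A", "B"]
insts[1].resps = ["C", "D"]

# FilterEnsemble stores zero-argument callables that construct Filter objects
ensemble = FilterEnsemble(name="take_first", filters=[TakeFirstFilter])
ensemble.apply(insts)
print(insts[0].filtered_resps["take_first"])  # ['A']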
scripts/yans/lm-evaluation-harness/lm_eval/api/group.py ADDED
@@ -0,0 +1,117 @@
+ import abc
+ from dataclasses import asdict, dataclass
+ from inspect import getsource
+ from typing import Any, Callable, List, Optional, Union
+
+
+ @dataclass
+ class AggMetricConfig(dict):
+     metric: Optional[str] = None
+     aggregation: Optional[str] = "mean"
+     weight_by_size: Optional[str] = False
+     # list of filter names which should be incorporated into the aggregated metric.
+     filter_list: Optional[Union[str, list]] = "none"
+
+     def __post_init__(self):
+         if self.aggregation != "mean":
+             raise ValueError(
+                 f"Currently, only 'mean' is supported for automatically aggregating scores across groups' subtasks. Got '{self.aggregation}'."
+             )
+
+         if isinstance(self.filter_list, str):
+             self.filter_list = [self.filter_list]
+
+
+ @dataclass
+ class GroupConfig(dict):
+     group: Optional[str] = None
+     group_alias: Optional[str] = None
+     task: Optional[Union[str, list]] = None
+     aggregate_metric_list: Optional[
+         Union[List[AggMetricConfig], AggMetricConfig, dict]
+     ] = None
+     metadata: Optional[dict] = (
+         None  # by default, not used in the code. allows for users to pass arbitrary info to tasks
+     )
+
+     def __getitem__(self, item):
+         return getattr(self, item)
+
+     def __setitem__(self, item, value):
+         return setattr(self, item, value)
+
+     def __post_init__(self):
+         if self.aggregate_metric_list is not None:
+             if isinstance(self.aggregate_metric_list, dict):
+                 self.aggregate_metric_list = [self.aggregate_metric_list]
+
+             self.aggregate_metric_list = [
+                 AggMetricConfig(**item) if isinstance(item, dict) else item
+                 for item in self.aggregate_metric_list
+             ]
+
+     def to_dict(self, keep_callable: bool = False) -> dict:
+         """dumps the current config as a dictionary object, as a printable format.
+         null fields will not be printed.
+         Used for dumping results alongside full task configuration
+
+         :return: dict
+             A printable dictionary version of the TaskConfig object.
+
+         # TODO: should any default value in the TaskConfig not be printed?
+         """
+         cfg_dict = asdict(self)
+         # remove values that are `None`
+         for k, v in list(cfg_dict.items()):
+             if callable(v):
+                 cfg_dict[k] = self.serialize_function(v, keep_callable=keep_callable)
+         return cfg_dict
+
+     def serialize_function(
+         self, value: Union[Callable, str], keep_callable=False
+     ) -> Union[Callable, str]:
+         """Serializes a given function or string.
+
+         If 'keep_callable' is True, the original callable is returned.
+         Otherwise, attempts to return the source code of the callable using 'getsource'.
+         """
+         if keep_callable:
+             return value
+         else:
+             try:
+                 return getsource(value)
+             except (TypeError, OSError):
+                 return str(value)
+
+
+ class ConfigurableGroup(abc.ABC):
+     def __init__(
+         self,
+         config: Optional[dict] = None,
+     ) -> None:
+         self._config = GroupConfig(**config)
+
+     @property
+     def group(self):
+         return self._config.group
+
+     @property
+     def group_alias(self):
+         return self._config.group_alias
+
+     @property
+     def version(self):
+         return self._config.version
+
+     @property
+     def config(self):
+         return self._config.to_dict()
+
+     @property
+     def group_name(self) -> Any:
+         return self._config.group
+
+     def __repr__(self):
+         return (
+             f"ConfigurableGroup(group={self.group}," f"group_alias={self.group_alias})"
+         )
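Illustrative sketch (not part of the diff above): how GroupConfig normalizes aggregate_metric_list in __post_init__, assuming the module shown here is importable. The group and task names are made-up values for the example.

from lm_eval.api.group import ConfigurableGroup, GroupConfig

# a single aggregation spec given as a plain dict is wrapped into [AggMetricConfig(...)]
cfg = GroupConfig(
    group="example_group",
    task=["subtask_a", "subtask_b"],
    aggregate_metric_list={"metric": "acc", "weight_by_size": True},
)
print(cfg.aggregate_metric_list[0].aggregation)  # 'mean' (the only supported value)
print(cfg["group"])  # dict-style access is routed to the dataclass attributes

group = ConfigurableGroup(config={"group": "example_group"})
print(group.group_name)  # 'example_group'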
scripts/yans/lm-evaluation-harness/lm_eval/api/instance.py ADDED
@@ -0,0 +1,38 @@
+ from dataclasses import dataclass, field
+ from typing import Literal, Optional, Tuple
+
+
+ OutputType = Literal[
+     "loglikelihood", "loglikelihood_rolling", "generate_until", "multiple_choice"
+ ]
+
+
+ @dataclass
+ class Instance:
+     request_type: OutputType
+     doc: dict
+     arguments: tuple
+     idx: int
+     metadata: Tuple[Optional[str], Optional[int], Optional[int]] = field(
+         default_factory=lambda: (None, None, None)
+     )
+     resps: list = field(default_factory=list)
+     filtered_resps: dict = field(default_factory=dict)
+
+     # initialized after init
+     task_name: Optional[str] = None
+     doc_id: Optional[int] = None
+     repeats: Optional[int] = None
+
+     def __post_init__(self) -> None:
+         # unpack metadata field
+         self.task_name, self.doc_id, self.repeats = self.metadata
+
+     @property
+     def args(self):
+         """
+         Returns (string,) where `string` is the string to calculate loglikelihood over
+         """
+         return (
+             self.arguments if isinstance(self.arguments, tuple) else (self.arguments,)
+         )
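Illustrative sketch (not part of the diff above): constructing an Instance and reading the fields that __post_init__ unpacks from metadata, assuming the module shown here is importable. The field values are made up.

from lm_eval.api.instance import Instance

inst = Instance(
    request_type="loglikelihood",
    doc={"question": "2+2="},
    arguments=("2+2=", " 4"),
    idx=0,
    metadata=("example_task", 7, 1),
)
print(inst.args)                                  # ('2+2=', ' 4')
print(inst.task_name, inst.doc_id, inst.repeats)  # example_task 7 1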
scripts/yans/lm-evaluation-harness/lm_eval/api/metrics.py ADDED
@@ -0,0 +1,570 @@
+ import logging
+ import math
+ import random
+ import re
+ import string
+ from collections.abc import Iterable
+ from typing import List
+
+ import numpy as np
+ import sacrebleu
+
+ from lm_eval.api.registry import register_aggregation, register_metric
+
+
+ eval_logger = logging.getLogger("lm-eval")
+
+
+ # Register Aggregations First
+ @register_aggregation("bypass")
+ def bypass_agg(arr):
+     return 999
+
+
+ @register_aggregation("mean")
+ def mean(arr):
+     return sum(arr) / len(arr)
+
+
+ @register_aggregation("median")
+ def median(arr):
+     return arr[len(arr) // 2]
+
+
+ # Certain metrics must be calculated across all documents in a benchmark.
+ # We use them as aggregation metrics, paired with no-op passthrough metric fns.
+ @register_aggregation("perplexity")
+ def perplexity(items):
+     return math.exp(-mean(items))
+
+
+ @register_aggregation("weighted_perplexity")
+ def weighted_perplexity(items):
+     return math.exp(-weighted_mean(items))
+
+
+ @register_aggregation("bits_per_byte")
+ def bits_per_byte(items):
+     return -weighted_mean(items) / math.log(2)
+
+
+ @register_aggregation("f1")
+ def f1_score(items):
+     from sklearn.metrics import f1_score
+
+     unzipped_list = list(zip(*items))
+     golds = unzipped_list[0]
+     preds = unzipped_list[1]
+     fscore = f1_score(golds, preds)
+
+     return np.max(fscore)
+
+
+ @register_aggregation("matthews_corrcoef")
+ def matthews_corrcoef(items):
+     from sklearn.metrics import matthews_corrcoef
+
+     unzipped_list = list(zip(*items))
+     golds = unzipped_list[0]
+     preds = unzipped_list[1]
+     return matthews_corrcoef(golds, preds)
+
+
+ @register_aggregation("bleu")
+ def bleu(items):
+     """The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
+     for evaluating a generated sentence to a reference sentence. It counts matching
+     n-grams in the candidate translation to n-grams in the reference text, where
+     1-gram or unigram would be each token and a bigram comparison would be each
+     word pair. The comparison is made regardless of word order
+     Source: https://machinelearningmastery.com/calculate-bleu-score-for-text-python/
+     Paper: https://www.aclweb.org/anthology/P02-1040/
+
+     Higher is better
+     """
+     refs = list(zip(*items))[0]
+     preds = list(zip(*items))[1]
+     refs, preds = _sacreformat(refs, preds)
+     return sacrebleu.corpus_bleu(preds, refs).score
+
+
+ @register_aggregation("chrf")
+ def chrf(items):
+     """chrF++ is a tool for automatic evaluation of machine translation output
+     based on character n-gram precision and recall enhanced with word n-grams.
+     Source: https://github.com/m-popovic/chrF
+     Paper: https://www.aclweb.org/anthology/W15-3049.pdf
+
+     Higher is better  # TODO I think
+     """
+     refs = list(zip(*items))[0]
+     preds = list(zip(*items))[1]
+     refs, preds = _sacreformat(refs, preds)
+     return sacrebleu.corpus_chrf(preds, refs).score
+
+
+ @register_aggregation("ter")
+ def ter(items):
+     """Translation Error Rate is an error metric for machine translation that
+     measures the number of edits required to change a system output into one
+     of the references
+     Source: http://www.cs.umd.edu/~snover/tercom/
+     Paper: http://mt-archive.info/AMTA-2006-Snover.pdf
+
+     Lower is better
+     """
+     refs = list(zip(*items))[0]
+     preds = list(zip(*items))[1]
+     refs, preds = _sacreformat(refs, preds)
+     return sacrebleu.corpus_ter(preds, refs).score
+
+
+ @register_aggregation("brier_score")
+ def brier_score(items):  # This is a passthrough function
+     gold, predictions = list(zip(*items))
+     bs, num_class = np.array(predictions).shape
+
+     gold = list(gold)
+     gold_one_hot = np.eye(num_class)[gold]
+     return np.mean(np.sum((predictions - gold_one_hot) ** 2, axis=1))
+
+
+ @register_metric(
+     metric="brier_score",
+     higher_is_better=False,
+     output_type=["multiple_choice"],
+     aggregation="brier_score",
+ )
+ def brier_score_fn(items):  # This is a passthrough function
+     return items
+
+
+ @register_metric(
+     metric="acc",
+     higher_is_better=True,
+     output_type=["loglikelihood", "multiple_choice"],
+     aggregation="mean",
+ )
+ def acc_fn(items):  # This is a passthrough function
+     return items
+
+
+ @register_metric(
+     metric="acc_norm",
+     higher_is_better=True,
+     output_type=["loglikelihood", "multiple_choice"],
+     aggregation="mean",
+ )
+ def acc_norm_fn(items):  # This is a passthrough function
+     return items
+
+
+ @register_metric(
+     metric="acc_mutual_info",
+     higher_is_better=True,
+     output_type="multiple_choice",
+     aggregation="mean",
+ )
+ def acc_mutual_info_fn(items):  # This is a passthrough function
+     return items
+
+
+ ### the code used in the `exact_match_hf_evaluate` function is ported from
+ ### https://github.com/huggingface/evaluate/blob/main/metrics/exact_match/exact_match.py
+ ### which is under the apache license.
+
+ # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ def exact_match_hf_evaluate(
+     predictions,
+     references,
+     regexes_to_ignore=None,
+     ignore_case=False,
+     ignore_punctuation=False,
+     ignore_numbers=False,
+ ):
+     if regexes_to_ignore is not None:
+         for s in regexes_to_ignore:
+             predictions = np.array([re.sub(s, "", x) for x in predictions])
+             references = np.array([re.sub(s, "", x) for x in references])
+     else:
+         predictions = np.asarray(predictions)
+         references = np.asarray(references)
+
+     if ignore_case:
+         predictions = np.char.lower(predictions)
+         references = np.char.lower(references)
+
+     if ignore_punctuation:
+         repl_table = string.punctuation.maketrans("", "", string.punctuation)
+         predictions = np.char.translate(predictions, table=repl_table)
+         references = np.char.translate(references, table=repl_table)
+
+     if ignore_numbers:
+         repl_table = string.digits.maketrans("", "", string.digits)
+         predictions = np.char.translate(predictions, table=repl_table)
+         references = np.char.translate(references, table=repl_table)
+
+     score_list = predictions == references
+
+     return {"exact_match": np.mean(score_list)}
+
+
+ ###
+
+
+ @register_metric(
+     metric="exact_match",
+     higher_is_better=True,
+     output_type="generate_until",
+     aggregation="mean",
+ )
+ def exact_match_fn(**kwargs):
+     return exact_match_hf_evaluate(**kwargs)
+
+
+ @register_metric(
+     metric="perplexity",
+     higher_is_better=False,
+     output_type="loglikelihood",
+     aggregation="perplexity",
+ )
+ def perplexity_fn(items):  # This is a passthrough function
+     return items
+
+
+ @register_metric(
+     metric="word_perplexity",
+     higher_is_better=False,
+     output_type="loglikelihood_rolling",
+     aggregation="weighted_perplexity",
+ )
+ def word_perplexity_fn(items):  # This is a passthrough function
+     return items
+
+
+ @register_metric(
+     metric="byte_perplexity",
+     higher_is_better=False,
+     output_type="loglikelihood_rolling",
+     aggregation="weighted_perplexity",
+ )
+ def byte_perplexity_fn(items):  # This is a passthrough function
+     return items
+
+
+ @register_metric(
+     metric="bits_per_byte",
+     higher_is_better=False,
+     output_type="loglikelihood_rolling",
+     aggregation="bits_per_byte",
+ )
+ def bits_per_byte_fn(items):  # This is a passthrough function
+     return items
+
+
+ def pop_stddev(arr):
+     mu = mean(arr)
+     return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / len(arr))
+
+
+ def sample_stddev(arr):
+     mu = mean(arr)
+     return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / (len(arr) - 1))
+
+
+ def mean_stderr(arr):
+     return sample_stddev(arr) / math.sqrt(len(arr))
+
+
+ @register_metric(
+     metric="bypass",
+     higher_is_better=True,
+     output_type=["loglikelihood", "multiple_choice", "generate_until"],
+     aggregation="bypass",
+ )
+ def bypass(items):
+     return None
+
+
+ @register_metric(
+     metric="mcc",
+     higher_is_better=True,
+     output_type="multiple_choice",
+     aggregation="matthews_corrcoef",
+ )
+ def mcc_fn(items):  # This is a passthrough function
+     return items
+
+
+ @register_metric(
+     metric="f1",
+     higher_is_better=True,
+     output_type="multiple_choice",
+     aggregation="f1",
+ )
+ def f1_fn(items):  # This is a passthrough function
+     return items
+
+
+ @register_metric(
+     metric="bleu",
+     higher_is_better=True,
+     output_type="generate_until",
+     aggregation="bleu",
+ )
+ def bleu_fn(items):  # This is a passthrough function
+     return items
+
+
+ @register_metric(
+     metric="chrf",
+     higher_is_better=True,
+     output_type="generate_until",
+     aggregation="chrf",
+ )
+ def chrf_fn(items):  # This is a passthrough function
+     return items
+
+
+ @register_metric(
+     metric="ter",
+     higher_is_better=True,
+     output_type="generate_until",
+     aggregation="ter",
+ )
+ def ter_fn(items):  # This is a passthrough function
+     return items
+
+
+ @register_metric(
+     metric="acc_all",
+     higher_is_better=True,
+     output_type="loglikelihood",
+     aggregation="mean",
+ )
+ def acc_all(items):
+     # Only count as correct if all answers are labeled correctly for each question
+     question_scoring_dict = {}
+     preds = list(zip(*items))[0]
+     docs = list(zip(*items))[1]
+
+     for doc, pred in zip(docs, preds):
+         paragraph_id = doc["idx"]["paragraph"]
+         question_id = doc["idx"]["question"]
+         if (paragraph_id, question_id) not in question_scoring_dict:
+             question_scoring_dict[(paragraph_id, question_id)] = []
+
+         gold_label = doc["label"] == 1
+
+         question_scoring_dict[(paragraph_id, question_id)].append(gold_label == pred)
+     acc = np.mean([int(all(x)) for x in question_scoring_dict.values()])
+     return acc
+
+
+ def acc_all_stderr(items):
+     # Only count as correct if all answers are labeled correctly for each question
+     question_scoring_dict = {}
+     preds = list(zip(*items))[0]
+     docs = list(zip(*items))[1]
+
+     for doc, pred in zip(docs, preds):
+         question_id = doc["idx"]["question"]
+         if question_id not in question_scoring_dict:
+             question_scoring_dict[question_id] = []
+
+         gold_label = doc["label"] == 1
+         question_scoring_dict[question_id].append(gold_label == pred)
+
+     acc = mean_stderr([int(all(x)) for x in question_scoring_dict.values()])
+     return acc
+
+
+ def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
+     """Compute max metric between prediction and each ground truth."""
+     scores_for_ground_truths = []
+     for ground_truth in ground_truths:
+         score = metric_fn(prediction, ground_truth)
+         scores_for_ground_truths.append(score)
+     return max(scores_for_ground_truths)
+
+
+ def weighted_mean(items):
+     a, b = zip(*items)
+     return sum(a) / sum(b)
+
+
+ def is_non_str_iterable(obj):
+     return isinstance(obj, Iterable) and not isinstance(obj, str)
+
+
+ def _sacreformat(refs, preds):
+     """Format refs and preds for sacrebleu corpus calculation. It is very particular"""
+     # Sacrebleu expects (List[str], List[List[str])
+     # e.g. sacrebleu.corpus_bleu([pred_t], [[ref1_stream], [ref2_stream], ...])
+
+     # Note [ref1_stream] is the first reference for each pred.
+     # So lists are size N and (M, N) for N preds and M possible refs for each pred
+     # This is a different order of dimensions that I would expect
+
+     # We expect refs to be List[str] or List[List[str]], the outer list corresponding to preds
+     # Must become List[List[str]] with the inner list corresponding to preds
+     if not is_non_str_iterable(refs):
+         refs = list(refs)
+     if not is_non_str_iterable(refs[0]):
+         refs = [[ref] for ref in refs]
+     refs = list(zip(*refs))
+     # Note the number of refs in each ref list much match the number of preds
+
+     # We expect preds to be List[str] or List[List[str]]. Must become List[str]
+     if not is_non_str_iterable(preds):
+         preds = list(preds)
+     if is_non_str_iterable(preds[0]):
+         assert len(preds[0]) == 1, f"Pred must be a str, was {preds[0]}"
+         preds = [pred[0] for pred in preds]
+
+     return refs, preds
+
+
+ # stderr stuff
+
+
+ class _bootstrap_internal:
+     def __init__(self, f, n) -> None:
+         self.f = f
+         self.n = n
+
+     def __call__(self, v):
+         i, xs = v
+         rnd = random.Random()
+         rnd.seed(i)
+         res = []
+         for _ in range(self.n):
+             res.append(self.f(rnd.choices(xs, k=len(xs))))
+         return res
+
+
+ def bootstrap_stderr(f, xs, iters):
+     import multiprocessing as mp
+
+     pool = mp.Pool(mp.cpu_count())
+     # this gives a biased estimate of the stderr (i.e w/ the mean, it gives something
+     # equivalent to stderr calculated without Bessel's correction in the stddev.
+     # Unfortunately, I haven't been able to figure out what the right correction is
+     # to make the bootstrap unbiased - i considered multiplying by sqrt(n/(n-1)) but
+     # that would be ad-hoc and I can't prove that that would actually be an unbiased estimator)
+     # Thankfully, shouldn't matter because our samples are pretty big usually anyways
+     res = []
+     chunk_size = min(1000, iters)
+     from tqdm import tqdm
+
+     print("bootstrapping for stddev:", f.__name__)
+     for bootstrap in tqdm(
+         pool.imap(
+             _bootstrap_internal(f, chunk_size),
+             [(i, xs) for i in range(iters // chunk_size)],
+         ),
+         total=iters // chunk_size,
+     ):
+         # sample w replacement
+         res.extend(bootstrap)
+
+     pool.close()
+     return sample_stddev(res)
+
+
+ def stderr_for_metric(metric, bootstrap_iters: int):
+     if bootstrap_iters <= 0:
+         # return no function (don't compute stderr) if bootstrap iters = 0
+         return None
+
+     bootstrappable = [
+         median,
+         matthews_corrcoef,
+         f1_score,
+         perplexity,
+         bleu,
+         chrf,
+         ter,
+     ]
+
+     if metric in bootstrappable:
+         return lambda x: bootstrap_stderr(metric, x, iters=bootstrap_iters)
+
+     stderr = {mean: mean_stderr, acc_all: acc_all_stderr}
+
+     return stderr.get(metric, None)
+
+
+ def pooled_sample_stderr(stderrs: List[float], sizes: List[int]):
+     # Used to aggregate bootstrapped stderrs across subtasks in a group,
+     # when we are weighting by the size of each subtask.
+     #
+
+     assert len(stderrs) == len(sizes)
+
+     # formula source: https://en.wikipedia.org/wiki/Pooled_variance
+     # and: https://stats.stackexchange.com/a/4841331
+     # this empirically seems to match running `stderr_for_metric` on all instances
+     # from the subtasks concatenated with each other.
+     pooled_sample_var = (
+         sum([(size - 1) * stderr**2 * size for size, stderr in zip(sizes, stderrs)])
+     ) / (sum(sizes) - len(sizes))
+
+     return np.sqrt(pooled_sample_var / sum(sizes))
+
+
+ def combined_sample_stderr(stderrs: List[float], sizes: List[int], metrics=None):
+     assert (
+         metrics is not None
+     ), "Need to pass a list of each subtask's metric for this stderr aggregation"
+     assert len(stderrs) == len(sizes) and len(sizes) == len(metrics)
+
+     # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1390 for more documentation.
+     # This formula depends on sample means.
+     # removed because it seems to give erroneously huge stderrs for groupings of tasks
+     # and does not seem to match up with bootstrap-calculated stderrs for groups.
+
+     ### don't use this unless a statistician has told you it's the right thing to do ###
+
+     # accumulators: we'll aggregate pairwise N - 1 times
+     variance = stderrs[0] ** 2
+     curr_size = sizes[0]
+     curr_score = metrics[0]
+
+     for stderr, size, score in zip(stderrs[1:], sizes[1:], metrics[1:]):
+         curr_score = ((curr_score * curr_size) + (score * size)) / (
+             curr_size + size
+         )  # NOTE: this assumes our aggregation fn is "mean"
+
+         variance = ((curr_size - 1) * variance + (size - 1) * (stderr**2)) / (
+             curr_size + size - 1
+         ) + curr_size * size / ((curr_size + size) * (curr_size + size - 1)) * (
+             curr_score - score
+         ) ** 2
+
+     return np.sqrt(variance)
+
+
+ def aggregate_subtask_metrics(metrics, sizes, weight_by_size=True):
+     # A helper function that is used to aggregate
+     # subtask scores cross-task.
+     # TODO: does not hold for non-mean aggregations
+     if not weight_by_size:
+         sizes = [1] * len(sizes)
+
+     assert len(metrics) == len(sizes)
+
+     return sum([metric * size for metric, size in zip(metrics, sizes)]) / sum(sizes)
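Illustrative sketch (not part of the diff above): the group-aggregation helpers at the end of metrics.py, assuming the module shown here is importable. The scores, sizes, and stderrs are made-up numbers.

from lm_eval.api.metrics import aggregate_subtask_metrics, pooled_sample_stderr

scores = [0.50, 0.80]   # per-subtask metric values
sizes = [100, 300]      # number of documents in each subtask

# size-weighted mean: (0.5*100 + 0.8*300) / 400 = 0.725
print(aggregate_subtask_metrics(scores, sizes))
# unweighted mean: (0.5 + 0.8) / 2 = 0.65
print(aggregate_subtask_metrics(scores, sizes, weight_by_size=False))
# pooled stderr for the same grouping, from per-subtask stderrs
print(pooled_sample_stderr([0.05, 0.02], sizes))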
scripts/yans/lm-evaluation-harness/lm_eval/api/model.py ADDED
@@ -0,0 +1,385 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+ import hashlib
3
+ import json
4
+ import logging
5
+ import os
6
+ from typing import Dict, List, Optional, Tuple, Type, TypeVar
7
+
8
+ import transformers
9
+ from sqlitedict import SqliteDict
10
+ from tqdm import tqdm
11
+
12
+ from lm_eval import utils
13
+
14
+
15
+ eval_logger = logging.getLogger("lm-eval")
16
+
17
+ T = TypeVar("T", bound="LM")
18
+
19
+
20
+ class LM(abc.ABC):
21
+ def __init__(self) -> None:
22
+ """Defines the interface that should be implemented by all LM subclasses.
23
+ LMs are assumed to take text (strings) as input and yield strings as output
24
+ (inputs/outputs should be tokenization-agnostic.)
25
+
26
+ """
27
+ # set rank and world size to a single process, by default.
28
+ self._rank = 0
29
+ self._world_size = 1
30
+ self.cache_hook = CacheHook(None)
31
+
32
+ @abc.abstractmethod
33
+ def loglikelihood(self, requests) -> List[Tuple[float, bool]]:
34
+ """Compute log-likelihood of generating a continuation from a context.
35
+ Downstream tasks should attempt to use loglikelihood instead of other
36
+ LM calls whenever possible.
37
+
38
+ :param requests: list[Instance]
39
+ A list of Instance objects, with property `args` which returns a tuple (context, continuation).
40
+ `context: str`
41
+ Context string. Implementations of LM must be able to handle an
42
+ empty context string.
43
+ `continuation: str`
44
+ The continuation over which log likelihood will be calculated. If
45
+ there is a word boundary, the space should be in the continuation.
46
+ For example, context="hello" continuation=" world" is correct.
47
+
48
+ :return: list[tuple[float, bool]]
49
+ A list of pairs (logprob, isgreedy)
50
+ `logprob: float`
51
+ The log probability of `continuation`.
52
+ `isgreedy`:
53
+ Whether `continuation` would be generated by greedy sampling from `context`.
54
+ """
55
+ pass
56
+
57
+ @abc.abstractmethod
58
+ def loglikelihood_rolling(self, requests) -> List[float]:
59
+ """Compute full log-likelihood of a string, with no truncation, for perplexity computation
60
+ - We will use the full max context length of the model.
61
+ - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
62
+ the max context length.
63
+ - IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
64
+ which may simply concatenate multiple documents together.
65
+ - IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
66
+ multiple chunks, the last input will still a full-sized context.
67
+ Example:
68
+ Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
69
+ Prefix: BOS/EOS
70
+ Max context length: 4
71
+ Resulting input/prediction pairs:
72
+
73
+ INPUT: BOS 0 1 2
74
+ PRED: 0 1 2 3
75
+
76
+ INPUT: 3 4 5 6
77
+ PRED: 4 5 6 7
78
+
79
+ INPUT: 5 6 7 8
80
+ PRED: 8 9
81
+
82
+ Observe that:
83
+ 1. Each token is predicted exactly once
84
+ 2. For the last pair, we provide the full context, but only score the last two tokens
85
+
86
+ :param requests: list[Instance]
87
+ A list of Instance objects with property `args` which returns a tuple (context,).
88
+ string: str
89
+ String for which we are computing overall loglikelihood
90
+ :return: list[tuple[float]]
91
+ A list of tuples (logprob,)
92
+ logprob: float
93
+ The log probability of `context` conditioned on the BOS/EOS token.
94
+ Can also be overridden for custom cases by `prefix_token_id`.
95
+ """
96
+ pass
97
+
98
+ # TODO: Add an optional max length
99
+ @abc.abstractmethod
100
+ def generate_until(self, requests) -> List[str]:
101
+ """Generate greedily until a stopping sequence
102
+
103
+ :param requests: list[Instance]
104
+ A list of Instance objects with property `args` which returns a tuple (context, gen_kwargs).
105
+ context: str
106
+ Context string
107
+ gen_kwargs: dict
108
+ A dictionary of keyword arguments to pass to the generation function e.g. top_k, until, etc.
109
+ :return: list[str]
110
+ A list of model generated continuations.
111
+ continuation: str
112
+ The generated continuation.
113
+ """
114
+ pass
115
+
116
+ def apply_chat_template(self, chat_history: List[Dict[str, str]]) -> str:
117
+ """
118
+ Defines how to transform few-shot examples provided as chat history into a format that can be used as input to the LM.
119
+
120
+ :param chat_history: list[dict[str, str]]
121
+ A list of dictionaries with keys 'role' and 'content'.
122
+ Values are strings representing the role name and the content of the message, respectively.
123
+ :return: str
124
+ A string representing the chat history in a format that can be used as input to the LM.
125
+ """
126
+ raise NotImplementedError(
127
+ "To use this model with chat templates, please implement the 'apply_chat_template' method for your model type."
128
+ )
129
+
130
+ @classmethod
131
+ def create_from_arg_string(
132
+ cls: Type[T], arg_string: str, additional_config: Optional[dict] = None
133
+ ) -> T:
134
+ """
135
+ Creates an instance of the LM class using the given argument string and additional config.
136
+
137
+ Parameters:
138
+ - arg_string: A string containing arguments in the format key1=value1,key2=value2.
139
+ - additional_config: Optional dictionary containing additional configuration parameters.
140
+
141
+ Returns:
142
+ - Instance of the LM class.
143
+ """
144
+ additional_config = {} if additional_config is None else additional_config
145
+ args = utils.simple_parse_args_string(arg_string)
146
+ args2 = {k: v for k, v in additional_config.items() if v is not None}
147
+ return cls(**args, **args2)
148
+
149
+ @classmethod
150
+ def create_from_arg_obj(
151
+ cls: Type[T], arg_dict: dict, additional_config: Optional[dict] = None
152
+ ) -> T:
153
+ """
154
+ Creates an instance of the LM class using the given arg_obj
155
+
156
+ Parameters:
157
+ - arg_obj: A dict containing arguments in the format key1=value1,key2=value2.
158
+ - additional_config: Optional dictionary containing additional configuration parameters.
159
+
160
+ Returns:
161
+ - Instance of the LM class.
162
+ """
163
+
164
+ additional_config = {} if additional_config is None else additional_config
165
+ additional_config = {
166
+ k: v for k, v in additional_config.items() if v is not None
167
+ }
168
+
169
+ return cls(**arg_dict, **additional_config)
170
+
171
+ @property
172
+ def rank(self):
173
+ # used in the case of parallelism. Hardcoded to
174
+ # ensure no errors arise using API models which do
175
+ # not support multi-device parallelism nor expect it.
176
+ return self._rank
177
+
178
+ @property
179
+ def world_size(self):
180
+ # used in the case of parallelism. Hardcoded to
181
+ # ensure no errors arise using API models which do
182
+ # not support multi-device parallelism nor expect it.
183
+ return self._world_size
184
+
185
+ @property
186
+ def tokenizer_name(self) -> str:
187
+ """Must be defined for LM subclasses which implement Chat Templating.
188
+ Should return the name of the tokenizer or chat template used.
189
+ Used only to properly fingerprint caches when requests are being cached with `--cache_requests`, otherwise not used.
190
+ """
191
+ raise NotImplementedError(
192
+ "To use this model with chat templates, please implement the 'tokenizer_name' property."
193
+ )
194
+
195
+ @property
196
+ def chat_template(self) -> str:
197
+ """Must be defined for LM subclasses that implement Chat Templating.
198
+ Should return the structure of the chat template applied to user/assistant messages.
199
+ This is used only to save in the experiment results for reproducibility.
200
+ """
201
+ raise NotImplementedError(
202
+ "To use this model with chat templates, please implement the 'chat_template' property."
203
+ )
204
+
205
+ def set_cache_hook(self, cache_hook) -> None:
206
+ self.cache_hook = cache_hook
207
+
208
+
209
+ ### SQLite-based caching of LM responses
210
+ def hash_args(attr, args):
211
+ dat = json.dumps([attr] + list(args))
212
+ return hashlib.sha256(dat.encode("utf-8")).hexdigest()
213
+
214
+
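# Minimal sketch of the fingerprinting above: the request method name plus its
# JSON-serialised arguments determine the SQLite cache key.
import hashlib
import json

def demo_hash(attr, args):
    dat = json.dumps([attr] + list(args))
    return hashlib.sha256(dat.encode("utf-8")).hexdigest()

key = demo_hash("generate_until", ["Translate: hello", {"until": ["\n\n"]}])
print(key[:16])  # identical (attr, args) pairs always map to the same key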
215
+ class CacheHook:
216
+ def __init__(self, cachinglm) -> None:
217
+ if cachinglm is None:
218
+ self.dbdict = None
219
+ return
220
+
221
+ self.dbdict = cachinglm.dbdict
222
+
223
+ def add_partial(self, attr, req, res) -> None:
224
+ if self.dbdict is None:
225
+ return
226
+ hsh = hash_args(attr, req)
227
+ self.dbdict[hsh] = res
228
+
229
+
230
+ class CachingLM:
231
+ def __init__(self, lm, cache_db) -> None:
232
+ """LM wrapper that returns cached results if they exist, and uses the underlying LM if not.
233
+
234
+ :param lm: LM
235
+ Underlying LM
236
+ :param cache_db: str
237
+ Path to cache db
238
+ """
239
+ self.lm = lm
240
+ self.cache_db = cache_db
241
+ if os.path.dirname(cache_db):
242
+ os.makedirs(os.path.dirname(cache_db), exist_ok=True)
243
+ self.dbdict = SqliteDict(cache_db, autocommit=True)
244
+
245
+ # add hook to lm
246
+ lm.set_cache_hook(self.get_cache_hook())
247
+
248
+ def __getattr__(self, attr: str):
249
+ lm_attr = getattr(self.lm, attr)
250
+ if attr not in ["loglikelihood", "loglikelihood_rolling", "generate_until"]:
251
+ eval_logger.debug(f"Passing through attribute '{attr}' to underlying LM")
252
+ return lm_attr
253
+
254
+ def fn(requests):
255
+ res = []
256
+ remaining_reqs = []
257
+ warned = False
258
+ # figure out which ones are cached and which ones are new
259
+ eval_logger.info(
260
+ f"Loading '{attr}' responses from cache '{self.cache_db}' where possible..."
261
+ )
262
+ for req in tqdm(requests, desc="Checking cached requests"):
263
+ hsh = hash_args(attr, req.args)
264
+ if attr == "generate_until" and req.args[1].get("do_sample", False):
265
+ # when we are doing non-greedy generation, don't use the cache
266
+ # (else every "randomly sampled" generation would be identical for repeats > 1).
267
+ if not warned:
268
+ eval_logger.warning(
269
+ f"Arguments to lm.generate_until() '{req.args[1]}' include non-deterministic sampling. Caching will not be performed for such requests."
270
+ )
271
+ warned = True
272
+ res.append(None)
273
+ remaining_reqs.append(req)
274
+ elif hsh in self.dbdict:
275
+ ob = self.dbdict[hsh]
276
+
277
+ assert ob is not None
278
+
279
+ res.append(ob)
280
+ else:
281
+ res.append(None)
282
+ remaining_reqs.append(req)
283
+ eval_logger.info(
284
+ f"Cached requests: {len(requests) - len(remaining_reqs)}, Requests remaining: {len(remaining_reqs)}"
285
+ )
286
+ # actually run the LM on the requests that do not have cached results
287
+ rem_res = getattr(self.lm, attr)(remaining_reqs)
288
+
289
+ # stick the new ones back into the list and also cache any of the new ones
290
+ resptr = 0
291
+ for req, r in zip(remaining_reqs, rem_res):
292
+ while res[resptr] is not None:
293
+ resptr += 1
294
+
295
+ res[resptr] = r
296
+
297
+ # caching
298
+ hsh = hash_args(attr, req.args)
299
+ self.dbdict[hsh] = r
300
+ self.dbdict.commit()
301
+
302
+ return res
303
+
304
+ return fn
305
+
306
+ def get_cache_hook(self):
307
+ return CacheHook(self)
308
+
309
+
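# Usage sketch (hedged): `base_lm` stands in for any concrete LM instance.
#
#   from lm_eval.api.model import CachingLM
#
#   lm = CachingLM(base_lm, "eval_cache/my_model.db")
#   results = lm.loglikelihood(requests)
#
# Cached loglikelihood/generate_until responses are read from the SQLite file;
# the rest fall through to `base_lm` and are written back. Sampled generations
# (do_sample=True) deliberately bypass the cache, as noted in `fn` above.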
310
+ class TemplateLM(LM):
311
+ """
312
+ A class acting as intermediary between the LM base class
313
+ and boilerplate often included in other LM subclasses.
314
+ """
315
+
316
+ @property
317
+ @abc.abstractmethod
318
+ def eot_token_id(self):
319
+ pass
320
+
321
+ @property
322
+ def prefix_token_id(self):
323
+ # it is used as prefix for loglikelihood
324
+ return self.eot_token_id
325
+
326
+ @abc.abstractmethod
327
+ def tok_encode(self, string: str, **kwargs) -> List[int]:
328
+ """
329
+ Tokenize a string using the model's tokenizer and return a list of token IDs.
330
+ """
331
+ pass
332
+
333
+ @abc.abstractmethod
334
+ def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]:
335
+ pass
336
+
337
+ def _encode_pair(
338
+ self, context: str, continuation: str
339
+ ) -> Tuple[List[int], List[int]]:
340
+ n_spaces = len(context) - len(context.rstrip())
341
+ if n_spaces > 0:
342
+ continuation = context[-n_spaces:] + continuation
343
+ context = context[:-n_spaces]
344
+
345
+ model_class = getattr(self, "AUTO_MODEL_CLASS", None)
346
+
347
+ if model_class == transformers.AutoModelForSeq2SeqLM:
348
+ context_enc = self.tok_encode(context)
349
+ continuation_enc = self.tok_encode(continuation, add_special_tokens=False)
350
+ else:
351
+ whole_enc = self.tok_encode(context + continuation)
352
+ context_enc = self.tok_encode(context)
353
+
354
+ context_enc_len = len(context_enc)
355
+ continuation_enc = whole_enc[context_enc_len:]
356
+
357
+ return context_enc, continuation_enc
358
+
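# Worked example of the whitespace shift above, independent of any tokenizer:
# trailing spaces move from the context onto the continuation so the
# continuation keeps its leading-space token form.
context, continuation = "Question: 2+2= ", "4"
n_spaces = len(context) - len(context.rstrip())
if n_spaces > 0:
    continuation = context[-n_spaces:] + continuation
    context = context[:-n_spaces]
assert (context, continuation) == ("Question: 2+2=", " 4")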
359
+ def loglikelihood(
360
+ self, requests, disable_tqdm: bool = False
361
+ ) -> List[Tuple[float, bool]]:
362
+ new_reqs = []
363
+ for context, continuation in [req.args for req in requests]:
364
+ if context == "":
365
+ # BOS or EOS as context
366
+ context_enc, continuation_enc = (
367
+ [self.prefix_token_id],
368
+ self.tok_encode(continuation),
369
+ )
370
+ else:
371
+ context_enc, continuation_enc = self._encode_pair(context, continuation)
372
+
373
+ new_reqs.append(((context, continuation), context_enc, continuation_enc))
374
+
375
+ return self._loglikelihood_tokens(new_reqs, disable_tqdm=disable_tqdm)
376
+
377
+ @abc.abstractmethod
378
+ def loglikelihood_rolling(
379
+ self, requests, disable_tqdm: bool = False
380
+ ) -> List[float]:
381
+ pass
382
+
383
+ @abc.abstractmethod
384
+ def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
385
+ pass
scripts/yans/lm-evaluation-harness/lm_eval/api/registry.py ADDED
@@ -0,0 +1,192 @@
1
+ import logging
2
+ from typing import Callable, Dict
3
+
4
+ import evaluate as hf_evaluate
5
+
6
+ from lm_eval.api.model import LM
7
+
8
+
9
+ eval_logger = logging.getLogger("lm-eval")
10
+
11
+ MODEL_REGISTRY = {}
12
+
13
+
14
+ def register_model(*names):
15
+ # either pass a list or a single alias.
16
+ # function receives them as a tuple of strings
17
+
18
+ def decorate(cls):
19
+ for name in names:
20
+ assert issubclass(
21
+ cls, LM
22
+ ), f"Model '{name}' ({cls.__name__}) must extend LM class"
23
+
24
+ assert (
25
+ name not in MODEL_REGISTRY
26
+ ), f"Model named '{name}' conflicts with existing model! Please register with a non-conflicting alias instead."
27
+
28
+ MODEL_REGISTRY[name] = cls
29
+ return cls
30
+
31
+ return decorate
32
+
33
+
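# Registration sketch: the alias "my-echo" and class EchoLM are invented here
# purely for illustration.
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model, get_model

@register_model("my-echo")
class EchoLM(LM):
    """Placeholder subclass; a real backend implements the LM abstract methods."""

assert get_model("my-echo") is EchoLM
# Re-using an existing alias, or decorating a class that does not subclass LM,
# trips the asserts inside `decorate`.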
34
+ def get_model(model_name):
35
+ try:
36
+ return MODEL_REGISTRY[model_name]
37
+ except KeyError:
38
+ raise ValueError(
39
+ f"Attempted to load model '{model_name}', but no model for this name found! Supported model names: {', '.join(MODEL_REGISTRY.keys())}"
40
+ )
41
+
42
+
43
+ TASK_REGISTRY = {}
44
+ GROUP_REGISTRY = {}
45
+ ALL_TASKS = set()
46
+ func2task_index = {}
47
+
48
+
49
+ def register_task(name):
50
+ def decorate(fn):
51
+ assert (
52
+ name not in TASK_REGISTRY
53
+ ), f"task named '{name}' conflicts with existing registered task!"
54
+
55
+ TASK_REGISTRY[name] = fn
56
+ ALL_TASKS.add(name)
57
+ func2task_index[fn.__name__] = name
58
+ return fn
59
+
60
+ return decorate
61
+
62
+
63
+ def register_group(name):
64
+ def decorate(fn):
65
+ func_name = func2task_index[fn.__name__]
66
+ if name in GROUP_REGISTRY:
67
+ GROUP_REGISTRY[name].append(func_name)
68
+ else:
69
+ GROUP_REGISTRY[name] = [func_name]
70
+ ALL_TASKS.add(name)
71
+ return fn
72
+
73
+ return decorate
74
+
75
+
76
+ OUTPUT_TYPE_REGISTRY = {}
77
+ METRIC_REGISTRY = {}
78
+ METRIC_AGGREGATION_REGISTRY = {}
79
+ AGGREGATION_REGISTRY: Dict[str, Callable[[], Dict[str, Callable]]] = {}
80
+ HIGHER_IS_BETTER_REGISTRY = {}
81
+ FILTER_REGISTRY = {}
82
+
83
+ DEFAULT_METRIC_REGISTRY = {
84
+ "loglikelihood": [
85
+ "perplexity",
86
+ "acc",
87
+ ],
88
+ "loglikelihood_rolling": ["word_perplexity", "byte_perplexity", "bits_per_byte"],
89
+ "multiple_choice": ["acc", "acc_norm"],
90
+ "generate_until": ["exact_match"],
91
+ }
92
+
93
+
94
+ def register_metric(**args):
95
+ # TODO: do we want to enforce a certain interface to registered metrics?
96
+ def decorate(fn):
97
+ assert "metric" in args
98
+ name = args["metric"]
99
+
100
+ for key, registry in [
101
+ ("metric", METRIC_REGISTRY),
102
+ ("higher_is_better", HIGHER_IS_BETTER_REGISTRY),
103
+ ("aggregation", METRIC_AGGREGATION_REGISTRY),
104
+ ]:
105
+ if key in args:
106
+ value = args[key]
107
+ assert (
108
+ value not in registry
109
+ ), f"{key} named '{value}' conflicts with existing registered {key}!"
110
+
111
+ if key == "metric":
112
+ registry[name] = fn
113
+ elif key == "aggregation":
114
+ registry[name] = AGGREGATION_REGISTRY[value]
115
+ else:
116
+ registry[name] = value
117
+
118
+ return fn
119
+
120
+ return decorate
121
+
122
+
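# Metric registration sketch (hedged; "char_len_diff" is an invented metric and
# the registry does not enforce a particular metric-function signature):
#
#   from lm_eval.api.registry import register_metric
#
#   @register_metric(metric="char_len_diff", higher_is_better=False, aggregation="mean")
#   def char_len_diff(references, predictions, **kwargs):
#       return abs(len(references[0]) - len(predictions[0]))
#
# The "aggregation" value is looked up in AGGREGATION_REGISTRY, so a string such
# as "mean" must already have been registered via @register_aggregation.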
123
+ def get_metric(name: str, hf_evaluate_metric=False) -> Callable:
124
+ if not hf_evaluate_metric:
125
+ if name in METRIC_REGISTRY:
126
+ return METRIC_REGISTRY[name]
127
+ else:
128
+ eval_logger.warning(
129
+ f"Could not find registered metric '{name}' in lm-eval, searching in HF Evaluate library..."
130
+ )
131
+
132
+ try:
133
+ metric_object = hf_evaluate.load(name)
134
+ return metric_object.compute
135
+ except Exception:
136
+ eval_logger.error(
137
+ f"{name} not found in the evaluate library! Please check https://huggingface.co/evaluate-metric",
138
+ )
139
+
140
+
141
+ def register_aggregation(name: str):
142
+ def decorate(fn):
143
+ assert (
144
+ name not in AGGREGATION_REGISTRY
145
+ ), f"aggregation named '{name}' conflicts with existing registered aggregation!"
146
+
147
+ AGGREGATION_REGISTRY[name] = fn
148
+ return fn
149
+
150
+ return decorate
151
+
152
+
153
+ def get_aggregation(name: str) -> Callable[[], Dict[str, Callable]]:
154
+ try:
155
+ return AGGREGATION_REGISTRY[name]
156
+ except KeyError:
157
+ eval_logger.warning(f"{name} not a registered aggregation metric!")
158
+
159
+
160
+ def get_metric_aggregation(name: str) -> Callable[[], Dict[str, Callable]]:
161
+ try:
162
+ return METRIC_AGGREGATION_REGISTRY[name]
163
+ except KeyError:
164
+ eval_logger.warning(f"{name} metric is not assigned a default aggregation!")
165
+
166
+
167
+ def is_higher_better(metric_name) -> bool:
168
+ try:
169
+ return HIGHER_IS_BETTER_REGISTRY[metric_name]
170
+ except KeyError:
171
+ eval_logger.warning(
172
+ f"higher_is_better not specified for metric '{metric_name}'!"
173
+ )
174
+
175
+
176
+ def register_filter(name):
177
+ def decorate(cls):
178
+ if name in FILTER_REGISTRY:
179
+ eval_logger.info(
180
+ f"Registering filter `{name}` that is already in Registry {FILTER_REGISTRY}"
181
+ )
182
+ FILTER_REGISTRY[name] = cls
183
+ return cls
184
+
185
+ return decorate
186
+
187
+
188
+ def get_filter(filter_name: str) -> type:
189
+ try:
190
+ return FILTER_REGISTRY[filter_name]
191
+ except KeyError:
192
+ eval_logger.warning(f"filter `{filter_name}` is not registered!")
scripts/yans/lm-evaluation-harness/lm_eval/api/samplers.py ADDED
@@ -0,0 +1,198 @@
1
+ from functools import partial
2
+
3
+ import datasets
4
+
5
+
6
+ class ContextSampler:
7
+ def __init__(self, docs, task, fewshot_indices=None, rnd=None) -> None:
8
+ self.rnd = rnd
9
+ if not self.rnd:
10
+ raise ValueError(
11
+ "A `random.Random` generator argument must be provided to `rnd` of ContextSampler!"
12
+ )
13
+
14
+ self.task = task
15
+ self.config = task._config
16
+
17
+ self.target_delimiter = self.config.target_delimiter
18
+ self.fewshot_delimiter = self.config.fewshot_delimiter
19
+
20
+ if (
21
+ self.config.fewshot_config is not None
22
+ and self.config.fewshot_config.get("doc_to_text", None) is not None
23
+ ):
24
+ self.doc_to_text = partial(
25
+ self.task.doc_to_text,
26
+ doc_to_text=self.config.fewshot_config.get("doc_to_text", None),
27
+ )
28
+ else:
29
+ self.doc_to_text = self.task.doc_to_text
30
+
31
+ if (
32
+ self.config.fewshot_config is not None
33
+ and self.config.fewshot_config.get("doc_to_target", None) is not None
34
+ ):
35
+ self.doc_to_target = partial(
36
+ self.task.doc_to_target,
37
+ doc_to_target=self.config.fewshot_config.get("doc_to_target", None),
38
+ )
39
+ else:
40
+ self.doc_to_target = self.task.doc_to_target
41
+
42
+ if (
43
+ self.config.fewshot_config is not None
44
+ and self.config.fewshot_config.get("doc_to_choice", None) is not None
45
+ ):
46
+ self.doc_to_choice = partial(
47
+ self.task.doc_to_choice,
48
+ doc_to_choice=self.config.fewshot_config.get("doc_to_choice", None),
49
+ )
50
+ else:
51
+ self.doc_to_choice = self.task.doc_to_choice
52
+
53
+ self.docs = docs # HF dataset split, provided by task._fewshot_docs()
54
+ if fewshot_indices: # restrict few-shot docs to the provided indices
55
+ if not isinstance(self.docs, datasets.Dataset):
56
+ raise ValueError(
57
+ "Got `fewshot_indices` but fewshot_docs are not a HF dataset. Don't use both `fewshot_indices` and a user-defined few-shot sample list simultaneously"
58
+ )
59
+ self.docs = self.docs.select(fewshot_indices)
60
+
61
+ def get_context(self, doc, num_fewshot):
62
+ # draw an extra fewshot sample if using same split as evaluating on
63
+ n_samples = (
64
+ num_fewshot + 1
65
+ if self.config.fewshot_split == self.config.test_split
66
+ else num_fewshot
67
+ )
68
+
69
+ # draw `n_samples` docs from fewshot_docs
70
+ fewshotex = self.sample(n_samples)
71
+
72
+ # get rid of the doc that's the one we're evaluating, if it's in the fewshot
73
+ # TODO: should we just stop people from using fewshot from same split as evaluating?
74
+ selected_docs = [x for x in fewshotex if x != doc][:num_fewshot]
75
+
76
+ labeled_examples = ""
77
+ for doc in selected_docs:
78
+ doc_content = self.doc_to_text(doc)
79
+ doc_target = self.doc_to_target(doc)
80
+ labeled_examples += (
81
+ doc_content
82
+ if self.config.doc_to_choice is None or isinstance(doc_content, str)
83
+ else self.doc_to_choice(doc)[doc_content]
84
+ )
85
+ labeled_examples += self.target_delimiter
86
+ if doc_target != "":
87
+ labeled_examples += (
88
+ str(doc_target[0])
89
+ if isinstance(doc_target, list)
90
+ else doc_target
91
+ if self.config.doc_to_choice is None or isinstance(doc_target, str)
92
+ else str(self.doc_to_choice(doc)[doc_target])
93
+ )
94
+ labeled_examples += self.fewshot_delimiter
95
+
96
+ return labeled_examples
97
+
98
+ def get_chat_context(
99
+ self,
100
+ doc,
101
+ num_fewshot,
102
+ fewshot_as_multiturn: bool = False,
103
+ ):
104
+ chat_history = []
105
+ # draw an extra fewshot sample if using same split as evaluating on
106
+ n_samples = (
107
+ num_fewshot + 1
108
+ if self.config.fewshot_split == self.config.test_split
109
+ else num_fewshot
110
+ )
111
+ # draw `n_samples` docs from fewshot_docs
112
+ fewshotex = self.sample(n_samples)
113
+
114
+ # get rid of the doc that's the one we're evaluating, if it's in the fewshot
115
+ # TODO: should we just stop people from using fewshot from same split as evaluating?
116
+ selected_docs = [x for x in fewshotex if x != doc][:num_fewshot]
117
+
118
+ if fewshot_as_multiturn:
119
+ for doc in selected_docs:
120
+ doc_content = self.doc_to_text(doc)
121
+ doc_target = self.doc_to_target(doc)
122
+ chat_history.append(
123
+ {
124
+ "role": "user",
125
+ "content": doc_content
126
+ if self.config.doc_to_choice is None
127
+ or isinstance(doc_content, str)
128
+ else self.doc_to_choice(doc)[doc_content],
129
+ }
130
+ )
131
+ chat_history.append(
132
+ {
133
+ "role": "assistant",
134
+ "content": str(doc_target[0])
135
+ if isinstance(doc_target, list)
136
+ else doc_target
137
+ if self.config.doc_to_choice is None
138
+ or isinstance(doc_target, str)
139
+ else str(self.doc_to_choice(doc)[doc_target]),
140
+ }
141
+ )
142
+ else:
143
+ # get fewshot context as one user turn
144
+ chat_history.append(
145
+ {"role": "user", "content": self.get_context(doc, num_fewshot)}
146
+ )
147
+
148
+ return chat_history
149
+
150
+ def sample(self, n):
151
+ """
152
+ Draw `n` samples from our fewshot docs. This method should be overridden by subclasses.
153
+ """
154
+
155
+ return self.rnd.sample(self.docs, n)
156
+
157
+
158
+ class FirstNSampler(ContextSampler):
159
+ def sample(self, n):
160
+ """
161
+ Draw the first `n` samples in order from the specified split.
162
+ Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.
163
+ """
164
+ assert (
165
+ n <= len(self.docs)
166
+ ), f"Error: number of fewshot samples requested exceeds the {len(self.docs)} that are available."
167
+ return self.docs[:n]
168
+
169
+
170
+ class BalancedSampler(ContextSampler):
171
+ def sample(self, n) -> None:
172
+ """
173
+ TODO: this should return approximately class-balanced samples from our fewshot examples.
174
+ TODO: what order should they be in? maybe random?
175
+ """
176
+
177
+ pass
178
+
179
+
180
+ class ManualSampler(ContextSampler):
181
+ def sample(self, n) -> None:
182
+ """ """
183
+ pass
184
+
185
+
186
+ SAMPLER_REGISTRY = {
187
+ "default": ContextSampler,
188
+ "first_n": FirstNSampler,
189
+ }
190
+
191
+
192
+ def get_sampler(name):
193
+ try:
194
+ return SAMPLER_REGISTRY[name]
195
+ except KeyError:
196
+ raise ValueError(
197
+ f"Attempted to use context sampler '{name}', but no sampling strategy with this name was found! Supported sampler names: {', '.join(SAMPLER_REGISTRY.keys())}"
198
+ )
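# Selection sketch: resolve a sampler class by name and instantiate it the way
# ConfigurableTask does (`docs`, `task`, and `doc` below are stand-ins).
#
#   import random
#   from lm_eval.api.samplers import get_sampler
#
#   sampler = get_sampler("first_n")(docs, task, rnd=random.Random(1234))
#   context = sampler.get_context(doc, num_fewshot=5)
#
# "default" draws a random subset with rnd.sample; "first_n" always takes the
# first n docs, for benchmarks with a canonical ordered few-shot set.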
scripts/yans/lm-evaluation-harness/lm_eval/api/task.py ADDED
@@ -0,0 +1,1674 @@
1
+ import abc
2
+ import ast
3
+ import logging
4
+ import random
5
+ import re
6
+ from collections.abc import Callable
7
+ from copy import deepcopy
8
+ from dataclasses import asdict, dataclass
9
+ from inspect import getsource
10
+ from typing import (
11
+ Any,
12
+ Dict,
13
+ Iterable,
14
+ Iterator,
15
+ List,
16
+ Literal,
17
+ Mapping,
18
+ Optional,
19
+ Tuple,
20
+ Union,
21
+ )
22
+
23
+ import datasets
24
+ import numpy as np
25
+ from tqdm import tqdm
26
+
27
+ from lm_eval import utils
28
+ from lm_eval.api import samplers
29
+ from lm_eval.api.instance import Instance, OutputType
30
+ from lm_eval.api.metrics import bits_per_byte, mean, weighted_perplexity
31
+ from lm_eval.api.registry import (
32
+ AGGREGATION_REGISTRY,
33
+ DEFAULT_METRIC_REGISTRY,
34
+ get_aggregation,
35
+ get_metric,
36
+ get_metric_aggregation,
37
+ is_higher_better,
38
+ )
39
+ from lm_eval.caching.cache import load_from_cache, save_to_cache
40
+ from lm_eval.filters import build_filter_ensemble
41
+ from lm_eval.prompts import get_prompt
42
+
43
+
44
+ ALL_OUTPUT_TYPES = [
45
+ "loglikelihood",
46
+ "multiple_choice",
47
+ "loglikelihood_rolling",
48
+ "generate_until",
49
+ ]
50
+
51
+ eval_logger = logging.getLogger("lm-eval")
52
+
53
+
54
+ @dataclass
55
+ class TaskConfig(dict):
56
+ # task naming/registry
57
+ task: Optional[str] = None
58
+ task_alias: Optional[str] = None
59
+ tag: Optional[Union[str, list]] = None
60
+ group: Optional[Union[str, list]] = None
61
+ # HF dataset options.
62
+ # which dataset to use,
63
+ # and what splits for what purpose
64
+ dataset_path: Optional[str] = None
65
+ dataset_name: Optional[str] = None
66
+ dataset_kwargs: Optional[dict] = None
67
+ training_split: Optional[str] = None
68
+ validation_split: Optional[str] = None
69
+ test_split: Optional[str] = None
70
+ fewshot_split: Optional[str] = (
71
+ None # TODO: assert that this is not None if num_fewshot > 0 (?); assert that it is not the same split as the one being evaluated (?)
72
+ )
73
+ # formatting / prompting options.
74
+ # see docs/advanced_task_guide.md for more info
75
+ process_docs: Optional[Callable] = None
76
+ doc_to_text: Optional[Union[Callable, str]] = None
77
+ doc_to_target: Optional[Union[Callable, str]] = None
78
+ doc_to_choice: Optional[Union[Callable, str, dict, list]] = None
79
+ process_results: Optional[Union[Callable, str]] = None
80
+ use_prompt: Optional[str] = None
81
+ description: str = ""
82
+ target_delimiter: str = " "
83
+ fewshot_delimiter: str = "\n\n"
84
+ fewshot_config: Optional[dict] = None
85
+ # runtime configuration options
86
+ num_fewshot: Optional[int] = None
87
+ # scoring options
88
+ metric_list: Optional[list] = None
89
+ output_type: OutputType = "generate_until"
90
+ generation_kwargs: Optional[dict] = None
91
+ repeats: int = 1
92
+ filter_list: Optional[Union[str, list]] = None
93
+ should_decontaminate: bool = False
94
+ doc_to_decontamination_query: Optional[str] = None
95
+ metadata: Optional[dict] = (
96
+ None # by default, not used in the code. allows for users to pass arbitrary info to tasks
97
+ )
98
+
99
+ def __post_init__(self) -> None:
100
+ if self.group is not None:
101
+ eval_logger.warning(
102
+ "A task YAML file was found to contain a `group` key. Groups which provide aggregate scores over several subtasks now require a separate config file--if not aggregating, you may want to use the `tag` config option instead within your config. Setting `group` within a TaskConfig will be deprecated in v0.4.4. Please see https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/task_guide.md for more information."
103
+ )
104
+
105
+ if self.tag is None:
106
+ self.tag = self.group
107
+ else:
108
+ raise ValueError(
109
+ "Got both a `group` and `tag` entry within a TaskConfig. Please use one or the other--`group` values will be deprecated in v0.4.4."
110
+ )
111
+
112
+ if self.generation_kwargs is not None:
113
+ if self.output_type != "generate_until":
114
+ eval_logger.warning(
115
+ f"[{self.task}] passed `generation_kwargs`, but not using `output_type: generate_until`!"
116
+ )
117
+
118
+ if "temperature" in self.generation_kwargs:
119
+ self.generation_kwargs["temperature"] = float(
120
+ self.generation_kwargs["temperature"]
121
+ )
122
+
123
+ if "until" not in self.generation_kwargs:
124
+ self.generation_kwargs["until"] = [self.fewshot_delimiter]
125
+ else:
126
+ if self.output_type == "generate_until":
127
+ # ensure that we greedily generate in absence of explicit arguments otherwise
128
+ self.generation_kwargs = {
129
+ "until": (
130
+ None
131
+ if self.fewshot_delimiter is None
132
+ else [self.fewshot_delimiter]
133
+ ),
134
+ "do_sample": False,
135
+ }
136
+
137
+ def __getitem__(self, item):
138
+ return getattr(self, item)
139
+
140
+ def __setitem__(self, item, value):
141
+ return setattr(self, item, value)
142
+
143
+ def to_dict(self, keep_callable: bool = False) -> dict:
144
+ """dumps the current config as a dictionary object, as a printable format.
145
+ null fields will not be printed.
146
+ Used for dumping results alongside full task configuration
147
+
148
+ :return: dict
149
+ A printable dictionary version of the TaskConfig object.
150
+
151
+ # TODO: should any default value in the TaskConfig not be printed?
152
+ """
153
+ cfg_dict = asdict(self)
154
+ # remove values that are `None`
155
+ for k, v in list(cfg_dict.items()):
156
+ if v is None:
157
+ cfg_dict.pop(k)
158
+ elif k == "metric_list":
159
+ for metric_dict in v:
160
+ for metric_key, metric_value in metric_dict.items():
161
+ if callable(metric_value):
162
+ metric_dict[metric_key] = self.serialize_function(
163
+ metric_value, keep_callable=keep_callable
164
+ )
165
+ cfg_dict[k] = v
166
+ elif callable(v):
167
+ cfg_dict[k] = self.serialize_function(v, keep_callable=keep_callable)
168
+ return cfg_dict
169
+
170
+ def serialize_function(
171
+ self, value: Union[Callable, str], keep_callable=False
172
+ ) -> Union[Callable, str]:
173
+ """Serializes a given function or string.
174
+
175
+ If 'keep_callable' is True, the original callable is returned.
176
+ Otherwise, attempts to return the source code of the callable using 'getsource'.
177
+ """
178
+ if keep_callable:
179
+ return value
180
+ else:
181
+ try:
182
+ return getsource(value)
183
+ except (TypeError, OSError):
184
+ return str(value)
185
+
186
+
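# Configuration sketch (hedged): task YAML fields map onto TaskConfig
# attributes; the dataset and task names below are placeholders.
#
#   from lm_eval.api.task import TaskConfig
#
#   cfg = TaskConfig(
#       task="my_translation_task",
#       dataset_path="my_org/my_dataset",
#       test_split="test",
#       output_type="generate_until",
#       doc_to_text="{{source}}",
#       doc_to_target="{{target}}",
#       metric_list=[{"metric": "exact_match"}],
#       num_fewshot=0,
#   )
#
# With no generation_kwargs supplied, __post_init__ falls back to greedy
# decoding ({"do_sample": False}) with the fewshot_delimiter as the stop string.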
187
+ class Task(abc.ABC):
188
+ """A task represents an entire benchmark including its dataset, problems,
189
+ answers, and evaluation methods. See BoolQ for a simple example implementation
190
+
191
+ A `doc` can be any python object which represents one instance of evaluation.
192
+ This is usually a dictionary e.g.
193
+ {"question": ..., "answer": ...} or
194
+ a tuple such as (question, answer).
195
+ """
196
+
197
+ VERSION: Optional[Union[int, str]] = None
198
+
199
+ # The name of the `Task` benchmark as denoted in the HuggingFace datasets Hub
200
+ # or a path to a custom `datasets` loading script.
201
+ DATASET_PATH: Optional[str] = None
202
+
203
+ # The name of a subset within `DATASET_PATH`.
204
+ DATASET_NAME: Optional[str] = None
205
+
206
+ OUTPUT_TYPE: Optional[OutputType] = None
207
+
208
+ def __init__(
209
+ self,
210
+ data_dir: Optional[str] = None,
211
+ cache_dir: Optional[str] = None,
212
+ download_mode: Optional[datasets.DownloadMode] = None,
213
+ config: Optional[Mapping] = None, # Union[dict, TaskConfig]
214
+ ) -> None:
215
+ """
216
+ :param data_dir: str
217
+ Stores the path to a local folder containing the `Task`'s data files.
218
+ Use this to specify the path to manually downloaded data (usually when
219
+ the dataset is not publicly accessible).
220
+ :param cache_dir: str
221
+ The directory to read/write the `Task` dataset. This follows the
222
+ HuggingFace `datasets` API with the default cache directory located at:
223
+ `~/.cache/huggingface/datasets`
224
+ NOTE: You can change the cache location globally for a given process
225
+ to another directory:
226
+ `export HF_DATASETS_CACHE="/path/to/another/directory"`
227
+ :param download_mode: datasets.DownloadMode
228
+ How to treat pre-existing `Task` downloads and data.
229
+ - `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
230
+ Reuse download and reuse dataset.
231
+ - `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
232
+ Reuse download with fresh dataset.
233
+ - `datasets.DownloadMode.FORCE_REDOWNLOAD`
234
+ Fresh download and fresh dataset.
235
+ """
236
+ self.download(data_dir, cache_dir, download_mode)
237
+ self._training_docs: Optional[list] = None
238
+ self._fewshot_docs: Optional[list] = None
239
+ self._instances: Optional[List[Instance]] = None
240
+
241
+ self._config: TaskConfig = TaskConfig({**config}) if config else TaskConfig()
242
+
243
+ self._filters = [build_filter_ensemble("none", [["take_first", None]])]
244
+ self.fewshot_rnd: Optional[random.Random] = (
245
+ None # purposely induce errors in case of improper usage
246
+ )
247
+
248
+ def download(
249
+ self,
250
+ data_dir: Optional[str] = None,
251
+ cache_dir: Optional[str] = None,
252
+ download_mode=None,
253
+ ) -> None:
254
+ """Downloads and returns the task dataset.
255
+ Override this method to download the dataset from a custom API.
256
+
257
+ :param data_dir: str
258
+ Stores the path to a local folder containing the `Task`'s data files.
259
+ Use this to specify the path to manually downloaded data (usually when
260
+ the dataset is not publicly accessible).
261
+ :param cache_dir: str
262
+ The directory to read/write the `Task` dataset. This follows the
263
+ HuggingFace `datasets` API with the default cache directory located at:
264
+ `~/.cache/huggingface/datasets`
265
+ NOTE: You can change the cache location globally for a given process
266
+ by setting the shell environment variable, `HF_DATASETS_CACHE`,
267
+ to another directory:
268
+ `export HF_DATASETS_CACHE="/path/to/another/directory"`
269
+ :param download_mode: datasets.DownloadMode
270
+ How to treat pre-existing `Task` downloads and data.
271
+ - `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
272
+ Reuse download and reuse dataset.
273
+ - `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
274
+ Reuse download with fresh dataset.
275
+ - `datasets.DownloadMode.FORCE_REDOWNLOAD`
276
+ Fresh download and fresh dataset.
277
+ """
278
+ self.dataset = datasets.load_dataset(
279
+ path=self.DATASET_PATH,
280
+ name=self.DATASET_NAME,
281
+ data_dir=data_dir,
282
+ cache_dir=cache_dir,
283
+ download_mode=download_mode,
284
+ )
285
+
286
+ @property
287
+ def config(self) -> TaskConfig:
288
+ """Returns the TaskConfig associated with this class."""
289
+ return self._config
290
+
291
+ @abc.abstractmethod
292
+ def has_training_docs(self):
293
+ """Whether the task has a training set"""
294
+ pass
295
+
296
+ @abc.abstractmethod
297
+ def has_validation_docs(self):
298
+ """Whether the task has a validation set"""
299
+ pass
300
+
301
+ @abc.abstractmethod
302
+ def has_test_docs(self):
303
+ """Whether the task has a test set"""
304
+ pass
305
+
306
+ def training_docs(self) -> Iterable:
307
+ """
308
+ :return: Iterable[obj]
309
+ An iterable of any objects that doc_to_text can handle
310
+ """
311
+ return []
312
+
313
+ def validation_docs(self) -> Iterable:
314
+ """
315
+ :return: Iterable[obj]
316
+ An iterable of any objects that doc_to_text can handle
317
+ """
318
+ return []
319
+
320
+ def test_docs(self) -> Iterable:
321
+ """
322
+ :return: Iterable[obj]
323
+ An iterable of any objects that doc_to_text can handle
324
+ """
325
+ return []
326
+
327
+ def fewshot_docs(self) -> Iterable:
328
+ """
329
+ :return: Iterable[obj]
330
+ An iterable of any objects that doc_to_text can handle
331
+ """
332
+ if self.has_training_docs():
333
+ return self.training_docs()
334
+ elif self.has_validation_docs():
335
+ return self.validation_docs()
336
+ else:
337
+ eval_logger.warning(
338
+ f"[Task: {self.config.task}] has_training_docs and has_validation_docs are False"
339
+ ", using test_docs as fewshot_docs but this is not recommended."
340
+ )
341
+ return self.test_docs()
342
+
343
+ def _process_doc(self, doc: dict) -> dict:
344
+ """
345
+ Override this to process (detokenize, strip, replace, etc.) individual
346
+ documents. This can be used in a map over documents of a data split.
347
+ E.g. `map(self._process_doc, self.dataset["validation"])`
348
+
349
+ :return: dict
350
+ The processed version of the specified `doc`.
351
+ """
352
+ return doc
353
+
354
+ @property
355
+ def instances(self) -> List[Instance]:
356
+ """After calling `task.build_all_requests()`, tasks
357
+ maintain a list of the dataset instances which will be evaluated.
358
+ """
359
+ return self._instances
360
+
361
+ def fewshot_examples(self, k, rnd):
362
+ if self._training_docs is None:
363
+ self._training_docs = list(self.training_docs())
364
+
365
+ return rnd.sample(self._training_docs, k)
366
+
367
+ def doc_to_decontamination_query(self, doc):
368
+ raise NotImplementedError(
369
+ "Override doc_to_decontamination_query with document specific decontamination query."
370
+ )
371
+
372
+ @abc.abstractmethod
373
+ def doc_to_text(self, doc):
374
+ pass
375
+
376
+ @abc.abstractmethod
377
+ def doc_to_target(self, doc):
378
+ pass
379
+
380
+ def build_all_requests(
381
+ self,
382
+ *,
383
+ limit: Union[int, None] = None,
384
+ rank: int = 0,
385
+ world_size: int = 1,
386
+ cache_requests: bool = False,
387
+ rewrite_requests_cache: bool = False,
388
+ system_instruction: Optional[str] = None,
389
+ apply_chat_template: bool = False,
390
+ fewshot_as_multiturn: bool = False,
391
+ chat_template: Optional[Callable] = None,
392
+ tokenizer_name: str = "",
393
+ ) -> None:
394
+ """Build a set of Instances for a task, and store them in task.instances"""
395
+
396
+ # used with caching
397
+ og_limit = limit
398
+
399
+ cache_key = f"requests-{self._config.task}-{self.config.num_fewshot}shot-rank{rank}-world_size{world_size}"
400
+ cache_key += "-chat_template" if apply_chat_template else ""
401
+ cache_key += "-fewshot_as_multiturn" if fewshot_as_multiturn else ""
402
+ cache_key += (
403
+ f"-system_prompt_hash{utils.hash_string(system_instruction)}"
404
+ if system_instruction is not None
405
+ else ""
406
+ )
407
+ cache_key += f"-tokenizer{tokenizer_name}"
408
+
409
+ cached_instances = load_from_cache(file_name=cache_key)
410
+
411
+ if cache_requests and cached_instances and not rewrite_requests_cache:
412
+ cached_instances = cached_instances[:limit]
413
+
414
+ flattened_instances = [
415
+ instance
416
+ for instance_group in cached_instances
417
+ for instance in instance_group
418
+ ]
419
+
420
+ self._instances = flattened_instances
421
+ return
422
+
423
+ eval_logger.info(f"Building contexts for {self.config.task} on rank {rank}...")
424
+
425
+ instances = []
426
+
427
+ # process all documents when caching is specified for simplicity
428
+ if (
429
+ cache_requests
430
+ and (not cached_instances or rewrite_requests_cache)
431
+ and limit is not None
432
+ ):
433
+ limit = None
434
+
435
+ doc_id_docs = list(
436
+ self.doc_iterator(rank=rank, limit=limit, world_size=world_size)
437
+ )
438
+
439
+ num_docs = len(doc_id_docs)
440
+
441
+ for doc_id, doc in tqdm(
442
+ doc_id_docs,
443
+ total=num_docs,
444
+ ):
445
+ # sample fewshot context #TODO: need to offset doc_id by rank now!
446
+ fewshot_ctx = self.fewshot_context(
447
+ doc,
448
+ 0 if self.config.num_fewshot is None else self.config.num_fewshot,
449
+ system_instruction,
450
+ apply_chat_template,
451
+ fewshot_as_multiturn,
452
+ chat_template,
453
+ )
454
+
455
+ # TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute
456
+ inst = self.construct_requests(
457
+ doc=doc,
458
+ ctx=fewshot_ctx,
459
+ metadata=(self.config["task"], doc_id, self.config.repeats),
460
+ )
461
+
462
+ if not isinstance(inst, list):
463
+ inst = [inst]
464
+
465
+ instances.append(inst)
466
+
467
+ # now flatten, this is to allow slicing to work with pickles
468
+
469
+ sliced_instances = instances[:og_limit]
470
+
471
+ flattened_instances = [
472
+ instance
473
+ for instance_group in sliced_instances
474
+ for instance in instance_group
475
+ ]
476
+
477
+ self._instances = flattened_instances
478
+
479
+ if len(self._instances) == 0:
480
+ raise ValueError("task.build_all_requests() did not find any docs!")
481
+
482
+ if cache_requests and (not cached_instances or rewrite_requests_cache):
483
+ save_to_cache(file_name=cache_key, obj=instances)
484
+
485
+ @abc.abstractmethod
486
+ def construct_requests(self, doc, ctx, **kwargs):
487
+ """Uses RequestFactory to construct Requests and returns an iterable of
488
+ Requests which will be sent to the LM.
489
+
490
+ :param doc:
491
+ The document as returned from training_docs, validation_docs, or test_docs.
492
+ :param ctx: str
493
+ The context string, generated by fewshot_context. This includes the natural
494
+ language description, as well as the few shot examples, and the question
495
+ part of the document for `doc`.
496
+ :param doc_idx: int
497
+ The index of a document within `self.test_docs()` or `self.validation_docs()`,
498
+ whichever is the main split used.
499
+ :param repeats: int
500
+ TODO: update this docstring
501
+ The number of times each instance in a dataset is inferred on. Defaults to 1,
502
+ can be increased for techniques like majority voting.
503
+ """
504
+ pass
505
+
506
+ @abc.abstractmethod
507
+ def process_results(self, doc, results):
508
+ """Take a single document and the LM results and evaluates, returning a
509
+ dict where keys are the names of submetrics and values are the values of
510
+ the metric for that one document
511
+
512
+ :param doc:
513
+ The document as returned from training_docs, validation_docs, or test_docs.
514
+ :param results:
515
+ The results of the requests created in construct_requests.
516
+ """
517
+ pass
518
+
519
+ @abc.abstractmethod
520
+ def aggregation(self):
521
+ """
522
+ :returns: {str: [metric_score] -> float}
523
+ A dictionary where keys are the names of submetrics and values are
524
+ functions that aggregate a list of metric scores
525
+ """
526
+ pass
527
+
528
+ @abc.abstractmethod
529
+ def higher_is_better(self):
530
+ """
531
+ :returns: {str: bool}
532
+ A dictionary where keys are the names of submetrics and values are
533
+ whether a higher value of the submetric is better
534
+ """
535
+ pass
536
+
537
+ def get_config(self, key: str) -> Any:
538
+ return getattr(self._config, key, None)
539
+
540
+ @classmethod
541
+ def count_bytes(cls, doc):
542
+ """Used for byte-level perplexity metrics in rolling loglikelihood"""
543
+ return len(doc.encode("utf-8"))
544
+
545
+ @classmethod
546
+ def count_words(cls, doc):
547
+ """Downstream loglikelihood_rolling perplexity tasks with custom word boundaries should override this!"""
548
+ return len(re.split(r"\s+", doc))
549
+
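# Quick check of the two counters above, which feed byte- and word-level
# perplexity normalisation:
import re

doc = "Hello world"
assert len(doc.encode("utf-8")) == 11        # count_bytes
assert len(re.split(r"\s+", doc)) == 2       # count_words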
550
+ @utils.positional_deprecated
551
+ def fewshot_context(
552
+ self,
553
+ doc,
554
+ num_fewshot,
555
+ rnd=None,
556
+ description=None,
557
+ ):
558
+ """Returns a fewshot context string that is made up of a prepended description
559
+ (if provided), the `num_fewshot` number of examples, and an appended prompt example.
560
+
561
+ :param doc: str
562
+ The document as returned from training_docs, validation_docs, or test_docs.
563
+ :param num_fewshot: int
564
+ The number of fewshot examples to provide in the returned context string.
565
+ :param rnd: random.Random
566
+ The pseudo-random number generator used to randomly sample examples.
567
+ WARNING: This argument is effectively required even though it defaults to `None`.
568
+ :param description: str
569
+ The task's description that will be prepended to the fewshot examples.
570
+ :returns: str
571
+ The fewshot context.
572
+ """
573
+ if rnd is None:
574
+ if self.fewshot_rnd is not None:
575
+ rnd = self.fewshot_rnd
576
+ else:
577
+ raise ValueError(
578
+ "A `random.Random` generator argument must be provided to `rnd`"
579
+ )
580
+
581
+ description = description if description else ""
582
+
583
+ if num_fewshot == 0:
584
+ labeled_examples = ""
585
+ else:
586
+ # for sets with no training docs, draw from other set *but ensure no overlap with current doc*
587
+ if self.has_training_docs():
588
+ fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
589
+ else:
590
+ if self._fewshot_docs is None:
591
+ self._fewshot_docs = list(
592
+ self.validation_docs()
593
+ if self.has_validation_docs()
594
+ else self.test_docs()
595
+ )
596
+
597
+ fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)
598
+
599
+ # get rid of the doc that's the one we're evaluating, if it's in the fewshot
600
+ fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]
601
+
602
+ labeled_examples = (
603
+ "\n\n".join(
604
+ [
605
+ self.doc_to_text(doc) + self.doc_to_target(doc)
606
+ for doc in fewshotex
607
+ ]
608
+ )
609
+ + "\n\n"
610
+ )
611
+
612
+ example = self.doc_to_text(doc)
613
+ return description + labeled_examples + example
614
+
615
+ def apply_filters(self) -> Optional[List[Instance]]:
616
+ """Iterates over FilterEnsembles and applies them to instances"""
617
+ if hasattr(self, "_filters"):
618
+ for f in self._filters:
619
+ f.apply(self._instances)
620
+ else:
621
+ eval_logger.warning("No filter defined, passing through instances")
622
+ return self._instances
623
+
624
+ def dump_config(self) -> dict:
625
+ """Returns the config as a dictionary."""
626
+ # TODO: this should only return the overrides applied to a non-YAML task's configuration.
627
+ # (num_fewshot)
628
+ return self.config.to_dict()
629
+
630
+ def set_config(self, key: str, value: Any, update: bool = False) -> None:
631
+ """Set or update the configuration for a given key."""
632
+ if key is None:
633
+ raise ValueError("Key must be provided.")
634
+
635
+ if update:
636
+ current_value = getattr(self._config, key, {})
637
+ if not isinstance(current_value, dict):
638
+ raise TypeError(
639
+ f"Expected a dict for key '{key}', got {type(current_value).__name__} instead."
640
+ )
641
+ current_value.update(value)
642
+ else:
643
+ setattr(self._config, key, value)
644
+
645
+ def override_metric(self, metric_name: str) -> None:
646
+ """
647
+ Override the default metrics used for evaluation with custom metrics.
648
+
649
+ Parameters:
650
+ - metric_name (str): The name of the custom metric to override. Should be registered in api.metrics.
651
+ """
652
+ (
653
+ self._metric_fn_list,
654
+ self._aggregation_list,
655
+ self._metric_fn_kwargs,
656
+ self._higher_is_better,
657
+ ) = ({}, {}, {}, {})
658
+ self._metric_fn_list[metric_name] = get_metric(metric_name)
659
+ self._aggregation_list[metric_name] = get_metric_aggregation(metric_name)
660
+ self._higher_is_better[metric_name] = is_higher_better(metric_name)
661
+ self._metric_fn_kwargs[metric_name] = {}
662
+ if not isinstance(self, ConfigurableTask):
663
+ self.process_results = lambda x, y: {metric_name: get_metric(metric_name)}
664
+ self.aggregation = lambda: {
665
+ metric_name: get_metric_aggregation(metric_name)
666
+ }
667
+ setattr(self._config, "metric_list", [{"metric": metric_name}])
668
+ setattr(self._config, "process_results", None)
669
+
670
+ def set_fewshot_seed(self, seed: Optional[int] = None) -> None:
671
+ self.fewshot_rnd = random.Random(seed)
672
+ if hasattr(self, "sampler"):
673
+ self.sampler.rnd = self.fewshot_rnd
674
+
675
+ @property
676
+ def eval_docs(self) -> Union[datasets.Dataset, List[dict]]:
677
+ if self.has_test_docs():
678
+ return self.test_docs()
679
+ elif self.has_validation_docs():
680
+ return self.validation_docs()
681
+ else:
682
+ raise ValueError(
683
+ f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"
684
+ )
685
+
686
+ def doc_iterator(
687
+ self, *, rank: int = 0, limit: Union[int, None] = None, world_size: int = 1
688
+ ) -> Iterator[Tuple[int, Any]]:
689
+ limit = int(limit) if limit else None
690
+ doc_iterator = utils.create_iterator(
691
+ enumerate(self.eval_docs),
692
+ rank=int(rank),
693
+ limit=limit,
694
+ world_size=int(world_size),
695
+ )
696
+ return doc_iterator
697
+
698
+
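# Sharding sketch: each rank receives an interleaved stride of the enumerated
# docs (utils.create_iterator is assumed to behave like the islice pattern below).
from itertools import islice

docs = [f"doc_{i}" for i in range(10)]
world_size, limit = 2, None
shards = [
    list(islice(enumerate(docs), rank, limit, world_size))
    for rank in range(world_size)
]
print(shards[0])  # rank 0 sees indices 0, 2, 4, ...
print(shards[1])  # rank 1 sees indices 1, 3, 5, ...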
699
+ class ConfigurableTask(Task):
700
+ VERSION = "Yaml"
701
+ OUTPUT_TYPE = None
702
+ CONFIG = None
703
+
704
+ def __init__(
705
+ self,
706
+ data_dir=None,
707
+ cache_dir=None,
708
+ download_mode=None,
709
+ config: Optional[dict] = None,
710
+ ) -> None: # TODO no super() call here
711
+ # Get pre-configured attributes
712
+ self._config = self.CONFIG
713
+
714
+ # Use new configurations if there was no preconfiguration
715
+ if self.config is None:
716
+ self._config = TaskConfig(**config)
717
+ # Overwrite configs
718
+ else:
719
+ if config is not None:
720
+ self._config.__dict__.update(config)
721
+
722
+ if self.config is None:
723
+ raise ValueError(
724
+ "Must pass a config to ConfigurableTask, either in cls.CONFIG or `config` kwarg"
725
+ )
726
+
727
+ if isinstance(self.config.metadata, dict):
728
+ if "version" in self.config.metadata:
729
+ self.VERSION = self.config.metadata["version"]
730
+
731
+ if self.config.output_type is not None:
732
+ if self.config.output_type not in ALL_OUTPUT_TYPES:
733
+ raise ValueError(
734
+ f"Got invalid output_type '{self.config.output_type}', must be in '{','.join(ALL_OUTPUT_TYPES)}'"
735
+ )
736
+ self.OUTPUT_TYPE = self.config.output_type
737
+
738
+ if self.config.dataset_path is not None:
739
+ self.DATASET_PATH = self.config.dataset_path
740
+
741
+ if self.config.dataset_name is not None:
742
+ self.DATASET_NAME = self.config.dataset_name
743
+
744
+ self._metric_fn_list = {}
745
+ self._metric_fn_kwargs = {}
746
+ self._aggregation_list = {}
747
+ self._higher_is_better = {}
748
+
749
+ if self.config.metric_list is None:
750
+ # TODO: handle this in TaskConfig.__post_init__ ?
751
+ _metric_list = DEFAULT_METRIC_REGISTRY[self.config.output_type]
752
+
753
+ for metric_name in _metric_list:
754
+ self._metric_fn_list[metric_name] = get_metric(metric_name)
755
+ self._metric_fn_kwargs[metric_name] = {}
756
+ self._aggregation_list[metric_name] = get_metric_aggregation(
757
+ metric_name
758
+ )
759
+ self._higher_is_better[metric_name] = is_higher_better(metric_name)
760
+ else:
761
+ for metric_config in self.config.metric_list:
762
+ if "metric" not in metric_config:
763
+ raise ValueError(
764
+ "'metric' key not provided for an entry in 'metric_list', must be specified!"
765
+ )
766
+ metric_name = metric_config["metric"]
767
+ kwargs = {
768
+ key: metric_config[key]
769
+ for key in metric_config
770
+ if key
771
+ not in ["metric", "aggregation", "higher_is_better", "hf_evaluate"]
772
+ }
773
+ hf_evaluate_metric = (
774
+ "hf_evaluate" in metric_config
775
+ and metric_config["hf_evaluate"] is True
776
+ )
777
+
778
+ if self.config.process_results is not None:
779
+ self._metric_fn_list[metric_name] = None
780
+ self._metric_fn_kwargs[metric_name] = {}
781
+ elif callable(metric_name):
782
+ metric_fn = metric_name.__call__
783
+ metric_name = metric_name.__name__
784
+ self._metric_fn_list[metric_name] = metric_fn
785
+ self._metric_fn_kwargs[metric_name] = kwargs
786
+ else:
787
+ self._metric_fn_list[metric_name] = get_metric(
788
+ metric_name, hf_evaluate_metric
789
+ )
790
+ self._metric_fn_kwargs[metric_name] = kwargs
791
+
792
+ if "aggregation" in metric_config:
793
+ agg_name = metric_config["aggregation"]
794
+ if isinstance(agg_name, str):
795
+ self._aggregation_list[metric_name] = get_aggregation(agg_name)
796
+ elif callable(agg_name): # noqa: E721
797
+ self._aggregation_list[metric_name] = metric_config[
798
+ "aggregation"
799
+ ]
800
+ else:
801
+ INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
802
+ metric_agg = get_metric_aggregation(metric_name)
803
+ eval_logger.warning(
804
+ f"[Task: {self.config.task}] metric {metric_name} is defined, but aggregation is not. "
805
+ f"using default "
806
+ f"aggregation={INV_AGG_REGISTRY[metric_agg]}"
807
+ )
808
+ self._aggregation_list[metric_name] = metric_agg
809
+
810
+ if "higher_is_better" in metric_config:
811
+ self._higher_is_better[metric_name] = metric_config[
812
+ "higher_is_better"
813
+ ]
814
+ else:
815
+ eval_logger.warning(
816
+ f"[Task: {self.config.task}] metric {metric_name} is defined, but higher_is_better is not. "
817
+ f"using default "
818
+ f"higher_is_better={is_higher_better(metric_name)}"
819
+ )
820
+ self._higher_is_better[metric_name] = is_higher_better(metric_name)
821
+
822
+ self.download(self.config.dataset_kwargs)
823
+ self._training_docs = None
824
+ self._fewshot_docs = None
825
+
826
+ if self.config.filter_list is not None:
827
+ self._filters = []
828
+ for filter_config in self.config.filter_list:
829
+ filter_name = filter_config["name"]
830
+ filter_functions = filter_config["filter"]
831
+ components = []
832
+ for function in filter_functions:
833
+ kwargs = {
834
+ key: function[key] for key in function if key != "function"
835
+ }
836
+ components.append([function["function"], kwargs])
837
+ filter_pipeline = build_filter_ensemble(filter_name, components)
838
+ self._filters.append(filter_pipeline)
839
+ else:
840
+ self._filters = [build_filter_ensemble("none", [["take_first", None]])]
841
+
842
+ if self.config.use_prompt is not None:
843
+ eval_logger.info(f"loading prompt {self.config.use_prompt}")
844
+ self.prompt = get_prompt(
845
+ self.config.use_prompt, self.DATASET_PATH, self.DATASET_NAME
846
+ )
847
+ else:
848
+ self.prompt = None
849
+
850
+ if self.fewshot_docs() is not None:
851
+ self.fewshot_rnd = (
852
+ random.Random()
853
+ ) # setting with no seed, to be overridden at a later time
854
+ config_sampler: Union[str, Callable] = (
855
+ self.config.fewshot_config.get("sampler", "default")
856
+ if self.config.fewshot_config
857
+ else "default"
858
+ )
859
+ if isinstance(config_sampler, str):
860
+ self.sampler = samplers.get_sampler(config_sampler)(
861
+ list(self.fewshot_docs()), self, rnd=self.fewshot_rnd
862
+ )
863
+ elif callable(config_sampler) and issubclass(
864
+ config_sampler, samplers.ContextSampler
865
+ ):
866
+ self.sampler = config_sampler(
867
+ docs=list(self.fewshot_docs()), task=self, rnd=self.fewshot_rnd
868
+ )
869
+ else:
870
+ raise TypeError(
871
+ f"fewshot_config.sampler should be a string or callable of ContextSampler type, "
872
+ f"not {type(config_sampler)}"
873
+ )
874
+
875
+ self.task_docs = self.eval_docs
876
+
877
+ # Test One Doc
878
+ self.features = list(self.task_docs.features.keys())
879
+ self.multiple_input = 0
880
+ self.multiple_target = 0
881
+ test_doc = self.task_docs[0]
882
+ test_text = self.doc_to_text(test_doc)
883
+ test_target = self.doc_to_target(test_doc)
884
+
885
+ if self.config.doc_to_choice is not None:
886
+ test_choice = self.doc_to_choice(test_doc)
887
+ if not isinstance(test_choice, list):
888
+ eval_logger.error("doc_to_choice must return list")
889
+ else:
890
+ num_choice = len(test_choice)
891
+
892
+ if isinstance(test_text, int):
893
+ self.multiple_input = num_choice
894
+ else:
895
+ test_choice = None
896
+
897
+ if isinstance(test_target, list):
898
+ self.multiple_target = len(test_target)
899
+ else:
900
+ if (isinstance(test_target, int)) and (test_choice is not None):
901
+ test_target = test_choice[test_target]
902
+ else:
903
+ test_target = str(test_target)
904
+
905
+ if test_choice is not None:
906
+ check_choices = test_choice
907
+ else:
908
+ check_choices = [test_target]
909
+ if self.config.doc_to_choice is not None:
910
+ for choice in check_choices:
911
+ choice_has_whitespace = True if choice[0].isspace() else False
912
+ delimiter_has_whitespace = (
913
+ True
914
+ if self.config.target_delimiter.rstrip()
915
+ != self.config.target_delimiter
916
+ else False
917
+ )
918
+
919
+ if delimiter_has_whitespace and choice_has_whitespace:
920
+ eval_logger.debug(
921
+ f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" have whitespace'
922
+ )
923
+ elif (not delimiter_has_whitespace) and (not choice_has_whitespace):
924
+ eval_logger.debug(
925
+ f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" do not have whitespace, ignore if the language you are evaluating on does not require/use whitespace'
926
+ )
927
+
928
+ def download(self, dataset_kwargs: Optional[Dict[str, Any]] = None) -> None:
929
+ self.dataset = datasets.load_dataset(
930
+ path=self.DATASET_PATH,
931
+ name=self.DATASET_NAME,
932
+ **dataset_kwargs if dataset_kwargs is not None else {},
933
+ )
934
+
935
+ def has_training_docs(self) -> bool:
936
+ if self.config.training_split is not None:
937
+ return True
938
+ else:
939
+ return False
940
+
941
+ def has_validation_docs(self) -> bool:
942
+ if self.config.validation_split is not None:
943
+ return True
944
+ else:
945
+ return False
946
+
947
+ def has_test_docs(self) -> bool:
948
+ if self.config.test_split is not None:
949
+ return True
950
+ else:
951
+ return False
952
+
953
+ def training_docs(self) -> datasets.Dataset:
954
+ if self.has_training_docs():
955
+ if self.config.process_docs is not None:
956
+ return self.config.process_docs(
957
+ self.dataset[self.config.training_split]
958
+ )
959
+ return self.dataset[self.config.training_split]
960
+
961
+ def validation_docs(self) -> datasets.Dataset:
962
+ if self.has_validation_docs():
963
+ if self.config.process_docs is not None:
964
+ return self.config.process_docs(
965
+ self.dataset[self.config.validation_split]
966
+ )
967
+ return self.dataset[self.config.validation_split]
968
+
969
+ def test_docs(self) -> datasets.Dataset:
970
+ if self.has_test_docs():
971
+ if self.config.process_docs is not None:
972
+ return self.config.process_docs(self.dataset[self.config.test_split])
973
+ return self.dataset[self.config.test_split]
974
+
975
+ def fewshot_docs(self):
976
+ if self.config.fewshot_split is not None:
977
+ if self.config.process_docs is not None:
978
+ return self.config.process_docs(self.dataset[self.config.fewshot_split])
979
+ return self.dataset[self.config.fewshot_split]
980
+ elif (
981
+ self.config.fewshot_config is not None
982
+ and self.config.fewshot_config.get("samples", None) is not None
983
+ ):
984
+ if isinstance(self.config.fewshot_config["samples"], list):
985
+ return self.config.fewshot_config["samples"]
986
+ elif callable(self.config.fewshot_config["samples"]):
987
+ return self.config.fewshot_config["samples"]()
988
+ else:
989
+ raise Exception(
990
+ "`fewshot_config['samples']` was incorrectly defined in the configuration. It should be either a list of sample dicts or a function returning such a list."
991
+ )
992
+ else:
993
+ if (self.config.num_fewshot is not None) and (self.config.num_fewshot > 0):
994
+ eval_logger.warning(
995
+ f"[Task: {self.config.task}] "
996
+ "num_fewshot > 0 but fewshot_split is None. "
997
+ "using preconfigured rule."
998
+ )
999
+ return super().fewshot_docs()
1000
+
1001
+ @staticmethod
1002
+ def append_target_question(
1003
+ labeled_examples: List[Dict[str, str]],
1004
+ question: str,
1005
+ fewshot_as_multiturn: bool = False,
1006
+ ) -> None:
1007
+ """Adds a target question to the labeled examples list.
1008
+ If fewshot_as_multiturn is True, or labeled_examples is empty, or the last entry is a system turn, appends the question as a new user entry.
1009
+ Otherwise, it is appended to the last user entry, ensuring that the conversation alternates between the user and the assistant.
1010
+ """
1011
+ if not fewshot_as_multiturn:
1012
+ # if no messages or last message is system, append as new user entry
1013
+ if len(labeled_examples) == 0 or labeled_examples[-1]["role"] == "system":
1014
+ labeled_examples.append({"role": "user", "content": question})
1015
+ # if last message is user, append to it to avoid two user messages in a row
1016
+ else:
1017
+ labeled_examples[-1]["content"] += question
1018
+ else:
1019
+ # if fewshot_as_multiturn is True, append as next user entry (last is always assistant)
1020
+ labeled_examples.append({"role": "user", "content": question})
1021
+
1022
+ @utils.positional_deprecated
1023
+ def fewshot_context(
1024
+ self,
1025
+ doc: str,
1026
+ num_fewshot: int,
1027
+ system_instruction: Optional[str] = None,
1028
+ apply_chat_template: bool = False,
1029
+ fewshot_as_multiturn: bool = False,
1030
+ chat_template: Optional[Callable] = None,
1031
+ ) -> str:
1032
+ """Returns a fewshot context string that is made up of a prepended description
1033
+ (if provided), the `num_fewshot` number of examples, and an appended prompt example.
1034
+
1035
+ :param doc: str
1036
+ The document as returned from training_docs, validation_docs, or test_docs.
1037
+ :param num_fewshot: int
1038
+ The number of fewshot examples to provide in the returned context string.
1039
+ :param system_instruction: str
1040
+ System instruction to be applied to the prompt.
1041
+ :param apply_chat_template: bool
1042
+ Whether to apply the chat template to the fewshot context.
1043
+ :param fewshot_as_multiturn: bool
1044
+ Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
1045
+ :param chat_template: Callable
1046
+ Chat template to be applied to the fewshot context.
1047
+ :returns: str
1048
+ The fewshot context.
1049
+ """
1050
+
1051
+ if apply_chat_template:
1052
+ labeled_examples = []
1053
+ else:
1054
+ labeled_examples = ""
1055
+
1056
+ # get task description
1057
+ if description := self.config.description:
1058
+ description = utils.apply_template(self.config.description, doc)
1059
+
1060
+ # create system prompt based on the provided system instruction and description
1061
+ if system_instruction is not None and description:
1062
+ system_prompt = (
1063
+ f"{system_instruction}{self.sampler.fewshot_delimiter}{description}"
1064
+ )
1065
+ elif system_instruction is not None:
1066
+ system_prompt = system_instruction
1067
+ elif description:
1068
+ system_prompt = description
1069
+ else:
1070
+ system_prompt = ""
1071
+
1072
+ # add system prompt if specified
1073
+ if system_prompt:
1074
+ if apply_chat_template:
1075
+ labeled_examples.append({"role": "system", "content": system_prompt})
1076
+ else:
1077
+ labeled_examples = system_prompt
1078
+
1079
+ # if few-shot - append examples after the system prompt
1080
+ if num_fewshot > 0:
1081
+ if apply_chat_template:
1082
+ labeled_examples.extend(
1083
+ self.sampler.get_chat_context(
1084
+ doc, num_fewshot, fewshot_as_multiturn
1085
+ )
1086
+ )
1087
+ else:
1088
+ labeled_examples += self.sampler.get_context(doc, num_fewshot)
1089
+
1090
+ example = self.doc_to_text(doc)
1091
+ if apply_chat_template:
1092
+ if self.multiple_input:
1093
+ return chat_template(labeled_examples)
1094
+ if isinstance(example, str):
1095
+ self.append_target_question(
1096
+ labeled_examples, example, fewshot_as_multiturn
1097
+ )
1098
+ # for loglikelihood create a list of questions with appended choices
1099
+ elif isinstance(example, list):
1100
+ labeled_examples_list = []
1101
+ # copy chat history for each example and append the answer
1102
+ for ex in example:
1103
+ chat = deepcopy(labeled_examples)
1104
+ self.append_target_question(chat, ex, fewshot_as_multiturn)
1105
+ labeled_examples_list.append(chat_template(chat))
1106
+ return labeled_examples_list
1107
+ # if example is an integer, append the choice or convert to string
1108
+ elif isinstance(example, int):
1109
+ if self.config.doc_to_choice is not None:
1110
+ choices = self.doc_to_choice(doc)
1111
+ self.append_target_question(
1112
+ labeled_examples, choices[example], fewshot_as_multiturn
1113
+ )
1114
+ else:
1115
+ self.append_target_question(
1116
+ labeled_examples, str(example), fewshot_as_multiturn
1117
+ )
1118
+ # return lm.apply_chat_template(labeled_examples)
1119
+ return chat_template(labeled_examples)
1120
+ else:
1121
+ if self.multiple_input:
1122
+ return labeled_examples
1123
+ if isinstance(example, str):
1124
+ return labeled_examples + example
1125
+ elif isinstance(example, list):
1126
+ return [labeled_examples + ex for ex in example]
1127
+ elif isinstance(example, int):
1128
+ if self.config.doc_to_choice is not None:
1129
+ choices = self.doc_to_choice(doc)
1130
+ return labeled_examples + choices[example]
1131
+ else:
1132
+ return labeled_examples + str(example)
1133
+
1134
+ def apply_filters(self):
1135
+ """Iterates over FilterEnsembles and applies them to instances"""
1136
+ if hasattr(self, "_filters"):
1137
+ for f in self._filters:
1138
+ f.apply(self._instances)
1139
+ else:
1140
+ eval_logger.warning("No filter defined, passing through instances")
1141
+ return self._instances
1142
+
1143
+ def should_decontaminate(self):
1144
+ return self.config.should_decontaminate
1145
+
1146
+ def doc_to_decontamination_query(self, doc):
1147
+ if self.config.should_decontaminate:
1148
+ if self.config.doc_to_decontamination_query is None:
1149
+ return self.doc_to_text(doc)
1150
+ else:
1151
+ doc_to_decontamination_query = self.config.doc_to_decontamination_query
1152
+ if doc_to_decontamination_query in self.features:
1153
+ return doc[doc_to_decontamination_query]
1154
+ elif callable(doc_to_decontamination_query):
1155
+ return doc_to_decontamination_query(doc)
1156
+ else:
1157
+ return ast.literal_eval(
1158
+ utils.apply_template(
1159
+ self.config.doc_to_decontamination_query, doc
1160
+ )
1161
+ )
1162
+
1163
+ def _process_doc(self, doc: dict) -> dict:
1164
+ """
1165
+ Override this to process (detokenize, strip, replace, etc.) individual
1166
+ documents. This can be used in a map over documents of a data split.
1167
+ E.g. `map(self._process_doc, self.dataset["validation"])`
1168
+
1169
+ :return: dict
1170
+ The processed version of the specified `doc`.
1171
+ """
1172
+ return doc
1173
+
1174
+ def doc_to_text(self, doc, doc_to_text=None):
1175
+ if self.prompt is not None:
1176
+ doc_to_text = self.prompt
1177
+ elif doc_to_text is not None:
1178
+ doc_to_text = doc_to_text
1179
+ else:
1180
+ doc_to_text = self.config.doc_to_text
1181
+
1182
+ if isinstance(doc_to_text, int):
1183
+ return doc_to_text
1184
+ elif isinstance(doc_to_text, str):
1185
+ if doc_to_text in self.features:
1186
+ # if self.config.doc_to_choice is not None:
1187
+ # return self.doc_to_choice(doc)[doc[doc_to_text]]
1188
+ # else:
1189
+ return doc[doc_to_text]
1190
+ else:
1191
+ text_string = utils.apply_template(doc_to_text, doc)
1192
+ if text_string.isdigit() and self._config.doc_to_choice is not None:
1193
+ return ast.literal_eval(text_string)
1194
+ else:
1195
+ return text_string
1196
+ elif callable(doc_to_text):
1197
+ return doc_to_text(doc)
1198
+ # Used when applying a Promptsource template
1199
+ elif hasattr(doc_to_text, "apply"):
1200
+ applied_prompt = doc_to_text.apply(doc)
1201
+ if len(applied_prompt) == 2:
1202
+ return applied_prompt[0]
1203
+ else:
1204
+ eval_logger.warning("Applied prompt returns empty string")
1205
+ return self.config.fewshot_delimiter
1206
+ else:
1207
+ print(type(doc_to_text))
1208
+ raise TypeError
1209
+
1210
+ def doc_to_target(self, doc: Mapping, doc_to_target=None) -> Union[int, str, list]:
1211
+ if self.prompt is not None:
1212
+ doc_to_target = self.prompt
1213
+ elif doc_to_target is not None:
1214
+ doc_to_target = doc_to_target
1215
+ else:
1216
+ doc_to_target = self.config.doc_to_target
1217
+
1218
+ if isinstance(doc_to_target, int):
1219
+ return doc_to_target
1220
+ elif isinstance(doc_to_target, str):
1221
+ if doc_to_target in self.features:
1222
+ # if self.config.doc_to_choice is not None:
1223
+ # return self.doc_to_choice(doc)[doc[doc_to_target]]
1224
+ # else:
1225
+ return doc[doc_to_target]
1226
+ else:
1227
+ target_string = utils.apply_template(doc_to_target, doc)
1228
+ if target_string.isdigit() and self._config.doc_to_choice is not None:
1229
+ return ast.literal_eval(target_string)
1230
+ elif (
1231
+ len(target_string) >= 2
1232
+ and (target_string[0] == "[")
1233
+ and (target_string[-1] == "]")
1234
+ ):
1235
+ try:
1236
+ return ast.literal_eval(target_string)
1237
+ except (SyntaxError, ValueError):
1238
+ return target_string
1239
+ else:
1240
+ return target_string
1241
+ elif isinstance(doc_to_target, list):
1242
+ return doc_to_target
1243
+ elif callable(doc_to_target):
1244
+ return doc_to_target(doc)
1245
+ # Used when applying a Promptsource template
1246
+ elif hasattr(doc_to_target, "apply"):
1247
+ applied_prompt = doc_to_target.apply(doc)
1248
+ if len(applied_prompt) == 2:
1249
+ return applied_prompt[1]
1250
+ else:
1251
+ eval_logger.warning("Applied prompt returns empty string")
1252
+ return self.config.fewshot_delimiter
1253
+ else:
1254
+ raise TypeError
1255
+
1256
+ def doc_to_choice(self, doc: Any, doc_to_choice=None) -> List[str]:
1257
+ if self.prompt is not None:
1258
+ doc_to_choice = self.prompt
1259
+ elif doc_to_choice is not None:
1260
+ doc_to_choice = doc_to_choice
1261
+ elif self.config.doc_to_choice is None:
1262
+ eval_logger.error("doc_to_choice was called but not set in config")
1263
+ else:
1264
+ doc_to_choice = self.config.doc_to_choice
1265
+
1266
+ if isinstance(doc_to_choice, str):
1267
+ if doc_to_choice in self.features:
1268
+ return doc[doc_to_choice]
1269
+ else:
1270
+ return ast.literal_eval(utils.apply_template(doc_to_choice, doc))
1271
+ elif isinstance(doc_to_choice, list):
1272
+ return doc_to_choice
1273
+ elif isinstance(doc_to_choice, dict):
1274
+ return list(doc_to_choice.values())
1275
+ elif callable(doc_to_choice):
1276
+ return doc_to_choice(doc)
1277
+ elif hasattr(doc_to_choice, "get_answer_choices_list"):
1278
+ return doc_to_choice.get_answer_choices_list(doc)
1279
+ else:
1280
+ raise TypeError
1281
+
1282
+ def construct_requests(
1283
+ self, doc: dict, ctx: str, **kwargs
1284
+ ) -> Union[List[Instance], Instance]:
1285
+ if self.OUTPUT_TYPE == "loglikelihood":
1286
+ arguments = (ctx, self.doc_to_target(doc))
1287
+ elif self.OUTPUT_TYPE == "loglikelihood_rolling":
1288
+ arguments = (self.doc_to_target(doc),)
1289
+ elif self.OUTPUT_TYPE == "multiple_choice":
1290
+ choices = self.doc_to_choice(doc)
1291
+ target_delimiter = self.config.target_delimiter
1292
+ if self.multiple_input:
1293
+ # If there are multiple inputs, choices are placed in the ctx
1294
+ cont = self.doc_to_target(doc)
1295
+ arguments = [
1296
+ (ctx + choice, f"{target_delimiter}{cont}") for choice in choices
1297
+ ]
1298
+ else:
1299
+ # Otherwise they are placed in the continuation
1300
+ arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices]
1301
+
1302
+ request_list = [
1303
+ Instance(
1304
+ request_type="loglikelihood",
1305
+ doc=doc,
1306
+ arguments=arg,
1307
+ idx=i,
1308
+ **kwargs,
1309
+ )
1310
+ for i, arg in enumerate(arguments)
1311
+ ]
1312
+ # TODO: we should raise a warning telling users this will increase runtime by up to ~2x.
1313
+ if "acc_mutual_info" in self._metric_fn_list.keys():
1314
+ # if we are calculating multiple choice accuracy
1315
+ # using mutual information instead of raw loglikelihood as metric, need unconditional lls.
1316
+
1317
+ # here mutual info refers to calculating
1318
+ # log(P(choice|ctx) / P(choice)) = log(P(choice|ctx)) - log(P(choice))
1319
+ # in other words normalizing by subtracting the unconditional logprob of each choice.
1320
+ request_list.extend(
1321
+ [
1322
+ Instance(
1323
+ request_type="loglikelihood",
1324
+ doc=doc,
1325
+ arguments=("", "{}".format(choice)),
1326
+ idx=i,
1327
+ **kwargs,
1328
+ )
1329
+ for i, choice in enumerate(choices)
1330
+ ]
1331
+ )
1332
+ return request_list
1333
+
1334
+ elif self.OUTPUT_TYPE == "generate_until":
1335
+ arguments = (ctx, deepcopy(self.config.generation_kwargs))
1336
+
1337
+ return Instance(
1338
+ request_type=self.OUTPUT_TYPE, doc=doc, arguments=arguments, idx=0, **kwargs
1339
+ )
1340
+
1341
+ def process_results(self, doc, results):
1342
+ if callable(self.config.process_results):
1343
+ return self.config.process_results(doc, results)
1344
+
1345
+ result_dict = {}
1346
+ use_metric = list(self._metric_fn_list.keys())
1347
+ if self.OUTPUT_TYPE == "loglikelihood":
1348
+ results = results[0]
1349
+ ll, is_greedy = results
1350
+ return {
1351
+ **({"perplexity": ll} if "perplexity" in use_metric else {}),
1352
+ **({"acc": int(is_greedy)} if "acc" in use_metric else {}),
1353
+ }
1354
+ elif self.OUTPUT_TYPE == "loglikelihood_rolling":
1355
+ (loglikelihood,) = results
1356
+ _words = self.count_words(self.doc_to_target(doc))
1357
+ _bytes = self.count_bytes(self.doc_to_target(doc))
1358
+ return {
1359
+ **(
1360
+ {"word_perplexity": (loglikelihood, _words)}
1361
+ if "word_perplexity" in use_metric
1362
+ else {}
1363
+ ),
1364
+ **(
1365
+ {"byte_perplexity": (loglikelihood, _bytes)}
1366
+ if "byte_perplexity" in use_metric
1367
+ else {}
1368
+ ),
1369
+ **(
1370
+ {"bits_per_byte": (loglikelihood, _bytes)}
1371
+ if "bits_per_byte" in use_metric
1372
+ else {}
1373
+ ),
1374
+ }
1375
+ elif self.OUTPUT_TYPE == "multiple_choice":
1376
+ lls, is_greedy = zip(*results)
1377
+
1378
+ # retrieve choices in List[str] form, to compute choice lengths, etc.
1379
+ choices = self.doc_to_choice(doc)
1380
+ completion_len = np.array([float(len(i)) for i in choices])
1381
+
1382
+ if (
1383
+ 2 * len(choices) == len(lls)
1384
+ and "acc_mutual_info" in self._metric_fn_list.keys()
1385
+ ):
1386
+ # then we are doing mutual info.
1387
+ # this stores the "dryrun" / unconditional answer loglikelihoods
1388
+ lls_unconditional = lls[1::2]
1389
+ if len(lls_unconditional) != len(choices):
1390
+ raise ValueError
1391
+ # and this stores our "regular" conditional loglikelihoods
1392
+ lls = lls[::2]
1393
+
1394
+ pred = np.argmax(lls)
1395
+ pred_norm = np.argmax(lls / completion_len)
1396
+
1397
+ if self.multiple_input:
1398
+ gold = self.doc_to_text(doc)
1399
+ else:
1400
+ gold = self.doc_to_target(doc)
1401
+
1402
+ gold_index_error = False
1403
+ if isinstance(gold, list):
1404
+ gold = [i if i < len(choices) else -100 for i in gold]
1405
+ if -100 in gold:
1406
+ gold_index_error = True
1407
+ else:
1408
+ if isinstance(gold, int):
1409
+ gold = gold if gold < len(choices) else -100
1410
+ elif isinstance(gold, str):
1411
+ gold = choices.index(gold) if gold in choices else -100
1412
+
1413
+ if gold == -100:
1414
+ gold_index_error = True
1415
+
1416
+ if gold_index_error:
1417
+ eval_logger.warning(
1418
+ f"Label index was not within the range of available choices. "
1419
+ f"Sample:\n\n{doc}\n\n"
1420
+ )
1421
+
1422
+ if self.multiple_target:
1423
+ acc = 1.0 if pred in gold else 0.0
1424
+ acc_norm = 1.0 if pred_norm in gold else 0.0
1425
+ exact_match = int(any([is_greedy[i] if i != -100 else 0 for i in gold]))
1426
+ else:
1427
+ acc = 1.0 if pred == gold else 0.0
1428
+ acc_norm = 1.0 if pred_norm == gold else 0.0
1429
+ # TODO: this gets score of 0 on arc_challenge for pythia-70m. need to test that this works properly
1430
+ exact_match = int(is_greedy[gold]) if gold != -100 else 0
1431
+
1432
+ prob_norm = utils.softmax(lls)
1433
+
1434
+ # TODO use keyword arguments to the metric?
1435
+ # gold, pred, norm stuff, the original lls,
1436
+ result_dict = {
1437
+ **({"acc": acc} if "acc" in use_metric else {}),
1438
+ **({"f1": (gold, pred)} if "f1" in use_metric else {}),
1439
+ **({"mcc": (gold, pred)} if "mcc" in use_metric else {}),
1440
+ **({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}),
1441
+ **({"exact_match": exact_match} if "exact_match" in use_metric else {}),
1442
+ **(
1443
+ {"brier_score": (gold, prob_norm)}
1444
+ if "brier_score" in use_metric
1445
+ else {}
1446
+ ),
1447
+ }
1448
+
1449
+ if "acc_mutual_info" in use_metric:
1450
+ lls_mutual_info = [
1451
+ ll_c - ll_u for ll_c, ll_u in zip(lls, lls_unconditional)
1452
+ ]
1453
+ acc_mutual_info = 1.0 if np.argmax(lls_mutual_info) == gold else 0.0
1454
+ result_dict["acc_mutual_info"] = acc_mutual_info
1455
+
1456
+ elif self.OUTPUT_TYPE == "generate_until":
1457
+ gold = self.doc_to_target(doc)
1458
+ result = results[0]
1459
+ if self.config.doc_to_choice is not None:
1460
+ # If you set doc_to_choice,
1461
+ # it assumes that doc_to_target returns a number.
1462
+ choices = self.doc_to_choice(doc)
1463
+ gold = choices[gold]
1464
+ # we expect multiple_targets to be a list.
1465
+ elif self.multiple_target:
1466
+ gold = list(gold)
1467
+ elif type(gold) != type(result):
1468
+ # cast gold to the same type as result
1469
+ gold = type(result)(gold)
1470
+
1471
+ for metric in self._metric_fn_list.keys():
1472
+ if self.multiple_target:
1473
+ # in the case where we have multiple targets,
1474
+ # return true if any are true
1475
+ # TODO: this may break for multiple_target, non zero-or-1 metrics
1476
+ scores = []
1477
+ if not isinstance(gold, list):
1478
+ # sometimes, a multiple_target dataset has exceptions where one doc has only one string answer
1479
+ # print(gold)
1480
+ gold = [gold]
1481
+ if metric == "exact_match":
1482
+ result = [result for _ in range(len(gold))]
1483
+ scores = self._metric_fn_list[metric](
1484
+ references=gold,
1485
+ predictions=result,
1486
+ **self._metric_fn_kwargs[metric],
1487
+ )[metric]
1488
+ result_score = 1.0 if scores > 0.0 else 0.0
1489
+ else:
1490
+ for gold_option in gold:
1491
+ try:
1492
+ result_score = self._metric_fn_list[metric](
1493
+ references=[gold_option],
1494
+ predictions=[result],
1495
+ **self._metric_fn_kwargs[metric],
1496
+ )
1497
+ except (
1498
+ TypeError
1499
+ ): # TODO: this is hacky and I don't want to do it
1500
+ result_score = self._metric_fn_list[metric](
1501
+ [gold_option, result]
1502
+ )
1503
+ if isinstance(result_score, dict):
1504
+ # TODO: this handles the case where HF evaluate returns a dict.
1505
+ result_score = result_score[metric]
1506
+ scores.append(result_score)
1507
+ if any(scores):
1508
+ result_score = 1.0
1509
+ else:
1510
+ result_score = 0.0
1511
+ else:
1512
+ try:
1513
+ result_score = self._metric_fn_list[metric](
1514
+ references=[gold],
1515
+ predictions=[result],
1516
+ **self._metric_fn_kwargs[metric],
1517
+ )
1518
+ except TypeError: # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
1519
+ result_score = self._metric_fn_list[metric]([gold, result])
1520
+ if isinstance(result_score, dict):
1521
+ # TODO: this handles the case where HF evaluate returns a dict.
1522
+ result_score = result_score[metric]
1523
+ result_dict[metric] = result_score
1524
+ else:
1525
+ raise ValueError(
1526
+ f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ",
1527
+ "'loglikelihood', 'loglikelihood_rolling', 'generate_until' or 'multiple_choice'",
1528
+ )
1529
+
1530
+ return result_dict
1531
+
1532
+ def aggregation(self) -> dict:
1533
+ return self._aggregation_list
1534
+
1535
+ def higher_is_better(self) -> dict:
1536
+ return self._higher_is_better
1537
+
1538
+ def get_config(self, key: str) -> Any:
1539
+ return getattr(self._config, key, None)
1540
+
1541
+ @property
1542
+ def task_name(self) -> Any:
1543
+ return getattr(self.config, "task", None)
1544
+
1545
+ def __repr__(self):
1546
+ return (
1547
+ f"ConfigurableTask(task_name={getattr(self.config, 'task', None)},"
1548
+ f"output_type={self.OUTPUT_TYPE},"
1549
+ f"num_fewshot={getattr(self.config, 'num_fewshot', None)},"
1550
+ f"num_samples={len(self.eval_docs)})"
1551
+ )
1552
+
1553
+
1554
+ class MultipleChoiceTask(Task):
1555
+ OUTPUT_TYPE = "loglikelihood"
1556
+
1557
+ def doc_to_target(self, doc: dict) -> str:
1558
+ return " " + doc["choices"][doc["gold"]]
1559
+
1560
+ def construct_requests(self, doc: dict, ctx: str, **kwargs) -> List[Instance]:
1561
+ # TODO: add mutual info here?
1562
+ return [
1563
+ Instance(
1564
+ request_type="loglikelihood",
1565
+ doc=doc,
1566
+ arguments=(ctx, " {}".format(choice)),
1567
+ idx=i,
1568
+ **kwargs,
1569
+ )
1570
+ for i, choice in enumerate(doc["choices"])
1571
+ ]
1572
+
1573
+ def process_results(self, doc: dict, results: Iterable[Tuple[float, bool]]) -> dict:
1574
+ results = [
1575
+ res[0] for res in results
1576
+ ] # only retain loglikelihoods, discard is_greedy TODO: do we need is_greedy anywhere?
1577
+ gold = doc["gold"]
1578
+
1579
+ acc = 1.0 if np.argmax(results) == gold else 0.0
1580
+ completion_len = np.array([float(len(i)) for i in doc["choices"]])
1581
+ acc_norm = 1.0 if np.argmax(results / completion_len) == gold else 0.0
1582
+
1583
+ return {
1584
+ "acc": acc,
1585
+ "acc_norm": acc_norm,
1586
+ }
1587
+
1588
+ def higher_is_better(self) -> dict:
1589
+ return {
1590
+ "acc": True,
1591
+ "acc_norm": True,
1592
+ }
1593
+
1594
+ def aggregation(self) -> dict:
1595
+ return {
1596
+ "acc": mean,
1597
+ "acc_norm": mean,
1598
+ }
1599
+
1600
+
1601
+ class PerplexityTask(Task):
1602
+ OUTPUT_TYPE = "loglikelihood_rolling"
1603
+
1604
+ def has_training_docs(self) -> bool:
1605
+ return False
1606
+
1607
+ def fewshot_examples(self, k: int, rnd) -> List:
1608
+ if k != 0:
1609
+ raise ValueError(
1610
+ "The number of fewshot examples must be 0 for perplexity tasks."
1611
+ )
1612
+ return []
1613
+
1614
+ def fewshot_context(self, doc: dict, num_fewshot: int) -> Literal[""]:
1615
+ if num_fewshot != 0:
1616
+ raise ValueError(
1617
+ "The number of fewshot examples must be 0 for perplexity tasks."
1618
+ )
1619
+
1620
+ return ""
1621
+
1622
+ def higher_is_better(self) -> dict:
1623
+ return {
1624
+ "word_perplexity": False,
1625
+ "byte_perplexity": False,
1626
+ "bits_per_byte": False,
1627
+ }
1628
+
1629
+ def doc_to_decontamination_query(self, doc):
1630
+ return doc
1631
+
1632
+ def doc_to_text(self, doc) -> str:
1633
+ return ""
1634
+
1635
+ def doc_to_target(self, doc):
1636
+ return doc
1637
+
1638
+ def construct_requests(self, doc: dict, ctx: Optional[str], **kwargs):
1639
+ if bool(ctx):
1640
+ raise ValueError
1641
+
1642
+ return Instance(
1643
+ request_type=self.OUTPUT_TYPE,
1644
+ doc=doc,
1645
+ arguments=(self.doc_to_target(doc),),
1646
+ idx=0,
1647
+ **kwargs,
1648
+ )
1649
+
1650
+ def process_results(self, doc: dict, results: Tuple[float]) -> dict:
1651
+ (loglikelihood,) = results
1652
+ words = self.count_words(self.doc_to_target(doc))
1653
+ bytes_ = self.count_bytes(self.doc_to_target(doc))
1654
+ return {
1655
+ "word_perplexity": (loglikelihood, words),
1656
+ "byte_perplexity": (loglikelihood, bytes_),
1657
+ "bits_per_byte": (loglikelihood, bytes_),
1658
+ }
1659
+
1660
+ def aggregation(self) -> dict:
1661
+ return {
1662
+ "word_perplexity": weighted_perplexity,
1663
+ "byte_perplexity": weighted_perplexity,
1664
+ "bits_per_byte": bits_per_byte,
1665
+ }
1666
+
1667
+ @classmethod
1668
+ def count_bytes(cls, doc) -> int:
1669
+ return len(doc.encode("utf-8"))
1670
+
1671
+ @classmethod
1672
+ def count_words(cls, doc) -> int:
1673
+ """Downstream tasks with custom word boundaries should override this!"""
1674
+ return len(re.split(r"\s+", doc))
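The PerplexityTask above returns a (loglikelihood, count) pair per document and defers aggregation to weighted_perplexity / bits_per_byte. A minimal standalone sketch of that bookkeeping follows; the counting mirrors count_words / count_bytes above, while the exp(-sum(ll)/sum(weights)) aggregation formula is an assumption added here for illustration and is not taken from this diff.

# Illustrative sketch only (not part of the committed file): aggregate
# per-document (loglikelihood, weight) pairs into a corpus-level perplexity.
import math
import re


def count_bytes(doc: str) -> int:
    # same rule as PerplexityTask.count_bytes above
    return len(doc.encode("utf-8"))


def count_words(doc: str) -> int:
    # same whitespace rule as PerplexityTask.count_words above
    return len(re.split(r"\s+", doc))


def weighted_perplexity(pairs) -> float:
    # assumed aggregation: exp of the negative total loglikelihood per unit
    total_ll = sum(ll for ll, _ in pairs)
    total_weight = sum(weight for _, weight in pairs)
    return math.exp(-total_ll / total_weight)


docs = ["a toy document", "another short one"]  # hypothetical corpus
lls = [-12.3, -9.8]  # hypothetical rolling loglikelihoods

word_ppl = weighted_perplexity(list(zip(lls, (count_words(d) for d in docs))))
byte_ppl = weighted_perplexity(list(zip(lls, (count_bytes(d) for d in docs))))
print(f"word ppl={word_ppl:.2f}, byte ppl={byte_ppl:.2f}")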
scripts/yans/lm-evaluation-harness/lm_eval/models/__init__.py ADDED
@@ -0,0 +1,28 @@
1
+ from . import (
2
+ anthropic_llms,
3
+ api_models,
4
+ dummy,
5
+ gguf,
6
+ huggingface,
7
+ mamba_lm,
8
+ nemo_lm,
9
+ neuralmagic,
10
+ neuron_optimum,
11
+ openai_completions,
12
+ optimum_lm,
13
+ textsynth,
14
+ vllm_causallms,
15
+ )
16
+
17
+
18
+ # TODO: implement __all__
19
+
20
+
21
+ try:
22
+ # enable hf hub transfer if available
23
+ import hf_transfer # type: ignore # noqa
24
+ import huggingface_hub.constants # type: ignore
25
+
26
+ huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True
27
+ except ImportError:
28
+ pass
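The trailing try/except ImportError is an optional-dependency toggle: when hf_transfer is importable, Hub downloads are switched to the faster transfer backend, otherwise the import error is silently ignored. A small check of whether that flag applies might look like the sketch below (it assumes huggingface_hub is installed; hf_transfer may or may not be).

# Sketch only (not part of the committed file): inspect the flag that the
# module's import-time toggle above would set when hf_transfer is available.
from importlib.util import find_spec

import huggingface_hub.constants as hub_constants

if find_spec("hf_transfer") is not None:
    print("hf_transfer installed; HF_HUB_ENABLE_HF_TRANSFER =",
          hub_constants.HF_HUB_ENABLE_HF_TRANSFER)
else:
    print("hf_transfer not installed; standard downloads will be used")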
scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (631 Bytes). View file
 
scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/anthropic_llms.cpython-310.pyc ADDED
Binary file (11 kB). View file
 
scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/api_models.cpython-310.pyc ADDED
Binary file (16.6 kB). View file
 
scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/dummy.cpython-310.pyc ADDED
Binary file (1.58 kB). View file
 
scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/gguf.cpython-310.pyc ADDED
Binary file (4.11 kB). View file
 
scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/huggingface.cpython-310.pyc ADDED
Binary file (29.8 kB). View file
 
scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/mamba_lm.cpython-310.pyc ADDED
Binary file (3.69 kB). View file
 
scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/nemo_lm.cpython-310.pyc ADDED
Binary file (13.7 kB). View file
 
scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/neuralmagic.cpython-310.pyc ADDED
Binary file (11 kB). View file
 
scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/neuron_optimum.cpython-310.pyc ADDED
Binary file (18.3 kB). View file
 
scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/openai_completions.cpython-310.pyc ADDED
Binary file (6.39 kB). View file
 
scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/optimum_lm.cpython-310.pyc ADDED
Binary file (2.65 kB). View file
 
scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/textsynth.cpython-310.pyc ADDED
Binary file (5.23 kB). View file
 
scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/utils.cpython-310.pyc ADDED
Binary file (21.3 kB). View file
 
scripts/yans/lm-evaluation-harness/lm_eval/models/__pycache__/vllm_causallms.cpython-310.pyc ADDED
Binary file (14.3 kB). View file
 
scripts/yans/lm-evaluation-harness/lm_eval/models/anthropic_llms.py ADDED
@@ -0,0 +1,362 @@
1
+ import os
2
+ from functools import cached_property
3
+ from typing import Any, Dict, List, Tuple, Union
4
+
5
+ from tqdm import tqdm
6
+
7
+ from lm_eval import utils
8
+ from lm_eval.api.model import LM
9
+ from lm_eval.api.registry import register_model
10
+ from lm_eval.models.openai_completions import LocalCompletionsAPI
11
+ from lm_eval.models.utils import retry_on_specific_exceptions
12
+
13
+
14
+ eval_logger = utils.eval_logger
15
+
16
+
17
+ def anthropic_completion(
18
+ client, #: anthropic.Anthropic,
19
+ model: str,
20
+ prompt: str,
21
+ max_tokens_to_sample: int,
22
+ temperature: float,
23
+ stop: List[str],
24
+ **kwargs: Any,
25
+ ) -> str:
26
+ """Wrapper function around the Anthropic completion API client with exponential back-off
27
+ in case of RateLimitError.
28
+
29
+ params:
30
+ client: anthropic.Anthropic
31
+ Anthropic API client
32
+ model: str
33
+ Anthropic model e.g. 'claude-instant-v1', 'claude-2'
34
+ prompt: str
35
+ Prompt to feed to the model
36
+ max_tokens_to_sample: int
37
+ Maximum number of tokens to sample from the model
38
+ temperature: float
39
+ Sampling temperature
40
+ stop: List[str]
41
+ List of stop sequences
42
+ kwargs: Any
43
+ Additional model_args to pass to the API client
44
+ """
45
+
46
+ try:
47
+ import anthropic
48
+ except ModuleNotFoundError:
49
+ raise Exception(
50
+ "attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
51
+ please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`",
52
+ )
53
+
54
+ def _exception_callback(e: Exception, sleep_time: float) -> None:
55
+ eval_logger.warning(
56
+ f"RateLimitError occurred: {e.__cause__}\n Retrying in {sleep_time} seconds"
57
+ )
58
+
59
+ @retry_on_specific_exceptions(
60
+ on_exceptions=[anthropic.RateLimitError],
61
+ max_retries=None, # retry forever, consider changing
62
+ on_exception_callback=_exception_callback,
63
+ )
64
+ def completion():
65
+ response = client.completions.create(
66
+ prompt=f"{anthropic.HUMAN_PROMPT} {prompt}{anthropic.AI_PROMPT}",
67
+ model=model,
68
+ # NOTE: Claude really likes to do CoT, and overly aggressive stop sequences
69
+ # (e.g. gsm8k's ":") may truncate a lot of the input.
70
+ stop_sequences=[anthropic.HUMAN_PROMPT] + stop,
71
+ max_tokens_to_sample=max_tokens_to_sample,
72
+ temperature=temperature,
73
+ **kwargs,
74
+ )
75
+ return response.completion
76
+
77
+ return completion()
78
+
79
+
80
+ def anthropic_chat(
81
+ client, #: anthropic.Anthropic,
82
+ model: str,
83
+ prompt: str,
84
+ max_tokens: int,
85
+ temperature: float,
86
+ stop: List[str],
87
+ **kwargs: Any,
88
+ ) -> str:
89
+ """Wrapper function around the Anthropic completion API client with exponential back-off
90
+ in case of RateLimitError.
91
+
92
+ params:
93
+ client: anthropic.Anthropic
94
+ Anthropic API client
95
+ model: str
96
+ Anthropic model e.g. 'claude-3-opus-20240229', 'claude-3-sonnet-20240229'
97
+ prompt: str
98
+ Prompt to feed to the model
99
+ max_tokens: int
100
+ Maximum number of tokens to sample from the model
101
+ temperature: float
102
+ Sampling temperature
103
+ stop: List[str]
104
+ List of stop sequences
105
+ kwargs: Any
106
+ Additional model_args to pass to the API client
107
+ """
108
+
109
+ try:
110
+ import anthropic
111
+ except ModuleNotFoundError:
112
+ raise Exception(
113
+ "attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
114
+ please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`",
115
+ )
116
+
117
+ def _exception_callback(e: Exception, sleep_time: float) -> None:
118
+ eval_logger.warning(
119
+ f"RateLimitError occurred: {e.__cause__}\n Retrying in {sleep_time} seconds"
120
+ )
121
+
122
+ @retry_on_specific_exceptions(
123
+ on_exceptions=[
124
+ anthropic.RateLimitError,
125
+ anthropic.APIConnectionError,
126
+ anthropic.APIStatusError,
127
+ ],
128
+ max_retries=None, # retry forever, consider changing
129
+ on_exception_callback=_exception_callback,
130
+ )
131
+ def messages():
132
+ response = client.messages.create(
133
+ model=model,
134
+ max_tokens=max_tokens,
135
+ temperature=temperature,
136
+ messages=[{"role": "user", "content": f"{prompt}"}],
137
+ **kwargs,
138
+ )
139
+ return response.content[0].text
140
+
141
+ return messages()
142
+
143
+
144
+ @register_model("anthropic-completions")
145
+ class AnthropicLM(LM):
146
+ REQ_CHUNK_SIZE = 20 # TODO: not used
147
+
148
+ def __init__(
149
+ self,
150
+ batch_size: int = 1,
151
+ model: str = "claude-2.0",
152
+ max_tokens_to_sample: int = 256,
153
+ temperature: float = 0, # defaults to 1
154
+ **kwargs, # top_p, top_k, etc.
155
+ ) -> None:
156
+ """Anthropic API wrapper.
157
+
158
+ :param model: str
159
+ Anthropic model e.g. 'claude-instant-v1', 'claude-2'
160
+ :param max_tokens_to_sample: int
161
+ Maximum number of tokens to sample from the model
162
+ :param temperature: float
163
+ Sampling temperature
164
+ :param kwargs: Any
165
+ Additional model_args to pass to the API client
166
+ """
167
+ super().__init__()
168
+
169
+ try:
170
+ import anthropic
171
+ except ModuleNotFoundError:
172
+ raise Exception(
173
+ "attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
174
+ please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`",
175
+ )
176
+
177
+ self.model = model
178
+ # defaults to os.environ.get("ANTHROPIC_API_KEY")
179
+ self.client = anthropic.Anthropic()
180
+ self.temperature = temperature
181
+ self.max_tokens_to_sample = max_tokens_to_sample
182
+ self.tokenizer = self.client.get_tokenizer()
183
+ self.kwargs = kwargs
184
+
185
+ @property
186
+ def eot_token_id(self):
187
+ # Not sure but anthropic.HUMAN_PROMPT ?
188
+ raise NotImplementedError("No idea about anthropic tokenization.")
189
+
190
+ @property
191
+ def max_length(self) -> int:
192
+ return 2048
193
+
194
+ @property
195
+ def max_gen_toks(self) -> int:
196
+ return self.max_tokens_to_sample
197
+
198
+ @property
199
+ def batch_size(self):
200
+ # Isn't used because we override _loglikelihood_tokens
201
+ raise NotImplementedError("No support for logits.")
202
+
203
+ @property
204
+ def device(self):
205
+ # Isn't used because we override _loglikelihood_tokens
206
+ raise NotImplementedError("No support for logits.")
207
+
208
+ def tok_encode(self, string: str) -> List[int]:
209
+ return self.tokenizer.encode(string).ids
210
+
211
+ def tok_decode(self, tokens: List[int]) -> str:
212
+ return self.tokenizer.decode(tokens)
213
+
214
+ def _loglikelihood_tokens(self, requests, disable_tqdm: bool = False):
215
+ raise NotImplementedError("No support for logits.")
216
+
217
+ def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
218
+ try:
219
+ import anthropic
220
+ except ModuleNotFoundError:
221
+ raise Exception(
222
+ "attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
223
+ please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`",
224
+ )
225
+
226
+ if not requests:
227
+ return []
228
+
229
+ _requests: List[Tuple[str, dict]] = [req.args for req in requests]
230
+
231
+ res = []
232
+ for request in tqdm(_requests, disable=disable_tqdm):
233
+ try:
234
+ inp = request[0]
235
+ request_args = request[1]
236
+ # generation_kwargs
237
+ until = request_args.get("until")
238
+ max_gen_toks = request_args.get("max_gen_toks", self.max_length)
239
+ temperature = request_args.get("temperature", self.temperature)
240
+ response = anthropic_completion(
241
+ client=self.client,
242
+ model=self.model,
243
+ prompt=inp,
244
+ max_tokens_to_sample=max_gen_toks,
245
+ temperature=temperature, # TODO: implement non-greedy sampling for Anthropic
246
+ stop=until, # type: ignore
247
+ **self.kwargs,
248
+ )
249
+ res.append(response)
250
+
251
+ self.cache_hook.add_partial("generate_until", request, response)
252
+ except anthropic.APIConnectionError as e: # type: ignore # noqa: F821
253
+ eval_logger.critical(f"Server unreachable: {e.__cause__}")
254
+ break
255
+ except anthropic.APIStatusError as e: # type: ignore # noqa: F821
256
+ eval_logger.critical(f"API error {e.status_code}: {e.message}")
257
+ break
258
+
259
+ return res
260
+
261
+ def _model_call(self, inps):
262
+ # Isn't used because we override _loglikelihood_tokens
263
+ raise NotImplementedError()
264
+
265
+ def _model_generate(self, context, max_length, eos_token_id):
266
+ # Isn't used because we override generate_until
267
+ raise NotImplementedError()
268
+
269
+ def loglikelihood(self, requests, disable_tqdm: bool = False):
270
+ raise NotImplementedError("No support for logits.")
271
+
272
+ def loglikelihood_rolling(self, requests, disable_tqdm: bool = False):
273
+ raise NotImplementedError("No support for logits.")
274
+
275
+
276
+ @register_model("anthropic-chat", "anthropic-chat-completions")
277
+ class AnthropicChat(LocalCompletionsAPI):
278
+ def __init__(
279
+ self,
280
+ base_url="https://api.anthropic.com/v1/messages",
281
+ tokenizer_backend=None,
282
+ **kwargs,
283
+ ):
284
+ super().__init__(
285
+ base_url=base_url, tokenizer_backend=tokenizer_backend, **kwargs
286
+ )
287
+ eval_logger.warning(
288
+ "Chat completions does not support batching. Defaulting to batch size 1."
289
+ )
290
+ self._batch_size = 1
291
+ self.anthropic_version = "2023-06-01"
292
+ eval_logger.warning(
293
+ f"Using Anthropic Version: {self.anthropic_version}. Confirm the current version here: https://docs.anthropic.com/en/api/versioning"
294
+ )
295
+
296
+ @cached_property
297
+ def api_key(self):
298
+ """Override this property to return the API key for the API request."""
299
+ key = os.environ.get("ANTHROPIC_API_KEY", None)
300
+ if key is None:
301
+ raise ValueError(
302
+ "API key not found. Please set the ANTHROPIC_API_KEY environment variable."
303
+ )
304
+ return key
305
+
306
+ @cached_property
307
+ def header(self):
308
+ return {
309
+ "x-api-key": f"{self.api_key}",
310
+ "anthropic-version": self.anthropic_version,
311
+ }
312
+
313
+ def _create_payload(
314
+ self, messages: List[Dict], generate=True, gen_kwargs: dict = None, **kwargs
315
+ ) -> dict:
316
+ system = (
317
+ messages[0].get("content") if messages[0].get("role") == "system" else None
318
+ )
319
+ if system:
320
+ messages = messages[1:]
321
+ gen_kwargs.pop("do_sample", False)
322
+ max_tokens = gen_kwargs.pop("max_gen_toks", self._max_gen_toks)
323
+ temperature = gen_kwargs.pop("temperature", 0)
324
+ stop = gen_kwargs.pop("until", ["\n\nHuman:"])
325
+ if not isinstance(stop, list):
326
+ stop = [stop]
327
+ out = {
328
+ "messages": messages,
329
+ "model": self.model,
330
+ "max_tokens": max_tokens,
331
+ "temperature": temperature,
332
+ "stop_sequences": stop,
333
+ **gen_kwargs,
334
+ }
335
+ if system:
336
+ out["system"] = system
337
+ return out
338
+
339
+ def parse_generations(
340
+ self, outputs: Union[Dict, List[Dict]], **kwargs
341
+ ) -> List[str]:
342
+ res = []
343
+ if not isinstance(outputs, list):
344
+ outputs = [outputs]
345
+ for out in outputs:
346
+ for choices in out["content"]:
347
+ res.append(choices["text"])
348
+ return res
349
+
350
+ def tok_encode(
351
+ self,
352
+ string: str,
353
+ left_truncate_len=None,
354
+ add_special_tokens=None,
355
+ **kwargs,
356
+ ) -> List[str]:
357
+ return [string]
358
+
359
+ def loglikelihood(self, requests, **kwargs):
360
+ raise NotImplementedError(
361
+ "Anthropic Chat Completions API does not support the return of loglikelihood"
362
+ )
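For reference, _create_payload above turns an lm-eval chat history plus gen_kwargs into the Anthropic Messages request body: a leading system turn is hoisted into the "system" field and the remaining turns are passed through. A rough sketch of the resulting shape follows; the model name and values are made up for illustration and are not part of this diff.

# Sketch only: the approximate request body AnthropicChat._create_payload builds.
chat_history = [
    {"role": "system", "content": "Answer tersely."},
    {"role": "user", "content": "What is 2 + 2?"},
]

system = chat_history[0]["content"] if chat_history[0]["role"] == "system" else None
body = {
    "messages": chat_history[1:] if system else chat_history,
    "model": "claude-3-sonnet-20240229",  # hypothetical model choice
    "max_tokens": 256,                    # from max_gen_toks / _max_gen_toks
    "temperature": 0,
    "stop_sequences": ["\n\nHuman:"],     # default `until` used above
}
if system:
    body["system"] = system
print(body)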
scripts/yans/lm-evaluation-harness/lm_eval/models/api_models.py ADDED
@@ -0,0 +1,641 @@
1
+ import abc
2
+ import asyncio
3
+ import copy
4
+ import itertools
5
+ import json
6
+ from functools import cached_property
7
+ from typing import (
8
+ Any,
9
+ Awaitable,
10
+ Callable,
11
+ Dict,
12
+ Iterable,
13
+ List,
14
+ Literal,
15
+ NamedTuple,
16
+ Optional,
17
+ Tuple,
18
+ Union,
19
+ )
20
+
21
+
22
+ try:
23
+ import requests
24
+ from aiohttp import ClientSession, TCPConnector
25
+ from tenacity import RetryError, retry, stop_after_attempt, wait_exponential
26
+ from tqdm import tqdm
27
+ from tqdm.asyncio import tqdm_asyncio
28
+ except ModuleNotFoundError:
29
+ pass
30
+
31
+
32
+ from importlib.util import find_spec
33
+
34
+ from lm_eval import utils
35
+ from lm_eval.api.instance import Instance
36
+ from lm_eval.api.model import TemplateLM
37
+ from lm_eval.models.utils import Collator, chunks, configure_pad_token
38
+
39
+
40
+ LogLikelihoodInputs = Tuple[Tuple[str, str], List[int], List[int]]
41
+
42
+
43
+ # utility class to keep track of json encoded chats
44
+ class JsonChatStr(NamedTuple):
45
+ prompt: str
46
+
47
+ def encode(self, encoding):
48
+ return self.prompt.encode(encoding)
49
+
50
+
51
+ eval_logger = utils.eval_logger
52
+
53
+
54
+ class TemplateAPI(TemplateLM):
55
+ def __init__(
56
+ self,
57
+ model: str = None,
58
+ pretrained: str = None, # `model` takes precedence over `pretrained` when passed.
59
+ base_url: str = None,
60
+ tokenizer: Optional[str] = None,
61
+ # Loglikelihood tasks require a tokenizer to calculate context lengths,
62
+ # however the requests can be sent as a string if the API doesn't support token inputs.
63
+ # use tokenized_requests=False
64
+ tokenizer_backend: Optional[
65
+ Literal["tiktoken", "huggingface", None]
66
+ ] = "huggingface",
67
+ truncate: bool = False,
68
+ # number of concurrent requests. More useful if not batching
69
+ num_concurrent: int = 1,
70
+ max_retries: int = 3,
71
+ max_gen_toks: int = 256,
72
+ batch_size: Union[str, int] = 1,
73
+ seed: int = 1234,
74
+ max_length: Optional[int] = 2048,
75
+ add_bos_token: bool = False,
76
+ custom_prefix_token_id=None,
77
+ # send the requests as tokens or strings
78
+ tokenized_requests=True,
79
+ **kwargs,
80
+ ) -> None:
81
+ super().__init__()
82
+ missing_packages = [
83
+ pkg
84
+ for pkg in ["aiohttp", "tqdm", "tenacity", "requests"]
85
+ if find_spec(pkg) is None
86
+ ]
87
+ if missing_packages:
88
+ raise ModuleNotFoundError(
89
+ f"Attempted to use an API model, but the required packages {missing_packages} are not installed. "
90
+ 'Please install these via `pip install lm-eval[api]` or `pip install -e ."[api]"`'
91
+ )
92
+ self.model = model or pretrained
93
+ self.base_url = base_url
94
+ self.tokenizer = tokenizer
95
+ if not isinstance(batch_size, int) and "auto" in batch_size:
96
+ eval_logger.warning(
97
+ "Automatic batch size is not supported for API models. Defaulting to batch size 1."
98
+ )
99
+ elif int(batch_size) > 1:
100
+ eval_logger.warning(
101
+ "Batch size > 1 detected. Ensure your API supports batched requests with varying total sequence lengths."
102
+ )
103
+ self._batch_size = int(batch_size) if batch_size != "auto" else 1
104
+ self._truncate = truncate
105
+ self._max_gen_toks = int(max_gen_toks)
106
+ self._seed = int(seed)
107
+ self.max_length = max_length
108
+ if int(num_concurrent) <= 1:
109
+ eval_logger.info(
110
+ "Concurrent requests are disabled. To enable concurrent requests, set `num_concurrent` > 1."
111
+ )
112
+ self._concurrent = int(num_concurrent)
113
+ self.tokenizer_backend = tokenizer_backend
114
+ self.add_bos_token = add_bos_token
115
+ self.custom_prefix_token_id = custom_prefix_token_id
116
+ self.tokenized_requests = tokenized_requests
117
+ self.max_retries = int(max_retries)
118
+
119
+ eval_logger.info(f"Using tokenizer {self.tokenizer_backend}")
120
+ if self.tokenizer_backend is None:
121
+ self.tokenizer = None
122
+ self.tokenized_requests = False
123
+ else:
124
+ if self.tokenizer is None:
125
+ if self.tokenizer_backend == "huggingface":
126
+ import transformers
127
+
128
+ self.tokenizer = transformers.AutoTokenizer.from_pretrained(
129
+ self.tokenizer if self.tokenizer else self.model
130
+ )
131
+ # Not used as the API will handle padding but to mirror the behavior of the HFLM
132
+ self.tokenizer = configure_pad_token(self.tokenizer)
133
+ elif self.tokenizer_backend == "tiktoken":
134
+ try:
135
+ import tiktoken
136
+
137
+ self.tokenizer = tiktoken.encoding_for_model(self.model)
138
+ except ModuleNotFoundError as e:
139
+ raise Exception(
140
+ "Attempted to use 'openai' LM type, but the package `tiktoken` is not installed. "
141
+ "Please install it via `pip install lm-eval[api]` or `pip install -e .[api]`."
142
+ ) from e
143
+ if "openai" not in self.base_url:
144
+ eval_logger.warning(
145
+ f"Passed `base_url={self.base_url}` but using (OpenAI) Tiktoken tokenizer backend. "
146
+ "Pass `tokenizer_backend=huggingface` and provide the HF tokenizer name if your model does not use Tiktoken."
147
+ )
148
+ else:
149
+ import transformers
150
+
151
+ assert isinstance(tokenizer, str), "tokenizer must be a string"
152
+ self.tokenizer = transformers.AutoTokenizer.from_pretrained(
153
+ tokenizer,
154
+ )
155
+
156
+ @abc.abstractmethod
157
+ def _create_payload(
158
+ self,
159
+ messages: Union[List[List[int]], List[dict], List[str], str],
160
+ *,
161
+ generate: bool = True,
162
+ gen_kwargs: Optional[dict] = None,
163
+ seed: int = 1234,
164
+ **kwargs,
165
+ ) -> dict:
166
+ """This method is responsible for creating the json payload that will be sent to the API."""
167
+ raise NotImplementedError
168
+
169
+ def create_message(
170
+ self,
171
+ messages: Union[List[List[int]], List[str], List[JsonChatStr]],
172
+ generate=False,
173
+ ) -> Union[List[List[int]], List[dict], List[str], str]:
174
+ """Helper method to transform the prompt into the expected API input format. messages consist of batched requests"""
175
+ if isinstance(messages[0], JsonChatStr):
176
+ # for chat completions we need to decode the json string to list[dict,...]
177
+ assert (
178
+ self._batch_size == 1
179
+ ), "non-tokenized chat requests are only supported with batch_size=1"
180
+ # list[dict["role":..., "content":...],...]
181
+ return json.loads(messages[0].prompt)
182
+
183
+ if not self.tokenized_requests:
184
+ # if messages are tokenized:
185
+ if isinstance(messages[0][0], int):
186
+ # assuming decoding is lossless. However, this is only for loglikelihood requests
187
+ # as we need to compute the context length. For generations, we don't need to tokenize.
188
+ messages = self.decode_batch(messages)
189
+ if self._batch_size <= 1:
190
+ # if batch is 1 return str
191
+ return messages[0]
192
+ else:
193
+ # list[str,...]
194
+ return messages
195
+
196
+ # list[list[int], ...]
197
+ return messages
198
+
199
+ @staticmethod
200
+ @abc.abstractmethod
201
+ def parse_logprobs(
202
+ outputs: Union[Any, List[Any]],
203
+ tokens: List[List[int]] = None,
204
+ ctxlen: List[int] = None,
205
+ **kwargs,
206
+ ) -> List[Tuple[float, bool]]:
207
+ """Method used to parse the logprobs from the (batched) API response. This method should return a list of tuples"""
208
+ raise NotImplementedError
209
+
210
+ @staticmethod
211
+ @abc.abstractmethod
212
+ def parse_generations(outputs: Union[Any, List[Any]], **kwargs) -> List[str]:
213
+ """Method used to parse the generations from the (batched) API response. This method should return a list of str"""
214
+ raise NotImplementedError
215
+
216
+ @cached_property
217
+ def api_key(self) -> str:
218
+ """Override this property to return the API key for the API request."""
219
+ return ""
220
+
221
+ @cached_property
222
+ def header(self) -> dict:
223
+ """Override this property to return the headers for the API request."""
224
+ return {"Authorization": f"Bearer {self.api_key}"}
225
+
226
+ @property
227
+ def chat_template(self) -> str:
228
+ """Must be defined for LM subclasses that implement Chat Templating.
229
+ Should return the structure of the chat template applied to user/assistant messages.
230
+ Only used for logging and reproducibility.
231
+ """
232
+ return ""
233
+
234
+ @property
235
+ def tokenizer_name(self) -> str:
236
+ """Must be defined for LM subclasses which implement Chat Templating.
237
+ Should return the name of the tokenizer or chat template used.
238
+ Used only to properly fingerprint caches when requests are being cached with `--cache_requests`, otherwise not used.
239
+ """
240
+ return ""
241
+
242
+ def apply_chat_template(
243
+ self, chat_history: List[Dict[str, str]]
244
+ ) -> Union[str, JsonChatStr]:
245
+ """Applies a chat template to a list of chat history between user and model."""
246
+ if self.tokenizer_backend == "huggingface" and self.tokenized_requests:
247
+ return self.tokenizer.apply_chat_template(
248
+ chat_history, tokenize=False, add_generation_prompt=True
249
+ )
250
+ else:
251
+ # bit of a hack. We'll load back before sending to the API
252
+ return JsonChatStr(json.dumps(chat_history))
253
+
254
+ @cached_property
255
+ def eot_token_id(self) -> Optional[int]:
256
+ if self.tokenizer is None:
257
+ return None
258
+ else:
259
+ if self.tokenizer_backend == "huggingface":
260
+ return self.tokenizer.eos_token_id
261
+ elif self.tokenizer_backend == "tiktoken":
262
+ return self.tokenizer.eot_token
263
+
264
+ @cached_property
265
+ def prefix_token_id(self) -> Optional[int]:
266
+ if self.tokenizer is None:
267
+ return None
268
+ else:
269
+ if self.custom_prefix_token_id is not None:
270
+ return self.custom_prefix_token_id
271
+ if self.tokenizer_backend == "huggingface":
272
+ if self.tokenizer.bos_token_id is not None:
273
+ return self.tokenizer.bos_token_id
274
+ return self.tokenizer.eos_token_id
275
+ else:
276
+ return self.tokenizer.eot_token
277
+
278
+ def tok_encode(
279
+ self,
280
+ string: str,
281
+ left_truncate_len: int = None,
282
+ add_special_tokens: bool = False,
283
+ truncation: bool = False,
284
+ **kwargs,
285
+ ) -> Union[List[List[int]], List[int], List[str]]:
286
+ if self.tokenizer_backend is None:
287
+ return [string]
288
+ elif self.tokenizer_backend == "huggingface":
289
+ # by default for CausalLM - false or self.add_bos_token is set
290
+ if not add_special_tokens:
291
+ add_special_tokens = False or self.add_bos_token
292
+ encoding: Union[List[List[int]], List[int]] = self.tokenizer(
293
+ string,
294
+ add_special_tokens=add_special_tokens,
295
+ truncation=truncation,
296
+ return_attention_mask=False,
297
+ ).input_ids
298
+
299
+ # left-truncate the encoded context to be at most `left_truncate_len` tokens long
300
+ if left_truncate_len:
301
+ if not isinstance(string, str):
302
+ encoding = [enc[-left_truncate_len:] for enc in encoding]
303
+ else:
304
+ encoding = encoding[-left_truncate_len:]
305
+
306
+ return encoding
307
+
308
+ else:
309
+ try:
310
+ encoding = self.tokenizer.encode(string)
311
+ except Exception:
312
+ encoding = self.tokenizer.encode_batch(string)
313
+ return encoding
314
+
315
+ def decode_batch(self, tokens: List[List[int]]) -> List[str]:
316
+ if self.tokenizer_backend == "huggingface":
317
+ return self.tokenizer.batch_decode(tokens)
318
+ elif self.tokenizer_backend == "tiktoken":
319
+ return self.tokenizer.decode_batch(tokens)
320
+
321
+ def model_call(
322
+ self,
323
+ messages: Union[List[List[int]], List[str], List[JsonChatStr]],
324
+ *,
325
+ generate: bool = True,
326
+ gen_kwargs: Optional[Dict] = None,
327
+ **kwargs,
328
+ ) -> Optional[dict]:
329
+ # !!! Copy: shared dict for each request, need new object !!!
330
+ gen_kwargs = copy.deepcopy(gen_kwargs)
331
+ try:
332
+ response = requests.post(
333
+ self.base_url,
334
+ json=self._create_payload(
335
+ self.create_message(messages),
336
+ generate=generate,
337
+ gen_kwargs=gen_kwargs,
338
+ seed=self._seed,
339
+ **kwargs,
340
+ ),
341
+ headers=self.header,
342
+ )
343
+ if not response.ok:
344
+ eval_logger.warning(
345
+ f"API request failed with error message: {response.text}. Retrying..."
346
+ )
347
+ response.raise_for_status()
348
+ return response.json()
349
+ except RetryError:
350
+ eval_logger.error(
351
+ "API request failed after multiple retries. Please check the API status."
352
+ )
353
+ return None
354
+
355
+ async def amodel_call(
356
+ self,
357
+ session: ClientSession,
358
+ messages: Union[List[List[int]], List[str], List[JsonChatStr]],
359
+ *,
360
+ generate: bool = True,
361
+ cache_keys: list = None,
362
+ ctxlens: Optional[List[int]] = None,
363
+ gen_kwargs: Optional[Dict] = None,
364
+ **kwargs,
365
+ ) -> Union[List[str], List[Tuple[float, bool]], None]:
366
+ # !!! Copy: shared dict for each request, need new object !!!
367
+ gen_kwargs = copy.deepcopy(gen_kwargs)
368
+ payload = self._create_payload(
369
+ self.create_message(messages),
370
+ generate=generate,
371
+ gen_kwargs=gen_kwargs,
372
+ seed=self._seed,
373
+ **kwargs,
374
+ )
375
+ cache_method = "generate_until" if generate else "loglikelihood"
376
+ try:
377
+ async with session.post(
378
+ self.base_url,
379
+ json=payload,
380
+ headers=self.header,
381
+ ) as response:
382
+ if not response.ok:
383
+ error_text = await response.text()
384
+ eval_logger.warning(
385
+ f"API request failed with error message: {error_text}. Retrying..."
386
+ )
387
+ # raising exception will retry the request
388
+ response.raise_for_status()
389
+ outputs = await response.json()
390
+ answers = (
391
+ self.parse_generations(
392
+ outputs=outputs,
393
+ )
394
+ if generate
395
+ else self.parse_logprobs(
396
+ outputs=outputs,
397
+ tokens=messages,
398
+ ctxlens=ctxlens,
399
+ )
400
+ )
401
+ if cache_keys:
402
+ for res, cache in zip(answers, cache_keys):
403
+ self.cache_hook.add_partial(cache_method, cache, res)
404
+ return answers
405
+ # If the retries also fail
406
+ except RetryError:
407
+ eval_logger.error(
408
+ "API request failed after multiple retries. Please check the API status."
409
+ )
410
+ return None
411
+
412
+ def batch_logliklehood_requests(
413
+ self, chunks: Iterable[List[LogLikelihoodInputs]]
414
+ ) -> Tuple[List[List[int]], List[int], List[Tuple[str, str]]]:
415
+ inputs = []
416
+ ctxlens = []
417
+ cache_keys = []
418
+ for chunk in chunks:
419
+ for cache_key, context_enc, continuation_enc in chunk:
420
+ inp = (context_enc + continuation_enc)[-(self.max_length) :]
421
+ ctxlen = len(context_enc) - max(
422
+ 0, len(context_enc) + len(continuation_enc) - (self.max_length)
423
+ )
424
+
425
+ inputs.append(inp)
426
+ ctxlens.append(ctxlen)
427
+ cache_keys.append(cache_key)
428
+ return inputs, ctxlens, cache_keys
429
+
430
+ async def get_batched_requests(
431
+ self,
432
+ requests: list,
433
+ cache_keys: list,
434
+ *,
435
+ generate: bool = True,
436
+ ctxlens: List[int] = None,
437
+ **kwargs,
438
+ ) -> Union[List[List[str]], List[List[Tuple[float, bool]]]]:
439
+ ctxlens = ctxlens if ctxlens else [None] * len(requests)
440
+ conn = TCPConnector(limit=self._concurrent)
441
+ async with ClientSession(connector=conn) as session:
442
+ retry_: Callable[..., Awaitable[Any]] = retry(
443
+ stop=stop_after_attempt(self.max_retries),
444
+ wait=wait_exponential(multiplier=0.5, min=1, max=10),
445
+ reraise=True,
446
+ )(self.amodel_call)
447
+ # Create tasks for each batch of request
448
+ tasks = [
449
+ asyncio.create_task(
450
+ retry_(
451
+ session=session,
452
+ messages=message,
453
+ cache_keys=cache_key,
454
+ generate=generate,
455
+ ctxlens=ctxlen,
456
+ **kwargs,
457
+ )
458
+ )
459
+ for message, cache_key, ctxlen in zip(
460
+ chunks(requests, n=self._batch_size),
461
+ chunks(cache_keys, n=self._batch_size),
462
+ chunks(ctxlens, n=self._batch_size),
463
+ )
464
+ ]
465
+
466
+ return await tqdm_asyncio.gather(*tasks, desc="Requesting API")
467
+
468
+ def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]:
469
+ assert (
470
+ self.tokenizer is not None
471
+ ), "Tokenizer is required for loglikelihood tasks to compute context lengths."
472
+ res = []
473
+
474
+ def _collate(req: LogLikelihoodInputs):
475
+ """Defines the key for the sorted method"""
476
+ # the negative sign on len(toks) sorts descending - this has a few advantages:
477
+ # - time estimates will always be over not underestimates, which is more useful for planning
478
+ # - to know the size of a batch when going through the list, you know the first one is always the batch
479
+ # padded context length. this is useful to simplify the batching logic and more importantly to make
480
+ # automatic adaptive batches much much easier to implement
481
+ # - any OOMs will happen right away rather than near the end
482
+
483
+ toks = req[1] + req[2]
484
+ return -len(toks), tuple(toks)
485
+
486
+ re_ord = Collator(
487
+ requests,
488
+ sort_fn=_collate,
489
+ group_by=None,
490
+ )
491
+ # if concurrent then we'll batch in the async context
492
+ chunked = re_ord.get_batched(n=self._batch_size if self._concurrent <= 1 else 0)
493
+ if self._concurrent <= 1:
494
+ pbar = tqdm(desc="Requesting API", total=len(requests))
495
+ for chunk in chunked:
496
+ inputs, ctxlens, cache_keys = self.batch_logliklehood_requests([chunk])
497
+
498
+ outputs = retry(
499
+ stop=stop_after_attempt(self.max_retries),
500
+ wait=wait_exponential(multiplier=0.5, min=1, max=10),
501
+ reraise=True,
502
+ )(self.model_call)(messages=inputs, generate=False)
503
+ if isinstance(outputs, dict):
504
+ outputs = [outputs]
505
+ for answer_, cache_key in zip(
506
+ self.parse_logprobs(
507
+ outputs=outputs, tokens=inputs, ctxlens=ctxlens
508
+ ),
509
+ cache_keys,
510
+ ):
511
+ if answer_ is not None:
512
+ res.append(answer_)
513
+ # partial caching
514
+ if cache_key is not None:
515
+ self.cache_hook.add_partial(
516
+ "loglikelihood", cache_key, answer_
517
+ )
518
+ pbar.update(1)
519
+ else:
520
+ inputs, ctxlens, cache_keys = self.batch_logliklehood_requests(chunked)
521
+ res = itertools.chain.from_iterable(
522
+ asyncio.run(
523
+ self.get_batched_requests(
524
+ inputs, cache_keys, generate=False, ctxlens=ctxlens
525
+ )
526
+ )
527
+ )
528
+
529
+ return re_ord.get_original(res)
530
+
531
+ def generate_until(
532
+ self, requests: List[Instance], disable_tqdm: bool = False
533
+ ) -> List[str]:
534
+ res = []
535
+
536
+ def _collate_gen(_requests):
537
+ # sort by the length of the non-tokenized contexts
538
+ return -len(_requests[0])
539
+
540
+ # Let the API deal with tokenization
541
+ requests, all_gen_kwargs = zip(*(req.args for req in requests))
542
+ if self.tokenized_requests:
543
+ encodings_list = self.tok_encode(
544
+ requests, add_special_tokens=self.add_bos_token
545
+ )
546
+ else:
547
+ encodings_list = [None] * len(requests)
548
+ requests = [
549
+ (a, b, c) for a, b, c in zip(requests, all_gen_kwargs, encodings_list)
550
+ ]
551
+
552
+ re_ord = Collator(
553
+ requests,
554
+ sort_fn=_collate_gen,
555
+ group_by="gen_kwargs",
556
+ )
557
+ chunked = re_ord.get_batched(
558
+ n=self._batch_size if self._concurrent <= 1 else 0, batch_fn=None
559
+ )
560
+ if self._concurrent <= 1:
561
+ pbar = tqdm(desc="Requesting API", total=len(requests))
562
+ for chunk in chunked:
563
+ contexts, all_gen_kwargs, encodings_list = zip(*chunk)
564
+ req = encodings_list if self.tokenized_requests else contexts
565
+ outputs = retry(
566
+ stop=stop_after_attempt(self.max_retries),
567
+ wait=wait_exponential(multiplier=0.5, min=1, max=10),
568
+ reraise=True,
569
+ )(self.model_call)(
570
+ messages=req,
571
+ generate=True,
572
+ gen_kwargs=copy.deepcopy(all_gen_kwargs[0]),
573
+ )
574
+ for generated_text, context in zip(
575
+ self.parse_generations(
576
+ outputs=outputs,
577
+ contexts=contexts,
578
+ ),
579
+ contexts,
580
+ ):
581
+ if generated_text is not None:
582
+ res.append(generated_text)
583
+
584
+ # partial caching
585
+ if context is not None:
586
+ self.cache_hook.add_partial(
587
+ "generate_until",
588
+ (context, all_gen_kwargs[0]),
589
+ generated_text,
590
+ )
591
+ pbar.update(1)
592
+ else:
593
+ for chunk in chunked:
594
+ contexts, all_gen_kwargs, encodings_list = zip(*chunk)
595
+ req = encodings_list if self.tokenized_requests else contexts
596
+ results = itertools.chain.from_iterable(
597
+ asyncio.run(
598
+ self.get_batched_requests(
599
+ req,
600
+ cache_keys=[(ctx, all_gen_kwargs[0]) for ctx in contexts],
601
+ generate=True,
602
+ gen_kwargs=copy.deepcopy(all_gen_kwargs[0]),
603
+ )
604
+ )
605
+ )
606
+ res.extend(results)
607
+
608
+ return re_ord.get_original(res)
609
+
610
+ def loglikelihood_rolling(
611
+ self, requests: List[Instance], disable_tqdm: bool = False
612
+ ) -> List[float]:
613
+ loglikelihoods = []
614
+
615
+ for (string,) in tqdm([req.args for req in requests], disable=disable_tqdm):
616
+ rolling_token_windows = list(
617
+ map(
618
+ utils.make_disjoint_window,
619
+ utils.get_rolling_token_windows(
620
+ token_list=self.tok_encode(string),
621
+ prefix_token=self.prefix_token_id,
622
+ max_seq_len=self.max_length,
623
+ context_len=1,
624
+ ),
625
+ )
626
+ )
627
+
628
+ # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
629
+ rolling_token_windows = [(None,) + x for x in rolling_token_windows]
630
+
631
+ string_nll = self._loglikelihood_tokens(
632
+ rolling_token_windows,
633
+ disable_tqdm=True,
634
+ )
635
+
636
+ # discard is_greedy
637
+ string_nll = [x[0] for x in string_nll]
638
+
639
+ string_nll = sum(string_nll)
640
+ loglikelihoods.append(string_nll)
641
+ return loglikelihoods
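The `_collate` helper above sorts loglikelihood requests by descending total token length so the largest, most memory-hungry batch runs first (any OOM surfaces immediately and time estimates stay conservative), and `Collator.get_original` later restores the caller's order. A minimal standalone sketch of that sort-then-restore pattern, on toy data only and not using the harness's `Collator` class:

# Illustrative sketch, not part of the uploaded file: toy whitespace "tokenization".
requests = ["a b c d e", "x", "p q r"]
order = sorted(range(len(requests)), key=lambda i: -len(requests[i].split()))
results = [f"processed:{requests[i]}" for i in order]  # stand-in for batched model calls
restored = [None] * len(requests)
for out_pos, orig_idx in enumerate(order):
    restored[orig_idx] = results[out_pos]
print(restored)  # outputs are back in the original request order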
scripts/yans/lm-evaluation-harness/lm_eval/models/huggingface.py ADDED
@@ -0,0 +1,1356 @@
1
+ import copy
2
+ import os
3
+ from datetime import timedelta
4
+ from pathlib import Path
5
+ from typing import Dict, List, Literal, Optional, Tuple, Union
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import transformers
10
+ from accelerate import (
11
+ Accelerator,
12
+ InitProcessGroupKwargs,
13
+ find_executable_batch_size,
14
+ )
15
+ from accelerate.utils import get_max_memory
16
+ from huggingface_hub import HfApi
17
+ from packaging import version
18
+ from peft import PeftModel
19
+ from peft import __version__ as PEFT_VERSION
20
+ from tqdm import tqdm
21
+ from transformers.models.auto.modeling_auto import (
22
+ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
23
+ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
24
+ )
25
+
26
+ from lm_eval import utils
27
+ from lm_eval.api.instance import Instance
28
+ from lm_eval.api.model import TemplateLM
29
+ from lm_eval.api.registry import register_model
30
+ from lm_eval.models.utils import (
31
+ Collator,
32
+ clear_torch_cache,
33
+ configure_pad_token,
34
+ get_dtype,
35
+ pad_and_concat,
36
+ stop_sequences_criteria,
37
+ )
38
+
39
+
40
+ eval_logger = utils.eval_logger
41
+
42
+
43
+ @register_model("hf-auto", "hf", "huggingface")
44
+ class HFLM(TemplateLM):
45
+ """
46
+ An abstracted Huggingface model class. Enables usage with both models of
47
+ `transformers.AutoModelForCausalLM` and `transformers.AutoModelForSeq2SeqLM` classes.
48
+
49
+ Supports data-parallel multi-GPU with HF Accelerate.
50
+ """
51
+
52
+ AUTO_MODEL_CLASS = None
53
+ _DEFAULT_MAX_LENGTH = 2048
54
+
55
+ def __init__(
56
+ self,
57
+ pretrained: Union[str, transformers.PreTrainedModel],
58
+ backend: Optional[Literal["default", "causal", "seq2seq"]] = "default",
59
+ # override whether the model should be treated as decoder-only (causal) or encoder-decoder (seq2seq)
60
+ revision: Optional[str] = "main",
61
+ subfolder: Optional[str] = None,
62
+ tokenizer: Optional[
63
+ Union[
64
+ str,
65
+ transformers.PreTrainedTokenizer,
66
+ transformers.PreTrainedTokenizerFast,
67
+ ]
68
+ ] = None,
69
+ truncation: Optional[bool] = False,
70
+ logits_cache: bool = True,
71
+ max_length: Optional[int] = None,
72
+ device: Optional[str] = "cuda",
73
+ dtype: Optional[Union[str, torch.dtype]] = "auto",
74
+ batch_size: Optional[Union[int, str]] = 1,
75
+ max_batch_size: Optional[int] = 64,
76
+ trust_remote_code: Optional[bool] = False,
77
+ use_fast_tokenizer: Optional[bool] = True,
78
+ add_bos_token: Optional[bool] = False,
79
+ prefix_token_id: Optional[int] = None,
80
+ # arguments used for splitting a model across GPUs naively.
81
+ # only used if `parallelize=True`.
82
+ parallelize: Optional[bool] = False,
83
+ max_memory_per_gpu: Optional[Union[int, str]] = None,
84
+ max_cpu_memory: Optional[Union[int, str]] = None,
85
+ offload_folder: Optional[Union[str, os.PathLike]] = "./offload",
86
+ # PEFT, delta weights and quantization options
87
+ peft: Optional[str] = None,
88
+ delta: Optional[str] = None,
89
+ autogptq: Optional[Union[bool, str]] = False,
90
+ **kwargs,
91
+ ) -> None:
92
+ super().__init__()
93
+
94
+ # optionally: take in an already-initialized transformers.PreTrainedModel
95
+ if not isinstance(pretrained, str):
96
+ eval_logger.warning(
97
+ "`pretrained` model kwarg is not of type `str`. Many other model arguments may be ignored. Please do not launch via accelerate or use `parallelize=True` if passing an existing model this way."
98
+ )
99
+ assert not parallelize, "`parallelize=True` is not compatible with passing pre-initialized model to `pretrained`"
100
+ self._model = pretrained
101
+ self._device = self._model.device
102
+ self._config = self._model.config
103
+ gpus = 0
104
+
105
+ else:
106
+ assert isinstance(device, str)
107
+ assert isinstance(pretrained, str)
108
+ assert isinstance(batch_size, (int, str))
109
+
110
+ gpus = torch.cuda.device_count()
111
+ accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
112
+ accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
113
+ if accelerator.num_processes > 1:
114
+ self.accelerator = accelerator
115
+
116
+ if "npu" in accelerator.device.type:
117
+ gpus = torch.npu.device_count()
118
+
119
+ # using one process with no model parallelism
120
+ if not (parallelize or accelerator.num_processes > 1):
121
+ # use user-passed device
122
+ device_list = set(
123
+ ["cuda", "cpu"]
124
+ + [f"cuda:{i}" for i in range(gpus)]
125
+ + ["mps", "mps:0"]
126
+ + [f"npu:{i}" for i in range(gpus)]
127
+ )
128
+ if device and device in device_list:
129
+ self._device = torch.device(device)
130
+ eval_logger.info(f"Using device '{device}'")
131
+ if device in ("mps", "mps:0") and version.parse(
132
+ torch.__version__
133
+ ) < version.parse("2.1"):
134
+ raise RuntimeError(
135
+ f"mps requires torch >= 2.1. You have {torch.__version__}"
136
+ )
137
+ else:
138
+ eval_logger.info("Device not specified")
139
+ eval_logger.info(f"Cuda Available? {torch.cuda.is_available()}")
140
+ self._device = (
141
+ torch.device("cuda")
142
+ if torch.cuda.is_available()
143
+ else torch.device("cpu")
144
+ )
145
+ else: # Parallelism managed by accelerate
146
+ if device != "cuda":
147
+ eval_logger.info(
148
+ f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model."
149
+ )
150
+ # TODO: include in warning that `load_in_8bit` etc. affect this too
151
+ self._device = (
152
+ self.accelerator.device
153
+ if hasattr(self, "accelerator")
154
+ else torch.device(device)
155
+ )
156
+
157
+ revision = str(revision) # cast to string if not already one
158
+ # TODO: update this to be less of a hack once subfolder is fixed in HF
159
+ revision = revision + ("/" + subfolder if subfolder is not None else "")
160
+
161
+ self._get_config(
162
+ pretrained,
163
+ revision=revision,
164
+ trust_remote_code=trust_remote_code,
165
+ )
166
+
167
+ # determine which of 'causal' and 'seq2seq' backends to use
168
+ self._get_backend(
169
+ config=self.config, backend=backend, trust_remote_code=trust_remote_code
170
+ )
171
+
172
+ # load tokenizer so we know tokenizer vocabulary size before loading model and PEFT
173
+ self._create_tokenizer(
174
+ pretrained,
175
+ tokenizer,
176
+ revision=revision,
177
+ trust_remote_code=trust_remote_code,
178
+ use_fast_tokenizer=use_fast_tokenizer,
179
+ )
180
+
181
+ # if we passed `pretrained` as a string, initialize our model now
182
+ if isinstance(pretrained, str):
183
+ self._create_model(
184
+ pretrained=pretrained,
185
+ revision=revision,
186
+ dtype=dtype,
187
+ trust_remote_code=trust_remote_code,
188
+ parallelize=parallelize,
189
+ gpus=gpus,
190
+ max_memory_per_gpu=max_memory_per_gpu,
191
+ max_cpu_memory=max_cpu_memory,
192
+ offload_folder=offload_folder,
193
+ peft=peft,
194
+ delta=delta,
195
+ autogptq=autogptq,
196
+ **kwargs,
197
+ )
198
+
199
+ # access self._model through self.model property outside this method
200
+ if isinstance(self.model, torch.nn.Module):
201
+ self.model.eval()
202
+ self.model.tie_weights()
203
+
204
+ self.truncation = truncation
205
+ self.logits_cache = logits_cache
206
+ self.vocab_size = self.tokenizer.vocab_size
207
+ # select (or create) a pad token to use
208
+ self.tokenizer = configure_pad_token(self.tokenizer, model_config=self.config)
209
+
210
+ self.add_bos_token = add_bos_token
211
+ if "gemma" in getattr(self.config, "model_type", ""):
212
+ self.add_bos_token = True
213
+ eval_logger.info(
214
+ f"Model type is '{self.config.model_type}', part of the Gemma family--a BOS token will be used as Gemma underperforms without it."
215
+ )
216
+
217
+ self._max_length = max_length
218
+ self.pretrained = pretrained
219
+ self.delta = delta
220
+ self.peft = peft
221
+ self.revision = revision
222
+ self.batch_schedule = 1
223
+ self.batch_sizes = {}
224
+ self.max_batch_size = max_batch_size
225
+
226
+ if str(batch_size).startswith("auto"):
227
+ batch_size = batch_size.split(":")
228
+ self.batch_size_per_gpu = batch_size[0]
229
+ self.batch_schedule = float(batch_size[1]) if len(batch_size) > 1 else 1
230
+ else:
231
+ self.batch_size_per_gpu = int(batch_size)
232
+
233
+ if isinstance(pretrained, str):
234
+ if gpus >= 1 or str(self.device) == "mps":
235
+ # TODO: can remove this whole snippet except in the mps case, perhaps?
236
+ if not (parallelize or autogptq or hasattr(self, "accelerator")):
237
+ # place model onto device requested manually,
238
+ # if not using HF Accelerate or device_map
239
+ # or any other option that preloads model onto device
240
+ try:
241
+ self.model.to(self.device)
242
+ except ValueError:
243
+ eval_logger.debug(
244
+ "Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes` or `device_map` is provided. If the desired GPU is being used, this message is safe to ignore."
245
+ )
246
+ # multigpu data-parallel support when launched with accelerate
247
+ if gpus > 1:
248
+ if accelerator.num_processes > 1:
249
+ if parallelize:
250
+ eval_logger.warning(
251
+ "You are both using a HF Accelerate `device_map` (`--model_args parallelize=True`) and launching via `accelerate launch`. This will attempt to do model and data parallelism depending on the resources available."
252
+ )
253
+ elif gpus > accelerator.num_processes:
254
+ eval_logger.warning(
255
+ "WARNING: The number of total system GPUs does not match the number of spawned processes. "
256
+ "If you would like to use data parallelism, please launch the script "
257
+ "with 'accelerate launch *script*'. "
258
+ f"Current run will proceed with {accelerator.num_processes} devices."
259
+ )
260
+ if self.accelerator.is_local_main_process:
261
+ eval_logger.info(
262
+ f"Using {gpus} devices with data parallelism"
263
+ )
264
+
265
+ self._device = torch.device(f"{accelerator.device}")
266
+ self.accelerator = accelerator
267
+
268
+ self._rank = self.accelerator.local_process_index
269
+ self._world_size = self.accelerator.num_processes
270
+ else:
271
+ # if we aren't launching via accelerate, ditch
272
+ self._rank = 0
273
+ self._world_size = 1
274
+ else:
275
+ # if a PreTrainedModel was passed into HFLM, we forgo distributed setup.
276
+ eval_logger.warning(
277
+ "Passed an already-initialized model through `pretrained`, assuming single-process call to evaluate() or custom distributed integration"
278
+ )
279
+ self._rank = 0
280
+ self._world_size = 1
281
+
282
+ self.custom_prefix_token_id = prefix_token_id
283
+ if prefix_token_id is not None:
284
+ eval_logger.info(
285
+ f"Loglikelihood prefix token id used in evaluation: {self.prefix_token_id}"
286
+ )
287
+
288
+ def _get_accelerate_args(
289
+ self,
290
+ parallelize: bool = None,
291
+ device_map: Optional[str] = "auto",
292
+ max_memory_per_gpu: Optional[Union[int, str]] = None,
293
+ max_cpu_memory: Optional[Union[int, str]] = None,
294
+ offload_folder: Optional[str] = "./offload",
295
+ gpus: Optional[int] = None,
296
+ ) -> dict:
297
+ """Returns the kwargs needed to apply `accelerate` in `AutoModel.from_pretrained`."""
298
+ num_local_processes = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
299
+ num_machines = int(os.environ.get("WORLD_SIZE", 0)) // num_local_processes
300
+ if (
301
+ num_machines == 0
302
+ and hasattr(self, "accelerator")
303
+ and self.accelerator is not None
304
+ ):
305
+ eval_logger.info(
306
+ "We are not in a distributed setting for accelerate. Setting model_parallel to False."
307
+ )
308
+ parallelize = False
309
+
310
+ if parallelize is None:
311
+ # If parallelism is unset by the user, we automatically assign model parallelism
312
+ # if enough extra GPUs are available
313
+ max_memory_all_gpus = get_max_memory()
314
+ # We just want gpu, not cpu, max memory
315
+ if "cpu" in max_memory_all_gpus:
316
+ del max_memory_all_gpus["cpu"]
317
+ parallelize = bool(num_local_processes < len(max_memory_all_gpus))
318
+ eval_logger.info(
319
+ f"Setting model parallel to {parallelize} since "
320
+ f"the number of local processes is {num_local_processes} "
321
+ f"and the number of GPUs is {len(max_memory_all_gpus)}"
322
+ )
323
+
324
+ args = {}
325
+ if parallelize: # Model parallelism will be used
326
+ max_memory = {}
327
+ if max_memory_per_gpu is not None: # Using the provided memory requirements
328
+ max_memory_per_gpu_map = {
329
+ device_idx: max_memory_per_gpu for device_idx in range(gpus)
330
+ }
331
+ else: # Estimating the possible memory requirements
332
+ max_memory_all_gpus = get_max_memory()
333
+ if "cpu" in max_memory_all_gpus:
334
+ del max_memory_all_gpus["cpu"]
335
+ if not hasattr(self, "accelerator"):
336
+ max_memory_per_gpu_map = {
337
+ k: v for k, v in max_memory_all_gpus.items()
338
+ }
339
+ else:
340
+ # use only 1 / num_processes of the GPUs if we are running under accelerate launch
341
+ max_memory_per_gpu_map = {
342
+ k: v
343
+ for k, v in max_memory_all_gpus.items()
344
+ if k % num_local_processes
345
+ == (self.accelerator.process_index % num_local_processes)
346
+ }
347
+ args["max_memory"] = max_memory_per_gpu_map
348
+ args["device_map"] = "auto"
349
+ eval_logger.info(
350
+ f"Model parallel was set to True, setting max memory per GPU to {max_memory_per_gpu_map} and device map to 'auto'"
351
+ )
352
+
353
+ if max_cpu_memory is not None:
354
+ max_memory["cpu"] = max_cpu_memory
355
+
356
+ args["offload_folder"] = offload_folder
357
+ elif (
358
+ device_map is None
359
+ ): # No model parallelism, we use the default provided device for our model
360
+ if hasattr(self, "accelerator"):
361
+ device_map = {"": f"{self.accelerator.device}"}
362
+ else:
363
+ device_map = {"": str(self.device)}
364
+ args["max_memory"] = None
365
+ args["device_map"] = device_map
366
+ eval_logger.info(
367
+ f"Model parallel was set to False, max memory was not set, and device map was set to {device_map}"
368
+ )
369
+ else:
370
+ args["max_memory"] = None
371
+ args["device_map"] = None
372
+ eval_logger.info("Model parallel was set to False.")
373
+
374
+ return args
375
+
376
+ @property
377
+ def config(self):
378
+ # return the associated transformers.AutoConfig for the given pretrained model.
379
+ return self._config
380
+
381
+ @property
382
+ def model(self):
383
+ # returns the model, unwrapping it if using Accelerate
384
+ if hasattr(self, "accelerator"):
385
+ return self.accelerator.unwrap_model(self._model)
386
+ else:
387
+ return self._model
388
+
389
+ @property
390
+ def eot_token_id(self):
391
+ # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
392
+ return self.tokenizer.eos_token_id
393
+
394
+ @property
395
+ def prefix_token_id(self):
396
+ # it is used as prefix for loglikelihood
397
+ if self.custom_prefix_token_id is not None:
398
+ return self.custom_prefix_token_id
399
+ if self.tokenizer.bos_token_id is not None:
400
+ return self.tokenizer.bos_token_id
401
+ return self.tokenizer.eos_token_id
402
+
403
+ @property
404
+ def max_length(self):
405
+ if self._max_length: # if max length manually set, return it
406
+ return self._max_length
407
+ seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
408
+ for attr in seqlen_config_attrs:
409
+ if hasattr(self.model.config, attr):
410
+ return getattr(self.model.config, attr)
411
+ if hasattr(self.tokenizer, "model_max_length"):
412
+ if self.tokenizer.model_max_length == 1000000000000000019884624838656:
413
+ return self._DEFAULT_MAX_LENGTH
414
+ return self.tokenizer.model_max_length
415
+ return self._DEFAULT_MAX_LENGTH
416
+
417
+ @property
418
+ def max_gen_toks(self) -> int:
419
+ return 256
420
+
421
+ @property
422
+ def batch_size(self):
423
+ return self.batch_size_per_gpu
424
+
425
+ @property
426
+ def device(self):
427
+ return self._device
428
+
429
+ @property
430
+ def rank(self):
431
+ return self._rank
432
+
433
+ @property
434
+ def world_size(self):
435
+ return self._world_size
436
+
437
+ @property
438
+ def tokenizer_name(self) -> str:
439
+ return self.tokenizer.name_or_path.replace("/", "__")
440
+
441
+ @property
442
+ def chat_template(self) -> str:
443
+ if self.tokenizer.chat_template is not None:
444
+ return self.tokenizer.chat_template
445
+ return self.tokenizer.default_chat_template
446
+
447
+ def _get_backend(
448
+ self,
449
+ config: Union[transformers.PretrainedConfig, transformers.AutoConfig],
450
+ backend: Optional[Literal["default", "causal", "seq2seq"]] = "default",
451
+ trust_remote_code: Optional[bool] = False,
452
+ ) -> None:
453
+ """
454
+ Helper method during initialization.
455
+ Determines the backend ("causal" (decoder-only) or "seq2seq" (encoder-decoder))
456
+ model type to be used.
457
+ """
458
+ assert backend in ["default", "causal", "seq2seq"]
459
+
460
+ if backend != "default":
461
+ # if we've settled on non-default backend, use that manually
462
+ if backend == "causal":
463
+ self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
464
+ elif backend == "seq2seq":
465
+ self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM
466
+ eval_logger.info(
467
+ f"Overrode HF model backend type, and using type '{backend}'"
468
+ )
469
+ else:
470
+ # determine and use the default HF backend for this model, based on its config + metadata.
471
+ if (
472
+ getattr(config, "model_type")
473
+ in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
474
+ ):
475
+ # first check if model type is listed under seq2seq models, since some
476
+ # models like MBart are listed in both seq2seq and causal mistakenly in HF transformers.
477
+ # these special cases should be treated as seq2seq models.
478
+ self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM
479
+ elif (
480
+ getattr(self.config, "model_type") in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
481
+ ):
482
+ self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
483
+ else:
484
+ if not trust_remote_code:
485
+ eval_logger.warning(
486
+ "HF model type is neither marked as CausalLM or Seq2SeqLM. \
487
+ This is expected if your model requires `trust_remote_code=True` but may be an error otherwise."
488
+ )
489
+ # if model type is neither in HF transformers causal or seq2seq model registries
490
+ # then we default to AutoModelForCausalLM
491
+ self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
492
+
493
+ assert self.AUTO_MODEL_CLASS in [
494
+ transformers.AutoModelForCausalLM,
495
+ transformers.AutoModelForSeq2SeqLM,
496
+ ]
497
+ return None
498
+
499
+ def _get_config(
500
+ self,
501
+ pretrained: str,
502
+ revision: str = "main",
503
+ trust_remote_code: bool = False,
504
+ ) -> None:
505
+ self._config = transformers.AutoConfig.from_pretrained(
506
+ pretrained,
507
+ revision=revision,
508
+ trust_remote_code=trust_remote_code,
509
+ )
510
+
511
+ def _create_model(
512
+ self,
513
+ pretrained: str,
514
+ revision: Optional[str] = "main",
515
+ dtype: Optional[Union[str, torch.dtype]] = "auto",
516
+ trust_remote_code: Optional[bool] = False,
517
+ # arguments used for splitting a model across GPUs naively.
518
+ # only used if `parallelize=True`.
519
+ # (accelerate naive PP (device_map) options)
520
+ parallelize: Optional[bool] = False,
521
+ gpus: Optional[int] = None,
522
+ max_memory_per_gpu: Optional[Union[int, str]] = None,
523
+ max_cpu_memory: Optional[Union[int, str]] = None,
524
+ offload_folder: Optional[str] = "./offload",
525
+ # PEFT, delta weights and quantization options
526
+ peft: Optional[str] = None,
527
+ delta: Optional[str] = None,
528
+ autogptq: Optional[Union[bool, str]] = False,
529
+ **kwargs,
530
+ ) -> None:
531
+ """
532
+ Initializes an HF or HF-compatible PreTrainedModel from scratch
533
+ inside HFLM, using the kwargs passed into self.__init__().
534
+
535
+ Also handles functionality such as AutoGPTQ usage and PEFT wrapping.
536
+
537
+ For future similar extensions to AutoGPTQ that are not core to HF's ecosystem,
538
+ (such as PyTorch models that are nearly, but not quite, fully mirroring
539
+ HF's public interface relied on in this HFLM class)
540
+ please consider subclassing HFLM and overriding this and other methods as needed.
541
+ """
542
+
543
+ model_kwargs = kwargs if kwargs else {}
544
+
545
+ model_kwargs.update(
546
+ self._get_accelerate_args(
547
+ parallelize=parallelize,
548
+ device_map=kwargs.get("device_map", None),
549
+ max_memory_per_gpu=max_memory_per_gpu,
550
+ max_cpu_memory=max_cpu_memory,
551
+ offload_folder=offload_folder,
552
+ gpus=gpus,
553
+ )
554
+ )
555
+
556
+ if not autogptq:
557
+ if model_kwargs.get("load_in_4bit", None):
558
+ assert (
559
+ transformers.__version__ >= "4.30.0"
560
+ ), "load_in_4bit requires transformers >= 4.30.0"
561
+ if transformers.__version__ >= "4.30.0":
562
+ if model_kwargs.get("load_in_4bit", None):
563
+ if model_kwargs.get("bnb_4bit_compute_dtype", None):
564
+ model_kwargs["bnb_4bit_compute_dtype"] = get_dtype(
565
+ model_kwargs["bnb_4bit_compute_dtype"]
566
+ )
567
+
568
+ self._model = self.AUTO_MODEL_CLASS.from_pretrained(
569
+ pretrained,
570
+ revision=revision,
571
+ torch_dtype=get_dtype(dtype),
572
+ trust_remote_code=trust_remote_code,
573
+ **model_kwargs,
574
+ )
575
+ else:
576
+ try:
577
+ from auto_gptq import AutoGPTQForCausalLM
578
+ except ModuleNotFoundError:
579
+ raise Exception(
580
+ "Tried to load auto_gptq, but auto-gptq is not installed ",
581
+ "please install auto-gptq via pip install lm-eval[gptq] or pip install -e .[gptq]",
582
+ )
583
+
584
+ self._model = AutoGPTQForCausalLM.from_quantized(
585
+ pretrained,
586
+ trust_remote_code=trust_remote_code,
587
+ model_basename=None if autogptq is True else Path(autogptq).stem,
588
+ use_safetensors=True
589
+ if autogptq is True
590
+ else autogptq.endswith(".safetensors"),
591
+ **model_kwargs,
592
+ )
593
+
594
+ if peft and delta:
595
+ raise ValueError(
596
+ "Cannot use both 'peft' and 'delta' options at the same time."
597
+ )
598
+
599
+ if peft:
600
+ if model_kwargs.get("load_in_4bit", None):
601
+ if version.parse(PEFT_VERSION) < version.parse("0.4.0"):
602
+ raise AssertionError("load_in_4bit requires peft >= 0.4.0")
603
+ if self._model.config.vocab_size != len(self.tokenizer):
604
+ # resize model for LoRAs with added tokens
605
+ self._model.resize_token_embeddings(len(self.tokenizer))
606
+ eval_logger.info(
607
+ f"Model config indicates vocab_size='{self._model.config.vocab_size}', but found tokenizer with vocab size '{len(self.tokenizer)}'. Resizing model embedding layer..."
608
+ )
609
+ self._model = PeftModel.from_pretrained(
610
+ self._model, peft, revision=revision
611
+ )
612
+ elif delta:
613
+ if autogptq:
614
+ eval_logger.warning(
615
+ "Delta weights might trigger unexpected behavior when used with AutoGPTQ."
616
+ )
617
+ _model_delta = self.AUTO_MODEL_CLASS.from_pretrained(
618
+ delta,
619
+ revision=revision,
620
+ torch_dtype=get_dtype(dtype),
621
+ trust_remote_code=trust_remote_code,
622
+ **model_kwargs,
623
+ )
624
+ for name, param in self._model.state_dict().items():
625
+ try:
626
+ param.data += _model_delta.state_dict()[name]
627
+ except KeyError:
628
+ raise KeyError(f"Delta model is missing weights for layer: {name}")
629
+ except Exception as e:
630
+ raise RuntimeError(
631
+ f"Failed to add delta weights to layer {name}. Error: {e}"
632
+ )
633
+
634
+ del _model_delta
635
+
636
+ return None
637
+
638
+ def _create_tokenizer(
639
+ self,
640
+ pretrained: Union[str, transformers.PreTrainedModel],
641
+ tokenizer: Optional[
642
+ Union[
643
+ str,
644
+ transformers.PreTrainedTokenizer,
645
+ transformers.PreTrainedTokenizerFast,
646
+ ]
647
+ ],
648
+ revision: Optional[str] = "main",
649
+ trust_remote_code: Optional[bool] = False,
650
+ use_fast_tokenizer: Optional[bool] = True,
651
+ ) -> None:
652
+ """
653
+ Helper method during initialization.
654
+
655
+ Create a tokenizer object corresponding to the correct
656
+ tokenizer for value of `pretrained`, or use the pre-initialized tokenizer passed.
657
+ """
658
+
659
+ if tokenizer:
660
+ if isinstance(tokenizer, str):
661
+ self.tokenizer = transformers.AutoTokenizer.from_pretrained(
662
+ tokenizer,
663
+ revision=revision,
664
+ trust_remote_code=trust_remote_code,
665
+ use_fast=use_fast_tokenizer,
666
+ )
667
+ else:
668
+ assert isinstance(
669
+ tokenizer, transformers.PreTrainedTokenizer
670
+ ) or isinstance(tokenizer, transformers.PreTrainedTokenizerFast)
671
+ self.tokenizer = tokenizer
672
+ else:
673
+ # Get tokenizer based on 'pretrained'
674
+ if isinstance(pretrained, str):
675
+ model_name = pretrained
676
+ else:
677
+ # get the HF hub name via accessor on model
678
+ model_name = self.model.name_or_path
679
+ self.tokenizer = transformers.AutoTokenizer.from_pretrained(
680
+ model_name,
681
+ revision=revision,
682
+ trust_remote_code=trust_remote_code,
683
+ use_fast=use_fast_tokenizer,
684
+ )
685
+ return None
686
+
687
+ def _detect_batch_size(self, requests=None, pos: int = 0):
688
+ if requests:
689
+ _, context_enc, continuation_enc = requests[pos]
690
+ max_length = len(
691
+ (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1]
692
+ )
693
+ max_context_enc = len(context_enc[-(self.max_length + 1) :])
694
+ max_cont_enc = len(continuation_enc[-(self.max_length + 1) :])
695
+ else:
696
+ max_length = self.max_length
697
+ max_context_enc = max_length
698
+ max_cont_enc = max_length
699
+
700
+ # if OOM, then halves batch_size and tries again
701
+ @find_executable_batch_size(starting_batch_size=self.max_batch_size)
702
+ def forward_batch(batch_size):
703
+ if self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
704
+ length = max(max_context_enc, max_cont_enc)
705
+ batched_conts = torch.ones(
706
+ (batch_size, length), device=self.device
707
+ ).long()
708
+ test_batch = torch.ones((batch_size, length), device=self.device).long()
709
+ call_kwargs = {
710
+ "attn_mask": test_batch,
711
+ "labels": batched_conts,
712
+ }
713
+ else:
714
+ call_kwargs = {}
715
+ test_batch = torch.ones(
716
+ (batch_size, max_length), device=self.device
717
+ ).long()
718
+ for _ in range(5):
719
+ out = F.log_softmax(self._model_call(test_batch, **call_kwargs), dim=-1) # noqa: F841
720
+
721
+ return batch_size
722
+
723
+ try:
724
+ batch_size = forward_batch()
725
+ except RuntimeError as e:
726
+ if "No executable batch size found" in str(e):
727
+ batch_size = 1
728
+ else:
729
+ raise
730
+
731
+ if self.world_size > 1:
732
+ # if multi-GPU, always take minimum over all selected batch sizes
733
+ max_rnk_bs = torch.tensor([batch_size], device=self.device)
734
+ gathered = (
735
+ self.accelerator.gather(max_rnk_bs).cpu().detach().numpy().tolist()
736
+ )
737
+ batch_size = min(gathered)
738
+ clear_torch_cache()
739
+ return batch_size
740
+
741
+ clear_torch_cache()
742
+ return batch_size
743
+
744
+ def tok_encode(
745
+ self, string: str, left_truncate_len=None, add_special_tokens=None
746
+ ) -> List[int]:
747
+ """ """
748
+ # default for None - empty dict, use predefined tokenizer param
749
+ # used for all models except for CausalLM or predefined value
750
+ special_tokens_kwargs = {}
751
+
752
+ # by default for CausalLM - false or self.add_bos_token is set
753
+ if add_special_tokens is None:
754
+ if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
755
+ special_tokens_kwargs = {
756
+ "add_special_tokens": False or self.add_bos_token
757
+ }
758
+ # otherwise the method explicitly defines the value
759
+ else:
760
+ special_tokens_kwargs = {"add_special_tokens": add_special_tokens}
761
+
762
+ encoding = self.tokenizer.encode(string, **special_tokens_kwargs)
763
+
764
+ # left-truncate the encoded context to be at most `left_truncate_len` tokens long
765
+ if left_truncate_len:
766
+ encoding = encoding[-left_truncate_len:]
767
+
768
+ return encoding
769
+
770
+ def tok_batch_encode(
771
+ self,
772
+ strings: List[str],
773
+ padding_side: str = "left",
774
+ left_truncate_len: int = None,
775
+ truncation: bool = False,
776
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
777
+ # encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode.
778
+ old_padding_side = self.tokenizer.padding_side
779
+ self.tokenizer.padding_side = padding_side
780
+
781
+ add_special_tokens = {}
782
+ if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
783
+ add_special_tokens = {"add_special_tokens": False or self.add_bos_token}
784
+
785
+ encoding = self.tokenizer(
786
+ strings,
787
+ truncation=truncation,
788
+ padding="longest",
789
+ return_tensors="pt",
790
+ **add_special_tokens,
791
+ )
792
+ if left_truncate_len:
793
+ encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:]
794
+ encoding["attention_mask"] = encoding["attention_mask"][
795
+ :, -left_truncate_len:
796
+ ]
797
+ self.tokenizer.padding_side = old_padding_side
798
+
799
+ return encoding["input_ids"], encoding["attention_mask"]
800
+
801
+ def tok_decode(self, tokens, skip_special_tokens=True):
802
+ return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
803
+
804
+ def _model_call(self, inps, attn_mask=None, labels=None):
805
+ """
806
+ :param inps: torch.Tensor
807
+ A torch tensor of shape [batch, (sequence_ctx + sequence_cont)] or of shape
808
+ [batch, sequence_ctx]. the size of sequence may vary from call to call
809
+ :param attn_mask: torch.Tensor, optional
810
+ A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. Only passed
811
+ (and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM
812
+ :param labels: torch.Tensor, optional
813
+ A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. Only passed
814
+ (and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM
815
+ :return
816
+ A torch tensor of shape [batch, sequence, vocab] with the
817
+ logits returned from the model's decoder
818
+ """
819
+ with torch.no_grad():
820
+ if attn_mask is not None or labels is not None:
821
+ assert attn_mask is not None and labels is not None
822
+ assert self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM
823
+ return self.model(
824
+ input_ids=inps, attention_mask=attn_mask, labels=labels
825
+ ).logits
826
+ else:
827
+ assert self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
828
+ return self.model(inps).logits
829
+
830
+ def _model_generate(self, context, max_length, stop, **generation_kwargs):
831
+ # temperature = 0.0 if not set
832
+ # if do_sample is false and temp==0.0:
833
+ # remove temperature, as do_sample=False takes care of this
834
+ # and we don't want a warning from HF
835
+ generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
836
+ do_sample = generation_kwargs.get("do_sample", None)
837
+
838
+ # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies
839
+ if generation_kwargs.get("temperature") == 0.0 and do_sample is None:
840
+ generation_kwargs["do_sample"] = do_sample = False
841
+
842
+ if do_sample is False and generation_kwargs.get("temperature") == 0.0:
843
+ generation_kwargs.pop("temperature")
844
+ # build stopping criteria
845
+ stopping_criteria = stop_sequences_criteria(
846
+ self.tokenizer, stop, context.shape[1], context.shape[0]
847
+ )
848
+ return self.model.generate(
849
+ input_ids=context,
850
+ max_length=max_length,
851
+ stopping_criteria=stopping_criteria,
852
+ pad_token_id=self.tokenizer.pad_token_id,
853
+ use_cache=True,
854
+ **generation_kwargs,
855
+ )
856
+
857
+ def _select_cont_toks(
858
+ self, logits: torch.Tensor, contlen: int = None, inplen: int = None
859
+ ) -> torch.Tensor:
860
+ if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
861
+ assert (
862
+ contlen and inplen
863
+ ), "Must pass input len and cont. len to select scored logits for causal LM"
864
+ # discard right-padding.
865
+ # also discard the input/context tokens. we'll only score continuations.
866
+ logits = logits[inplen - contlen : inplen]
867
+ elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
868
+ assert (
869
+ contlen and not inplen
870
+ ), "Selecting scored logits for Seq2SeqLM requires only cont. len"
871
+ # only discard right-padding.
872
+ # the logits input to this fn only contain decoder-side tokens.
873
+ logits = logits[:contlen]
874
+
875
+ return logits
876
+
877
+ def loglikelihood_rolling(
878
+ self, requests: List[Instance], disable_tqdm: bool = False
879
+ ) -> List[float]:
880
+ loglikelihoods = []
881
+
882
+ adaptive_batch_size = None
883
+ if self.batch_size == "auto":
884
+ # using rolling window with maximum context
885
+ print("Passed argument batch_size = auto. Detecting largest batch size")
886
+ batch_size = self._detect_batch_size()
887
+ print(f"Determined Largest batch size: {batch_size}")
888
+ adaptive_batch_size = batch_size
889
+
890
+ for (string,) in tqdm(
891
+ [req.args for req in requests], disable=(disable_tqdm or (self.rank != 0))
892
+ ):
893
+ rolling_token_windows = list(
894
+ map(
895
+ utils.make_disjoint_window,
896
+ utils.get_rolling_token_windows(
897
+ token_list=self.tok_encode(string),
898
+ prefix_token=self.prefix_token_id,
899
+ max_seq_len=self.max_length,
900
+ context_len=1,
901
+ ),
902
+ )
903
+ )
904
+
905
+ # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
906
+ rolling_token_windows = [(None,) + x for x in rolling_token_windows]
907
+
908
+ pad_amnt = 0
909
+ if self.world_size > 1:
910
+ # We pad out the external document-level iterator so the inner iterator doesn't hang
911
+ mytensor = torch.tensor(len(rolling_token_windows), device=self.device)
912
+ gathered = (
913
+ self.accelerator.gather(mytensor).cpu().detach().numpy().tolist()
914
+ )
915
+
916
+ pad_amnt = max(gathered) - gathered[self.rank]
917
+ if pad_amnt > 0:
918
+ rolling_token_windows += pad_amnt * [rolling_token_windows[0]]
919
+
920
+ string_nll = self._loglikelihood_tokens(
921
+ requests=rolling_token_windows,
922
+ disable_tqdm=True,
923
+ override_bs=adaptive_batch_size,
924
+ )
925
+
926
+ if (self.world_size > 1) and (pad_amnt > 0):
927
+ string_nll = [x[0] for x in string_nll[:-pad_amnt]]
928
+ else:
929
+ # discard is_greedy
930
+ string_nll = [x[0] for x in string_nll]
931
+
932
+ string_nll = sum(string_nll)
933
+ loglikelihoods.append(string_nll)
934
+
935
+ return loglikelihoods
936
+
937
+ def _batch_scheduler(self, pos, n_reordered_requests):
938
+ sched = pos // int(len(n_reordered_requests) / self.batch_schedule)
939
+ if sched in self.batch_sizes:
940
+ return self.batch_sizes[sched]
941
+ if (len(self.batch_sizes) > 1) and (
942
+ self.batch_sizes[sched - 1] == self.max_batch_size
943
+ ):
944
+ # if previous batch size is already maximal, skip recomputation
945
+ self.batch_sizes[sched] = self.max_batch_size
946
+ return self.batch_sizes[sched]
947
+ print(
948
+ f"Passed argument batch_size = auto:{self.batch_schedule}. Detecting largest batch size"
949
+ )
950
+ self.batch_sizes[sched] = self._detect_batch_size(n_reordered_requests, pos)
951
+ print(f"Determined largest batch size: {self.batch_sizes[sched]}")
952
+ return self.batch_sizes[sched]
953
+
954
+ def _loglikelihood_tokens(
955
+ self,
956
+ requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
957
+ disable_tqdm: bool = False,
958
+ override_bs: int = None,
959
+ ) -> List[Tuple[float, bool]]:
960
+ # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
961
+ res = []
962
+
963
+ def _collate(req: Tuple[Tuple[str, str], List[int], List[int]]):
964
+ """Defines the key for the sorted method"""
965
+ # the negative sign on len(toks) sorts descending - this has a few advantages:
966
+ # - time estimates will always be over not underestimates, which is more useful for planning
967
+ # - to know the size of a batch when going through the list, you know the first one is always the batch
968
+ # padded context length. this is useful to simplify the batching logic and more importantly to make
969
+ # automatic adaptive batches much much easier to implement
970
+ # - any OOMs will happen right away rather than near the end
971
+
972
+ toks = req[1] + req[2]
973
+ return -len(toks), tuple(toks)
974
+
975
+ def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
976
+ """Defines the key to group and lookup one-token continuations"""
977
+ # Use with group_by="contexts" (optional)
978
+ # allows for the creation of a lookup, so we can reuse logits in case of one-token continuations.
979
+ # speeds up some multiple-choice tasks proportionally to the number of choices.
980
+ # groups requests by context+continuation[:-1] and infer on one request/group.
981
+ return req[-2] + req[-1][:-1]
982
+
983
+ re_ord = Collator(
984
+ requests,
985
+ sort_fn=_collate,
986
+ group_by="contexts"
987
+ if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
988
+ and self.logits_cache
989
+ else None,
990
+ group_fn=_lookup_one_token_cont,
991
+ )
992
+
993
+ # automatic (variable) batch size detection for vectorization
994
+ # pull longest context sample from request
995
+ n_reordered_requests = len(re_ord)
996
+ batch_size = (
997
+ self.batch_size
998
+ if self.batch_size != "auto"
999
+ else override_bs
1000
+ if override_bs is not None
1001
+ else 0
1002
+ )
1003
+ batch_fn = (
1004
+ self._batch_scheduler
1005
+ if self.batch_size == "auto"
1006
+ and n_reordered_requests > 0
1007
+ and not override_bs
1008
+ else None
1009
+ )
1010
+
1011
+ chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn)
1012
+ pbar = tqdm(
1013
+ total=len(requests),
1014
+ disable=(disable_tqdm or (self.rank != 0)),
1015
+ desc="Running loglikelihood requests",
1016
+ )
1017
+ for chunk in chunks:
1018
+ inps = []
1019
+ cont_toks_list = []
1020
+ inplens = []
1021
+
1022
+ conts = []
1023
+ encoder_attns = []
1024
+
1025
+ padding_len_inp = None
1026
+ padding_len_cont = None
1027
+ # because vectorizing is annoying, we first convert each (context, continuation) pair to padded
1028
+ # tensors, then we pack them together into a batch, call the model, and then pick it all apart
1029
+ # again because vectorizing is annoying
1030
+
1031
+ for _, context_enc, continuation_enc in chunk:
1032
+ # sanity check
1033
+ assert len(context_enc) > 0
1034
+ assert len(continuation_enc) > 0
1035
+ assert len(continuation_enc) <= self.max_length
1036
+
1037
+ # how this all works (illustrated on a causal decoder-only setup):
1038
+ # CTX CONT
1039
+ # inp 0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1]
1040
+ # model \ \
1041
+ # logits 1 2 3|4 5 6 7 8 9 <- the ctx half gets tossed out by the
1042
+ # cont_toks 4 5 6 7 8 9 [:, -len(continuation_enc):, :self.vocab_size] slice
1043
+
1044
+ # when too long to fit in context, truncate from the left
1045
+ if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
1046
+ inp = torch.tensor(
1047
+ (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
1048
+ dtype=torch.long,
1049
+ device=self.device,
1050
+ )
1051
+ (inplen,) = inp.shape
1052
+ elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
1053
+ inp = torch.tensor(
1054
+ (context_enc)[-self.max_length :],
1055
+ dtype=torch.long,
1056
+ device=self.device,
1057
+ )
1058
+ (inplen,) = inp.shape
1059
+
1060
+ # build encoder attn masks
1061
+ encoder_attns.append(torch.ones_like(inp))
1062
+
1063
+ cont = torch.tensor(
1064
+ (continuation_enc)[-self.max_length :],
1065
+ # TODO: left-shift these?
1066
+ # TODO: our code assumes we never end up truncating conts for either model type
1067
+ dtype=torch.long,
1068
+ device=self.device,
1069
+ )
1070
+ (contlen,) = cont.shape
1071
+
1072
+ conts.append(cont)
1073
+
1074
+ padding_len_cont = (
1075
+ max(padding_len_cont, contlen)
1076
+ if padding_len_cont is not None
1077
+ else contlen
1078
+ )
1079
+
1080
+ padding_len_inp = (
1081
+ max(padding_len_inp, inplen)
1082
+ if padding_len_inp is not None
1083
+ else inplen
1084
+ )
1085
+
1086
+ inps.append(inp) # [1, inp_length]
1087
+ cont_toks_list.append(continuation_enc)
1088
+ inplens.append(inplen)
1089
+
1090
+ # create encoder attn mask and batched conts, if seq2seq
1091
+ call_kwargs = {}
1092
+ if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
1093
+ batched_inps = pad_and_concat(
1094
+ padding_len_inp, inps, padding_side="right"
1095
+ ) # [batch, padding_len_inp]
1096
+ elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
1097
+ # TODO: left-pad encoder inps and mask?
1098
+ batched_inps = pad_and_concat(
1099
+ padding_len_inp, inps
1100
+ ) # [batch, padding_len_inp]
1101
+ batched_conts = pad_and_concat(
1102
+ padding_len_cont, conts
1103
+ ) # [batch, padding_len_cont]
1104
+ batched_encoder_mask = pad_and_concat(
1105
+ padding_len_inp, encoder_attns
1106
+ ) # [batch, padding_len_inp]
1107
+ call_kwargs = {
1108
+ "attn_mask": batched_encoder_mask,
1109
+ "labels": batched_conts,
1110
+ }
1111
+
1112
+ multi_logits = F.log_softmax(
1113
+ self._model_call(batched_inps, **call_kwargs), dim=-1
1114
+ ) # [batch, padding_length (inp or cont), vocab]
1115
+
1116
+ for (request_str, ctx_tokens, _), logits, inplen, cont_toks in zip(
1117
+ chunk, multi_logits, inplens, cont_toks_list
1118
+ ):
1119
+ # Slice to original seq length
1120
+ contlen = len(cont_toks)
1121
+ # take only logits in the continuation
1122
+ # (discard context toks if decoder-only ; discard right-padding)
1123
+ # also discards + checks for "virtual tokens" in the causal LM's input window
1124
+ # from prompt/prefix tuning tokens, if applicable
1125
+ ctx_len = (
1126
+ inplen + (logits.shape[0] - padding_len_inp)
1127
+ if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
1128
+ else None
1129
+ )
1130
+ logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len)
1131
+ logits = logits.unsqueeze(0) # [1, seq, vocab]
1132
+
1133
+ # Check if per-token argmax is exactly equal to continuation
1134
+ greedy_tokens = logits.argmax(dim=-1)
1135
+
1136
+ # check for one-token continuation cache hits.
1137
+ # noop in case group_by != "contexts" or no cache hit and returns the
1138
+ # original args. Otherwise, expands the logits batch dimension and yields each
1139
+ # batch along with matching continuation tokens and prompt strings.
1140
+ # logits -> [1, seq, vocab]
1141
+ for request_str, cont_toks, logits in re_ord.get_cache(
1142
+ req_str=request_str,
1143
+ cxt_toks=ctx_tokens,
1144
+ cont_toks=cont_toks,
1145
+ logits=logits,
1146
+ ):
1147
+ cont_toks = torch.tensor(
1148
+ cont_toks, dtype=torch.long, device=self.device
1149
+ ).unsqueeze(0) # [1, seq]
1150
+ max_equal = (greedy_tokens == cont_toks).all()
1151
+
1152
+ # Obtain log-probs at the corresponding continuation token indices
1153
+ # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
1154
+ logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
1155
+ -1
1156
+ ) # [1, seq]
1157
+
1158
+ # Answer: (log prob, is-exact-match)
1159
+ answer = (float(logits.sum()), bool(max_equal))
1160
+
1161
+ res.append(answer)
1162
+
1163
+ self.cache_hook.add_partial("loglikelihood", request_str, answer)
1164
+ pbar.update(1)
1165
+
1166
+ pbar.close()
1167
+
1168
+ return re_ord.get_original(res)
1169
+
1170
+ def generate_until(
1171
+ self, requests: List[Instance], disable_tqdm: bool = False
1172
+ ) -> List[str]:
1173
+ res = []
1174
+
1175
+ def _collate(req: Tuple[str, dict]):
1176
+ """Defines the key for the sorted method"""
1177
+ # the negative sign on len(toks) sorts descending - this has a few advantages:
1178
+ # - time estimates will always be over not underestimates, which is more useful for planning
1179
+ # - to know the size of a batch when going through the list, you know the first one is always the batch
1180
+ # padded context length. this is useful to simplify the batching logic and more importantly to make
1181
+ # automatic adaptive batches much much easier to implement
1182
+ # - any OOMs will happen right away rather than near the end
1183
+ toks = self.tok_encode(req[0])
1184
+ return -len(toks), req[0]
1185
+
1186
+ pbar = tqdm(
1187
+ total=len(requests),
1188
+ disable=(disable_tqdm or (self.rank != 0)),
1189
+ desc="Running generate_until requests",
1190
+ )
1191
+ adaptive_batch_size = None
1192
+ if self.batch_size == "auto":
1193
+ # using rolling window with maximum context
1194
+ print("Passed argument batch_size = auto. Detecting largest batch size")
1195
+ batch_size = self._detect_batch_size()
1196
+ print(f"Determined Largest batch size: {batch_size}")
1197
+ adaptive_batch_size = batch_size
1198
+ # for each different set of kwargs, we execute all requests, by batch.
1199
+ batch_size = (
1200
+ self.batch_size
1201
+ if self.batch_size != "auto"
1202
+ else adaptive_batch_size
1203
+ if adaptive_batch_size is not None
1204
+ else 0
1205
+ )
1206
+ batch_fn = (
1207
+ self._batch_scheduler
1208
+ if self.batch_size == "auto" and not adaptive_batch_size
1209
+ else None
1210
+ )
1211
+
1212
+ # we group requests by their generation_kwargs,
1213
+ # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
1214
+ # in the same batch.
1215
+ # group_fn=lambda x: x[1] -> x=(context, gen_kwargs)
1216
+ re_ords = Collator(
1217
+ [reg.args for reg in requests],
1218
+ sort_fn=_collate,
1219
+ group_by="gen_kwargs",
1220
+ group_fn=lambda x: x[1],
1221
+ )
1222
+ chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn)
1223
+ for chunk in chunks:
1224
+ contexts, all_gen_kwargs = zip(*chunk)
1225
+ # we assume all gen kwargs in the batch are the same
1226
+ # this is safe to assume because the `grouper` object ensures it.
1227
+ gen_kwargs = all_gen_kwargs[0]
1228
+ # unpack our keyword arguments.
1229
+ until = None
1230
+ if isinstance(gen_kwargs, dict):
1231
+ kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
1232
+ if "until" in kwargs.keys():
1233
+ until = kwargs.pop("until")
1234
+ if isinstance(until, str):
1235
+ until = [until]
1236
+ elif not isinstance(until, list):
1237
+ raise ValueError(
1238
+ f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
1239
+ )
1240
+ else:
1241
+ raise ValueError(
1242
+ f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
1243
+ )
1244
+ # add EOS token to stop sequences
1245
+ eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
1246
+ if not until:
1247
+ until = [eos]
1248
+ else:
1249
+ until.append(eos)
1250
+ if "max_gen_toks" in kwargs.keys():
1251
+ max_gen_toks = kwargs.pop("max_gen_toks")
1252
+ else:
1253
+ max_gen_toks = self.max_gen_toks
1254
+
1255
+ # set the max length in tokens of inputs ("context_enc")
1256
+ if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
1257
+ # max len for inputs = max length, minus room to generate the max new tokens
1258
+ max_ctx_len = self.max_length - max_gen_toks
1259
+ elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
1260
+ # max len for inputs = encoder's whole max_length
1261
+ max_ctx_len = self.max_length
1262
+
1263
+ # encode, pad, and truncate contexts for this batch
1264
+ context_enc, attn_masks = self.tok_batch_encode(
1265
+ contexts,
1266
+ left_truncate_len=max_ctx_len,
1267
+ truncation=self.truncation,
1268
+ )
1269
+ context_enc = context_enc.to(self.device)
1270
+ attn_masks = attn_masks.to(self.device)
1271
+
1272
+ if "max_length" not in kwargs:
1273
+ kwargs["max_length"] = context_enc.shape[1] + max_gen_toks
1274
+
1275
+ # perform batched generation
1276
+ cont = self._model_generate(
1277
+ context=context_enc,
1278
+ attention_mask=attn_masks,
1279
+ stop=until,
1280
+ **kwargs,
1281
+ )
1282
+
1283
+ cont_toks_list = cont.tolist()
1284
+ for cont_toks, context in zip(cont_toks_list, contexts):
1285
+ # discard context + left-padding toks if using causal decoder-only LM
1286
+ if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
1287
+ cont_toks = cont_toks[context_enc.shape[1] :]
1288
+
1289
+ s = self.tok_decode(cont_toks)
1290
+
1291
+ # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
1292
+ for term in until:
1293
+ if len(term) > 0:
1294
+ # ignore '' separator,
1295
+ # for seq2seq case where self.tok_decode(self.eot_token_id) = ''
1296
+ s = s.split(term)[0]
1297
+
1298
+ res.append(s)
1299
+
1300
+ self.cache_hook.add_partial("generate_until", (context, gen_kwargs), s)
1301
+ pbar.update(1)
1302
+ # reorder this group of results back to original unsorted form
1303
+ res = re_ords.get_original(res)
1304
+
1305
+ pbar.close()
1306
+
1307
+ return res
1308
+
1309
+ def apply_chat_template(self, chat_history: List[Dict[str, str]]) -> str:
1310
+ """
1311
+ Method to apply a chat template to a list of chat history between user and model.
1312
+ """
1313
+ return self.tokenizer.apply_chat_template(
1314
+ chat_history, tokenize=False, add_generation_prompt=True
1315
+ )
1316
+
1317
+ def get_model_info(self) -> dict:
1318
+ """
1319
+ Method to get Hugging Face model information for experiment reproducibility.
1320
+ """
1321
+
1322
+ def get_model_num_params(model) -> int:
1323
+ if hasattr(model, "num_parameters"):
1324
+ return model.num_parameters()
1325
+ if hasattr(model, "parameters"):
1326
+ return sum(p.numel() for p in model.parameters())
1327
+ else:
1328
+ return -1
1329
+
1330
+ def get_model_dtype(model) -> str:
1331
+ if hasattr(model, "dtype"):
1332
+ return model.dtype
1333
+ else:
1334
+ return ""
1335
+
1336
+ def get_model_sha(pretrained: str, revision: str) -> str:
1337
+ try:
1338
+ model_info = HfApi().model_info(repo_id=pretrained, revision=revision)
1339
+ return model_info.sha
1340
+ except Exception as e:
1341
+ eval_logger.warn(
1342
+ f"Failed to get model SHA for {pretrained} at revision {revision}. Error: {e}"
1343
+ )
1344
+ return ""
1345
+
1346
+ model_info = {
1347
+ "model_num_parameters": get_model_num_params(self._model),
1348
+ "model_dtype": get_model_dtype(self._model),
1349
+ "model_revision": self.revision,
1350
+ "model_sha": get_model_sha(self.pretrained, self.revision),
1351
+ }
1352
+ if self.peft:
1353
+ model_info["peft_sha"] = get_model_sha(self.peft, self.revision)
1354
+ if self.delta:
1355
+ model_info["delta_sha"] = get_model_sha(self.delta, self.revision)
1356
+ return model_info
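Note on the `generate_until` implementation in the diff above: requests are grouped by their `gen_kwargs` through the `Collator` before batching, so greedy and sampled requests never end up in the same batch. Below is a minimal sketch of that grouping behaviour; it mirrors the `Collator`, `get_batched`, and `get_original` calls shown in the diff, while the example request data and the length-based sort key are made up for illustration.

```python
# Hypothetical sketch of the Collator grouping used by generate_until above.
# The request tuples are invented; only the Collator / get_batched usage
# mirrors the code in this diff.
from lm_eval.models.utils import Collator

requests = [
    ("Translate to French: cat =>", {"until": ["\n"], "do_sample": False}),
    ("Translate to French: dog =>", {"until": ["\n"], "do_sample": False}),
    ("Write a haiku about autumn.", {"until": ["\n"], "do_sample": True, "temperature": 0.8}),
]

re_ords = Collator(
    requests,
    sort_fn=lambda x: (-len(x[0]), x[0]),  # longest contexts first, analogous to _collate
    group_by="gen_kwargs",                 # keep differing sampling settings apart
    group_fn=lambda x: x[1],
)

for chunk in re_ords.get_batched(n=2, batch_fn=None):
    contexts, all_gen_kwargs = zip(*chunk)
    # within a chunk every gen_kwargs dict is identical, so using the first is safe
    print(len(contexts), all_gen_kwargs[0])

# re_ords.get_original(results) would restore results to the original request order.
```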
scripts/yans/lm-evaluation-harness/lm_eval/models/nemo_lm.py ADDED
@@ -0,0 +1,537 @@
1
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import importlib
16
+ import pathlib
17
+ from copy import deepcopy
18
+ from typing import List, Literal
19
+
20
+ import filelock
21
+ import numpy as np
22
+ import torch
23
+ from tqdm import tqdm
24
+
25
+ from lm_eval.api.instance import Instance
26
+ from lm_eval.api.model import LM
27
+ from lm_eval.api.registry import register_model
28
+ from lm_eval.models.utils import Collator
29
+ from lm_eval.utils import (
30
+ eval_logger,
31
+ get_rolling_token_windows,
32
+ make_disjoint_window,
33
+ simple_parse_args_string,
34
+ )
35
+
36
+
37
+ def _patch_pretrained_cfg(
38
+ pretrained_cfg, trainer, tensor_model_parallel_size, pipeline_model_parallel_size
39
+ ):
40
+ try:
41
+ import omegaconf
42
+ except ModuleNotFoundError:
43
+ raise Exception(
44
+ "Attempted to use 'nemo_lm' model type, but package `nemo` is not installed"
45
+ "Please install nemo following the instructions in the README: either with a NVIDIA PyTorch or NeMo container, "
46
+ "or installing nemo following https://github.com/NVIDIA/NeMo.",
47
+ )
48
+
49
+ omegaconf.OmegaConf.set_struct(pretrained_cfg, True)
50
+ with omegaconf.open_dict(pretrained_cfg):
51
+ attributes_to_update = {
52
+ "sequence_parallel": False,
53
+ "activations_checkpoint_granularity": None,
54
+ "activations_checkpoint_method": None,
55
+ "precision": trainer.precision,
56
+ "global_batch_size": None,
57
+ "tensor_model_parallel_size": tensor_model_parallel_size,
58
+ "pipeline_model_parallel_size": pipeline_model_parallel_size,
59
+ "apply_rope_fusion": False,
60
+ }
61
+ for name, value in attributes_to_update.items():
62
+ if hasattr(pretrained_cfg, name):
63
+ pretrained_cfg[name] = value
64
+ return pretrained_cfg
65
+
66
+
67
+ def _get_target_from_class(target_class) -> str:
68
+ return f"{target_class.__module__}.{target_class.__name__}"
69
+
70
+
71
+ def load_model(
72
+ model_path: str,
73
+ trainer,
74
+ tensor_model_parallel_size: int,
75
+ pipeline_model_parallel_size: int,
76
+ ) -> torch.nn.Module:
77
+ try:
78
+ from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import (
79
+ MegatronGPTModel,
80
+ )
81
+ from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector
82
+ except ModuleNotFoundError:
83
+ raise Exception(
84
+ "Attempted to use 'nemo_lm' model type, but package `nemo` is not installed"
85
+ "Please install nemo following the instructions in the README: either with a NVIDIA PyTorch or NeMo container, "
86
+ "or installing nemo following https://github.com/NVIDIA/NeMo.",
87
+ )
88
+ model_path = pathlib.Path(model_path)
89
+
90
+ save_restore_connector = NLPSaveRestoreConnector()
91
+ if model_path.is_dir():
92
+ save_restore_connector.model_extracted_dir = model_path.as_posix()
93
+ pretrained_cfg = save_restore_connector.restore_from(
94
+ None, model_path.as_posix(), return_config=True, trainer=trainer
95
+ )
96
+ if not hasattr(pretrained_cfg, "target"):
97
+ pretrained_cfg["target"] = _get_target_from_class(MegatronGPTModel)
98
+
99
+ pretrained_cfg = _patch_pretrained_cfg(
100
+ pretrained_cfg,
101
+ trainer,
102
+ tensor_model_parallel_size=tensor_model_parallel_size,
103
+ pipeline_model_parallel_size=pipeline_model_parallel_size,
104
+ )
105
+
106
+ model_to_load_path = model_path
107
+ override_config = pretrained_cfg
108
+
109
+ module_name, class_name = override_config.target.rsplit(".", 1)
110
+ model_class = getattr(importlib.import_module(module_name), class_name)
111
+
112
+ # monkeypatch _build_tokenizer method to be process-safe
113
+ tokenizer_lock = filelock.FileLock(f"/tmp/{model_path.name}.tokenizer.lock")
114
+
115
+ def _synced_build_tokenizer(self):
116
+ with tokenizer_lock:
117
+ self._original_build_tokenizer()
118
+
119
+ model_class._original_build_tokenizer = model_class._build_tokenizer
120
+ model_class._build_tokenizer = _synced_build_tokenizer
121
+
122
+ model = model_class.restore_from(
123
+ restore_path=model_to_load_path.as_posix(),
124
+ trainer=trainer,
125
+ override_config_path=override_config,
126
+ save_restore_connector=save_restore_connector,
127
+ map_location=f"cuda:{trainer.local_rank}",
128
+ )
129
+
130
+ model.freeze()
131
+ model.training = False
132
+ try:
133
+ # Have to turn off activations_checkpoint_method for inference
134
+ model.model.language_model.encoder.activations_checkpoint_method = None
135
+ except AttributeError:
136
+ pass
137
+ return model
138
+
139
+
140
+ def setup_distributed_environment(trainer):
141
+ try:
142
+ from nemo.utils.app_state import AppState
143
+ except ModuleNotFoundError:
144
+ raise Exception(
145
+ "Attempted to use 'nemo_lm' model type, but package `nemo` is not installed"
146
+ "Please install nemo following the instructions in the README: either with a NVIDIA PyTorch or NeMo container, "
147
+ "or installing nemo following https://github.com/NVIDIA/NeMo.",
148
+ )
149
+
150
+ def dummy():
151
+ return
152
+
153
+ if trainer.strategy.launcher is not None:
154
+ trainer.strategy.launcher.launch(dummy, trainer=trainer)
155
+ trainer.strategy.setup_environment()
156
+
157
+ app_state = AppState()
158
+
159
+ return app_state
160
+
161
+
162
+ @register_model("nemo_lm")
163
+ class NeMoLM(LM):
164
+ def __init__(
165
+ self,
166
+ path: str,
167
+ max_length: int = 4096,
168
+ batch_size: int = 1,
169
+ max_gen_toks: int = 256,
170
+ devices: int = 1,
171
+ num_nodes: int = 1,
172
+ tensor_model_parallel_size: int = 1,
173
+ pipeline_model_parallel_size: int = 1,
174
+ precision: Literal[
175
+ "16-mixed",
176
+ "bf16-mixed",
177
+ "32-true",
178
+ "64-true",
179
+ 64,
180
+ 32,
181
+ 16,
182
+ "64",
183
+ "32",
184
+ "16",
185
+ "bf16",
186
+ ] = "bf16",
187
+ **kwargs,
188
+ ):
189
+ try:
190
+ from nemo.collections.nlp.modules.common.text_generation_utils import (
191
+ generate,
192
+ )
193
+ from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy
194
+ from pytorch_lightning.trainer.trainer import Trainer
195
+
196
+ self.generate = generate
197
+ except ModuleNotFoundError:
198
+ raise Exception(
199
+ "Attempted to use 'nemo_lm' model type, but package `nemo` is not installed"
200
+ "Please install nemo following the instructions in the README: either with a NVIDIA PyTorch or NeMo container, "
201
+ "or installing nemo following https://github.com/NVIDIA/NeMo.",
202
+ )
203
+
204
+ super().__init__()
205
+
206
+ if (
207
+ tensor_model_parallel_size == 1
208
+ and pipeline_model_parallel_size == 1
209
+ and devices > 1
210
+ ):
211
+ eval_logger.info(
212
+ f"The number of data replicas for evaluation is {devices}."
213
+ )
214
+ eval_logger.info(f"The total number of devices is {devices}.")
215
+ eval_logger.info(
216
+ "No tensor parallelism or pipeline parallelism is applied."
217
+ )
218
+
219
+ elif tensor_model_parallel_size * pipeline_model_parallel_size == devices:
220
+ eval_logger.info(
221
+ f"Setting tensor parallelism to {tensor_model_parallel_size} and pipeline parallelism to {pipeline_model_parallel_size}."
222
+ )
223
+ eval_logger.info(f"The total number of devices is {devices}.")
224
+ eval_logger.info("No data parallelism is applied.")
225
+
226
+ else:
227
+ raise ValueError(
228
+ "Please set the product of tensor_model_parallel_size and pipeline_model_parallel_size"
229
+ "equal to the specified number of devices."
230
+ )
231
+
232
+ if num_nodes > 1:
233
+ raise ValueError(
234
+ "A number of nodes greater than 1 is not supported yet. Please set num_nodes as 1."
235
+ )
236
+
237
+ trainer = Trainer(
238
+ strategy=NLPDDPStrategy(),
239
+ devices=devices,
240
+ accelerator="gpu",
241
+ num_nodes=num_nodes,
242
+ precision=precision,
243
+ logger=False,
244
+ enable_checkpointing=False,
245
+ use_distributed_sampler=False,
246
+ )
247
+ # Modify the following flags only for data replication
248
+ if (
249
+ tensor_model_parallel_size == 1
250
+ and pipeline_model_parallel_size == 1
251
+ and devices > 1
252
+ ):
253
+ self._device = torch.device(f"cuda:{trainer.global_rank}")
254
+ self._rank = trainer.global_rank
255
+ self._world_size = trainer.world_size
256
+ self.model = load_model(
257
+ path,
258
+ trainer,
259
+ tensor_model_parallel_size=tensor_model_parallel_size,
260
+ pipeline_model_parallel_size=pipeline_model_parallel_size,
261
+ ).cuda()
262
+ self.tokenizer = self.model.tokenizer
263
+ self.app_state = setup_distributed_environment(trainer)
264
+
265
+ self._max_length = max_length
266
+ self._batch_size = int(batch_size)
267
+ self._max_gen_toks = max_gen_toks
268
+
269
+ @classmethod
270
+ def create_from_arg_string(cls, arg_string, additional_config=None):
271
+ args = simple_parse_args_string(arg_string)
272
+ if additional_config:
273
+ args["batch_size"] = additional_config.get("batch_size", 1)
274
+
275
+ return cls(**args)
276
+
277
+ @property
278
+ def eot_token_id(self):
279
+ try:
280
+ return self.tokenizer.eos_id
281
+ except AttributeError:
282
+ return None
283
+
284
+ @property
285
+ def max_length(self):
286
+ return self._max_length
287
+
288
+ @property
289
+ def max_gen_toks(self):
290
+ return self._max_gen_toks
291
+
292
+ @property
293
+ def batch_size(self):
294
+ return self._batch_size
295
+
296
+ @property
297
+ def device(self):
298
+ return self._device
299
+
300
+ @property
301
+ def rank(self):
302
+ return self._rank
303
+
304
+ @property
305
+ def world_size(self):
306
+ return self._world_size
307
+
308
+ @property
309
+ def accelerator(self):
310
+ return self._Accelerator(self.world_size)
311
+
312
+ class _Accelerator:
313
+ def __init__(self, world_size):
314
+ self.world_size = world_size
315
+
316
+ def wait_for_everyone(self):
317
+ torch.distributed.barrier()
318
+
319
+ def gather(self, local_tensor):
320
+ gathered_tensors = [
321
+ torch.zeros(1, dtype=local_tensor.dtype).cuda()
322
+ for _ in range(self.world_size)
323
+ ]
324
+ torch.distributed.all_gather(gathered_tensors, local_tensor)
325
+ return torch.cat(gathered_tensors)
326
+
327
+ def tok_encode(self, string: str):
328
+ return self.tokenizer.text_to_ids(string)
329
+
330
+ def tok_decode(self, tokens):
331
+ return self.tokenizer.ids_to_text(tokens)
332
+
333
+ def _encode_pair(self, context, continuation):
334
+ n_spaces = len(context) - len(context.rstrip())
335
+ if n_spaces > 0:
336
+ continuation = context[-n_spaces:] + continuation
337
+ context = context[:-n_spaces]
338
+ whole_enc = self.tok_encode(context + continuation)
339
+ context_enc = self.tok_encode(context)
340
+ context_enc_len = len(context_enc)
341
+ continuation_enc = whole_enc[context_enc_len:]
342
+ return context_enc, continuation_enc
343
+
344
+ def loglikelihood(self, requests):
345
+ new_reqs = []
346
+ for context, continuation in [req.args for req in requests]:
347
+ if context == "":
348
+ # end of text as context
349
+ context_enc, continuation_enc = (
350
+ [self.eot_token_id],
351
+ self.tok_encode(continuation),
352
+ )
353
+ else:
354
+ context_enc, continuation_enc = self._encode_pair(context, continuation)
355
+
356
+ new_reqs.append(((context, continuation), context_enc, continuation_enc))
357
+
358
+ return self._loglikelihood_tokens(new_reqs)
359
+
360
+ def loglikelihood_rolling(
361
+ self, requests: List[Instance], disable_tqdm: bool = False
362
+ ) -> List[float]:
363
+ loglikelihoods = []
364
+
365
+ for (string,) in tqdm([req.args for req in requests], disable=disable_tqdm):
366
+ rolling_token_windows = list(
367
+ map(
368
+ make_disjoint_window,
369
+ get_rolling_token_windows(
370
+ token_list=self.tok_encode(string),
371
+ prefix_token=self.eot_token_id,
372
+ max_seq_len=self.max_length - 1,
373
+ context_len=1,
374
+ ),
375
+ )
376
+ )
377
+
378
+ rolling_token_windows = [(None,) + x for x in rolling_token_windows]
379
+
380
+ string_nll = self._loglikelihood_tokens(
381
+ rolling_token_windows,
382
+ )
383
+
384
+ # discard is_greedy
385
+ string_nll = [x[0] for x in string_nll]
386
+
387
+ string_nll = sum(string_nll)
388
+ loglikelihoods.append(string_nll)
389
+ return loglikelihoods
390
+
391
+ def _loglikelihood_tokens(self, requests, disable_tqdm=False):
392
+ res = []
393
+
394
+ def _collate(x):
395
+ toks = x[1] + x[2]
396
+ return -len(toks), tuple(toks)
397
+
398
+ re_ord = Collator(requests, sort_fn=_collate)
399
+ chunks = re_ord.get_batched(n=self.batch_size, batch_fn=None)
400
+ pbar = tqdm(
401
+ total=len(requests),
402
+ disable=(disable_tqdm or (self.rank != 0)),
403
+ desc="Running loglikelihood requests",
404
+ )
405
+ for chunk in chunks:
406
+ inps = []
407
+ ctxlens = []
408
+ contlens = []
409
+
410
+ for _, context_enc, continuation_enc in chunk:
411
+ # Leave one token for generation. Tokens_to_generate = 0 breaks NeMo.
412
+ inp = (context_enc + continuation_enc)[-(self.max_length - 1) :]
413
+
414
+ ctxlen = len(context_enc) - max(
415
+ 0, len(context_enc) + len(continuation_enc) - (self.max_length - 1)
416
+ )
417
+ ctxlens.append(ctxlen)
418
+ contlens.append(len(continuation_enc))
419
+
420
+ inps.append(self.tok_decode(inp))
421
+
422
+ output = self.generate(
423
+ self.model,
424
+ inputs=inps,
425
+ tokens_to_generate=1,
426
+ min_tokens_to_generate=1,
427
+ compute_logprob=True,
428
+ all_probs=True,
429
+ )
430
+
431
+ batch_token_ids = np.asarray(output["token_ids"])[:, :-1]
432
+ batch_logprobs = output["logprob"][:, :-1]
433
+ batch_full_logprob = output["full_logprob"][:, :-1, :]
434
+
435
+ # Compute greedy tokens for entire batch rather than calling it with proper ctxlen for each sample.
436
+ # Additional tokens for each sample will be trimmed later.
437
+ min_ctxlen = min(ctxlens)
438
+
439
+ # Use min_ctxlen-1 instead of min_ctxlen since full_logprobs are not returned for the first token.
440
+ batch_greedy_tokens = (
441
+ torch.argmax(batch_full_logprob[:, min_ctxlen - 1 :, :], -1)
442
+ .cpu()
443
+ .numpy()
444
+ )
445
+
446
+ for token_ids, greedy_tokens, logprobs, ctxlen, contlen, (
447
+ cache_key,
448
+ _,
449
+ _,
450
+ ) in zip(
451
+ batch_token_ids,
452
+ batch_greedy_tokens,
453
+ batch_logprobs,
454
+ ctxlens,
455
+ contlens,
456
+ chunk,
457
+ ):
458
+ # Trim at contlen since shorter contexts in a batch will have more than one token generated.
459
+ # Use ctxlen-1 instead of ctxlen, the same as for full_logprob in the batch_greedy_tokens calculation
460
+ logprobs = (logprobs[ctxlen - 1 :])[:contlen]
461
+ logprob = sum(logprobs).tolist()
462
+
463
+ continuation_tokens = (token_ids[ctxlen:])[:contlen]
464
+ len_diff = ctxlen - min_ctxlen
465
+ is_greedy = continuation_tokens == (greedy_tokens[len_diff:])[:contlen]
466
+ if not isinstance(is_greedy, bool):
467
+ is_greedy = is_greedy.all()
468
+ answer = (logprob, is_greedy)
469
+
470
+ if cache_key is not None:
471
+ self.cache_hook.add_partial("loglikelihood", cache_key, answer)
472
+
473
+ res.append(answer)
474
+ pbar.update(1)
475
+
476
+ pbar.close()
477
+
478
+ return re_ord.get_original(res)
479
+
480
+ def generate_until(self, requests):
481
+ if not requests:
482
+ return []
483
+ res = []
484
+
485
+ def get_until(req_args):
486
+ until = req_args.get("until", [])
487
+ until = deepcopy(until) # prevent from modifying req_args for cache_key
488
+ if self.tokenizer.ids_to_tokens([self.eot_token_id])[0] not in until:
489
+ until.append(self.tokenizer.ids_to_tokens([self.eot_token_id])[0])
490
+ return until
491
+
492
+ def _collate(x):
493
+ toks = self.tok_encode(x[0])
494
+ return len(toks), x[0]
495
+
496
+ re_ords = Collator(
497
+ [reg.args for reg in requests], sort_fn=_collate, group_by="gen_kwargs"
498
+ )
499
+ chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
500
+ for chunk in chunks:
501
+ contexts, all_gen_kwargs = zip(*chunk)
502
+ # we assume all gen kwargs in the batch are the same
503
+ # this is safe to assume because the `grouper` object ensures it.
504
+ req_args = all_gen_kwargs[0]
505
+ # unpack our keyword arguments.
506
+ until = get_until(req_args)
507
+ max_gen_toks = req_args.get("max_gen_toks", self.max_gen_toks)
508
+
509
+ remaining_length = self.max_length - max_gen_toks
510
+ contexts = []
511
+ for context, _ in chunk:
512
+ encoded_context = self.tok_encode(context)
513
+ encoded_context = encoded_context[-remaining_length:]
514
+ contexts.append(self.tok_decode(encoded_context))
515
+
516
+ output = self.generate(
517
+ self.model,
518
+ inputs=contexts,
519
+ tokens_to_generate=max_gen_toks,
520
+ end_strings=until,
521
+ greedy=True,
522
+ )
523
+
524
+ answers = output["sentences"]
525
+
526
+ continuations = []
527
+ for context, answer in zip(contexts, answers):
528
+ continuations.append(answer[len(context) :])
529
+
530
+ for term in until:
531
+ continuations = [answer.split(term)[0] for answer in continuations]
532
+
533
+ for request, answer in zip(chunk, continuations):
534
+ self.cache_hook.add_partial("greedy_until", request, answer)
535
+ res.append(answer)
536
+
537
+ return re_ords.get_original(res)
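For completeness, a hypothetical usage sketch of the `nemo_lm` model registered above, driven through the harness's standard `simple_evaluate` entry point. The checkpoint path, task list, and argument values are placeholders, and NeMo must be installed as the error messages in the file note.

```python
# Hypothetical sketch: running the "nemo_lm" model through lm_eval's
# simple_evaluate entry point. The .nemo path and task name are placeholders.
import lm_eval

results = lm_eval.simple_evaluate(
    model="nemo_lm",
    model_args="path=/path/to/checkpoint.nemo,devices=1,max_length=4096,max_gen_toks=256",
    tasks=["lambada_openai"],
    batch_size=1,
)
print(results["results"])
```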