Add files using upload-large-folder tool
Browse files- Koala-36M-v1/.gitattributes +68 -0
- Prism/LICENSE +201 -0
- Prism/LLaDA/LLaDA_Baseline/.gitignore +210 -0
- Prism/LLaDA/LLaDA_Baseline/LICENSE +21 -0
- Prism/LLaDA/LLaDA_Baseline/dllm_eval/__init__.py +7 -0
- Prism/LLaDA/LLaDA_Baseline/dllm_eval/__main__.py +527 -0
- Prism/LLaDA/LLaDA_Baseline/dllm_eval/evaluator.py +765 -0
- Prism/LLaDA/LLaDA_Baseline/dllm_eval/evaluator_utils.py +554 -0
- Prism/LLaDA/LLaDA_Baseline/dllm_eval/filters/__init__.py +25 -0
- Prism/LLaDA/LLaDA_Baseline/dllm_eval/filters/custom.py +17 -0
- Prism/LLaDA/LLaDA_Baseline/dllm_eval/filters/decontamination.py +25 -0
- Prism/LLaDA/LLaDA_Baseline/dllm_eval/filters/extraction.py +233 -0
- Prism/LLaDA/LLaDA_Baseline/dllm_eval/filters/selection.py +61 -0
- Prism/LLaDA/LLaDA_Baseline/dllm_eval/filters/transformation.py +122 -0
- Prism/LLaDA/LLaDA_Baseline/dllm_eval/utils.py +552 -0
- Prism/LLaDA/LLaDA_Baseline/evaluation_script.py +21 -0
- Prism/LLaDA/LLaDA_Baseline/metrics/gsm8k_all.py +286 -0
- Prism/LLaDA/LLaDA_Baseline/metrics/humaneval_all.py +183 -0
- Prism/LLaDA/LLaDA_Baseline/metrics/math500_all.py +213 -0
- Prism/LLaDA/LLaDA_Baseline/metrics/mbpp_all.py +194 -0
- Prism/LLaDA/LLaDA_Baseline/requirements.txt +9 -0
- Prism/LLaDA/LLaDA_Baseline/scripts/run_gsm8k.sh +32 -0
- Prism/LLaDA/LLaDA_Baseline/scripts/run_humaneval.sh +29 -0
- Prism/LLaDA/LLaDA_Baseline/scripts/run_math500.sh +29 -0
- Prism/LLaDA/LLaDA_Baseline/scripts/run_mbpp.sh +29 -0
- Prism/LLaDA/LLaDA_Prism/.gitignore +210 -0
- Prism/LLaDA/LLaDA_Prism/LICENSE +21 -0
- Prism/LLaDA/LLaDA_Prism/evaluation_script.py +21 -0
- Prism/LLaDA/LLaDA_Prism/requirements.txt +9 -0
- Prism/LLaDA/LLaDA_Truthfulqa/.gitignore +3 -0
- Prism/LLaDA/LLaDA_Truthfulqa/LICENSE +201 -0
- Prism/LLaDA/LLaDA_Truthfulqa/eval_llada.py +413 -0
- Prism/LLaDA/LLaDA_Truthfulqa/eval_llada_prism.py +333 -0
- Prism/README.md +107 -0
- URSA-1.7B/.gitattributes +37 -0
- URSA-1.7B/.gitignore +55 -0
- URSA-1.7B/LICENSE +176 -0
- URSA-1.7B/README.md +117 -0
- URSA-1.7B/model_index.json +19 -0
- URSA/.flake8 +21 -0
- URSA/.gitignore +55 -0
- URSA/=4.57.1 +70 -0
- URSA/LICENSE +176 -0
- URSA/README.md +191 -0
- URSA/inference.py +71 -0
- URSA/pyproject.toml +3 -0
- URSA/requirements.txt +10 -0
- URSA/setup.py +133 -0
- URSA/ursa.jpg +0 -0
- URSA/version.txt +1 -0
Koala-36M-v1/.gitattributes
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.lz4 filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
# Audio files - uncompressed
|
| 38 |
+
*.pcm filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
*.sam filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
*.raw filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
# Audio files - compressed
|
| 42 |
+
*.aac filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
*.flac filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
*.mp3 filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
*.ogg filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
*.wav filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
# Image files - uncompressed
|
| 48 |
+
*.bmp filter=lfs diff=lfs merge=lfs -text
|
| 49 |
+
*.gif filter=lfs diff=lfs merge=lfs -text
|
| 50 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
| 51 |
+
*.tiff filter=lfs diff=lfs merge=lfs -text
|
| 52 |
+
# Image files - compressed
|
| 53 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
| 54 |
+
*.jpeg filter=lfs diff=lfs merge=lfs -text
|
| 55 |
+
*.webp filter=lfs diff=lfs merge=lfs -text
|
| 56 |
+
# Video files - compressed
|
| 57 |
+
*.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 58 |
+
*.webm filter=lfs diff=lfs merge=lfs -text
|
| 59 |
+
Koala_36M_1.csv filter=lfs diff=lfs merge=lfs -text
|
| 60 |
+
Koala_36M_2.csv filter=lfs diff=lfs merge=lfs -text
|
| 61 |
+
Koala_36M_3.csv filter=lfs diff=lfs merge=lfs -text
|
| 62 |
+
Koala_36M_4.csv filter=lfs diff=lfs merge=lfs -text
|
| 63 |
+
Koala_36M_5.csv filter=lfs diff=lfs merge=lfs -text
|
| 64 |
+
Koala_36M_6.csv filter=lfs diff=lfs merge=lfs -text
|
| 65 |
+
Koala_36M_7.csv filter=lfs diff=lfs merge=lfs -text
|
| 66 |
+
Koala_36M_8.csv filter=lfs diff=lfs merge=lfs -text
|
| 67 |
+
Koala_36M_9.csv filter=lfs diff=lfs merge=lfs -text
|
| 68 |
+
Koala_36M_10.csv filter=lfs diff=lfs merge=lfs -text
|
Prism/LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
Prism/LLaDA/LLaDA_Baseline/.gitignore
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.jsonl
|
| 2 |
+
*.json
|
| 3 |
+
|
| 4 |
+
# Byte-compiled / optimized / DLL files
|
| 5 |
+
__pycache__/
|
| 6 |
+
*.py[codz]
|
| 7 |
+
*$py.class
|
| 8 |
+
|
| 9 |
+
# C extensions
|
| 10 |
+
*.so
|
| 11 |
+
|
| 12 |
+
# Distribution / packaging
|
| 13 |
+
.Python
|
| 14 |
+
build/
|
| 15 |
+
develop-eggs/
|
| 16 |
+
dist/
|
| 17 |
+
downloads/
|
| 18 |
+
eggs/
|
| 19 |
+
.eggs/
|
| 20 |
+
lib/
|
| 21 |
+
lib64/
|
| 22 |
+
parts/
|
| 23 |
+
sdist/
|
| 24 |
+
var/
|
| 25 |
+
wheels/
|
| 26 |
+
share/python-wheels/
|
| 27 |
+
*.egg-info/
|
| 28 |
+
.installed.cfg
|
| 29 |
+
*.egg
|
| 30 |
+
MANIFEST
|
| 31 |
+
|
| 32 |
+
# PyInstaller
|
| 33 |
+
# Usually these files are written by a python script from a template
|
| 34 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 35 |
+
*.manifest
|
| 36 |
+
*.spec
|
| 37 |
+
|
| 38 |
+
# Installer logs
|
| 39 |
+
pip-log.txt
|
| 40 |
+
pip-delete-this-directory.txt
|
| 41 |
+
|
| 42 |
+
# Unit test / coverage reports
|
| 43 |
+
htmlcov/
|
| 44 |
+
.tox/
|
| 45 |
+
.nox/
|
| 46 |
+
.coverage
|
| 47 |
+
.coverage.*
|
| 48 |
+
.cache
|
| 49 |
+
nosetests.xml
|
| 50 |
+
coverage.xml
|
| 51 |
+
*.cover
|
| 52 |
+
*.py.cover
|
| 53 |
+
.hypothesis/
|
| 54 |
+
.pytest_cache/
|
| 55 |
+
cover/
|
| 56 |
+
|
| 57 |
+
# Translations
|
| 58 |
+
*.mo
|
| 59 |
+
*.pot
|
| 60 |
+
|
| 61 |
+
# Django stuff:
|
| 62 |
+
*.log
|
| 63 |
+
local_settings.py
|
| 64 |
+
db.sqlite3
|
| 65 |
+
db.sqlite3-journal
|
| 66 |
+
|
| 67 |
+
# Flask stuff:
|
| 68 |
+
instance/
|
| 69 |
+
.webassets-cache
|
| 70 |
+
|
| 71 |
+
# Scrapy stuff:
|
| 72 |
+
.scrapy
|
| 73 |
+
|
| 74 |
+
# Sphinx documentation
|
| 75 |
+
docs/_build/
|
| 76 |
+
|
| 77 |
+
# PyBuilder
|
| 78 |
+
.pybuilder/
|
| 79 |
+
target/
|
| 80 |
+
|
| 81 |
+
# Jupyter Notebook
|
| 82 |
+
.ipynb_checkpoints
|
| 83 |
+
|
| 84 |
+
# IPython
|
| 85 |
+
profile_default/
|
| 86 |
+
ipython_config.py
|
| 87 |
+
|
| 88 |
+
# pyenv
|
| 89 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 90 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 91 |
+
# .python-version
|
| 92 |
+
|
| 93 |
+
# pipenv
|
| 94 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 95 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 96 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 97 |
+
# install all needed dependencies.
|
| 98 |
+
#Pipfile.lock
|
| 99 |
+
|
| 100 |
+
# UV
|
| 101 |
+
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
| 102 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 103 |
+
# commonly ignored for libraries.
|
| 104 |
+
#uv.lock
|
| 105 |
+
|
| 106 |
+
# poetry
|
| 107 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 108 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 109 |
+
# commonly ignored for libraries.
|
| 110 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 111 |
+
#poetry.lock
|
| 112 |
+
#poetry.toml
|
| 113 |
+
|
| 114 |
+
# pdm
|
| 115 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 116 |
+
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
|
| 117 |
+
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
|
| 118 |
+
#pdm.lock
|
| 119 |
+
#pdm.toml
|
| 120 |
+
.pdm-python
|
| 121 |
+
.pdm-build/
|
| 122 |
+
|
| 123 |
+
# pixi
|
| 124 |
+
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
|
| 125 |
+
#pixi.lock
|
| 126 |
+
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
|
| 127 |
+
# in the .venv directory. It is recommended not to include this directory in version control.
|
| 128 |
+
.pixi
|
| 129 |
+
|
| 130 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 131 |
+
__pypackages__/
|
| 132 |
+
|
| 133 |
+
# Celery stuff
|
| 134 |
+
celerybeat-schedule
|
| 135 |
+
celerybeat.pid
|
| 136 |
+
|
| 137 |
+
# SageMath parsed files
|
| 138 |
+
*.sage.py
|
| 139 |
+
|
| 140 |
+
# Environments
|
| 141 |
+
.env
|
| 142 |
+
.envrc
|
| 143 |
+
.venv
|
| 144 |
+
env/
|
| 145 |
+
venv/
|
| 146 |
+
ENV/
|
| 147 |
+
env.bak/
|
| 148 |
+
venv.bak/
|
| 149 |
+
|
| 150 |
+
# Spyder project settings
|
| 151 |
+
.spyderproject
|
| 152 |
+
.spyproject
|
| 153 |
+
|
| 154 |
+
# Rope project settings
|
| 155 |
+
.ropeproject
|
| 156 |
+
|
| 157 |
+
# mkdocs documentation
|
| 158 |
+
/site
|
| 159 |
+
|
| 160 |
+
# mypy
|
| 161 |
+
.mypy_cache/
|
| 162 |
+
.dmypy.json
|
| 163 |
+
dmypy.json
|
| 164 |
+
|
| 165 |
+
# Pyre type checker
|
| 166 |
+
.pyre/
|
| 167 |
+
|
| 168 |
+
# pytype static type analyzer
|
| 169 |
+
.pytype/
|
| 170 |
+
|
| 171 |
+
# Cython debug symbols
|
| 172 |
+
cython_debug/
|
| 173 |
+
|
| 174 |
+
# PyCharm
|
| 175 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 176 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 177 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 178 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 179 |
+
#.idea/
|
| 180 |
+
|
| 181 |
+
# Abstra
|
| 182 |
+
# Abstra is an AI-powered process automation framework.
|
| 183 |
+
# Ignore directories containing user credentials, local state, and settings.
|
| 184 |
+
# Learn more at https://abstra.io/docs
|
| 185 |
+
.abstra/
|
| 186 |
+
|
| 187 |
+
# Visual Studio Code
|
| 188 |
+
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
|
| 189 |
+
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
|
| 190 |
+
# and can be added to the global gitignore or merged into this file. However, if you prefer,
|
| 191 |
+
# you could uncomment the following to ignore the entire vscode folder
|
| 192 |
+
# .vscode/
|
| 193 |
+
|
| 194 |
+
# Ruff stuff:
|
| 195 |
+
.ruff_cache/
|
| 196 |
+
|
| 197 |
+
# PyPI configuration file
|
| 198 |
+
.pypirc
|
| 199 |
+
|
| 200 |
+
# Cursor
|
| 201 |
+
# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
|
| 202 |
+
# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
|
| 203 |
+
# refer to https://docs.cursor.com/context/ignore-files
|
| 204 |
+
.cursorignore
|
| 205 |
+
.cursorindexingignore
|
| 206 |
+
|
| 207 |
+
# Marimo
|
| 208 |
+
marimo/_static/
|
| 209 |
+
marimo/_lsp/
|
| 210 |
+
__marimo__/
|
Prism/LLaDA/LLaDA_Baseline/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025 preordinary
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
Prism/LLaDA/LLaDA_Baseline/dllm_eval/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
from .evaluator import evaluate, simple_evaluate
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
__version__ = "0.4.9"
|
Prism/LLaDA/LLaDA_Baseline/dllm_eval/__main__.py
ADDED
|
@@ -0,0 +1,527 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
import sys
|
| 6 |
+
from functools import partial
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Union
|
| 9 |
+
|
| 10 |
+
from dllm_eval import evaluator, utils
|
| 11 |
+
from dllm_eval.evaluator import request_caching_arg_to_dict
|
| 12 |
+
from dllm_eval.loggers import EvaluationTracker, WandbLogger
|
| 13 |
+
from dllm_eval.tasks import TaskManager
|
| 14 |
+
from dllm_eval.utils import (
|
| 15 |
+
handle_non_serializable,
|
| 16 |
+
make_table,
|
| 17 |
+
simple_parse_args_string,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def try_parse_json(value: str) -> Union[str, dict, None]:
|
| 22 |
+
if value is None:
|
| 23 |
+
return None
|
| 24 |
+
try:
|
| 25 |
+
return json.loads(value)
|
| 26 |
+
except json.JSONDecodeError:
|
| 27 |
+
if "{" in value:
|
| 28 |
+
raise argparse.ArgumentTypeError(
|
| 29 |
+
f"Invalid JSON: {value}. Hint: Use double quotes for JSON strings."
|
| 30 |
+
)
|
| 31 |
+
return value
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _int_or_none_list_arg_type(
|
| 35 |
+
min_len: int, max_len: int, defaults: str, value: str, split_char: str = ","
|
| 36 |
+
):
|
| 37 |
+
def parse_value(item):
|
| 38 |
+
item = item.strip().lower()
|
| 39 |
+
if item == "none":
|
| 40 |
+
return None
|
| 41 |
+
try:
|
| 42 |
+
return int(item)
|
| 43 |
+
except ValueError:
|
| 44 |
+
raise argparse.ArgumentTypeError(f"{item} is not an integer or None")
|
| 45 |
+
|
| 46 |
+
items = [parse_value(v) for v in value.split(split_char)]
|
| 47 |
+
num_items = len(items)
|
| 48 |
+
|
| 49 |
+
if num_items == 1:
|
| 50 |
+
# Makes downstream handling the same for single and multiple values
|
| 51 |
+
items = items * max_len
|
| 52 |
+
elif num_items < min_len or num_items > max_len:
|
| 53 |
+
raise argparse.ArgumentTypeError(
|
| 54 |
+
f"Argument requires {max_len} integers or None, separated by '{split_char}'"
|
| 55 |
+
)
|
| 56 |
+
elif num_items != max_len:
|
| 57 |
+
logging.warning(
|
| 58 |
+
f"Argument requires {max_len} integers or None, separated by '{split_char}'. "
|
| 59 |
+
"Missing values will be filled with defaults."
|
| 60 |
+
)
|
| 61 |
+
default_items = [parse_value(v) for v in defaults.split(split_char)]
|
| 62 |
+
items.extend(
|
| 63 |
+
default_items[num_items:]
|
| 64 |
+
) # extend items list with missing defaults
|
| 65 |
+
|
| 66 |
+
return items
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def check_argument_types(parser: argparse.ArgumentParser):
|
| 70 |
+
"""
|
| 71 |
+
Check to make sure all CLI args are typed, raises error if not
|
| 72 |
+
"""
|
| 73 |
+
for action in parser._actions:
|
| 74 |
+
if action.dest != "help" and not action.const:
|
| 75 |
+
if action.type is None:
|
| 76 |
+
raise ValueError(
|
| 77 |
+
f"Argument '{action.dest}' doesn't have a type specified."
|
| 78 |
+
)
|
| 79 |
+
else:
|
| 80 |
+
continue
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def setup_parser() -> argparse.ArgumentParser:
|
| 84 |
+
parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
|
| 85 |
+
parser.add_argument(
|
| 86 |
+
"--model", "-m", type=str, default="hf", help="Name of model e.g. `hf`"
|
| 87 |
+
)
|
| 88 |
+
parser.add_argument(
|
| 89 |
+
"--tasks",
|
| 90 |
+
"-t",
|
| 91 |
+
default=None,
|
| 92 |
+
type=str,
|
| 93 |
+
metavar="task1,task2",
|
| 94 |
+
help="Comma-separated list of task names or task groupings to evaluate on.\nTo get full list of tasks, use one of the commands `lm-eval --tasks {{list_groups,list_subtasks,list_tags,list}}` to list out all available names for task groupings; only (sub)tasks; tags; or all of the above",
|
| 95 |
+
)
|
| 96 |
+
parser.add_argument(
|
| 97 |
+
"--model_args",
|
| 98 |
+
"-a",
|
| 99 |
+
default="",
|
| 100 |
+
type=try_parse_json,
|
| 101 |
+
help="""Comma separated string or JSON formatted arguments for model, e.g. `pretrained=EleutherAI/pythia-160m,dtype=float32` or '{"pretrained":"EleutherAI/pythia-160m","dtype":"float32"}'""",
|
| 102 |
+
)
|
| 103 |
+
parser.add_argument(
|
| 104 |
+
"--num_fewshot",
|
| 105 |
+
"-f",
|
| 106 |
+
type=int,
|
| 107 |
+
default=None,
|
| 108 |
+
metavar="N",
|
| 109 |
+
help="Number of examples in few-shot context",
|
| 110 |
+
)
|
| 111 |
+
parser.add_argument(
|
| 112 |
+
"--batch_size",
|
| 113 |
+
"-b",
|
| 114 |
+
type=str,
|
| 115 |
+
default=1,
|
| 116 |
+
metavar="auto|auto:N|N",
|
| 117 |
+
help="Acceptable values are 'auto', 'auto:N' or N, where N is an integer. Default 1.",
|
| 118 |
+
)
|
| 119 |
+
parser.add_argument(
|
| 120 |
+
"--max_batch_size",
|
| 121 |
+
type=int,
|
| 122 |
+
default=None,
|
| 123 |
+
metavar="N",
|
| 124 |
+
help="Maximal batch size to try with --batch_size auto.",
|
| 125 |
+
)
|
| 126 |
+
parser.add_argument(
|
| 127 |
+
"--device",
|
| 128 |
+
type=str,
|
| 129 |
+
default=None,
|
| 130 |
+
help="Device to use (e.g. cuda, cuda:0, cpu).",
|
| 131 |
+
)
|
| 132 |
+
parser.add_argument(
|
| 133 |
+
"--output_path",
|
| 134 |
+
"-o",
|
| 135 |
+
default=None,
|
| 136 |
+
type=str,
|
| 137 |
+
metavar="DIR|DIR/file.json",
|
| 138 |
+
help="Path where result metrics will be saved. Can be either a directory or a .json file. If the path is a directory and log_samples is true, the results will be saved in the directory. Else the parent directory will be used.",
|
| 139 |
+
)
|
| 140 |
+
parser.add_argument(
|
| 141 |
+
"--limit",
|
| 142 |
+
"-L",
|
| 143 |
+
type=float,
|
| 144 |
+
default=None,
|
| 145 |
+
metavar="N|0<N<1",
|
| 146 |
+
help="Limit the number of examples per task. "
|
| 147 |
+
"If <1, limit is a percentage of the total number of examples.",
|
| 148 |
+
)
|
| 149 |
+
parser.add_argument(
|
| 150 |
+
"--samples",
|
| 151 |
+
"-E",
|
| 152 |
+
default=None,
|
| 153 |
+
type=str,
|
| 154 |
+
metavar="/path/to/json",
|
| 155 |
+
help='JSON string or path to JSON file containing doc indices of selected examples to test. Format: {"task_name":[indices],...}',
|
| 156 |
+
)
|
| 157 |
+
parser.add_argument(
|
| 158 |
+
"--use_cache",
|
| 159 |
+
"-c",
|
| 160 |
+
type=str,
|
| 161 |
+
default=None,
|
| 162 |
+
metavar="DIR",
|
| 163 |
+
help="A path to a sqlite db file for caching model responses. `None` if not caching.",
|
| 164 |
+
)
|
| 165 |
+
parser.add_argument(
|
| 166 |
+
"--cache_requests",
|
| 167 |
+
type=str,
|
| 168 |
+
default=None,
|
| 169 |
+
choices=["true", "refresh", "delete"],
|
| 170 |
+
help="Speed up evaluation by caching the building of dataset requests. `None` if not caching.",
|
| 171 |
+
)
|
| 172 |
+
parser.add_argument(
|
| 173 |
+
"--check_integrity",
|
| 174 |
+
action="store_true",
|
| 175 |
+
help="Whether to run the relevant part of the test suite for the tasks.",
|
| 176 |
+
)
|
| 177 |
+
parser.add_argument(
|
| 178 |
+
"--write_out",
|
| 179 |
+
"-w",
|
| 180 |
+
action="store_true",
|
| 181 |
+
default=False,
|
| 182 |
+
help="Prints the prompt for the first few documents.",
|
| 183 |
+
)
|
| 184 |
+
parser.add_argument(
|
| 185 |
+
"--log_samples",
|
| 186 |
+
"-s",
|
| 187 |
+
action="store_true",
|
| 188 |
+
default=False,
|
| 189 |
+
help="If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis. Use with --output_path.",
|
| 190 |
+
)
|
| 191 |
+
parser.add_argument(
|
| 192 |
+
"--system_instruction",
|
| 193 |
+
type=str,
|
| 194 |
+
default=None,
|
| 195 |
+
help="System instruction to be used in the prompt",
|
| 196 |
+
)
|
| 197 |
+
parser.add_argument(
|
| 198 |
+
"--apply_chat_template",
|
| 199 |
+
type=str,
|
| 200 |
+
nargs="?",
|
| 201 |
+
const=True,
|
| 202 |
+
default=False,
|
| 203 |
+
help=(
|
| 204 |
+
"If True, apply chat template to the prompt. "
|
| 205 |
+
"Providing `--apply_chat_template` without an argument will apply the default chat template to the prompt. "
|
| 206 |
+
"To apply a specific template from the available list of templates, provide the template name as an argument. "
|
| 207 |
+
"E.g. `--apply_chat_template template_name`"
|
| 208 |
+
),
|
| 209 |
+
)
|
| 210 |
+
parser.add_argument(
|
| 211 |
+
"--fewshot_as_multiturn",
|
| 212 |
+
action="store_true",
|
| 213 |
+
default=False,
|
| 214 |
+
help="If True, uses the fewshot as a multi-turn conversation",
|
| 215 |
+
)
|
| 216 |
+
parser.add_argument(
|
| 217 |
+
"--show_config",
|
| 218 |
+
action="store_true",
|
| 219 |
+
default=False,
|
| 220 |
+
help="If True, shows the the full config of all tasks at the end of the evaluation.",
|
| 221 |
+
)
|
| 222 |
+
parser.add_argument(
|
| 223 |
+
"--include_path",
|
| 224 |
+
type=str,
|
| 225 |
+
default=None,
|
| 226 |
+
metavar="DIR",
|
| 227 |
+
help="Additional path to include if there are external tasks to include.",
|
| 228 |
+
)
|
| 229 |
+
parser.add_argument(
|
| 230 |
+
"--gen_kwargs",
|
| 231 |
+
type=try_parse_json,
|
| 232 |
+
default=None,
|
| 233 |
+
help=(
|
| 234 |
+
"Either comma delimited string or JSON formatted arguments for model generation on greedy_until tasks,"
|
| 235 |
+
""" e.g. '{"temperature":0.7,"until":["hello"]}' or temperature=0,top_p=0.1."""
|
| 236 |
+
),
|
| 237 |
+
)
|
| 238 |
+
parser.add_argument(
|
| 239 |
+
"--verbosity",
|
| 240 |
+
"-v",
|
| 241 |
+
type=str.upper,
|
| 242 |
+
default=None,
|
| 243 |
+
metavar="CRITICAL|ERROR|WARNING|INFO|DEBUG",
|
| 244 |
+
help="(Deprecated) Controls logging verbosity level. Use the `LOGLEVEL` environment variable instead. Set to DEBUG for detailed output when testing or adding new task configurations.",
|
| 245 |
+
)
|
| 246 |
+
parser.add_argument(
|
| 247 |
+
"--wandb_args",
|
| 248 |
+
type=str,
|
| 249 |
+
default="",
|
| 250 |
+
help="Comma separated string arguments passed to wandb.init, e.g. `project=lm-eval,job_type=eval",
|
| 251 |
+
)
|
| 252 |
+
parser.add_argument(
|
| 253 |
+
"--wandb_config_args",
|
| 254 |
+
type=str,
|
| 255 |
+
default="",
|
| 256 |
+
help="Comma separated string arguments passed to wandb.config.update. Use this to trace parameters that aren't already traced by default. eg. `lr=0.01,repeats=3",
|
| 257 |
+
)
|
| 258 |
+
parser.add_argument(
|
| 259 |
+
"--hf_hub_log_args",
|
| 260 |
+
type=str,
|
| 261 |
+
default="",
|
| 262 |
+
help="Comma separated string arguments passed to Hugging Face Hub's log function, e.g. `hub_results_org=EleutherAI,hub_repo_name=lm-eval-results`",
|
| 263 |
+
)
|
| 264 |
+
parser.add_argument(
|
| 265 |
+
"--predict_only",
|
| 266 |
+
"-x",
|
| 267 |
+
action="store_true",
|
| 268 |
+
default=False,
|
| 269 |
+
help="Use with --log_samples. Only model outputs will be saved and metrics will not be evaluated.",
|
| 270 |
+
)
|
| 271 |
+
default_seed_string = "0,1234,1234,1234"
|
| 272 |
+
parser.add_argument(
|
| 273 |
+
"--seed",
|
| 274 |
+
type=partial(_int_or_none_list_arg_type, 3, 4, default_seed_string),
|
| 275 |
+
default=default_seed_string, # for backward compatibility
|
| 276 |
+
help=(
|
| 277 |
+
"Set seed for python's random, numpy, torch, and fewshot sampling.\n"
|
| 278 |
+
"Accepts a comma-separated list of 4 values for python's random, numpy, torch, and fewshot sampling seeds, "
|
| 279 |
+
"respectively, or a single integer to set the same seed for all four.\n"
|
| 280 |
+
f"The values are either an integer or 'None' to not set the seed. Default is `{default_seed_string}` "
|
| 281 |
+
"(for backward compatibility).\n"
|
| 282 |
+
"E.g. `--seed 0,None,8,52` sets `random.seed(0)`, `torch.manual_seed(8)`, and fewshot sampling seed to 52. "
|
| 283 |
+
"Here numpy's seed is not set since the second value is `None`.\n"
|
| 284 |
+
"E.g, `--seed 42` sets all four seeds to 42."
|
| 285 |
+
),
|
| 286 |
+
)
|
| 287 |
+
parser.add_argument(
|
| 288 |
+
"--trust_remote_code",
|
| 289 |
+
action="store_true",
|
| 290 |
+
help="Sets trust_remote_code to True to execute code to create HF Datasets from the Hub",
|
| 291 |
+
)
|
| 292 |
+
parser.add_argument(
|
| 293 |
+
"--confirm_run_unsafe_code",
|
| 294 |
+
action="store_true",
|
| 295 |
+
help="Confirm that you understand the risks of running unsafe code for tasks that require it",
|
| 296 |
+
)
|
| 297 |
+
parser.add_argument(
|
| 298 |
+
"--metadata",
|
| 299 |
+
type=json.loads,
|
| 300 |
+
default=None,
|
| 301 |
+
help="""JSON string metadata to pass to task configs, for example '{"max_seq_lengths":[4096,8192]}'. Will be merged with model_args. Can also be set in task config.""",
|
| 302 |
+
)
|
| 303 |
+
return parser
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def parse_eval_args(parser: argparse.ArgumentParser) -> argparse.Namespace:
|
| 307 |
+
check_argument_types(parser)
|
| 308 |
+
return parser.parse_args()
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
|
| 312 |
+
if not args:
|
| 313 |
+
# we allow for args to be passed externally, else we parse them ourselves
|
| 314 |
+
parser = setup_parser()
|
| 315 |
+
args = parse_eval_args(parser)
|
| 316 |
+
|
| 317 |
+
if args.wandb_args:
|
| 318 |
+
wandb_args_dict = simple_parse_args_string(args.wandb_args)
|
| 319 |
+
wandb_config_args_dict = simple_parse_args_string(args.wandb_config_args)
|
| 320 |
+
wandb_logger = WandbLogger(wandb_args_dict, wandb_config_args_dict)
|
| 321 |
+
|
| 322 |
+
utils.setup_logging(args.verbosity)
|
| 323 |
+
eval_logger = logging.getLogger(__name__)
|
| 324 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 325 |
+
|
| 326 |
+
# update the evaluation tracker args with the output path and the HF token
|
| 327 |
+
if args.output_path:
|
| 328 |
+
args.hf_hub_log_args += f",output_path={args.output_path}"
|
| 329 |
+
if os.environ.get("HF_TOKEN", None):
|
| 330 |
+
args.hf_hub_log_args += f",token={os.environ.get('HF_TOKEN')}"
|
| 331 |
+
evaluation_tracker_args = simple_parse_args_string(args.hf_hub_log_args)
|
| 332 |
+
evaluation_tracker = EvaluationTracker(**evaluation_tracker_args)
|
| 333 |
+
|
| 334 |
+
if args.predict_only:
|
| 335 |
+
args.log_samples = True
|
| 336 |
+
if (args.log_samples or args.predict_only) and not args.output_path:
|
| 337 |
+
raise ValueError(
|
| 338 |
+
"Specify --output_path if providing --log_samples or --predict_only"
|
| 339 |
+
)
|
| 340 |
+
|
| 341 |
+
if args.fewshot_as_multiturn and args.apply_chat_template is False:
|
| 342 |
+
raise ValueError(
|
| 343 |
+
"When `fewshot_as_multiturn` is selected, `apply_chat_template` must be set (either to `True` or to the chosen template name)."
|
| 344 |
+
)
|
| 345 |
+
|
| 346 |
+
if args.include_path is not None:
|
| 347 |
+
eval_logger.info(f"Including path: {args.include_path}")
|
| 348 |
+
metadata = (
|
| 349 |
+
simple_parse_args_string(args.model_args)
|
| 350 |
+
if isinstance(args.model_args, str)
|
| 351 |
+
else args.model_args
|
| 352 |
+
if isinstance(args.model_args, dict)
|
| 353 |
+
else {}
|
| 354 |
+
) | (
|
| 355 |
+
args.metadata
|
| 356 |
+
if isinstance(args.metadata, dict)
|
| 357 |
+
else simple_parse_args_string(args.metadata)
|
| 358 |
+
)
|
| 359 |
+
|
| 360 |
+
task_manager = TaskManager(include_path=args.include_path, metadata=metadata)
|
| 361 |
+
|
| 362 |
+
if "push_samples_to_hub" in evaluation_tracker_args and not args.log_samples:
|
| 363 |
+
eval_logger.warning(
|
| 364 |
+
"Pushing samples to the Hub requires --log_samples to be set. Samples will not be pushed to the Hub."
|
| 365 |
+
)
|
| 366 |
+
|
| 367 |
+
if args.limit:
|
| 368 |
+
eval_logger.warning(
|
| 369 |
+
" --limit SHOULD ONLY BE USED FOR TESTING."
|
| 370 |
+
"REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
|
| 371 |
+
)
|
| 372 |
+
if args.samples:
|
| 373 |
+
assert args.limit is None, (
|
| 374 |
+
"If --samples is not None, then --limit must be None."
|
| 375 |
+
)
|
| 376 |
+
if (samples := Path(args.samples)).is_file():
|
| 377 |
+
args.samples = json.loads(samples.read_text())
|
| 378 |
+
else:
|
| 379 |
+
args.samples = json.loads(args.samples)
|
| 380 |
+
|
| 381 |
+
if args.tasks is None:
|
| 382 |
+
eval_logger.error("Need to specify task to evaluate.")
|
| 383 |
+
sys.exit()
|
| 384 |
+
elif args.tasks == "list":
|
| 385 |
+
print(task_manager.list_all_tasks())
|
| 386 |
+
sys.exit()
|
| 387 |
+
elif args.tasks == "list_groups":
|
| 388 |
+
print(task_manager.list_all_tasks(list_subtasks=False, list_tags=False))
|
| 389 |
+
sys.exit()
|
| 390 |
+
elif args.tasks == "list_tags":
|
| 391 |
+
print(task_manager.list_all_tasks(list_groups=False, list_subtasks=False))
|
| 392 |
+
sys.exit()
|
| 393 |
+
elif args.tasks == "list_subtasks":
|
| 394 |
+
print(task_manager.list_all_tasks(list_groups=False, list_tags=False))
|
| 395 |
+
sys.exit()
|
| 396 |
+
else:
|
| 397 |
+
if os.path.isdir(args.tasks):
|
| 398 |
+
import glob
|
| 399 |
+
|
| 400 |
+
task_names = []
|
| 401 |
+
yaml_path = os.path.join(args.tasks, "*.yaml")
|
| 402 |
+
for yaml_file in glob.glob(yaml_path):
|
| 403 |
+
config = utils.load_yaml_config(yaml_file)
|
| 404 |
+
task_names.append(config)
|
| 405 |
+
else:
|
| 406 |
+
task_list = args.tasks.split(",")
|
| 407 |
+
task_names = task_manager.match_tasks(task_list)
|
| 408 |
+
for task in [task for task in task_list if task not in task_names]:
|
| 409 |
+
if os.path.isfile(task):
|
| 410 |
+
config = utils.load_yaml_config(task)
|
| 411 |
+
task_names.append(config)
|
| 412 |
+
task_missing = [
|
| 413 |
+
task for task in task_list if task not in task_names and "*" not in task
|
| 414 |
+
] # we don't want errors if a wildcard ("*") task name was used
|
| 415 |
+
|
| 416 |
+
if task_missing:
|
| 417 |
+
missing = ", ".join(task_missing)
|
| 418 |
+
eval_logger.error(
|
| 419 |
+
f"Tasks were not found: {missing}\n"
|
| 420 |
+
f"{utils.SPACING}Try `lm-eval --tasks list` for list of available tasks",
|
| 421 |
+
)
|
| 422 |
+
raise ValueError(
|
| 423 |
+
f"Tasks not found: {missing}. Try `lm-eval --tasks {{list_groups,list_subtasks,list_tags,list}}` to list out all available names for task groupings; only (sub)tasks; tags; or all of the above, or pass '--verbosity DEBUG' to troubleshoot task registration issues."
|
| 424 |
+
)
|
| 425 |
+
|
| 426 |
+
# Respect user's value passed in via CLI, otherwise default to True and add to comma-separated model args
|
| 427 |
+
if args.trust_remote_code:
|
| 428 |
+
eval_logger.info(
|
| 429 |
+
"Passed `--trust_remote_code`, setting environment variable `HF_DATASETS_TRUST_REMOTE_CODE=true`"
|
| 430 |
+
)
|
| 431 |
+
# HACK: import datasets and override its HF_DATASETS_TRUST_REMOTE_CODE value internally,
|
| 432 |
+
# because it's already been determined based on the prior env var before launching our
|
| 433 |
+
# script--`datasets` gets imported by dllm_eval internally before these lines can update the env.
|
| 434 |
+
import datasets
|
| 435 |
+
|
| 436 |
+
datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
|
| 437 |
+
|
| 438 |
+
args.model_args = args.model_args + ",trust_remote_code=True"
|
| 439 |
+
(
|
| 440 |
+
eval_logger.info(f"Selected Tasks: {task_names}")
|
| 441 |
+
if eval_logger.getEffectiveLevel() >= logging.INFO
|
| 442 |
+
else print(f"Selected Tasks: {task_names}")
|
| 443 |
+
)
|
| 444 |
+
|
| 445 |
+
request_caching_args = request_caching_arg_to_dict(
|
| 446 |
+
cache_requests=args.cache_requests
|
| 447 |
+
)
|
| 448 |
+
|
| 449 |
+
results = evaluator.simple_evaluate(
|
| 450 |
+
model=args.model,
|
| 451 |
+
model_args=args.model_args,
|
| 452 |
+
tasks=task_names,
|
| 453 |
+
num_fewshot=args.num_fewshot,
|
| 454 |
+
batch_size=args.batch_size,
|
| 455 |
+
max_batch_size=args.max_batch_size,
|
| 456 |
+
device=args.device,
|
| 457 |
+
use_cache=args.use_cache,
|
| 458 |
+
limit=args.limit,
|
| 459 |
+
samples=args.samples,
|
| 460 |
+
check_integrity=args.check_integrity,
|
| 461 |
+
write_out=args.write_out,
|
| 462 |
+
log_samples=args.log_samples,
|
| 463 |
+
evaluation_tracker=evaluation_tracker,
|
| 464 |
+
system_instruction=args.system_instruction,
|
| 465 |
+
apply_chat_template=args.apply_chat_template,
|
| 466 |
+
fewshot_as_multiturn=args.fewshot_as_multiturn,
|
| 467 |
+
gen_kwargs=args.gen_kwargs,
|
| 468 |
+
task_manager=task_manager,
|
| 469 |
+
predict_only=args.predict_only,
|
| 470 |
+
random_seed=args.seed[0],
|
| 471 |
+
numpy_random_seed=args.seed[1],
|
| 472 |
+
torch_random_seed=args.seed[2],
|
| 473 |
+
fewshot_random_seed=args.seed[3],
|
| 474 |
+
confirm_run_unsafe_code=args.confirm_run_unsafe_code,
|
| 475 |
+
metadata=metadata,
|
| 476 |
+
**request_caching_args,
|
| 477 |
+
)
|
| 478 |
+
|
| 479 |
+
if results is not None:
|
| 480 |
+
if args.log_samples:
|
| 481 |
+
samples = results.pop("samples")
|
| 482 |
+
dumped = json.dumps(
|
| 483 |
+
results, indent=2, default=handle_non_serializable, ensure_ascii=False
|
| 484 |
+
)
|
| 485 |
+
if args.show_config:
|
| 486 |
+
print(dumped)
|
| 487 |
+
|
| 488 |
+
batch_sizes = ",".join(map(str, results["config"]["batch_sizes"]))
|
| 489 |
+
|
| 490 |
+
# Add W&B logging
|
| 491 |
+
if args.wandb_args:
|
| 492 |
+
try:
|
| 493 |
+
wandb_logger.post_init(results)
|
| 494 |
+
wandb_logger.log_eval_result()
|
| 495 |
+
if args.log_samples:
|
| 496 |
+
wandb_logger.log_eval_samples(samples)
|
| 497 |
+
except Exception as e:
|
| 498 |
+
eval_logger.info(f"Logging to Weights and Biases failed due to {e}")
|
| 499 |
+
|
| 500 |
+
evaluation_tracker.save_results_aggregated(
|
| 501 |
+
results=results, samples=samples if args.log_samples else None
|
| 502 |
+
)
|
| 503 |
+
|
| 504 |
+
if args.log_samples:
|
| 505 |
+
for task_name, config in results["configs"].items():
|
| 506 |
+
evaluation_tracker.save_results_samples(
|
| 507 |
+
task_name=task_name, samples=samples[task_name]
|
| 508 |
+
)
|
| 509 |
+
|
| 510 |
+
if (
|
| 511 |
+
evaluation_tracker.push_results_to_hub
|
| 512 |
+
or evaluation_tracker.push_samples_to_hub
|
| 513 |
+
):
|
| 514 |
+
evaluation_tracker.recreate_metadata_card()
|
| 515 |
+
|
| 516 |
+
print(
|
| 517 |
+
f"{args.model} ({args.model_args}), gen_kwargs: ({args.gen_kwargs}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, "
|
| 518 |
+
f"batch_size: {args.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}"
|
| 519 |
+
)
|
| 520 |
+
|
| 521 |
+
if args.wandb_args:
|
| 522 |
+
# Tear down wandb run once all the logging is done.
|
| 523 |
+
wandb_logger.run.finish()
|
| 524 |
+
|
| 525 |
+
|
| 526 |
+
if __name__ == "__main__":
|
| 527 |
+
cli_evaluate()
|
Prism/LLaDA/LLaDA_Baseline/dllm_eval/evaluator.py
ADDED
|
@@ -0,0 +1,765 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import itertools
|
| 2 |
+
import json
|
| 3 |
+
import logging
|
| 4 |
+
import random
|
| 5 |
+
import time
|
| 6 |
+
from collections import defaultdict
|
| 7 |
+
from typing import TYPE_CHECKING, List, Optional, Union
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
import dllm_eval.api.metrics
|
| 13 |
+
import dllm_eval.api.registry
|
| 14 |
+
import dllm_eval.api.task
|
| 15 |
+
import dllm_eval.models
|
| 16 |
+
from dllm_eval.caching.cache import delete_cache
|
| 17 |
+
from dllm_eval.evaluator_utils import (
|
| 18 |
+
consolidate_group_results,
|
| 19 |
+
consolidate_results,
|
| 20 |
+
get_sample_size,
|
| 21 |
+
get_subtask_list,
|
| 22 |
+
get_task_list,
|
| 23 |
+
prepare_print_tasks,
|
| 24 |
+
print_writeout,
|
| 25 |
+
run_task_tests,
|
| 26 |
+
)
|
| 27 |
+
from dllm_eval.loggers import EvaluationTracker
|
| 28 |
+
from dllm_eval.loggers.utils import add_env_info, add_tokenizer_info, get_git_commit_hash
|
| 29 |
+
from dllm_eval.tasks import TaskManager, get_task_dict
|
| 30 |
+
from dllm_eval.utils import (
|
| 31 |
+
handle_non_serializable,
|
| 32 |
+
hash_string,
|
| 33 |
+
positional_deprecated,
|
| 34 |
+
setup_logging,
|
| 35 |
+
simple_parse_args_string,
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
if TYPE_CHECKING:
|
| 40 |
+
from dllm_eval.api.model import LM
|
| 41 |
+
from dllm_eval.api.task import Task
|
| 42 |
+
|
| 43 |
+
eval_logger = logging.getLogger(__name__)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
@positional_deprecated
|
| 47 |
+
def simple_evaluate(
|
| 48 |
+
model,
|
| 49 |
+
model_args: Optional[Union[str, dict]] = None,
|
| 50 |
+
tasks: Optional[List[Union[str, dict, object]]] = None,
|
| 51 |
+
num_fewshot: Optional[int] = None,
|
| 52 |
+
batch_size: Optional[Union[int, str]] = None,
|
| 53 |
+
max_batch_size: Optional[int] = None,
|
| 54 |
+
device: Optional[str] = None,
|
| 55 |
+
use_cache: Optional[str] = None,
|
| 56 |
+
cache_requests: bool = False,
|
| 57 |
+
rewrite_requests_cache: bool = False,
|
| 58 |
+
delete_requests_cache: bool = False,
|
| 59 |
+
limit: Optional[Union[int, float]] = None,
|
| 60 |
+
samples: Optional[dict] = None,
|
| 61 |
+
bootstrap_iters: int = 100000,
|
| 62 |
+
check_integrity: bool = False,
|
| 63 |
+
write_out: bool = False,
|
| 64 |
+
log_samples: bool = True,
|
| 65 |
+
evaluation_tracker: Optional[EvaluationTracker] = None,
|
| 66 |
+
system_instruction: Optional[str] = None,
|
| 67 |
+
apply_chat_template: Union[bool, str] = False,
|
| 68 |
+
fewshot_as_multiturn: bool = False,
|
| 69 |
+
gen_kwargs: Union[str, dict, None] = None,
|
| 70 |
+
task_manager: Optional[TaskManager] = None,
|
| 71 |
+
verbosity=None,
|
| 72 |
+
predict_only: bool = False,
|
| 73 |
+
random_seed: int = 0,
|
| 74 |
+
numpy_random_seed: int = 1234,
|
| 75 |
+
torch_random_seed: int = 1234,
|
| 76 |
+
fewshot_random_seed: int = 1234,
|
| 77 |
+
confirm_run_unsafe_code: bool = False,
|
| 78 |
+
metadata: Optional[dict] = None,
|
| 79 |
+
):
|
| 80 |
+
"""Instantiate and evaluate a model on a list of tasks.
|
| 81 |
+
|
| 82 |
+
:param model: Union[str, LM]
|
| 83 |
+
Name of model or LM object, see dllm_eval.models.get_model
|
| 84 |
+
:param model_args: Optional[str, dict]
|
| 85 |
+
String or dict arguments for each model class, see LM.create_from_arg_string and LM.create_from_arg_object.
|
| 86 |
+
Ignored if `model` argument is a LM object.
|
| 87 |
+
:param tasks: list[Union[str, dict, Task]]
|
| 88 |
+
List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
|
| 89 |
+
:param num_fewshot: int
|
| 90 |
+
Number of examples in few-shot context
|
| 91 |
+
:param batch_size: int or str, optional
|
| 92 |
+
Batch size for model
|
| 93 |
+
:param max_batch_size: int, optional
|
| 94 |
+
Maximal batch size to try with automatic batch size detection
|
| 95 |
+
:param device: str, optional
|
| 96 |
+
PyTorch device (e.g. "cpu" or "cuda:0") for running models
|
| 97 |
+
:param use_cache: str, optional
|
| 98 |
+
A path to a sqlite db file for caching model responses. `None` if not caching.
|
| 99 |
+
:param cache_requests: bool, optional
|
| 100 |
+
Speed up evaluation by caching the building of dataset requests. `None` if not caching.
|
| 101 |
+
:param rewrite_requests_cache: bool, optional
|
| 102 |
+
Rewrites all the request cache if set to `True`. `None` if not desired.
|
| 103 |
+
:param delete_requests_cache: bool, optional
|
| 104 |
+
Deletes all the request cache if set to `True`. `None` if not desired.
|
| 105 |
+
:param limit: int or float, optional
|
| 106 |
+
Limit the number of examples per task (only use this for testing), If <1, limit is a percentage of the total number of examples.
|
| 107 |
+
:param samples: dictionary, optional
|
| 108 |
+
Dictionary indicating which examples should be tested in each task, e.g., {"mmlu_astronomy":[0,3,6],"mmlu_anatomy":[1,4,7,10]}.
|
| 109 |
+
:param bootstrap_iters:
|
| 110 |
+
Number of iterations for bootstrap statistics, used when calculating stderrs. set to 0 for no stderr calculations to be performed.
|
| 111 |
+
:param check_integrity: bool
|
| 112 |
+
Whether to run the relevant part of the test suite for the tasks
|
| 113 |
+
:param write_out: bool
|
| 114 |
+
If True, write out an example document and model input for checking task integrity
|
| 115 |
+
:param log_samples: bool
|
| 116 |
+
If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
|
| 117 |
+
:param system_instruction: str
|
| 118 |
+
System instruction to be applied to the prompt
|
| 119 |
+
:param apply_chat_template: Union[bool, str]
|
| 120 |
+
Specifies whether to apply a chat template to the prompt.
|
| 121 |
+
- If set to True, the default chat template is applied.
|
| 122 |
+
- If set to a string, applies the specified chat template by name.
|
| 123 |
+
Defaults to False (no chat template applied).
|
| 124 |
+
:param fewshot_as_multiturn: bool
|
| 125 |
+
Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
|
| 126 |
+
:param gen_kwargs: dict or comma-separated string
|
| 127 |
+
Arguments for model generation
|
| 128 |
+
Ignored for all tasks with loglikelihood output_type
|
| 129 |
+
:param verbosity: str
|
| 130 |
+
Verbosity level for logging
|
| 131 |
+
:param predict_only: bool
|
| 132 |
+
If true only model outputs will be generated and returned. Metrics will not be evaluated
|
| 133 |
+
:param random_seed: int
|
| 134 |
+
Random seed for python's random module. If set to None, the seed will not be set.
|
| 135 |
+
:param numpy_random_seed: int
|
| 136 |
+
Random seed for numpy. If set to None, the seed will not be set.
|
| 137 |
+
:param torch_random_seed: int
|
| 138 |
+
Random seed for torch. If set to None, the seed will not be set.
|
| 139 |
+
:param fewshot_random_seed: int
|
| 140 |
+
Random seed for fewshot sampler random generator. If set to None, the seed of generator will be set to None.
|
| 141 |
+
:param metadata: dict
|
| 142 |
+
Additional metadata to be added to the task manager. Will get passed to the download function of the task.
|
| 143 |
+
|
| 144 |
+
return
|
| 145 |
+
Dictionary of results
|
| 146 |
+
"""
|
| 147 |
+
if verbosity is not None:
|
| 148 |
+
setup_logging(verbosity=verbosity)
|
| 149 |
+
start_date = time.time()
|
| 150 |
+
|
| 151 |
+
if limit is not None and samples is not None:
|
| 152 |
+
raise ValueError(
|
| 153 |
+
"Either 'limit' or 'samples' must be None, but both are not None."
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
if (
|
| 157 |
+
(isinstance(model_args, str) and "inst" in model_args.lower())
|
| 158 |
+
or (
|
| 159 |
+
isinstance(model_args, dict)
|
| 160 |
+
and any("inst" in str(v).lower() for v in model_args.values())
|
| 161 |
+
)
|
| 162 |
+
) and not apply_chat_template:
|
| 163 |
+
eval_logger.warning(
|
| 164 |
+
"Model appears to be an instruct variant but chat template is not applied. Recommend setting `apply_chat_template` (optionally `fewshot_as_multiturn`)."
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
if delete_requests_cache:
|
| 168 |
+
eval_logger.info("Deleting requests cache...")
|
| 169 |
+
delete_cache()
|
| 170 |
+
|
| 171 |
+
seed_message = []
|
| 172 |
+
if random_seed is not None:
|
| 173 |
+
# See https://github.com/EleutherAI/lm-evaluation-harness/pull/1412
|
| 174 |
+
seed_message.append(f"Setting random seed to {random_seed}")
|
| 175 |
+
random.seed(random_seed)
|
| 176 |
+
|
| 177 |
+
if numpy_random_seed is not None:
|
| 178 |
+
seed_message.append(f"Setting numpy seed to {numpy_random_seed}")
|
| 179 |
+
np.random.seed(numpy_random_seed)
|
| 180 |
+
|
| 181 |
+
if torch_random_seed is not None:
|
| 182 |
+
seed_message.append(f"Setting torch manual seed to {torch_random_seed}")
|
| 183 |
+
torch.manual_seed(torch_random_seed)
|
| 184 |
+
|
| 185 |
+
if fewshot_random_seed is not None:
|
| 186 |
+
seed_message.append(f"Setting fewshot manual seed to {fewshot_random_seed}")
|
| 187 |
+
|
| 188 |
+
if seed_message:
|
| 189 |
+
eval_logger.info(" | ".join(seed_message))
|
| 190 |
+
|
| 191 |
+
if tasks is None:
|
| 192 |
+
tasks = []
|
| 193 |
+
if len(tasks) == 0:
|
| 194 |
+
raise ValueError(
|
| 195 |
+
"No tasks specified, or no tasks found. Please verify the task names."
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
if gen_kwargs is not None:
|
| 199 |
+
if isinstance(gen_kwargs, str):
|
| 200 |
+
gen_kwargs = simple_parse_args_string(gen_kwargs)
|
| 201 |
+
eval_logger.warning(
|
| 202 |
+
f"generation_kwargs: {gen_kwargs} specified through cli, these settings will update set parameters in yaml tasks. "
|
| 203 |
+
"Ensure 'do_sample=True' for non-greedy decoding!"
|
| 204 |
+
)
|
| 205 |
+
if not gen_kwargs:
|
| 206 |
+
gen_kwargs = None
|
| 207 |
+
|
| 208 |
+
if isinstance(model, str):
|
| 209 |
+
if model_args is None:
|
| 210 |
+
eval_logger.warning("model_args not specified. Using defaults.")
|
| 211 |
+
model_args = ""
|
| 212 |
+
|
| 213 |
+
if isinstance(model_args, dict):
|
| 214 |
+
eval_logger.info(
|
| 215 |
+
f"Initializing {model} model, with arguments: {model_args}"
|
| 216 |
+
)
|
| 217 |
+
lm = dllm_eval.api.registry.get_model(model).create_from_arg_obj(
|
| 218 |
+
model_args,
|
| 219 |
+
{
|
| 220 |
+
"batch_size": batch_size,
|
| 221 |
+
"max_batch_size": max_batch_size,
|
| 222 |
+
"device": device,
|
| 223 |
+
},
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
else:
|
| 227 |
+
eval_logger.info(
|
| 228 |
+
f"Initializing {model} model, with arguments: {simple_parse_args_string(model_args)}"
|
| 229 |
+
)
|
| 230 |
+
lm = dllm_eval.api.registry.get_model(model).create_from_arg_string(
|
| 231 |
+
model_args,
|
| 232 |
+
{
|
| 233 |
+
"batch_size": batch_size,
|
| 234 |
+
"max_batch_size": max_batch_size,
|
| 235 |
+
"device": device,
|
| 236 |
+
},
|
| 237 |
+
)
|
| 238 |
+
else:
|
| 239 |
+
if not isinstance(model, dllm_eval.api.model.LM):
|
| 240 |
+
raise TypeError(
|
| 241 |
+
f"The value of `model` passed to simple_evaluate() was of type {type(model)}, but is required to be a subclass of dllm_eval.api.model.LM . This may be because you are passing an initialized Hugging Face PreTrainedModel without having wrapped it in `dllm_eval.models.huggingface.HFLM(pretrained=my_model)` first."
|
| 242 |
+
)
|
| 243 |
+
eval_logger.info("Using pre-initialized model")
|
| 244 |
+
lm = model
|
| 245 |
+
|
| 246 |
+
if use_cache is not None:
|
| 247 |
+
eval_logger.info(f"Using cache at {use_cache + '_rank' + str(lm.rank) + '.db'}")
|
| 248 |
+
lm = dllm_eval.api.model.CachingLM(
|
| 249 |
+
lm,
|
| 250 |
+
use_cache
|
| 251 |
+
# each rank receives a different cache db.
|
| 252 |
+
# necessary to avoid multiple writes to cache at once
|
| 253 |
+
+ "_rank"
|
| 254 |
+
+ str(lm.rank)
|
| 255 |
+
+ ".db",
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
if task_manager is None:
|
| 259 |
+
metadata = (
|
| 260 |
+
simple_parse_args_string(model_args)
|
| 261 |
+
if isinstance(model_args, str)
|
| 262 |
+
else model_args
|
| 263 |
+
if isinstance(model_args, dict)
|
| 264 |
+
else {}
|
| 265 |
+
) | (metadata or {})
|
| 266 |
+
task_manager = TaskManager(metadata=metadata)
|
| 267 |
+
|
| 268 |
+
task_dict = get_task_dict(
|
| 269 |
+
tasks,
|
| 270 |
+
task_manager,
|
| 271 |
+
)
|
| 272 |
+
|
| 273 |
+
# helper function to recursively apply config overrides to leaf subtasks, skipping their constituent groups.
|
| 274 |
+
# (setting of num_fewshot ; bypassing metric calculation ; setting fewshot seed)
|
| 275 |
+
def _adjust_config(task_dict):
|
| 276 |
+
adjusted_task_dict = {}
|
| 277 |
+
for task_name, task_obj in task_dict.items():
|
| 278 |
+
if isinstance(task_obj, dict):
|
| 279 |
+
adjusted_task_dict = {
|
| 280 |
+
**adjusted_task_dict,
|
| 281 |
+
**{task_name: _adjust_config(task_obj)},
|
| 282 |
+
}
|
| 283 |
+
|
| 284 |
+
else:
|
| 285 |
+
if task_obj.get_config("output_type") == "generate_until":
|
| 286 |
+
if gen_kwargs is not None:
|
| 287 |
+
task_obj.set_config(
|
| 288 |
+
key="generation_kwargs", value=gen_kwargs, update=True
|
| 289 |
+
)
|
| 290 |
+
eval_logger.info(
|
| 291 |
+
f"{task_obj.config.task}: Using gen_kwargs: {task_obj.config.generation_kwargs}"
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
if predict_only:
|
| 295 |
+
eval_logger.info(
|
| 296 |
+
f"Processing {task_name} in output-only mode. Metrics will not be calculated!"
|
| 297 |
+
)
|
| 298 |
+
# we have to change the class properties post-hoc. This is pretty hacky.
|
| 299 |
+
task_obj.override_metric(metric_name="bypass")
|
| 300 |
+
|
| 301 |
+
# override tasks' fewshot values to the provided num_fewshot arg value
|
| 302 |
+
# except if tasks have it set to 0 manually in their configs--then we should never overwrite that
|
| 303 |
+
if num_fewshot is not None:
|
| 304 |
+
if (default_num_fewshot := task_obj.get_config("num_fewshot")) == 0:
|
| 305 |
+
eval_logger.info(
|
| 306 |
+
f"num_fewshot has been set to 0 for {task_name} in its config. Manual configuration will be ignored."
|
| 307 |
+
)
|
| 308 |
+
else:
|
| 309 |
+
eval_logger.warning(
|
| 310 |
+
f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}"
|
| 311 |
+
)
|
| 312 |
+
task_obj.set_config(key="num_fewshot", value=num_fewshot)
|
| 313 |
+
else:
|
| 314 |
+
# if num_fewshot not provided, and the task does not define a default one, default to 0
|
| 315 |
+
if (
|
| 316 |
+
default_num_fewshot := task_obj.get_config("num_fewshot")
|
| 317 |
+
) is None:
|
| 318 |
+
task_obj.set_config(key="num_fewshot", value=0)
|
| 319 |
+
# fewshot_random_seed set for tasks, even with a default num_fewshot (e.g. in the YAML file)
|
| 320 |
+
task_obj.set_fewshot_seed(seed=fewshot_random_seed)
|
| 321 |
+
|
| 322 |
+
adjusted_task_dict[task_name] = task_obj
|
| 323 |
+
|
| 324 |
+
return adjusted_task_dict
|
| 325 |
+
|
| 326 |
+
task_dict = _adjust_config(task_dict)
|
| 327 |
+
|
| 328 |
+
if check_integrity:
|
| 329 |
+
run_task_tests(task_list=tasks)
|
| 330 |
+
|
| 331 |
+
if evaluation_tracker is not None:
|
| 332 |
+
evaluation_tracker.general_config_tracker.log_experiment_args(
|
| 333 |
+
model_source=model,
|
| 334 |
+
model_args=model_args,
|
| 335 |
+
system_instruction=system_instruction,
|
| 336 |
+
chat_template=lm.chat_template(apply_chat_template)
|
| 337 |
+
if apply_chat_template
|
| 338 |
+
else None,
|
| 339 |
+
fewshot_as_multiturn=fewshot_as_multiturn,
|
| 340 |
+
)
|
| 341 |
+
|
| 342 |
+
results = evaluate(
|
| 343 |
+
lm=lm,
|
| 344 |
+
task_dict=task_dict,
|
| 345 |
+
limit=limit,
|
| 346 |
+
samples=samples,
|
| 347 |
+
cache_requests=cache_requests,
|
| 348 |
+
rewrite_requests_cache=rewrite_requests_cache,
|
| 349 |
+
bootstrap_iters=bootstrap_iters,
|
| 350 |
+
write_out=write_out,
|
| 351 |
+
log_samples=True if predict_only else log_samples,
|
| 352 |
+
system_instruction=system_instruction,
|
| 353 |
+
apply_chat_template=apply_chat_template,
|
| 354 |
+
fewshot_as_multiturn=fewshot_as_multiturn,
|
| 355 |
+
verbosity=verbosity,
|
| 356 |
+
confirm_run_unsafe_code=confirm_run_unsafe_code,
|
| 357 |
+
)
|
| 358 |
+
if verbosity is not None:
|
| 359 |
+
setup_logging(verbosity=verbosity)
|
| 360 |
+
|
| 361 |
+
if lm.rank == 0:
|
| 362 |
+
if isinstance(model, str):
|
| 363 |
+
model_name = model
|
| 364 |
+
elif hasattr(model, "config") and hasattr(model.config, "_name_or_path"):
|
| 365 |
+
model_name = model.config._name_or_path
|
| 366 |
+
else:
|
| 367 |
+
model_name = type(model).__name__
|
| 368 |
+
|
| 369 |
+
# add info about the model and few shot config
|
| 370 |
+
results["config"] = {
|
| 371 |
+
"model": model_name,
|
| 372 |
+
"model_args": model_args,
|
| 373 |
+
}
|
| 374 |
+
# add more detailed model info if available
|
| 375 |
+
if isinstance(lm, dllm_eval.models.huggingface.HFLM):
|
| 376 |
+
results["config"].update(lm.get_model_info())
|
| 377 |
+
# add info about execution
|
| 378 |
+
results["config"].update(
|
| 379 |
+
{
|
| 380 |
+
"batch_size": batch_size,
|
| 381 |
+
"batch_sizes": (
|
| 382 |
+
list(lm.batch_sizes.values()) if hasattr(lm, "batch_sizes") else []
|
| 383 |
+
),
|
| 384 |
+
"device": device,
|
| 385 |
+
"use_cache": use_cache,
|
| 386 |
+
"limit": limit,
|
| 387 |
+
"bootstrap_iters": bootstrap_iters,
|
| 388 |
+
"gen_kwargs": gen_kwargs,
|
| 389 |
+
"random_seed": random_seed,
|
| 390 |
+
"numpy_seed": numpy_random_seed,
|
| 391 |
+
"torch_seed": torch_random_seed,
|
| 392 |
+
"fewshot_seed": fewshot_random_seed,
|
| 393 |
+
}
|
| 394 |
+
)
|
| 395 |
+
results["git_hash"] = get_git_commit_hash()
|
| 396 |
+
results["date"] = start_date
|
| 397 |
+
add_env_info(results) # additional environment info to results
|
| 398 |
+
add_tokenizer_info(results, lm) # additional info about tokenizer
|
| 399 |
+
return results
|
| 400 |
+
else:
|
| 401 |
+
return None
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
@positional_deprecated
|
| 405 |
+
def evaluate(
|
| 406 |
+
lm: "LM",
|
| 407 |
+
task_dict,
|
| 408 |
+
limit: Optional[int] = None,
|
| 409 |
+
samples: Optional[dict] = None,
|
| 410 |
+
cache_requests: bool = False,
|
| 411 |
+
rewrite_requests_cache: bool = False,
|
| 412 |
+
bootstrap_iters: Optional[int] = 100000,
|
| 413 |
+
write_out: bool = False,
|
| 414 |
+
log_samples: bool = True,
|
| 415 |
+
system_instruction: Optional[str] = None,
|
| 416 |
+
apply_chat_template: Union[bool, str] = False,
|
| 417 |
+
fewshot_as_multiturn: bool = False,
|
| 418 |
+
verbosity: str = "INFO",
|
| 419 |
+
confirm_run_unsafe_code: bool = False,
|
| 420 |
+
):
|
| 421 |
+
"""Instantiate and evaluate a model on a list of tasks.
|
| 422 |
+
|
| 423 |
+
:param lm: obj
|
| 424 |
+
Language Model
|
| 425 |
+
:param task_dict: dict[str, Task]
|
| 426 |
+
Dictionary of tasks. Tasks will be taken to have name type(task).config.task .
|
| 427 |
+
:param limit: int, optional
|
| 428 |
+
Limit the number of examples per task (only use this for testing)
|
| 429 |
+
:param samples: dictionary, optional
|
| 430 |
+
Dictionary indicating which examples should be tested in each task, e.g., {"mmlu_astronomy":[0,3,6],"mmlu_anatomy":[1,4,7,10]}.
|
| 431 |
+
:param cache_requests: bool, optional
|
| 432 |
+
Speed up evaluation by caching the building of dataset requests.
|
| 433 |
+
:param rewrite_requests_cache: bool, optional
|
| 434 |
+
Rewrites all the request cache if set to `True`.
|
| 435 |
+
:param bootstrap_iters:
|
| 436 |
+
Number of iterations for bootstrap statistics, used when calculating stderr. Set to 0 for skipping all stderr calculations.
|
| 437 |
+
:param write_out: bool
|
| 438 |
+
If True, write out an example document and model input for checking task integrity
|
| 439 |
+
:param log_samples: bool
|
| 440 |
+
If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
|
| 441 |
+
:param system_instruction: str
|
| 442 |
+
System instruction to be applied to the prompt
|
| 443 |
+
:param apply_chat_template: Union[bool, str]
|
| 444 |
+
Specifies whether to apply a chat template to the prompt.
|
| 445 |
+
- If set to True, the default chat template is applied.
|
| 446 |
+
- If set to a string, applies the specified chat template by name.
|
| 447 |
+
Defaults to False (no chat template applied).
|
| 448 |
+
:param fewshot_as_multiturn: bool
|
| 449 |
+
Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
|
| 450 |
+
:param verbosity: str
|
| 451 |
+
Verbosity level for logging
|
| 452 |
+
:param confirm_run_unsafe_code: bool
|
| 453 |
+
Whether to confirm running tasks marked as unsafe.
|
| 454 |
+
:return
|
| 455 |
+
Dictionary of results
|
| 456 |
+
"""
|
| 457 |
+
|
| 458 |
+
if limit is not None and samples is not None:
|
| 459 |
+
raise ValueError(
|
| 460 |
+
"Either 'limit' or 'samples' must be None, but both are not None."
|
| 461 |
+
)
|
| 462 |
+
if samples is not None:
|
| 463 |
+
eval_logger.info(f"Evaluating examples for tasks {list(samples.keys())}")
|
| 464 |
+
if apply_chat_template:
|
| 465 |
+
eval_logger.warning(
|
| 466 |
+
"Chat template formatting change affects loglikelihood and multiple-choice tasks. See docs/chat-template-readme.md for details."
|
| 467 |
+
)
|
| 468 |
+
# tracks all Instances/requests a model must generate output on.
|
| 469 |
+
requests = defaultdict(list)
|
| 470 |
+
# stores the amount to pad out reqs per req. type so that
|
| 471 |
+
# number of fwd passes per distributed rank is equal
|
| 472 |
+
padding_requests = defaultdict(int)
|
| 473 |
+
|
| 474 |
+
# get lists of group hierarchy and each type of request
|
| 475 |
+
eval_tasks = get_task_list(task_dict)
|
| 476 |
+
if not log_samples:
|
| 477 |
+
if not all(
|
| 478 |
+
"bypass" not in getattr(task_output.task, "_metric_fn_list", {}).keys()
|
| 479 |
+
for task_output in eval_tasks
|
| 480 |
+
):
|
| 481 |
+
raise ValueError("log_samples must be True for 'bypass' metric-only tasks")
|
| 482 |
+
|
| 483 |
+
# validation checks:
|
| 484 |
+
# 1.are we running multimodal task <-> non-multimodal model class, or vice-versa.
|
| 485 |
+
# 2.are we running code that is marked as unsafe.
|
| 486 |
+
incompatible_tasks = []
|
| 487 |
+
for task_output in eval_tasks:
|
| 488 |
+
task: Task = task_output.task
|
| 489 |
+
|
| 490 |
+
if getattr(task, "MULTIMODAL", False) and not getattr(lm, "MULTIMODAL", False):
|
| 491 |
+
incompatible_tasks.append(task_output.task_name)
|
| 492 |
+
elif getattr(task, "UNSAFE_CODE", False) and not confirm_run_unsafe_code:
|
| 493 |
+
raise ValueError(
|
| 494 |
+
f"Attempted to run task: {task_output.task_name} which is marked as unsafe. Set confirm_run_unsafe_code=True to run this task."
|
| 495 |
+
)
|
| 496 |
+
if len(incompatible_tasks) > 0:
|
| 497 |
+
if not getattr(lm, "MULTIMODAL", False):
|
| 498 |
+
raise ValueError(
|
| 499 |
+
f"Attempted to run tasks: {incompatible_tasks} which require multimodal input, but the selected model type does not currently implement this. Multimodal support is currently restricted to the ['hf-multimodal', 'vllm-vlm'] model type."
|
| 500 |
+
)
|
| 501 |
+
# end validation check
|
| 502 |
+
|
| 503 |
+
# Cache the limit arg.
|
| 504 |
+
limit_arg = limit
|
| 505 |
+
limits = []
|
| 506 |
+
for task_output in eval_tasks:
|
| 507 |
+
task: Task = task_output.task
|
| 508 |
+
|
| 509 |
+
limit = get_sample_size(task, limit_arg)
|
| 510 |
+
limits.append(limit)
|
| 511 |
+
task.build_all_requests(
|
| 512 |
+
limit=limit,
|
| 513 |
+
samples=samples.get(task_output.task_name, None)
|
| 514 |
+
if samples is not None
|
| 515 |
+
else samples,
|
| 516 |
+
rank=lm.rank,
|
| 517 |
+
world_size=lm.world_size,
|
| 518 |
+
cache_requests=cache_requests,
|
| 519 |
+
rewrite_requests_cache=rewrite_requests_cache,
|
| 520 |
+
system_instruction=system_instruction,
|
| 521 |
+
apply_chat_template=bool(apply_chat_template),
|
| 522 |
+
fewshot_as_multiturn=fewshot_as_multiturn,
|
| 523 |
+
chat_template=getattr(lm, "apply_chat_template")
|
| 524 |
+
if apply_chat_template
|
| 525 |
+
else None,
|
| 526 |
+
tokenizer_name=getattr(lm, "tokenizer_name", "")
|
| 527 |
+
if apply_chat_template
|
| 528 |
+
else "",
|
| 529 |
+
)
|
| 530 |
+
eval_logger.debug(
|
| 531 |
+
f"Task: {task_output.task_name}; number of requests on this rank: {len(task.instances)}"
|
| 532 |
+
)
|
| 533 |
+
if write_out:
|
| 534 |
+
print_writeout(task)
|
| 535 |
+
# aggregate Instances by LM method requested to get output.
|
| 536 |
+
for instance in task.instances:
|
| 537 |
+
reqtype = instance.request_type
|
| 538 |
+
requests[reqtype].append(instance)
|
| 539 |
+
|
| 540 |
+
if lm.world_size > 1:
|
| 541 |
+
instances_rnk = torch.tensor(len(task._instances), device=lm.device)
|
| 542 |
+
gathered_item = (
|
| 543 |
+
lm.accelerator.gather(instances_rnk).cpu().detach().numpy().tolist()
|
| 544 |
+
)
|
| 545 |
+
# "multiple_choice" task types dispatch (several) "loglikelihood" request types
|
| 546 |
+
reqtype = (
|
| 547 |
+
"loglikelihood"
|
| 548 |
+
if task.OUTPUT_TYPE == "multiple_choice"
|
| 549 |
+
else task.OUTPUT_TYPE
|
| 550 |
+
)
|
| 551 |
+
# compute number of pseudo-batches to pad with (FSDP/DDP require even batches among ranks)
|
| 552 |
+
numpad = max(gathered_item) - gathered_item[lm.rank]
|
| 553 |
+
# todo: may not account for padding in cases like SquadV2 which has multiple req types
|
| 554 |
+
padding_requests[reqtype] += numpad
|
| 555 |
+
|
| 556 |
+
### Run LM on inputs, get all outputs ###
|
| 557 |
+
# execute each type of request
|
| 558 |
+
for reqtype, reqs in requests.items():
|
| 559 |
+
eval_logger.info(f"Running {reqtype} requests")
|
| 560 |
+
# create `K` copies of each request `req` based off `K = req.repeats`
|
| 561 |
+
cloned_reqs = []
|
| 562 |
+
for req in reqs:
|
| 563 |
+
cloned_reqs.extend([req] * req.repeats)
|
| 564 |
+
|
| 565 |
+
if (lm.world_size > 1) and (padding_requests[reqtype] > 0):
|
| 566 |
+
for _ in range(padding_requests[reqtype]):
|
| 567 |
+
cloned_reqs.extend([req] * req.repeats)
|
| 568 |
+
|
| 569 |
+
# run requests through model
|
| 570 |
+
resps = getattr(lm, reqtype)(cloned_reqs)
|
| 571 |
+
|
| 572 |
+
# put responses from model into a list of length K for each request.
|
| 573 |
+
for x, req in zip(resps, cloned_reqs):
|
| 574 |
+
req.resps.append(x)
|
| 575 |
+
|
| 576 |
+
if lm.world_size > 1:
|
| 577 |
+
lm.accelerator.wait_for_everyone()
|
| 578 |
+
|
| 579 |
+
RANK = lm.rank
|
| 580 |
+
WORLD_SIZE = lm.world_size
|
| 581 |
+
### Postprocess outputs ###
|
| 582 |
+
# TODO: del model here, maybe (idea: allow user to specify device of e.g. reward model separately)
|
| 583 |
+
for task_output, limit in zip(eval_tasks, limits):
|
| 584 |
+
task = task_output.task
|
| 585 |
+
task.apply_filters()
|
| 586 |
+
|
| 587 |
+
### Collect values of metrics on all datapoints ###
|
| 588 |
+
# # unpack results and sort back in order and return control to Task
|
| 589 |
+
# TODO: make it possible to use a different metric per filter
|
| 590 |
+
# Pre-process task.instances to group by doc_id
|
| 591 |
+
instances_by_doc_id = defaultdict(list)
|
| 592 |
+
for instance in task.instances:
|
| 593 |
+
instances_by_doc_id[instance.doc_id].append(instance)
|
| 594 |
+
# Sort instances within each group
|
| 595 |
+
for instances in instances_by_doc_id.values():
|
| 596 |
+
instances.sort(key=lambda x: x.idx)
|
| 597 |
+
# iterate over different filters used
|
| 598 |
+
for filter_key in task.instances[0].filtered_resps.keys():
|
| 599 |
+
indices = (
|
| 600 |
+
samples.get(task_output.task_name, None)
|
| 601 |
+
if samples is not None
|
| 602 |
+
else None
|
| 603 |
+
)
|
| 604 |
+
doc_iterator = task.doc_iterator(
|
| 605 |
+
rank=RANK,
|
| 606 |
+
limit=limit,
|
| 607 |
+
world_size=WORLD_SIZE,
|
| 608 |
+
samples=indices,
|
| 609 |
+
)
|
| 610 |
+
for doc_id, doc in doc_iterator:
|
| 611 |
+
if indices:
|
| 612 |
+
doc_id_true = indices[doc_id]
|
| 613 |
+
else:
|
| 614 |
+
doc_id_true = doc_id
|
| 615 |
+
requests = instances_by_doc_id[doc_id]
|
| 616 |
+
metrics = task.process_results(
|
| 617 |
+
doc, [req.filtered_resps[filter_key] for req in requests]
|
| 618 |
+
)
|
| 619 |
+
if log_samples:
|
| 620 |
+
target = task.doc_to_target(doc)
|
| 621 |
+
example = {
|
| 622 |
+
"doc_id": doc_id_true,
|
| 623 |
+
"doc": doc,
|
| 624 |
+
"target": target,
|
| 625 |
+
"arguments": [req.args for req in requests],
|
| 626 |
+
"resps": [req.resps for req in requests],
|
| 627 |
+
"filtered_resps": [
|
| 628 |
+
req.filtered_resps[filter_key] for req in requests
|
| 629 |
+
],
|
| 630 |
+
"filter": filter_key,
|
| 631 |
+
"metrics": list(metrics.keys()),
|
| 632 |
+
"doc_hash": hash_string(
|
| 633 |
+
json.dumps(
|
| 634 |
+
requests[0].doc,
|
| 635 |
+
indent=2,
|
| 636 |
+
default=handle_non_serializable,
|
| 637 |
+
ensure_ascii=False,
|
| 638 |
+
)
|
| 639 |
+
),
|
| 640 |
+
"prompt_hash": hash_string(requests[0].arguments[0]),
|
| 641 |
+
"target_hash": hash_string(str(target)),
|
| 642 |
+
}
|
| 643 |
+
example.update(metrics)
|
| 644 |
+
task_output.logged_samples.append(example)
|
| 645 |
+
for metric, value in metrics.items():
|
| 646 |
+
task_output.sample_metrics[(metric, filter_key)].append(value)
|
| 647 |
+
|
| 648 |
+
if WORLD_SIZE > 1:
|
| 649 |
+
# if multigpu, then gather data across all ranks to rank 0
|
| 650 |
+
# first gather logged samples across all ranks
|
| 651 |
+
for task_output in eval_tasks:
|
| 652 |
+
if log_samples:
|
| 653 |
+
# for task_name, task_samples in list(samples.items()):
|
| 654 |
+
full_samples = [None] * WORLD_SIZE if RANK == 0 else None
|
| 655 |
+
torch.distributed.gather_object(
|
| 656 |
+
obj=task_output.logged_samples,
|
| 657 |
+
object_gather_list=full_samples,
|
| 658 |
+
dst=0,
|
| 659 |
+
)
|
| 660 |
+
|
| 661 |
+
if RANK == 0:
|
| 662 |
+
task_output.logged_samples = list(
|
| 663 |
+
itertools.chain.from_iterable(full_samples)
|
| 664 |
+
)
|
| 665 |
+
|
| 666 |
+
# then collect metrics across all ranks
|
| 667 |
+
for metrics in task_output.sample_metrics:
|
| 668 |
+
metric_list = [None] * WORLD_SIZE if RANK == 0 else None
|
| 669 |
+
torch.distributed.gather_object(
|
| 670 |
+
obj=task_output.sample_metrics[metrics],
|
| 671 |
+
object_gather_list=metric_list,
|
| 672 |
+
dst=0,
|
| 673 |
+
)
|
| 674 |
+
if RANK == 0:
|
| 675 |
+
task_output.sample_metrics[metrics] = list(
|
| 676 |
+
itertools.chain.from_iterable(metric_list)
|
| 677 |
+
)
|
| 678 |
+
|
| 679 |
+
if RANK == 0:
|
| 680 |
+
### Aggregate results over all datapoints ###
|
| 681 |
+
# aggregate results ; run bootstrap CIs
|
| 682 |
+
for task_output in eval_tasks:
|
| 683 |
+
task_output.calculate_aggregate_metric(bootstrap_iters=bootstrap_iters)
|
| 684 |
+
(
|
| 685 |
+
results,
|
| 686 |
+
samples,
|
| 687 |
+
configs,
|
| 688 |
+
versions,
|
| 689 |
+
num_fewshot,
|
| 690 |
+
higher_is_better,
|
| 691 |
+
) = consolidate_results(eval_tasks)
|
| 692 |
+
|
| 693 |
+
### Calculate group metrics ###
|
| 694 |
+
if bool(results):
|
| 695 |
+
results, versions, show_group_table, *_ = consolidate_group_results(
|
| 696 |
+
results, versions, task_dict
|
| 697 |
+
)
|
| 698 |
+
|
| 699 |
+
results_agg, group_agg = prepare_print_tasks(task_dict, results)
|
| 700 |
+
subtask_list = get_subtask_list(task_dict)
|
| 701 |
+
|
| 702 |
+
# collect all higher_is_better values for metrics
|
| 703 |
+
# in the group's subtasks.
|
| 704 |
+
# TODO: clean this up ; unify with the below metric_list loop?
|
| 705 |
+
_higher_is_better = {}
|
| 706 |
+
for group, task_list in subtask_list.items():
|
| 707 |
+
if (
|
| 708 |
+
len(task_list) != 0
|
| 709 |
+
): # subtask list will list "task_name": [] for solo tasks
|
| 710 |
+
for task in task_list:
|
| 711 |
+
for m, h in higher_is_better[task].items():
|
| 712 |
+
if m not in _higher_is_better.keys():
|
| 713 |
+
_higher_is_better[m] = h
|
| 714 |
+
|
| 715 |
+
if (
|
| 716 |
+
m in _higher_is_better
|
| 717 |
+
and _higher_is_better[m] is not None
|
| 718 |
+
and _higher_is_better[m] != h
|
| 719 |
+
):
|
| 720 |
+
eval_logger.warning(
|
| 721 |
+
f"Higher_is_better values for metric {m} in group {group} are not consistent. Defaulting to None."
|
| 722 |
+
)
|
| 723 |
+
_higher_is_better[m] = None
|
| 724 |
+
higher_is_better[group] = _higher_is_better
|
| 725 |
+
|
| 726 |
+
results_dict = {
|
| 727 |
+
"results": dict(results_agg.items()),
|
| 728 |
+
**(
|
| 729 |
+
{"groups": dict(group_agg.items())}
|
| 730 |
+
if (bool(group_agg) & show_group_table)
|
| 731 |
+
else {}
|
| 732 |
+
),
|
| 733 |
+
"group_subtasks": dict(reversed(subtask_list.items())),
|
| 734 |
+
"configs": dict(sorted(configs.items())),
|
| 735 |
+
"versions": dict(sorted(versions.items())),
|
| 736 |
+
"n-shot": dict(sorted(num_fewshot.items())),
|
| 737 |
+
"higher_is_better": dict(sorted(higher_is_better.items())),
|
| 738 |
+
"n-samples": {
|
| 739 |
+
task_output.task_name: {
|
| 740 |
+
"original": len(task_output.task.eval_docs),
|
| 741 |
+
"effective": min(
|
| 742 |
+
limit if limit else len(task_output.task.eval_docs),
|
| 743 |
+
len(task_output.task.eval_docs),
|
| 744 |
+
),
|
| 745 |
+
}
|
| 746 |
+
for task_output, limit in zip(eval_tasks, limits)
|
| 747 |
+
},
|
| 748 |
+
}
|
| 749 |
+
if log_samples:
|
| 750 |
+
results_dict["samples"] = dict(samples)
|
| 751 |
+
|
| 752 |
+
return results_dict
|
| 753 |
+
|
| 754 |
+
else:
|
| 755 |
+
return None
|
| 756 |
+
|
| 757 |
+
|
| 758 |
+
def request_caching_arg_to_dict(cache_requests: str) -> dict:
|
| 759 |
+
request_caching_args = {
|
| 760 |
+
"cache_requests": cache_requests in {"true", "refresh"},
|
| 761 |
+
"rewrite_requests_cache": cache_requests == "refresh",
|
| 762 |
+
"delete_requests_cache": cache_requests == "delete",
|
| 763 |
+
}
|
| 764 |
+
|
| 765 |
+
return request_caching_args
|
Prism/LLaDA/LLaDA_Baseline/dllm_eval/evaluator_utils.py
ADDED
|
@@ -0,0 +1,554 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections
|
| 2 |
+
import logging
|
| 3 |
+
import math
|
| 4 |
+
import pathlib
|
| 5 |
+
import sys
|
| 6 |
+
from typing import List, Optional, Tuple, Union
|
| 7 |
+
|
| 8 |
+
from dllm_eval.api.group import ConfigurableGroup
|
| 9 |
+
from dllm_eval.api.metrics import (
|
| 10 |
+
aggregate_subtask_metrics,
|
| 11 |
+
mean,
|
| 12 |
+
pooled_sample_stderr,
|
| 13 |
+
stderr_for_metric,
|
| 14 |
+
)
|
| 15 |
+
from dllm_eval.api.task import Task
|
| 16 |
+
from dllm_eval.utils import positional_deprecated
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
eval_logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class TaskOutput:
|
| 23 |
+
"""
|
| 24 |
+
Wrapper class for Task outputs.It contains various attributes and methods to manage and calculate metrics for the task.
|
| 25 |
+
|
| 26 |
+
Attributes:
|
| 27 |
+
task (object): The task object.
|
| 28 |
+
task_name (str): The name of the task.
|
| 29 |
+
task_config (dict): The configuration of the task.
|
| 30 |
+
version (str): The version of the task.
|
| 31 |
+
group_name (str): The name of the task group.
|
| 32 |
+
n_shot (int): The number of shots for the task.
|
| 33 |
+
task_alias (str): The alias of the task.
|
| 34 |
+
group_alias (str): The alias of the task group.
|
| 35 |
+
is_group (bool): Indicates if the task is a group.
|
| 36 |
+
logged_samples (list): The list of logged samples.
|
| 37 |
+
sample_len (int): The length of the samples.
|
| 38 |
+
sample_metrics (defaultdict): The dictionary of samples' metrics.
|
| 39 |
+
agg_metrics (defaultdict): The dictionary of aggregate metrics.
|
| 40 |
+
|
| 41 |
+
Methods:
|
| 42 |
+
from_taskdict(cls, task_name: str, task):
|
| 43 |
+
Creates a TaskOutput instance from a task dictionary.
|
| 44 |
+
|
| 45 |
+
calculate_aggregate_metric(bootstrap_iters=100000) -> None:
|
| 46 |
+
Calculates the aggregate metrics for the task.
|
| 47 |
+
"""
|
| 48 |
+
|
| 49 |
+
def __init__(
|
| 50 |
+
self,
|
| 51 |
+
task=None,
|
| 52 |
+
task_name=None,
|
| 53 |
+
task_config=None,
|
| 54 |
+
version=None,
|
| 55 |
+
group_name=None,
|
| 56 |
+
n_shot=None,
|
| 57 |
+
task_alias=None,
|
| 58 |
+
group_alias=None,
|
| 59 |
+
is_group=None,
|
| 60 |
+
):
|
| 61 |
+
self.task = task
|
| 62 |
+
self.task_config = task_config
|
| 63 |
+
self.task_name = task_name
|
| 64 |
+
self.group_name = group_name
|
| 65 |
+
self.version = version
|
| 66 |
+
self.n_shot = n_shot
|
| 67 |
+
self.task_alias = task_alias
|
| 68 |
+
self.group_alias = group_alias
|
| 69 |
+
self.is_group = is_group
|
| 70 |
+
self.logged_samples = []
|
| 71 |
+
self.sample_len = None
|
| 72 |
+
self.sample_metrics = collections.defaultdict(list)
|
| 73 |
+
self.agg_metrics = collections.defaultdict(list)
|
| 74 |
+
|
| 75 |
+
@classmethod
|
| 76 |
+
def from_taskdict(cls, task_name: str, task):
|
| 77 |
+
if isinstance(task, tuple):
|
| 78 |
+
group_name, task = task
|
| 79 |
+
else:
|
| 80 |
+
group_name = None
|
| 81 |
+
if not task:
|
| 82 |
+
# these gets filtered out in get_task_list
|
| 83 |
+
# once they are added to group hierarchy
|
| 84 |
+
is_group = True
|
| 85 |
+
return cls(
|
| 86 |
+
task=task, task_name=task_name, is_group=is_group, group_name=group_name
|
| 87 |
+
)
|
| 88 |
+
version = task.VERSION
|
| 89 |
+
task_config = dict(task.dump_config())
|
| 90 |
+
if (n_shot := task_config.get("num_fewshot")) == 0:
|
| 91 |
+
n_shot = task_config.get("metadata", {}).get("num_fewshot", 0)
|
| 92 |
+
task_alias = task_config.get("alias")
|
| 93 |
+
group_alias = task_config.get("group_alias")
|
| 94 |
+
return cls(
|
| 95 |
+
task=task,
|
| 96 |
+
task_name=task_name,
|
| 97 |
+
task_config=task_config,
|
| 98 |
+
group_name=group_name,
|
| 99 |
+
version=version,
|
| 100 |
+
n_shot=n_shot,
|
| 101 |
+
task_alias=task_alias,
|
| 102 |
+
group_alias=group_alias,
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
def calculate_aggregate_metric(self, bootstrap_iters=100000) -> None:
|
| 106 |
+
for (metric, filter_key), items in self.sample_metrics.items():
|
| 107 |
+
try:
|
| 108 |
+
agg_fn = self.task.aggregation()[metric]
|
| 109 |
+
except KeyError:
|
| 110 |
+
# This is when process results output an arbitrary metric
|
| 111 |
+
# TODO: Handle this better and allow other aggregate functions other than mean.
|
| 112 |
+
agg_fn = mean
|
| 113 |
+
metric_key = f"{metric},{filter_key}"
|
| 114 |
+
self.agg_metrics[metric_key] = agg_fn(items)
|
| 115 |
+
self.sample_len = len(items) # TODO: same sample size for each metric?
|
| 116 |
+
if isinstance(bootstrap_iters, int):
|
| 117 |
+
stderr_fn = stderr_for_metric(
|
| 118 |
+
metric=agg_fn,
|
| 119 |
+
bootstrap_iters=min(bootstrap_iters, 100)
|
| 120 |
+
if metric in ["bleu", "chrf", "ter"]
|
| 121 |
+
else bootstrap_iters,
|
| 122 |
+
)
|
| 123 |
+
self.agg_metrics[f"{metric}_stderr,{filter_key}"] = (
|
| 124 |
+
stderr_fn(items) if (stderr_fn and len(items) > 1) else "N/A"
|
| 125 |
+
)
|
| 126 |
+
else:
|
| 127 |
+
raise ValueError(
|
| 128 |
+
f"Received bootstrap_iters '{bootstrap_iters}' but expected an integer. Set to 0 to turn off stderr calculations."
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
def __repr__(self):
|
| 132 |
+
return (
|
| 133 |
+
f"TaskOutput(task_name={self.task_name}, "
|
| 134 |
+
f"group_name={self.group_name}, "
|
| 135 |
+
f"version={self.version}, "
|
| 136 |
+
f"n_shot={self.n_shot}, "
|
| 137 |
+
f"task_alias={self.task_alias}, "
|
| 138 |
+
f"group_alias={self.group_alias})"
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def get_task_list(task_dict: dict) -> List[TaskOutput]:
|
| 143 |
+
outputs = []
|
| 144 |
+
for task_name, task_obj in task_dict.items():
|
| 145 |
+
if isinstance(task_obj, dict):
|
| 146 |
+
_outputs = get_task_list(task_obj)
|
| 147 |
+
outputs.extend(_outputs)
|
| 148 |
+
else:
|
| 149 |
+
task_output = TaskOutput.from_taskdict(task_name, task_obj)
|
| 150 |
+
outputs.append(task_output)
|
| 151 |
+
|
| 152 |
+
return outputs
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def get_subtask_list(task_dict, task_root=None, depth=0):
|
| 156 |
+
subtask_list = {}
|
| 157 |
+
for group_obj, task_obj in task_dict.items():
|
| 158 |
+
if isinstance(group_obj, ConfigurableGroup):
|
| 159 |
+
# group_name = group_obj.group_name
|
| 160 |
+
group_name = group_obj.group_name
|
| 161 |
+
else:
|
| 162 |
+
group_name = group_obj
|
| 163 |
+
if isinstance(task_obj, dict):
|
| 164 |
+
_subtask_list = get_subtask_list(
|
| 165 |
+
task_obj, task_root=group_name, depth=depth + 1
|
| 166 |
+
)
|
| 167 |
+
if task_root:
|
| 168 |
+
subtask_list.setdefault((task_root, depth), []).extend(
|
| 169 |
+
[
|
| 170 |
+
_task
|
| 171 |
+
for (_task, _depth) in _subtask_list.keys()
|
| 172 |
+
if (_depth - 1) == depth
|
| 173 |
+
]
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
subtask_list = {**subtask_list, **_subtask_list}
|
| 177 |
+
else:
|
| 178 |
+
if isinstance(task_obj, ConfigurableGroup):
|
| 179 |
+
# group_or_task_name = task_obj.group_name
|
| 180 |
+
group_or_task_name = task_obj.group_name
|
| 181 |
+
elif isinstance(task_obj, Task):
|
| 182 |
+
# group_or_task_name = task_obj.task_name
|
| 183 |
+
group_or_task_name = task_obj.task_name
|
| 184 |
+
|
| 185 |
+
if task_root is None:
|
| 186 |
+
subtask_list.setdefault((group_or_task_name, depth), [])
|
| 187 |
+
else:
|
| 188 |
+
subtask_list.setdefault((task_root, depth), []).append(
|
| 189 |
+
group_or_task_name
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
if depth == 0:
|
| 193 |
+
_subtask_list = {}
|
| 194 |
+
for group_key, task_list in subtask_list.items():
|
| 195 |
+
group_name, depth = group_key
|
| 196 |
+
_subtask_list[group_name] = task_list
|
| 197 |
+
subtask_list = _subtask_list
|
| 198 |
+
|
| 199 |
+
return subtask_list
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def print_writeout(task) -> None:
|
| 203 |
+
for inst in task.instances:
|
| 204 |
+
# print the prompt for the first few documents
|
| 205 |
+
if inst.doc_id < 1:
|
| 206 |
+
eval_logger.info(
|
| 207 |
+
f"Task: {task}; document {inst.doc_id}; context prompt (starting on next line):\
|
| 208 |
+
\n{inst.args[0]}\n(end of prompt on previous line)\ntarget string or answer choice index (starting on next line):\n{task.doc_to_target(inst.doc)}\n(end of target on previous line)"
|
| 209 |
+
)
|
| 210 |
+
eval_logger.info(f"Request: {str(inst)}")
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def get_sample_size(task, limit: Optional[int]) -> Union[int, None]:
|
| 214 |
+
if limit is not None:
|
| 215 |
+
limit = (
|
| 216 |
+
int(math.ceil(len(task.eval_docs) * limit)) if limit < 1.0 else int(limit)
|
| 217 |
+
)
|
| 218 |
+
return limit
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def prepare_print_tasks(
|
| 222 |
+
task_dict: dict,
|
| 223 |
+
results: dict,
|
| 224 |
+
task_depth=0,
|
| 225 |
+
group_depth=0,
|
| 226 |
+
) -> Tuple[dict, dict]:
|
| 227 |
+
"""
|
| 228 |
+
@param task_dict: Dictionary representing the group hierarchy of tasks. Each key is a group name and its
|
| 229 |
+
value is a list of task names.
|
| 230 |
+
@param results: Dictionary containing the results of each task. Each key is a
|
| 231 |
+
group name and its value is a dictionary of task results.
|
| 232 |
+
@param task_depth: The indentation level for printing the task
|
| 233 |
+
hierarchy. Default is 0.
|
| 234 |
+
@param group_depth: The indentation level for printing the group
|
| 235 |
+
hierarchy. Default is 0.
|
| 236 |
+
@return: A tuple of two dictionaries: results_agg and groups_agg. results_agg contains
|
| 237 |
+
aggregated results for each task, and groups_agg contains aggregated results for each group.
|
| 238 |
+
|
| 239 |
+
Prepares the task hierarchy and aggregates the results for each task and group recursively for printing.
|
| 240 |
+
"""
|
| 241 |
+
|
| 242 |
+
def _sort_task_dict(task_dict):
|
| 243 |
+
"""
|
| 244 |
+
Helper utility. Sorts the task dict at the current level of the hierarchy based on alphabetized task name.
|
| 245 |
+
Required so that we end up sorting within each sub-header correctly.
|
| 246 |
+
"""
|
| 247 |
+
|
| 248 |
+
return dict(
|
| 249 |
+
sorted(
|
| 250 |
+
task_dict.items(),
|
| 251 |
+
key=lambda item: item[0].group_name
|
| 252 |
+
if isinstance(item[0], ConfigurableGroup)
|
| 253 |
+
else item[0],
|
| 254 |
+
)
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
task_agg = collections.defaultdict(dict)
|
| 258 |
+
group_agg = collections.defaultdict(dict)
|
| 259 |
+
task_dict = _sort_task_dict(task_dict)
|
| 260 |
+
for task_or_group_name, task_or_group_obj in task_dict.items():
|
| 261 |
+
tab_string = " " * task_depth + "- " if task_depth > 0 else ""
|
| 262 |
+
if isinstance(task_or_group_name, ConfigurableGroup):
|
| 263 |
+
# string_name = task_or_group_name.group_name
|
| 264 |
+
name = task_or_group_name.group_name
|
| 265 |
+
from_configurable_group = True
|
| 266 |
+
task_or_group_obj = _sort_task_dict(task_or_group_obj)
|
| 267 |
+
elif isinstance(task_or_group_name, str):
|
| 268 |
+
name = task_or_group_name
|
| 269 |
+
if isinstance(task_or_group_obj, Task):
|
| 270 |
+
# string_name = task_or_group_obj.task_name
|
| 271 |
+
name = task_or_group_obj.task_name
|
| 272 |
+
from_configurable_group = False
|
| 273 |
+
|
| 274 |
+
task_agg[name] = results[name].copy()
|
| 275 |
+
if from_configurable_group:
|
| 276 |
+
if task_or_group_name.group_alias is not None:
|
| 277 |
+
alias = task_or_group_name.group_alias
|
| 278 |
+
else:
|
| 279 |
+
alias = task_or_group_name.group
|
| 280 |
+
else:
|
| 281 |
+
if "alias" in task_agg[name]:
|
| 282 |
+
alias = task_agg[name]["alias"]
|
| 283 |
+
else:
|
| 284 |
+
alias = name
|
| 285 |
+
|
| 286 |
+
task_agg[name]["alias"] = tab_string + alias
|
| 287 |
+
if "samples" in task_agg[name]:
|
| 288 |
+
task_agg[name].pop("samples")
|
| 289 |
+
|
| 290 |
+
if from_configurable_group and (" " not in results[name]):
|
| 291 |
+
group_tab_string = " " * group_depth + "- " if group_depth > 0 else ""
|
| 292 |
+
group_agg[name] = results[name].copy()
|
| 293 |
+
group_agg[name]["alias"] = group_tab_string + alias
|
| 294 |
+
if "samples" in group_agg[name]:
|
| 295 |
+
group_agg[name].pop("samples")
|
| 296 |
+
|
| 297 |
+
if isinstance(task_or_group_obj, dict):
|
| 298 |
+
task_depth += 1
|
| 299 |
+
group_depth += 1
|
| 300 |
+
_task_agg, _group_agg = prepare_print_tasks(
|
| 301 |
+
task_or_group_obj, results, task_depth, group_depth
|
| 302 |
+
)
|
| 303 |
+
task_agg = {
|
| 304 |
+
**task_agg,
|
| 305 |
+
**_task_agg,
|
| 306 |
+
}
|
| 307 |
+
group_agg = {**group_agg, **_group_agg}
|
| 308 |
+
task_depth -= 1
|
| 309 |
+
group_depth -= 1
|
| 310 |
+
return task_agg, group_agg
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
def consolidate_results(
|
| 314 |
+
eval_tasks: List[TaskOutput],
|
| 315 |
+
) -> Tuple[dict, dict, dict, dict, dict, dict]:
|
| 316 |
+
"""
|
| 317 |
+
@param eval_tasks: list(TaskOutput).
|
| 318 |
+
@return: A tuple containing the consolidated results, samples, configs, versions, and num_fewshot.
|
| 319 |
+
|
| 320 |
+
Consolidates the results of multiple evaluation tasks into a single structure.
|
| 321 |
+
|
| 322 |
+
The method iterates over each evaluation instance and extracts relevant information to create the consolidated
|
| 323 |
+
results structure. The consolidated results structure has the following properties:
|
| 324 |
+
|
| 325 |
+
- results: A defaultdict with task names as keys and dictionaries as values. Each dictionary contains
|
| 326 |
+
metric/filter pairs as keys and corresponding metric values as values. The "alias" key is used to store task
|
| 327 |
+
aliases specified in the task configuration.
|
| 328 |
+
- samples: A defaultdict with task names as keys and lists of log samples as values.
|
| 329 |
+
- configs: A defaultdict with task names as keys and task configurations as values.
|
| 330 |
+
- versions: A defaultdict with task names as keys and task versions as values.
|
| 331 |
+
- num_fewshot: A defaultdict with task names as keys and number of few-shot samples as values.
|
| 332 |
+
- higher_is_better: A defaultdict with task names as keys and indicators of whether higher values are better
|
| 333 |
+
for each metric as values.
|
| 334 |
+
|
| 335 |
+
The method then returns the consolidated results, samples, configs, versions, and num_fewshot as a tuple.
|
| 336 |
+
"""
|
| 337 |
+
# stores the final result for each task, for each metric/filter pair.
|
| 338 |
+
results = collections.defaultdict(dict)
|
| 339 |
+
# logs info about each document evaluated.
|
| 340 |
+
samples = collections.defaultdict(list)
|
| 341 |
+
# store num-fewshot value per task
|
| 342 |
+
num_fewshot = collections.defaultdict(int)
|
| 343 |
+
# Tracks the YAML configs of all chosen task
|
| 344 |
+
configs = collections.defaultdict(dict)
|
| 345 |
+
# Tracks each task's version.
|
| 346 |
+
versions = collections.defaultdict(dict)
|
| 347 |
+
# Track `higher_is_better` for each metric
|
| 348 |
+
higher_is_better = collections.defaultdict(dict)
|
| 349 |
+
|
| 350 |
+
for task_output in eval_tasks:
|
| 351 |
+
if "task_alias" in (task_config := task_output.task_config):
|
| 352 |
+
results[task_output.task_name]["alias"] = task_config["task_alias"]
|
| 353 |
+
else:
|
| 354 |
+
results[task_output.task_name]["alias"] = task_output.task_name
|
| 355 |
+
if group_alias := task_output.group_alias:
|
| 356 |
+
if group_alias not in results and (group_name := task_output.group_name):
|
| 357 |
+
results[group_name]["alias"] = group_alias
|
| 358 |
+
num_fewshot[task_output.task_name] = task_output.n_shot
|
| 359 |
+
configs[task_output.task_name] = task_output.task_config
|
| 360 |
+
versions[task_output.task_name] = task_output.version
|
| 361 |
+
samples[task_output.task_name] = task_output.logged_samples
|
| 362 |
+
higher_is_better[task_output.task_name] = task_output.task.higher_is_better()
|
| 363 |
+
for (metric, filter_key), items in task_output.sample_metrics.items():
|
| 364 |
+
metric_key = f"{metric},{filter_key}"
|
| 365 |
+
results[task_output.task_name][metric_key] = task_output.agg_metrics[
|
| 366 |
+
metric_key
|
| 367 |
+
]
|
| 368 |
+
results[task_output.task_name]["samples"] = task_output.sample_len
|
| 369 |
+
results[task_output.task_name][f"{metric}_stderr,{filter_key}"] = (
|
| 370 |
+
task_output.agg_metrics[f"{metric}_stderr,{filter_key}"]
|
| 371 |
+
)
|
| 372 |
+
return results, samples, configs, versions, num_fewshot, higher_is_better
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
def consolidate_group_results(
|
| 376 |
+
results,
|
| 377 |
+
versions,
|
| 378 |
+
task_dict,
|
| 379 |
+
task_root=None,
|
| 380 |
+
show_group_table=False,
|
| 381 |
+
task_aggregation_list=None,
|
| 382 |
+
) -> Tuple[dict, dict, bool, Union[None,]]:
|
| 383 |
+
"""
|
| 384 |
+
(Recursively) calculates groups' aggregated metrics and updates the results and versions dictionaries with this info.
|
| 385 |
+
|
| 386 |
+
@return: a tuple [results, versions, show_group_table, task_aggregation_list] with formats described below:
|
| 387 |
+
|
| 388 |
+
- results: A defaultdict with task names (and, after this function is called, group names of
|
| 389 |
+
groups that perform aggregation) as keys, and dictionaries with "alias" and metric,filter_name pairs as keys.
|
| 390 |
+
- versions: A defaultdict with task names (and, after this function is called, group names of
|
| 391 |
+
groups that perform aggregation) as keys, and float values representing the task or group's version if a version is specified. (defaulting to None).
|
| 392 |
+
- show_group_table: a boolean which is true if there exists a group that requires printing of its aggregated scores in a group table.
|
| 393 |
+
- task_aggregation_list: a defaultdict listing the subtasks to average over to produce a given group's end metric.
|
| 394 |
+
|
| 395 |
+
The method then returns the updated results, versions, show_group_table, and task_aggregation_list as a tuple.
|
| 396 |
+
In the top-level invocation of this function, task_aggregation_list is ignored.
|
| 397 |
+
"""
|
| 398 |
+
if task_root is None:
|
| 399 |
+
task_root = {}
|
| 400 |
+
|
| 401 |
+
if task_aggregation_list is None:
|
| 402 |
+
task_aggregation_list = {}
|
| 403 |
+
|
| 404 |
+
for group_or_task, group_or_task_info in task_dict.items():
|
| 405 |
+
# Convert to string
|
| 406 |
+
if isinstance(group_or_task, ConfigurableGroup):
|
| 407 |
+
group_config = group_or_task.config
|
| 408 |
+
group_or_task = group_or_task.group_name
|
| 409 |
+
else:
|
| 410 |
+
group_config = None
|
| 411 |
+
|
| 412 |
+
if isinstance(group_or_task_info, Task):
|
| 413 |
+
if task_root:
|
| 414 |
+
task_aggregation_list.setdefault(task_root, []).append(
|
| 415 |
+
group_or_task_info.task_name
|
| 416 |
+
)
|
| 417 |
+
else:
|
| 418 |
+
(
|
| 419 |
+
results,
|
| 420 |
+
versions,
|
| 421 |
+
show_group_table,
|
| 422 |
+
_task_aggregation_list,
|
| 423 |
+
) = consolidate_group_results(
|
| 424 |
+
results,
|
| 425 |
+
versions,
|
| 426 |
+
group_or_task_info,
|
| 427 |
+
group_or_task,
|
| 428 |
+
show_group_table,
|
| 429 |
+
task_aggregation_list,
|
| 430 |
+
)
|
| 431 |
+
if task_root:
|
| 432 |
+
task_aggregation_list.setdefault(task_root, []).extend(
|
| 433 |
+
task_aggregation_list.get(group_or_task, [])
|
| 434 |
+
)
|
| 435 |
+
|
| 436 |
+
if (group_config is None) or (
|
| 437 |
+
group_config["aggregate_metric_list"] is None
|
| 438 |
+
):
|
| 439 |
+
results[group_or_task][" "] = " "
|
| 440 |
+
continue
|
| 441 |
+
|
| 442 |
+
if "aggregate_metric_list" in group_config:
|
| 443 |
+
agg_metric_list = group_config["aggregate_metric_list"]
|
| 444 |
+
|
| 445 |
+
show_group_table = show_group_table | bool(
|
| 446 |
+
group_config["aggregate_metric_list"]
|
| 447 |
+
)
|
| 448 |
+
|
| 449 |
+
task_list = _task_aggregation_list[group_or_task]
|
| 450 |
+
|
| 451 |
+
metric_list = list(
|
| 452 |
+
{
|
| 453 |
+
key
|
| 454 |
+
for task in task_list
|
| 455 |
+
for key in results[task].keys()
|
| 456 |
+
if "_stderr" not in key and key not in ["task", "alias", "samples"]
|
| 457 |
+
}
|
| 458 |
+
)
|
| 459 |
+
for metric in metric_list:
|
| 460 |
+
stderr = "_stderr,".join(metric.split(","))
|
| 461 |
+
|
| 462 |
+
# gather metrics, sizes, and stderrs from subtasks
|
| 463 |
+
metrics = [
|
| 464 |
+
results[task][metric]
|
| 465 |
+
for task in task_list
|
| 466 |
+
if metric in results[task]
|
| 467 |
+
] # TODO: copy?
|
| 468 |
+
stderrs = [
|
| 469 |
+
results[task][stderr]
|
| 470 |
+
for task in task_list
|
| 471 |
+
if stderr in results[task]
|
| 472 |
+
]
|
| 473 |
+
sizes = [
|
| 474 |
+
results[task]["samples"]
|
| 475 |
+
for task in task_list
|
| 476 |
+
if metric in results[task]
|
| 477 |
+
]
|
| 478 |
+
|
| 479 |
+
for metric_config in agg_metric_list:
|
| 480 |
+
for filter_name in metric_config["filter_list"]:
|
| 481 |
+
if metric != ",".join([metric_config["metric"], filter_name]):
|
| 482 |
+
continue
|
| 483 |
+
|
| 484 |
+
# compute group's pooled metric and stderr
|
| 485 |
+
if metric_config["aggregation"] == "mean":
|
| 486 |
+
aggregate_fn = aggregate_subtask_metrics
|
| 487 |
+
elif callable(metric_config["aggregation"]):
|
| 488 |
+
aggregate_fn = metric_config["aggregation"]
|
| 489 |
+
else:
|
| 490 |
+
raise ValueError(
|
| 491 |
+
f"Currently, only 'mean' is supported for automatically aggregating scores across groups' subtasks. Got '{metric_config['aggregation']}' for group '{group_or_task}'"
|
| 492 |
+
)
|
| 493 |
+
|
| 494 |
+
results[group_or_task][metric] = aggregate_fn(
|
| 495 |
+
metrics,
|
| 496 |
+
sizes,
|
| 497 |
+
metric_config["weight_by_size"],
|
| 498 |
+
)
|
| 499 |
+
# TODO: calculate groups' metrics using arbitrary agg fns
|
| 500 |
+
if "N/A" in stderrs:
|
| 501 |
+
results[group_or_task][stderr] = "N/A"
|
| 502 |
+
else:
|
| 503 |
+
# NOTE: this assumes we are using the mean to aggregate. There are warnings about this elsewhere
|
| 504 |
+
results[group_or_task][stderr] = pooled_sample_stderr(
|
| 505 |
+
stderrs, sizes
|
| 506 |
+
)
|
| 507 |
+
|
| 508 |
+
results[group_or_task]["samples"] = sum(sizes)
|
| 509 |
+
group_metadata = group_config.get("metadata", None)
|
| 510 |
+
if group_metadata is not None:
|
| 511 |
+
versions[group_or_task] = group_metadata.get("version", None)
|
| 512 |
+
# print(results)
|
| 513 |
+
return results, versions, show_group_table, task_aggregation_list
|
| 514 |
+
|
| 515 |
+
|
| 516 |
+
@positional_deprecated
|
| 517 |
+
def find_test_root(start_path: pathlib.Path) -> pathlib.Path:
|
| 518 |
+
"""
|
| 519 |
+
Search upward in the directory tree to a maximum of three layers
|
| 520 |
+
to find and return the package root (containing the 'tests' folder)
|
| 521 |
+
"""
|
| 522 |
+
cur_path = start_path.resolve()
|
| 523 |
+
max_layers = 3
|
| 524 |
+
for _ in range(max_layers):
|
| 525 |
+
if (cur_path / "tests" / "test_version_stable.py").exists():
|
| 526 |
+
return cur_path
|
| 527 |
+
else:
|
| 528 |
+
cur_path = cur_path.parent.resolve()
|
| 529 |
+
raise FileNotFoundError(
|
| 530 |
+
f"Unable to find package root within {max_layers} upwards" + f"of {start_path}"
|
| 531 |
+
)
|
| 532 |
+
|
| 533 |
+
|
| 534 |
+
@positional_deprecated
|
| 535 |
+
def run_task_tests(task_list: List[str]):
|
| 536 |
+
"""
|
| 537 |
+
Find the package root and run the tests for the given tasks
|
| 538 |
+
"""
|
| 539 |
+
import pytest
|
| 540 |
+
|
| 541 |
+
package_root = find_test_root(start_path=pathlib.Path(__file__))
|
| 542 |
+
task_string = " or ".join(task_list)
|
| 543 |
+
args = [
|
| 544 |
+
f"{package_root}/tests/test_version_stable.py",
|
| 545 |
+
f"--rootdir={package_root}",
|
| 546 |
+
"-k",
|
| 547 |
+
f"{task_string}",
|
| 548 |
+
]
|
| 549 |
+
sys.path.append(str(package_root))
|
| 550 |
+
pytest_return_val = pytest.main(args)
|
| 551 |
+
if pytest_return_val:
|
| 552 |
+
raise ValueError(
|
| 553 |
+
f"Not all tests for the specified tasks ({task_list}) ran successfully! Error code: {pytest_return_val}"
|
| 554 |
+
)
|
Prism/LLaDA/LLaDA_Baseline/dllm_eval/filters/__init__.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import partial
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
from dllm_eval.api.filter import FilterEnsemble
|
| 5 |
+
from dllm_eval.api.registry import get_filter
|
| 6 |
+
|
| 7 |
+
from . import custom, extraction, selection, transformation
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def build_filter_ensemble(
|
| 11 |
+
filter_name: str, components: List[List[str]]
|
| 12 |
+
) -> FilterEnsemble:
|
| 13 |
+
"""
|
| 14 |
+
Create a filtering pipeline.
|
| 15 |
+
"""
|
| 16 |
+
filters = []
|
| 17 |
+
for function, kwargs in components:
|
| 18 |
+
if kwargs is None:
|
| 19 |
+
kwargs = {}
|
| 20 |
+
# create a filter given its name in the registry
|
| 21 |
+
f = partial(get_filter(function), **kwargs)
|
| 22 |
+
# add the filter as a pipeline step
|
| 23 |
+
filters.append(f)
|
| 24 |
+
|
| 25 |
+
return FilterEnsemble(name=filter_name, filters=filters)
|
Prism/LLaDA/LLaDA_Baseline/dllm_eval/filters/custom.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dllm_eval.api.filter import Filter
|
| 2 |
+
from dllm_eval.api.registry import register_filter
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
@register_filter("custom")
|
| 6 |
+
class CustomFilter(Filter):
|
| 7 |
+
"""
|
| 8 |
+
Custom filter that applies a custom, user-defined function to the model responses.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
def __init__(self, **kwargs) -> None:
|
| 12 |
+
self.filter_fn = kwargs.pop("filter_fn")
|
| 13 |
+
|
| 14 |
+
super().__init__(**kwargs)
|
| 15 |
+
|
| 16 |
+
def apply(self, resps, docs):
|
| 17 |
+
return self.filter_fn(resps, docs)
|
Prism/LLaDA/LLaDA_Baseline/dllm_eval/filters/decontamination.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dllm_eval.api.filter import Filter
|
| 2 |
+
from dllm_eval.api.registry import register_filter
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
@register_filter("decontaminate")
|
| 6 |
+
class DecontaminationFilter(Filter):
|
| 7 |
+
"""
|
| 8 |
+
A filter which evaluates
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
name = "track_decontamination"
|
| 12 |
+
|
| 13 |
+
def __init__(self, path) -> None:
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
TODO: make sure only ever run one time on the train set (should this be cached as a class var? keyed by value for "path").
|
| 17 |
+
should further cache result on a given (task_name, doc_id)
|
| 18 |
+
"""
|
| 19 |
+
self._decontam_results = None
|
| 20 |
+
|
| 21 |
+
def apply(self, resps, docs) -> None:
|
| 22 |
+
"""
|
| 23 |
+
Return {"no_contamination", "only_contamination"} keys for the 2 different subsets
|
| 24 |
+
"""
|
| 25 |
+
pass
|
Prism/LLaDA/LLaDA_Baseline/dllm_eval/filters/extraction.py
ADDED
|
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import sys
|
| 3 |
+
import unicodedata
|
| 4 |
+
|
| 5 |
+
from dllm_eval.api.filter import Filter
|
| 6 |
+
from dllm_eval.api.registry import register_filter
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@register_filter("regex")
|
| 10 |
+
class RegexFilter(Filter):
|
| 11 |
+
"""A filter that extracts values from text using regex pattern matching.
|
| 12 |
+
|
| 13 |
+
This filter applies a regex pattern to each model response and extracts matched values.
|
| 14 |
+
If no match is found, returns a fallback value. Useful for extracting structured data
|
| 15 |
+
(like numbers) from unstructured model outputs.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
def __init__(
|
| 19 |
+
self,
|
| 20 |
+
regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
|
| 21 |
+
group_select: int = 0,
|
| 22 |
+
fallback: str = "[invalid]",
|
| 23 |
+
) -> None:
|
| 24 |
+
"""
|
| 25 |
+
pass a string `regex` to run `re.compile(r"regex")` on.
|
| 26 |
+
`fallback` defines the output returned if no matches for the regex are located.
|
| 27 |
+
"""
|
| 28 |
+
self.regex_pattern = regex_pattern
|
| 29 |
+
self.regex = re.compile(regex_pattern)
|
| 30 |
+
self.group_select = group_select
|
| 31 |
+
self.fallback = fallback
|
| 32 |
+
|
| 33 |
+
def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
|
| 34 |
+
# here, we assume we have a list, in which each element is
|
| 35 |
+
# a list of model responses for some particular input/target pair.
|
| 36 |
+
# so we process each of these (same input/target response sets)
|
| 37 |
+
# independently (and keep them a list.)
|
| 38 |
+
def filter_set(inst):
|
| 39 |
+
filtered = []
|
| 40 |
+
for resp in inst:
|
| 41 |
+
match = self.regex.findall(resp)
|
| 42 |
+
if match:
|
| 43 |
+
match = match[self.group_select]
|
| 44 |
+
if isinstance(match, tuple):
|
| 45 |
+
match = [m for m in match if m]
|
| 46 |
+
if match:
|
| 47 |
+
match = match[0]
|
| 48 |
+
else:
|
| 49 |
+
match = self.fallback
|
| 50 |
+
match = match.strip()
|
| 51 |
+
else:
|
| 52 |
+
match = self.fallback
|
| 53 |
+
filtered.append(match)
|
| 54 |
+
return filtered
|
| 55 |
+
|
| 56 |
+
filtered_resps = list(map(lambda x: filter_set(x), resps))
|
| 57 |
+
return filtered_resps
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@register_filter("regex_pos")
|
| 61 |
+
class POSFilter(Filter):
|
| 62 |
+
""" """
|
| 63 |
+
|
| 64 |
+
def __init__(
|
| 65 |
+
self,
|
| 66 |
+
regex_pattern: str = r"\['(.*?)'\]",
|
| 67 |
+
group_select=0,
|
| 68 |
+
fallback=None,
|
| 69 |
+
) -> None:
|
| 70 |
+
"""
|
| 71 |
+
pass a string `regex` to run `re.compile(r"regex")` on.
|
| 72 |
+
`fallback` defines the output returned if no matches for the regex are located.
|
| 73 |
+
"""
|
| 74 |
+
if fallback is None:
|
| 75 |
+
fallback = ["invalid"]
|
| 76 |
+
self.regex_pattern = regex_pattern
|
| 77 |
+
self.regex = re.compile(regex_pattern)
|
| 78 |
+
self.group_select = group_select
|
| 79 |
+
self.fallback = fallback
|
| 80 |
+
|
| 81 |
+
def apply(self, resps, docs):
|
| 82 |
+
def extract_tagged_tokens(text):
|
| 83 |
+
# Extract tagged tokens list from text input using regex
|
| 84 |
+
tokens = re.findall(r"\('([^']*)', '([^']*)'\)", text)
|
| 85 |
+
return [(token, pos) for token, pos in tokens]
|
| 86 |
+
|
| 87 |
+
def extract_pos_tags(result):
|
| 88 |
+
pos_tags = []
|
| 89 |
+
if isinstance(result, str):
|
| 90 |
+
result = extract_tagged_tokens(result)
|
| 91 |
+
pos_tags.extend(pos for _, pos in result)
|
| 92 |
+
return pos_tags if pos_tags else self.fallback
|
| 93 |
+
|
| 94 |
+
def filter_set(inst):
|
| 95 |
+
filtered = []
|
| 96 |
+
for resp in inst:
|
| 97 |
+
match = extract_pos_tags(resp)
|
| 98 |
+
filtered.append(match)
|
| 99 |
+
return filtered
|
| 100 |
+
|
| 101 |
+
filtered_resps = map(lambda x: filter_set(x), resps)
|
| 102 |
+
|
| 103 |
+
return filtered_resps
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
@register_filter("remove_whitespace")
|
| 107 |
+
class WhitespaceFilter(Filter):
|
| 108 |
+
"""Filters out leading whitespace from responses."""
|
| 109 |
+
|
| 110 |
+
def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
|
| 111 |
+
def filter_set(inst):
|
| 112 |
+
filtered_resp = []
|
| 113 |
+
for resp in inst:
|
| 114 |
+
resp = resp.lstrip()
|
| 115 |
+
filtered_resp.append(resp)
|
| 116 |
+
return filtered_resp
|
| 117 |
+
|
| 118 |
+
filtered_resps = [filter_set(resp) for resp in resps]
|
| 119 |
+
|
| 120 |
+
return filtered_resps
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
@register_filter("multi_choice_regex")
|
| 124 |
+
class MultiChoiceRegexFilter(RegexFilter):
|
| 125 |
+
"""
|
| 126 |
+
A filter used to extract a model's answer on multiple choice questions with
|
| 127 |
+
letter answers. assumes each document has a "choices" field
|
| 128 |
+
containing the list of answer choices and that the answer label symbols
|
| 129 |
+
are of the form (A), (B), (C), ... or A, B, C.
|
| 130 |
+
"""
|
| 131 |
+
|
| 132 |
+
def __init__(
|
| 133 |
+
self,
|
| 134 |
+
regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
|
| 135 |
+
group_select=0,
|
| 136 |
+
fallback: str = "[invalid]",
|
| 137 |
+
ignore_case=False,
|
| 138 |
+
ignore_punctuation=False,
|
| 139 |
+
regexes_to_ignore=None,
|
| 140 |
+
) -> None:
|
| 141 |
+
"""
|
| 142 |
+
regex_pattern: The basic regex pattern to use. If fails to match, we will use the customized match procedure
|
| 143 |
+
- step 1 : We parse the choices between ([A-Z])s then try to find these choices in the response.
|
| 144 |
+
- step 2 : We parse the choice with regex :[\s]*([A-?]), where ? varies by number of choices.
|
| 145 |
+
group_select: Selects the (group_select)th match from the findall result.
|
| 146 |
+
ignore_case: Ignores the case during step 1 matching
|
| 147 |
+
ignore_punctuation: Remove the punctuation during step 1 matching
|
| 148 |
+
regexes_to_ignore: Remove these regexes during step 1 matching
|
| 149 |
+
"""
|
| 150 |
+
super().__init__(regex_pattern, group_select, fallback)
|
| 151 |
+
self.ignore_case = ignore_case
|
| 152 |
+
self.ignore_punctuation = ignore_punctuation
|
| 153 |
+
self.regexes_to_ignore = regexes_to_ignore
|
| 154 |
+
|
| 155 |
+
def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
|
| 156 |
+
# here, we assume we have a list, in which each element is
|
| 157 |
+
# a list of model responses for some particular input/target pair.
|
| 158 |
+
# so we process each of these (same input/target response sets)
|
| 159 |
+
# independently (and keep them a list.)
|
| 160 |
+
|
| 161 |
+
def find_match(regex, resp, convert_dict={}):
|
| 162 |
+
match = regex.findall(resp)
|
| 163 |
+
if match:
|
| 164 |
+
match = match[self.group_select]
|
| 165 |
+
if isinstance(match, tuple):
|
| 166 |
+
match = [m for m in match if m][0]
|
| 167 |
+
match = match.strip()
|
| 168 |
+
if match and match in convert_dict:
|
| 169 |
+
match = convert_dict[match]
|
| 170 |
+
return match
|
| 171 |
+
|
| 172 |
+
punct_tbl = dict.fromkeys(
|
| 173 |
+
i
|
| 174 |
+
for i in range(sys.maxunicode)
|
| 175 |
+
if unicodedata.category(chr(i)).startswith("P")
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
def filter_ignores(st):
|
| 179 |
+
if self.regexes_to_ignore is not None:
|
| 180 |
+
for s in self.regexes_to_ignore:
|
| 181 |
+
st = re.sub(s, "", st)
|
| 182 |
+
|
| 183 |
+
if self.ignore_case:
|
| 184 |
+
st = st.lower()
|
| 185 |
+
|
| 186 |
+
if self.ignore_punctuation:
|
| 187 |
+
# https://stackoverflow.com/a/266162
|
| 188 |
+
st = st.translate(punct_tbl)
|
| 189 |
+
return st
|
| 190 |
+
|
| 191 |
+
filtered_resps = []
|
| 192 |
+
|
| 193 |
+
for r, doc in zip(resps, docs):
|
| 194 |
+
fallback_regexes = []
|
| 195 |
+
choice_to_alpha = {}
|
| 196 |
+
next_alpha = "A"
|
| 197 |
+
|
| 198 |
+
without_paren_fallback_regexes = []
|
| 199 |
+
without_paren_to_target = {}
|
| 200 |
+
|
| 201 |
+
choices = doc["choices"]
|
| 202 |
+
for c in choices:
|
| 203 |
+
m = filter_ignores(c.strip())
|
| 204 |
+
fallback_regexes.append(f"{re.escape(m)}")
|
| 205 |
+
choice_to_alpha[m] = f"({next_alpha})"
|
| 206 |
+
|
| 207 |
+
without_paren_fallback_regexes.append(next_alpha)
|
| 208 |
+
without_paren_to_target[next_alpha] = f"({next_alpha})"
|
| 209 |
+
|
| 210 |
+
next_alpha = chr(ord(next_alpha) + 1)
|
| 211 |
+
fallback_regex = re.compile("|".join(fallback_regexes))
|
| 212 |
+
without_paren_fallback_regex = "|".join(without_paren_fallback_regexes)
|
| 213 |
+
without_paren_fallback_regex = re.compile(
|
| 214 |
+
rf":[\s]*({without_paren_fallback_regex})"
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
filtered = []
|
| 218 |
+
for resp in r:
|
| 219 |
+
match = find_match(self.regex, resp)
|
| 220 |
+
if not match:
|
| 221 |
+
match = find_match(
|
| 222 |
+
fallback_regex, filter_ignores(resp), choice_to_alpha
|
| 223 |
+
)
|
| 224 |
+
if not match:
|
| 225 |
+
match = find_match(
|
| 226 |
+
without_paren_fallback_regex, resp, without_paren_to_target
|
| 227 |
+
)
|
| 228 |
+
if not match:
|
| 229 |
+
match = self.fallback
|
| 230 |
+
filtered.append(match)
|
| 231 |
+
filtered_resps.append(filtered)
|
| 232 |
+
|
| 233 |
+
return filtered_resps
|
Prism/LLaDA/LLaDA_Baseline/dllm_eval/filters/selection.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import Counter
|
| 2 |
+
|
| 3 |
+
from dllm_eval.api.filter import Filter
|
| 4 |
+
from dllm_eval.api.registry import register_filter
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
# TODO: implement "arg_max" filter. either it should take in an arbitrary "scoring"/reward function
|
| 8 |
+
# that takes an input and returns a scalar and then should select the max reward,
|
| 9 |
+
# or should implement different filters for different ways of handling a reward model's inference.
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@register_filter("take_first")
|
| 13 |
+
class TakeFirstFilter(Filter):
|
| 14 |
+
def __init__(self) -> None:
|
| 15 |
+
"""
|
| 16 |
+
Can define custom behavior here, if an individual instantiation of a Filter class should have state.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
def apply(self, resps, docs):
|
| 20 |
+
"""
|
| 21 |
+
Assuming each entry of `resps` is a list of model responses, we discard all but the first response.
|
| 22 |
+
"""
|
| 23 |
+
return map(lambda r: r[0], resps)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@register_filter("take_first_k")
|
| 27 |
+
class TakeKFilter(Filter):
|
| 28 |
+
def __init__(self, **kwargs) -> None:
|
| 29 |
+
self.k = kwargs.pop("k")
|
| 30 |
+
|
| 31 |
+
super().__init__(**kwargs)
|
| 32 |
+
|
| 33 |
+
def apply(self, resps, docs):
|
| 34 |
+
# need resp to be subscriptable to check below
|
| 35 |
+
resps = list(resps)
|
| 36 |
+
# check we have at least k responses per doc, else we can't take the first k
|
| 37 |
+
assert len(resps[0]) >= self.k, (
|
| 38 |
+
f"Need at least {self.k} responses per doc to take first {self.k}, but got {len(resps[0])} only! Please increase TaskConfig.repeats ."
|
| 39 |
+
)
|
| 40 |
+
return map(lambda r: r[: self.k], resps)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@register_filter("majority_vote")
|
| 44 |
+
class MajorityVoteFilter(Filter):
|
| 45 |
+
def __init__(self) -> None:
|
| 46 |
+
"""
|
| 47 |
+
Can define custom behavior here, if an individual instantiation of a Filter class should have state.
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
+
def apply(self, resps, docs):
|
| 51 |
+
"""
|
| 52 |
+
Each entry of `resps` is a list of model responses.
|
| 53 |
+
We select the response that occurs most frequently in each entry of `resps`.
|
| 54 |
+
"""
|
| 55 |
+
|
| 56 |
+
def select_majority(resp):
|
| 57 |
+
counts = Counter(resp)
|
| 58 |
+
vote = counts.most_common(1)[0][0]
|
| 59 |
+
return vote
|
| 60 |
+
|
| 61 |
+
return map(lambda r: [select_majority(r)], resps)
|
Prism/LLaDA/LLaDA_Baseline/dllm_eval/filters/transformation.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
|
| 3 |
+
from dllm_eval.api.filter import Filter
|
| 4 |
+
from dllm_eval.api.registry import register_filter
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@register_filter("lowercase")
|
| 8 |
+
class LowercaseFilter(Filter):
|
| 9 |
+
def __init__(self) -> None:
|
| 10 |
+
pass
|
| 11 |
+
|
| 12 |
+
def apply(self, resps, docs):
|
| 13 |
+
def filter_set(inst):
|
| 14 |
+
return [resp.lower() for resp in inst]
|
| 15 |
+
|
| 16 |
+
return [filter_set(resp) for resp in resps]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@register_filter("uppercase")
|
| 20 |
+
class UppercaseFilter(Filter):
|
| 21 |
+
def __init__(self) -> None:
|
| 22 |
+
pass
|
| 23 |
+
|
| 24 |
+
def apply(self, resps, docs):
|
| 25 |
+
def filter_set(inst):
|
| 26 |
+
return [resp.upper() for resp in inst]
|
| 27 |
+
|
| 28 |
+
return [filter_set(resp) for resp in resps]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@register_filter("map")
|
| 32 |
+
class MapFilter(Filter):
|
| 33 |
+
def __init__(self, mapping_dict: dict = None, default_value=None) -> None:
|
| 34 |
+
"""
|
| 35 |
+
Initializes the MapFilter with a given mapping dictionary and default value.
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
- mapping_dict (dict): A dictionary containing the key-value mappings.
|
| 39 |
+
Default is an empty dictionary.
|
| 40 |
+
- default_value (Any): The value to be returned when a key is not found in the mapping_dict.
|
| 41 |
+
Default is None.
|
| 42 |
+
|
| 43 |
+
Example:
|
| 44 |
+
mapper = MapFilter({'A': 1, 'B': 2}, default_value=0)
|
| 45 |
+
"""
|
| 46 |
+
if mapping_dict is None:
|
| 47 |
+
mapping_dict = {}
|
| 48 |
+
assert isinstance(mapping_dict, dict), (
|
| 49 |
+
"Provided mapping_dict is not a dictionary"
|
| 50 |
+
)
|
| 51 |
+
self.mapping_dict = mapping_dict
|
| 52 |
+
self.default_value = default_value
|
| 53 |
+
|
| 54 |
+
def apply(self, resps, docs):
|
| 55 |
+
def filter_set(inst):
|
| 56 |
+
return [self.mapping_dict.get(resp, self.default_value) for resp in inst]
|
| 57 |
+
|
| 58 |
+
return [filter_set(resp) for resp in resps]
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
@register_filter("format_span")
|
| 62 |
+
class SPANFilter(Filter):
|
| 63 |
+
def __init__(self) -> None:
|
| 64 |
+
pass
|
| 65 |
+
|
| 66 |
+
def apply(self, resps, docs):
|
| 67 |
+
def format_ner_text(text):
|
| 68 |
+
label_dict = {
|
| 69 |
+
"person": "PER",
|
| 70 |
+
"location": "LOC",
|
| 71 |
+
"organization": "ORG",
|
| 72 |
+
"counties": "LOC",
|
| 73 |
+
"places": "LOC",
|
| 74 |
+
"people": "PER",
|
| 75 |
+
"persons": "PER",
|
| 76 |
+
"company": "ORG",
|
| 77 |
+
"country": "LOC",
|
| 78 |
+
"continent": "LOC",
|
| 79 |
+
"time": "DATE",
|
| 80 |
+
"date": "DATE",
|
| 81 |
+
"per": "PER",
|
| 82 |
+
"loc": "LOC",
|
| 83 |
+
"org": "ORG",
|
| 84 |
+
}
|
| 85 |
+
text = text.lower()
|
| 86 |
+
for key, value in label_dict.items():
|
| 87 |
+
text = text.replace(key, value)
|
| 88 |
+
|
| 89 |
+
text = "$".join(i for i in text.split("$$"))
|
| 90 |
+
return text.rstrip("$$")
|
| 91 |
+
|
| 92 |
+
def format_named_entities(text):
|
| 93 |
+
"""
|
| 94 |
+
Extract named entities from text and format them as 'label: value $$ label: value'.
|
| 95 |
+
Handles grouped entities (e.g., LOC: kenya, uganda) and excludes 'none' values.
|
| 96 |
+
"""
|
| 97 |
+
# Regular expression to match label: entities pattern
|
| 98 |
+
pattern = r"\b(PER|LOC|ORG|DATE):\s*([^$]+)"
|
| 99 |
+
# Normalize newline characters
|
| 100 |
+
text = text.replace("\n", "$").strip()
|
| 101 |
+
matches = re.findall(pattern, text)
|
| 102 |
+
|
| 103 |
+
formatted_entities = []
|
| 104 |
+
|
| 105 |
+
for label, values in matches:
|
| 106 |
+
# Split multiple entities separated by commas and strip whitespace
|
| 107 |
+
entities = [value.strip() for value in values.split(",")]
|
| 108 |
+
|
| 109 |
+
# Exclude 'none' entities
|
| 110 |
+
for entity in entities:
|
| 111 |
+
if entity.lower() != "none":
|
| 112 |
+
formatted_entities.append(f"{label.lower()}: {entity}")
|
| 113 |
+
|
| 114 |
+
# Join entities with the desired separator
|
| 115 |
+
return " $ ".join(formatted_entities)
|
| 116 |
+
|
| 117 |
+
def filter_set(inst):
|
| 118 |
+
return [
|
| 119 |
+
format_named_entities(format_ner_text(resp.lower())) for resp in inst
|
| 120 |
+
]
|
| 121 |
+
|
| 122 |
+
return [filter_set(resp) for resp in resps]
|
Prism/LLaDA/LLaDA_Baseline/dllm_eval/utils.py
ADDED
|
@@ -0,0 +1,552 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections
|
| 2 |
+
import fnmatch
|
| 3 |
+
import functools
|
| 4 |
+
import hashlib
|
| 5 |
+
import importlib.util
|
| 6 |
+
import inspect
|
| 7 |
+
import json
|
| 8 |
+
import logging
|
| 9 |
+
import os
|
| 10 |
+
import re
|
| 11 |
+
from dataclasses import asdict, is_dataclass
|
| 12 |
+
from itertools import islice
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from typing import Any, Callable, Generator, List, Optional, Tuple
|
| 15 |
+
|
| 16 |
+
import numpy as np
|
| 17 |
+
import yaml
|
| 18 |
+
from jinja2 import BaseLoader, Environment, StrictUndefined
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
SPACING = " " * 47
|
| 22 |
+
|
| 23 |
+
HIGHER_IS_BETTER_SYMBOLS = {
|
| 24 |
+
True: "↑",
|
| 25 |
+
False: "↓",
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def setup_logging(verbosity=logging.INFO):
|
| 30 |
+
# Configure the root logger
|
| 31 |
+
class CustomFormatter(logging.Formatter):
|
| 32 |
+
def format(self, record):
|
| 33 |
+
if record.name.startswith("dllm_eval."):
|
| 34 |
+
record.name = record.name[len("dllm_eval.") :]
|
| 35 |
+
return super().format(record)
|
| 36 |
+
|
| 37 |
+
formatter = CustomFormatter(
|
| 38 |
+
"%(asctime)s %(levelname)-8s [%(name)s:%(lineno)d] %(message)s",
|
| 39 |
+
datefmt="%Y-%m-%d:%H:%M:%S",
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
log_level = os.environ.get("LOGLEVEL", verbosity) or verbosity
|
| 43 |
+
|
| 44 |
+
level_map = {
|
| 45 |
+
"DEBUG": logging.DEBUG,
|
| 46 |
+
"INFO": logging.INFO,
|
| 47 |
+
"WARNING": logging.WARNING,
|
| 48 |
+
"ERROR": logging.ERROR,
|
| 49 |
+
"CRITICAL": logging.CRITICAL,
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
log_level = level_map.get(str(log_level).upper(), logging.INFO)
|
| 53 |
+
|
| 54 |
+
if not logging.root.handlers:
|
| 55 |
+
handler = logging.StreamHandler()
|
| 56 |
+
handler.setFormatter(formatter)
|
| 57 |
+
|
| 58 |
+
root_logger = logging.getLogger()
|
| 59 |
+
root_logger.addHandler(handler)
|
| 60 |
+
root_logger.setLevel(log_level)
|
| 61 |
+
|
| 62 |
+
if log_level == logging.DEBUG:
|
| 63 |
+
third_party_loggers = ["urllib3", "filelock", "fsspec"]
|
| 64 |
+
for logger_name in third_party_loggers:
|
| 65 |
+
logging.getLogger(logger_name).setLevel(logging.INFO)
|
| 66 |
+
else:
|
| 67 |
+
logging.getLogger().setLevel(log_level)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def hash_string(string: str) -> str:
|
| 71 |
+
return hashlib.sha256(string.encode("utf-8")).hexdigest()
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def escaped_split(text, sep_char, maxsplit=-1):
|
| 75 |
+
"""Split text into a list on occurrences of the given separation
|
| 76 |
+
character `sep_char`. The separation character may be escaped by a
|
| 77 |
+
backslash to avoid splitting at that location.
|
| 78 |
+
|
| 79 |
+
The separation character must be a string of size 1.
|
| 80 |
+
|
| 81 |
+
If `maxsplit` is given, at most `maxsplit` splits are done (thus,
|
| 82 |
+
the list will have at most `maxsplit + 1` elements). If `maxsplit`
|
| 83 |
+
is not specified or less than 0, then there is no limit on the
|
| 84 |
+
number of splits (all possible splits are made).
|
| 85 |
+
"""
|
| 86 |
+
assert len(sep_char) == 1, (
|
| 87 |
+
"separation string must be a single character for escaped splitting"
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
if maxsplit == 0:
|
| 91 |
+
return text
|
| 92 |
+
maxsplit = max(0, maxsplit)
|
| 93 |
+
|
| 94 |
+
return re.split(r"(?<!\\)" + sep_char, text, maxsplit)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def handle_arg_string(arg):
|
| 98 |
+
if arg.lower() == "true":
|
| 99 |
+
return True
|
| 100 |
+
elif arg.lower() == "false":
|
| 101 |
+
return False
|
| 102 |
+
elif arg.isnumeric():
|
| 103 |
+
return int(arg)
|
| 104 |
+
try:
|
| 105 |
+
return float(arg)
|
| 106 |
+
except ValueError:
|
| 107 |
+
return arg
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def handle_non_serializable(o):
|
| 111 |
+
if isinstance(o, np.int64) or isinstance(o, np.int32):
|
| 112 |
+
return int(o)
|
| 113 |
+
elif isinstance(o, set):
|
| 114 |
+
return list(o)
|
| 115 |
+
else:
|
| 116 |
+
return str(o)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def sanitize_list(sub):
|
| 120 |
+
"""
|
| 121 |
+
Takes possible nested list and recursively converts all inner component to strings
|
| 122 |
+
"""
|
| 123 |
+
if isinstance(sub, list):
|
| 124 |
+
return [sanitize_list(item) for item in sub]
|
| 125 |
+
if isinstance(sub, tuple):
|
| 126 |
+
return tuple(sanitize_list(item) for item in sub)
|
| 127 |
+
else:
|
| 128 |
+
return str(sub)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def simple_parse_args_string(args_string: Optional[str]) -> dict:
|
| 132 |
+
"""
|
| 133 |
+
Parses something like
|
| 134 |
+
args1=val1,arg2=val2
|
| 135 |
+
Into a dictionary
|
| 136 |
+
"""
|
| 137 |
+
if args_string is None:
|
| 138 |
+
return {}
|
| 139 |
+
args_string = args_string.strip()
|
| 140 |
+
if not args_string:
|
| 141 |
+
return {}
|
| 142 |
+
arg_list = [arg for arg in args_string.split(",") if arg]
|
| 143 |
+
args_dict = {
|
| 144 |
+
kv[0]: handle_arg_string("=".join(kv[1:]))
|
| 145 |
+
for kv in [arg.split("=") for arg in arg_list]
|
| 146 |
+
}
|
| 147 |
+
return args_dict
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def join_iters(iters):
|
| 151 |
+
for iter in iters:
|
| 152 |
+
yield from iter
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def group(arr, fn):
|
| 156 |
+
res = collections.defaultdict(list)
|
| 157 |
+
|
| 158 |
+
for ob in arr:
|
| 159 |
+
res[fn(ob)].append(ob)
|
| 160 |
+
|
| 161 |
+
return list(res.values())
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
# Returns a list containing all values of the source_list that
|
| 165 |
+
# match at least one of the patterns
|
| 166 |
+
def pattern_match(patterns, source_list):
|
| 167 |
+
if isinstance(patterns, str):
|
| 168 |
+
patterns = [patterns]
|
| 169 |
+
|
| 170 |
+
task_names = set()
|
| 171 |
+
for pattern in patterns:
|
| 172 |
+
for matching in fnmatch.filter(source_list, pattern):
|
| 173 |
+
task_names.add(matching)
|
| 174 |
+
return sorted(list(task_names))
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def softmax(x) -> np.ndarray:
|
| 178 |
+
"""Compute softmax values for each sets of scores in x."""
|
| 179 |
+
e_x = np.exp(x - np.max(x))
|
| 180 |
+
return e_x / e_x.sum()
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def general_detokenize(string) -> str:
|
| 184 |
+
string = string.replace(" n't", "n't")
|
| 185 |
+
string = string.replace(" )", ")")
|
| 186 |
+
string = string.replace("( ", "(")
|
| 187 |
+
string = string.replace('" ', '"')
|
| 188 |
+
string = string.replace(' "', '"')
|
| 189 |
+
string = re.sub(r" (['.,])", r"\1", string)
|
| 190 |
+
return string
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def get_file_task_name(filename: str) -> str:
|
| 194 |
+
"""
|
| 195 |
+
Given the sample results filenames, extracts and returns the task name.
|
| 196 |
+
"""
|
| 197 |
+
return filename[filename.find("_") + 1 : filename.rfind("_")]
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def get_file_datetime(filename: str) -> str:
|
| 201 |
+
"""
|
| 202 |
+
Given the results and sample results filenames, extracts and returns the datetime.
|
| 203 |
+
"""
|
| 204 |
+
return filename[filename.rfind("_") + 1 :].replace(".jsonl", "")
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def sanitize_model_name(model_name: str) -> str:
|
| 208 |
+
"""
|
| 209 |
+
Given the model name, returns a sanitized version of it.
|
| 210 |
+
"""
|
| 211 |
+
return re.sub(r"[\"<>:/\|\\?\*\[\]]+", "__", model_name)
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def sanitize_task_name(task_name: str) -> str:
|
| 215 |
+
"""
|
| 216 |
+
Given the task name, returns a sanitized version of it.
|
| 217 |
+
"""
|
| 218 |
+
return re.sub(r"\W", "_", task_name)
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def get_latest_filename(filenames: List[str]) -> str:
|
| 222 |
+
"""
|
| 223 |
+
Given a list of filenames, returns the filename with the latest datetime.
|
| 224 |
+
"""
|
| 225 |
+
return max(filenames, key=lambda f: get_file_datetime(f))
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def get_results_filenames(filenames: List[str]) -> List[str]:
|
| 229 |
+
"""
|
| 230 |
+
Extracts filenames that correspond to aggregated results.
|
| 231 |
+
"""
|
| 232 |
+
return [f for f in filenames if "/results_" in f and ".json" in f]
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def get_sample_results_filenames(filenames: List[str]) -> List[str]:
|
| 236 |
+
"""
|
| 237 |
+
Extracts filenames that correspond to sample results.
|
| 238 |
+
"""
|
| 239 |
+
return [f for f in filenames if "/samples_" in f and ".json" in f]
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def get_rolling_token_windows(
|
| 243 |
+
token_list: List[int], prefix_token: int, max_seq_len: int, context_len: int
|
| 244 |
+
) -> Generator[Tuple[List[int], List[int]], None, None]:
|
| 245 |
+
"""
|
| 246 |
+
- context_len allows for a rolling window context, allowing each prediction window to potentially
|
| 247 |
+
condition on some context
|
| 248 |
+
|
| 249 |
+
:param token_list: list
|
| 250 |
+
List of tokens to be PREDICTED
|
| 251 |
+
:param max_seq_len: int
|
| 252 |
+
max_seq_len of model (or max_seq_len we want to use)
|
| 253 |
+
:param context_len: int
|
| 254 |
+
Amount of desired token context for prediction. Needs to be at least 1.
|
| 255 |
+
:param prefix_token: token
|
| 256 |
+
Dummy token like <eos> so the first token has something to condition on
|
| 257 |
+
:return: generator
|
| 258 |
+
Generator of tuples
|
| 259 |
+
(input_tokens, pred_tokens)
|
| 260 |
+
Note: Score only the last len(pred_tokens) logits of the LM
|
| 261 |
+
"""
|
| 262 |
+
assert 1 <= context_len <= max_seq_len
|
| 263 |
+
if not token_list:
|
| 264 |
+
return
|
| 265 |
+
# +1 offset, going from input->preds
|
| 266 |
+
pred_len = max_seq_len - context_len + 1
|
| 267 |
+
predicted = 0
|
| 268 |
+
|
| 269 |
+
# Special handling for first window: predict all tokens
|
| 270 |
+
first_seq_len = min(max_seq_len, len(token_list))
|
| 271 |
+
yield [prefix_token] + token_list[: first_seq_len - 1], token_list[:first_seq_len]
|
| 272 |
+
predicted += first_seq_len
|
| 273 |
+
|
| 274 |
+
while predicted < len(token_list):
|
| 275 |
+
window_pred_len = min(len(token_list) - predicted, pred_len)
|
| 276 |
+
window_end = predicted + window_pred_len
|
| 277 |
+
|
| 278 |
+
yield (
|
| 279 |
+
token_list[window_end - max_seq_len - 1 : window_end - 1],
|
| 280 |
+
token_list[window_end - window_pred_len : window_end],
|
| 281 |
+
)
|
| 282 |
+
predicted += window_pred_len
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
def make_disjoint_window(
|
| 286 |
+
pair: Tuple[List[int], List[int]],
|
| 287 |
+
) -> Tuple[List[int], List[int]]:
|
| 288 |
+
"""Takes output from get_rolling_token_windows and makes the context not overlap with the continuation"""
|
| 289 |
+
a, b = pair
|
| 290 |
+
return a[: len(a) - (len(b) - 1)], b
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
class EnhancedJSONEncoder(json.JSONEncoder):
|
| 294 |
+
"""
|
| 295 |
+
Provides a proper json encoding for the loggers and trackers json dumps.
|
| 296 |
+
Notably manages the json encoding of dataclasses.
|
| 297 |
+
"""
|
| 298 |
+
|
| 299 |
+
def default(self, o):
|
| 300 |
+
if is_dataclass(o):
|
| 301 |
+
return asdict(o)
|
| 302 |
+
return super().default(o)
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
class Reorderer:
|
| 306 |
+
def __init__(self, arr: List[Any], fn: Callable) -> None:
|
| 307 |
+
"""Reorder an array according to some function
|
| 308 |
+
|
| 309 |
+
Args:
|
| 310 |
+
arr (List[Any]): The initial array
|
| 311 |
+
fn (Callable[[Any], Any]): A function to determine the priority of elements
|
| 312 |
+
"""
|
| 313 |
+
self.size = len(arr)
|
| 314 |
+
arr = list(enumerate(arr))
|
| 315 |
+
arr = group(arr, lambda x: fn(x[1]))
|
| 316 |
+
# arr = [([y[0] for y in x], x[0][1]) for x in arr]
|
| 317 |
+
# TODO: overhaul reorderer. It currently grouped requests by content but we don't want this
|
| 318 |
+
arr = [([y[0]], x[0][1]) for x in arr for y in x]
|
| 319 |
+
arr.sort(key=lambda x: fn(x[1]))
|
| 320 |
+
|
| 321 |
+
self.arr = arr
|
| 322 |
+
|
| 323 |
+
def get_reordered(self):
|
| 324 |
+
"""Gets the reordered array
|
| 325 |
+
|
| 326 |
+
Returns:
|
| 327 |
+
List[Any]: The reordered array
|
| 328 |
+
"""
|
| 329 |
+
return [x[1] for x in self.arr]
|
| 330 |
+
|
| 331 |
+
def get_original(self, newarr):
|
| 332 |
+
"""Restores the original order of a new array based on the old array's order
|
| 333 |
+
|
| 334 |
+
Args:
|
| 335 |
+
newarr (List[Any]): The array to be restored
|
| 336 |
+
|
| 337 |
+
Returns:
|
| 338 |
+
List[Any]: The array restored to the original order
|
| 339 |
+
"""
|
| 340 |
+
res = [None] * self.size
|
| 341 |
+
cov = [False] * self.size
|
| 342 |
+
|
| 343 |
+
for (inds, _), v in zip(self.arr, newarr):
|
| 344 |
+
for ind in inds:
|
| 345 |
+
res[ind] = v
|
| 346 |
+
cov[ind] = True
|
| 347 |
+
|
| 348 |
+
assert all(cov)
|
| 349 |
+
|
| 350 |
+
return res
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
def make_table(result_dict, column: str = "results", sort_results: bool = False):
|
| 354 |
+
"""Generate table of results."""
|
| 355 |
+
from pytablewriter import LatexTableWriter, MarkdownTableWriter
|
| 356 |
+
|
| 357 |
+
if column == "results":
|
| 358 |
+
column_name = "Tasks"
|
| 359 |
+
elif column == "groups":
|
| 360 |
+
column_name = "Groups"
|
| 361 |
+
|
| 362 |
+
all_headers = [
|
| 363 |
+
column_name,
|
| 364 |
+
"Version",
|
| 365 |
+
"Filter",
|
| 366 |
+
"n-shot",
|
| 367 |
+
"Metric",
|
| 368 |
+
"",
|
| 369 |
+
"Value",
|
| 370 |
+
"",
|
| 371 |
+
"Stderr",
|
| 372 |
+
]
|
| 373 |
+
|
| 374 |
+
md_writer = MarkdownTableWriter()
|
| 375 |
+
latex_writer = LatexTableWriter()
|
| 376 |
+
md_writer.headers = all_headers
|
| 377 |
+
latex_writer.headers = all_headers
|
| 378 |
+
|
| 379 |
+
values = []
|
| 380 |
+
|
| 381 |
+
keys = result_dict[column].keys()
|
| 382 |
+
if sort_results:
|
| 383 |
+
# sort entries alphabetically by task or group name.
|
| 384 |
+
# NOTE: we default here to false, because order matters for multi-level table printing a la mmlu.
|
| 385 |
+
# sorting here would mess that up
|
| 386 |
+
keys = sorted(keys)
|
| 387 |
+
for k in keys:
|
| 388 |
+
dic = result_dict[column][k]
|
| 389 |
+
version = result_dict["versions"].get(k, " N/A")
|
| 390 |
+
n = str(result_dict.get("n-shot", " ").get(k, " "))
|
| 391 |
+
higher_is_better = result_dict.get("higher_is_better", {}).get(k, {})
|
| 392 |
+
|
| 393 |
+
if "alias" in dic:
|
| 394 |
+
k = dic.pop("alias")
|
| 395 |
+
|
| 396 |
+
metric_items = dic.items()
|
| 397 |
+
metric_items = sorted(metric_items)
|
| 398 |
+
|
| 399 |
+
for (mf), v in metric_items:
|
| 400 |
+
m, _, f = mf.partition(",")
|
| 401 |
+
if m.endswith("_stderr"):
|
| 402 |
+
continue
|
| 403 |
+
|
| 404 |
+
hib = HIGHER_IS_BETTER_SYMBOLS.get(higher_is_better.get(m), "")
|
| 405 |
+
|
| 406 |
+
v = "%.4f" % v if isinstance(v, float) else v
|
| 407 |
+
|
| 408 |
+
if m + "_stderr" + "," + f in dic:
|
| 409 |
+
se = dic[m + "_stderr" + "," + f]
|
| 410 |
+
se = " N/A" if se == "N/A" else "%.4f" % se
|
| 411 |
+
values.append([k, version, f, n, m, hib, v, "±", se])
|
| 412 |
+
else:
|
| 413 |
+
values.append([k, version, f, n, m, hib, v, "", ""])
|
| 414 |
+
k = ""
|
| 415 |
+
version = ""
|
| 416 |
+
md_writer.value_matrix = values
|
| 417 |
+
latex_writer.value_matrix = values
|
| 418 |
+
|
| 419 |
+
# todo: make latex table look good
|
| 420 |
+
# print(latex_writer.dumps())
|
| 421 |
+
|
| 422 |
+
return md_writer.dumps()
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
def positional_deprecated(fn):
|
| 426 |
+
"""
|
| 427 |
+
A decorator to nudge users into passing only keyword args (`kwargs`) to the
|
| 428 |
+
wrapped function, `fn`.
|
| 429 |
+
"""
|
| 430 |
+
|
| 431 |
+
@functools.wraps(fn)
|
| 432 |
+
def _wrapper(*args, **kwargs):
|
| 433 |
+
if len(args) != 1 if inspect.ismethod(fn) else 0:
|
| 434 |
+
print(
|
| 435 |
+
f"WARNING: using {fn.__name__} with positional arguments is "
|
| 436 |
+
"deprecated and will be disallowed in a future version of "
|
| 437 |
+
"lm-evaluation-harness!"
|
| 438 |
+
)
|
| 439 |
+
return fn(*args, **kwargs)
|
| 440 |
+
|
| 441 |
+
return _wrapper
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
def ignore_constructor(loader, node):
|
| 445 |
+
return node
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
def import_function(loader: yaml.Loader, node, yaml_path: Path):
|
| 449 |
+
function_name = loader.construct_scalar(node)
|
| 450 |
+
|
| 451 |
+
*module_name, function_name = function_name.split(".")
|
| 452 |
+
if isinstance(module_name, list):
|
| 453 |
+
module_name = ".".join(module_name)
|
| 454 |
+
module_path = yaml_path.parent / f"{module_name}.py"
|
| 455 |
+
|
| 456 |
+
spec = importlib.util.spec_from_file_location(module_name, module_path.as_posix())
|
| 457 |
+
|
| 458 |
+
if spec is None:
|
| 459 |
+
raise ImportError(f"Could not import module {module_name} from {module_path}.")
|
| 460 |
+
module = importlib.util.module_from_spec(spec)
|
| 461 |
+
|
| 462 |
+
if spec.loader is None:
|
| 463 |
+
raise ImportError(f"Module loader is None, {module_name} from {module_path}.")
|
| 464 |
+
spec.loader.exec_module(module)
|
| 465 |
+
|
| 466 |
+
function = getattr(module, function_name)
|
| 467 |
+
return function
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None, mode="full"):
|
| 471 |
+
if mode == "simple":
|
| 472 |
+
constructor_fn = ignore_constructor
|
| 473 |
+
elif mode == "full":
|
| 474 |
+
if yaml_path is None:
|
| 475 |
+
raise ValueError("yaml_path must be provided if mode is 'full'.")
|
| 476 |
+
# Attach yaml_path to the import function so that it can be used later
|
| 477 |
+
constructor_fn = functools.partial(import_function, yaml_path=Path(yaml_path))
|
| 478 |
+
|
| 479 |
+
loader = yaml.CLoader if yaml.__with_libyaml__ else yaml.FullLoader
|
| 480 |
+
# Add the import_function constructor to the YAML loader
|
| 481 |
+
yaml.add_constructor("!function", constructor_fn, Loader=loader)
|
| 482 |
+
if yaml_config is None:
|
| 483 |
+
with open(yaml_path, "rb") as file:
|
| 484 |
+
yaml_config = yaml.load(file, Loader=loader)
|
| 485 |
+
|
| 486 |
+
if yaml_dir is None:
|
| 487 |
+
yaml_dir = os.path.dirname(yaml_path)
|
| 488 |
+
|
| 489 |
+
assert yaml_dir is not None
|
| 490 |
+
|
| 491 |
+
if "include" in yaml_config:
|
| 492 |
+
include_path = yaml_config["include"]
|
| 493 |
+
del yaml_config["include"]
|
| 494 |
+
|
| 495 |
+
if isinstance(include_path, str):
|
| 496 |
+
include_path = [include_path]
|
| 497 |
+
|
| 498 |
+
# Load from the last one first
|
| 499 |
+
include_path.reverse()
|
| 500 |
+
final_yaml_config = {}
|
| 501 |
+
for path in include_path:
|
| 502 |
+
# Assumes that path is a full path.
|
| 503 |
+
# If not found, assume the included yaml
|
| 504 |
+
# is in the same dir as the original yaml
|
| 505 |
+
if not os.path.isfile(path):
|
| 506 |
+
path = os.path.join(yaml_dir, path)
|
| 507 |
+
|
| 508 |
+
try:
|
| 509 |
+
included_yaml_config = load_yaml_config(yaml_path=path, mode=mode)
|
| 510 |
+
final_yaml_config.update(included_yaml_config)
|
| 511 |
+
except Exception as ex:
|
| 512 |
+
# If failed to load, ignore
|
| 513 |
+
raise ex
|
| 514 |
+
|
| 515 |
+
final_yaml_config.update(yaml_config)
|
| 516 |
+
return final_yaml_config
|
| 517 |
+
return yaml_config
|
| 518 |
+
|
| 519 |
+
|
| 520 |
+
def regex_replace(string, pattern, repl, count: int = 0):
|
| 521 |
+
"""Implements the `re.sub` function as a custom Jinja filter."""
|
| 522 |
+
return re.sub(pattern, repl, string, count=count)
|
| 523 |
+
|
| 524 |
+
|
| 525 |
+
env = Environment(
|
| 526 |
+
loader=BaseLoader, undefined=StrictUndefined, keep_trailing_newline=True
|
| 527 |
+
)
|
| 528 |
+
env.filters["regex_replace"] = regex_replace
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
def apply_template(template: str, doc: dict) -> str:
|
| 532 |
+
rtemplate = env.from_string(template)
|
| 533 |
+
return rtemplate.render(**doc)
|
| 534 |
+
|
| 535 |
+
|
| 536 |
+
def create_iterator(raw_iterator, *, rank=0, world_size=1, limit=None):
|
| 537 |
+
"""
|
| 538 |
+
Method for creating a (potentially) sliced and limited
|
| 539 |
+
iterator from a raw document iterator. Used for splitting data
|
| 540 |
+
among ranks in multigpu setting or only pulling a sample of documents
|
| 541 |
+
"""
|
| 542 |
+
return islice(raw_iterator, rank, limit, world_size)
|
| 543 |
+
|
| 544 |
+
|
| 545 |
+
def weighted_f1_score(items):
|
| 546 |
+
from sklearn.metrics import f1_score
|
| 547 |
+
|
| 548 |
+
unzipped_list = list(zip(*items))
|
| 549 |
+
golds = unzipped_list[0]
|
| 550 |
+
preds = unzipped_list[1]
|
| 551 |
+
fscore = f1_score(golds, preds, average="weighted")
|
| 552 |
+
return fscore
|
Prism/LLaDA/LLaDA_Baseline/evaluation_script.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import random
|
| 4 |
+
import numpy as np
|
| 5 |
+
from dllm_eval.__main__ import cli_evaluate
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def set_seed(seed):
|
| 9 |
+
torch.manual_seed(seed)
|
| 10 |
+
random.seed(seed)
|
| 11 |
+
np.random.seed(seed)
|
| 12 |
+
|
| 13 |
+
torch.backends.cudnn.deterministic = True
|
| 14 |
+
torch.backends.cudnn.benchmark = False
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
if __name__ == "__main__":
|
| 18 |
+
os.environ["HF_ALLOW_CODE_EVAL"] = "1"
|
| 19 |
+
os.environ["HF_DATASETS_TRUST_REMOTE_CODE"] = "1"
|
| 20 |
+
set_seed(42)
|
| 21 |
+
cli_evaluate()
|
Prism/LLaDA/LLaDA_Baseline/metrics/gsm8k_all.py
ADDED
|
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import re
|
| 3 |
+
import os
|
| 4 |
+
import math
|
| 5 |
+
import argparse
|
| 6 |
+
from collections import Counter
|
| 7 |
+
|
| 8 |
+
RES_PATH = "<PATH_TO_RESULTS_JSONL>"
|
| 9 |
+
|
| 10 |
+
def last_boxed_only_string(string):
|
| 11 |
+
if not string: return None
|
| 12 |
+
idx = max(string.rfind("\\boxed"), string.rfind("\\fbox"))
|
| 13 |
+
if idx < 0: return None
|
| 14 |
+
|
| 15 |
+
if "\\boxed " in string[idx:idx+8] and "{" not in string[idx:idx+8]:
|
| 16 |
+
return "\\boxed " + string[idx:].split("\\boxed ")[-1].split("$")[0].strip()
|
| 17 |
+
|
| 18 |
+
i = idx
|
| 19 |
+
right_brace_idx = None
|
| 20 |
+
num_left_braces_open = 0
|
| 21 |
+
while i < len(string):
|
| 22 |
+
if string[i] == "{":
|
| 23 |
+
num_left_braces_open += 1
|
| 24 |
+
elif string[i] == "}":
|
| 25 |
+
num_left_braces_open -= 1
|
| 26 |
+
if num_left_braces_open == 0:
|
| 27 |
+
right_brace_idx = i
|
| 28 |
+
break
|
| 29 |
+
i += 1
|
| 30 |
+
return string[idx : right_brace_idx + 1] if right_brace_idx else None
|
| 31 |
+
|
| 32 |
+
def remove_boxed(s):
|
| 33 |
+
if not s: return None
|
| 34 |
+
if "\\boxed " in s: return s[len("\\boxed ") :]
|
| 35 |
+
if "\\boxed{" in s and s.endswith("}"): return s[len("\\boxed{") : -1]
|
| 36 |
+
if "\\fbox{" in s and s.endswith("}"): return s[len("\\fbox{") : -1]
|
| 37 |
+
return s
|
| 38 |
+
|
| 39 |
+
def strip_string(string):
|
| 40 |
+
if string is None: return ""
|
| 41 |
+
string = str(string).strip()
|
| 42 |
+
while re.search(r"(\d),(\d{3})", string):
|
| 43 |
+
string = re.sub(r"(\d),(\d{3})", r"\1\2", string)
|
| 44 |
+
|
| 45 |
+
string = string.replace("\n", "").replace("\\!", "")
|
| 46 |
+
string = string.replace("tfrac", "frac").replace("dfrac", "frac")
|
| 47 |
+
string = string.replace("\\left", "").replace("\\right", "")
|
| 48 |
+
string = string.replace("^{\\circ}", "").replace("^\\circ", "")
|
| 49 |
+
string = string.replace("\\$", "").replace("\\%", "").replace("\%", "")
|
| 50 |
+
|
| 51 |
+
if "=" in string and len(string.split("=")[0]) <= 5:
|
| 52 |
+
string = string.split("=")[1].strip()
|
| 53 |
+
|
| 54 |
+
string = string.replace(" ", "")
|
| 55 |
+
string = string.rstrip(".")
|
| 56 |
+
return string
|
| 57 |
+
|
| 58 |
+
def normalize_to_number(s):
|
| 59 |
+
s_clean = strip_string(s)
|
| 60 |
+
try:
|
| 61 |
+
if '/' in s_clean and len(s_clean.split('/')) == 2:
|
| 62 |
+
parts = s_clean.split('/')
|
| 63 |
+
return float(parts[0]) / float(parts[1])
|
| 64 |
+
return float(s_clean)
|
| 65 |
+
except:
|
| 66 |
+
return s_clean
|
| 67 |
+
|
| 68 |
+
def extract_answer_gsm8k_debug(text):
|
| 69 |
+
if not text: return "", "empty"
|
| 70 |
+
text = text.replace("<|role_end|>", "").replace("<|endoftext|>", "").strip()
|
| 71 |
+
|
| 72 |
+
boxed = last_boxed_only_string(text)
|
| 73 |
+
if boxed:
|
| 74 |
+
ans = remove_boxed(boxed)
|
| 75 |
+
if ans:
|
| 76 |
+
return strip_string(ans), "boxed"
|
| 77 |
+
|
| 78 |
+
tag_match = re.search(r"<answer>(.*?)</answer>", text, re.DOTALL)
|
| 79 |
+
if tag_match:
|
| 80 |
+
return strip_string(tag_match.group(1)), "xml_tag"
|
| 81 |
+
|
| 82 |
+
last_text = text[-200:] if len(text) > 200 else text
|
| 83 |
+
marker = "the answer is"
|
| 84 |
+
if marker in last_text.lower():
|
| 85 |
+
idx = last_text.lower().rfind(marker)
|
| 86 |
+
after = last_text[idx + len(marker):].strip()
|
| 87 |
+
after = re.split(r"[.\n]", after)[0]
|
| 88 |
+
after = after.replace(":", "").replace("$", "").strip()
|
| 89 |
+
return strip_string(after), "text_marker"
|
| 90 |
+
|
| 91 |
+
tail = text[-50:]
|
| 92 |
+
nums = re.findall(r"(?<!\d)-?\d+\.?\d*(?!\d)", tail)
|
| 93 |
+
if nums:
|
| 94 |
+
return strip_string(nums[-1]), "regex_last_num"
|
| 95 |
+
|
| 96 |
+
return "", "failed"
|
| 97 |
+
|
| 98 |
+
def extract_gold_gsm8k(target_str):
|
| 99 |
+
if "####" in target_str:
|
| 100 |
+
return strip_string(target_str.split("####")[-1])
|
| 101 |
+
return strip_string(target_str)
|
| 102 |
+
|
| 103 |
+
def is_equiv(pred, gold):
|
| 104 |
+
p_val = normalize_to_number(pred)
|
| 105 |
+
g_val = normalize_to_number(gold)
|
| 106 |
+
|
| 107 |
+
if isinstance(p_val, float) and isinstance(g_val, float):
|
| 108 |
+
return math.isclose(p_val, g_val, rel_tol=1e-4)
|
| 109 |
+
return str(p_val) == str(g_val)
|
| 110 |
+
|
| 111 |
+
def run_evaluation(target_path):
|
| 112 |
+
jsonl_files = []
|
| 113 |
+
if os.path.isdir(target_path):
|
| 114 |
+
for root, dirs, files in os.walk(target_path):
|
| 115 |
+
for file in files:
|
| 116 |
+
if file.endswith(".jsonl") and not file.startswith("eval_voted_"):
|
| 117 |
+
jsonl_files.append(os.path.join(root, file))
|
| 118 |
+
else:
|
| 119 |
+
jsonl_files = [target_path]
|
| 120 |
+
|
| 121 |
+
for file_path in jsonl_files:
|
| 122 |
+
print(f">>> 正在评测: {file_path}")
|
| 123 |
+
detailed_results = []
|
| 124 |
+
|
| 125 |
+
correct_voted_count = 0
|
| 126 |
+
correct_any_count = 0
|
| 127 |
+
total_count = 0
|
| 128 |
+
nfe_list = []
|
| 129 |
+
svf_list = []
|
| 130 |
+
|
| 131 |
+
with open(file_path, 'r', encoding='utf-8') as f:
|
| 132 |
+
for line in f:
|
| 133 |
+
if not line.strip(): continue
|
| 134 |
+
try:
|
| 135 |
+
item = json.loads(line)
|
| 136 |
+
except:
|
| 137 |
+
continue
|
| 138 |
+
|
| 139 |
+
doc = item.get("doc", {})
|
| 140 |
+
ground_truth = extract_gold_gsm8k(str(item.get("target", "")))
|
| 141 |
+
|
| 142 |
+
total_nfe_item = item.get("nfe", 0)
|
| 143 |
+
nfe_list.append(total_nfe_item)
|
| 144 |
+
svf_list.append(item.get("svf_calls", 0))
|
| 145 |
+
|
| 146 |
+
trajectories = item.get("all_trajectories", [])
|
| 147 |
+
if not trajectories:
|
| 148 |
+
resps = item.get("resps", [])
|
| 149 |
+
for r in resps:
|
| 150 |
+
text = r[0] if isinstance(r, list) else r
|
| 151 |
+
trajectories.append({"resp": text, "score": 0.0})
|
| 152 |
+
|
| 153 |
+
parsed_paths = []
|
| 154 |
+
traj_debug_info = []
|
| 155 |
+
|
| 156 |
+
for idx, traj in enumerate(trajectories):
|
| 157 |
+
raw_text = traj.get("resp", "")
|
| 158 |
+
score = traj.get("score", 0.0)
|
| 159 |
+
|
| 160 |
+
extracted, method = extract_answer_gsm8k_debug(raw_text)
|
| 161 |
+
|
| 162 |
+
is_correct_single = False
|
| 163 |
+
if extracted:
|
| 164 |
+
is_correct_single = is_equiv(extracted, ground_truth)
|
| 165 |
+
val_key = normalize_to_number(extracted)
|
| 166 |
+
|
| 167 |
+
parsed_paths.append({
|
| 168 |
+
"original_text": extracted,
|
| 169 |
+
"val_key": val_key,
|
| 170 |
+
"score": score,
|
| 171 |
+
"method": method
|
| 172 |
+
})
|
| 173 |
+
|
| 174 |
+
traj_debug_info.append({
|
| 175 |
+
"id": idx,
|
| 176 |
+
"extracted": extracted,
|
| 177 |
+
"score": score,
|
| 178 |
+
"is_correct": is_correct_single,
|
| 179 |
+
"extract_method": method
|
| 180 |
+
})
|
| 181 |
+
|
| 182 |
+
if not parsed_paths:
|
| 183 |
+
detailed_results.append({
|
| 184 |
+
"question": doc.get("question", "N/A"),
|
| 185 |
+
"final_voted_answer": "",
|
| 186 |
+
"ground_truth": ground_truth,
|
| 187 |
+
"is_voted_correct": False,
|
| 188 |
+
"trajectory_details": traj_debug_info,
|
| 189 |
+
"nfe": total_nfe_item,
|
| 190 |
+
"svf_calls": item.get("svf_calls", 0)
|
| 191 |
+
})
|
| 192 |
+
total_count += 1
|
| 193 |
+
continue
|
| 194 |
+
|
| 195 |
+
has_correct = any(p['score'] > -999 and is_equiv(p['original_text'], ground_truth) for p in parsed_paths)
|
| 196 |
+
if has_correct:
|
| 197 |
+
correct_any_count += 1
|
| 198 |
+
|
| 199 |
+
parsed_paths.sort(key=lambda x: x['score'], reverse=True)
|
| 200 |
+
top_k_count = max(1, int(len(parsed_paths) * 0.6))
|
| 201 |
+
voting_candidates = parsed_paths[:top_k_count]
|
| 202 |
+
|
| 203 |
+
ans_stats = {}
|
| 204 |
+
for p in voting_candidates:
|
| 205 |
+
k = p['val_key']
|
| 206 |
+
if k not in ans_stats:
|
| 207 |
+
ans_stats[k] = {
|
| 208 |
+
"total_weight": 0.0,
|
| 209 |
+
"count": 0,
|
| 210 |
+
"max_score": -float('inf'),
|
| 211 |
+
"best_repr": p['original_text']
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
try:
|
| 215 |
+
weight = math.exp(p['score'])
|
| 216 |
+
except OverflowError:
|
| 217 |
+
weight = float('inf')
|
| 218 |
+
|
| 219 |
+
ans_stats[k]["total_weight"] += weight
|
| 220 |
+
ans_stats[k]["count"] += 1
|
| 221 |
+
if p['score'] > ans_stats[k]["max_score"]:
|
| 222 |
+
ans_stats[k]["max_score"] = p['score']
|
| 223 |
+
ans_stats[k]["best_repr"] = p['original_text']
|
| 224 |
+
|
| 225 |
+
sorted_answers = sorted(
|
| 226 |
+
ans_stats.items(),
|
| 227 |
+
key=lambda x: (x[1]["total_weight"], x[1]["max_score"]),
|
| 228 |
+
reverse=True
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
best_pred = str(sorted_answers[0][1]["best_repr"])
|
| 232 |
+
is_voted_correct = is_equiv(best_pred, ground_truth)
|
| 233 |
+
if is_voted_correct:
|
| 234 |
+
correct_voted_count += 1
|
| 235 |
+
|
| 236 |
+
vote_summary = []
|
| 237 |
+
for val, info in sorted_answers:
|
| 238 |
+
vote_summary.append({
|
| 239 |
+
"answer": str(val),
|
| 240 |
+
"count": info["count"],
|
| 241 |
+
"total_weight": info["total_weight"],
|
| 242 |
+
"is_correct": is_equiv(str(val), ground_truth)
|
| 243 |
+
})
|
| 244 |
+
|
| 245 |
+
total_count += 1
|
| 246 |
+
|
| 247 |
+
detailed_results.append({
|
| 248 |
+
"question": doc.get("question", "N/A"),
|
| 249 |
+
"final_voted_answer": best_pred,
|
| 250 |
+
"ground_truth": ground_truth,
|
| 251 |
+
"is_voted_correct": is_voted_correct,
|
| 252 |
+
"vote_stats": vote_summary,
|
| 253 |
+
"trajectory_details": traj_debug_info,
|
| 254 |
+
"nfe": total_nfe_item,
|
| 255 |
+
"svf_calls": item.get("svf_calls", 0)
|
| 256 |
+
})
|
| 257 |
+
|
| 258 |
+
accuracy = (correct_voted_count / total_count * 100) if total_count > 0 else 0
|
| 259 |
+
pass_at_k = (correct_any_count / total_count * 100) if total_count > 0 else 0
|
| 260 |
+
avg_nfe = int(round(sum(nfe_list) / len(nfe_list))) if nfe_list else 0
|
| 261 |
+
avg_svf = int(round(sum(svf_list) / len(svf_list))) if svf_list else 0
|
| 262 |
+
|
| 263 |
+
print(f"--- Accuracy: {accuracy:.2f}% | NFE: {avg_nfe} | SVF: {avg_svf} ---")
|
| 264 |
+
|
| 265 |
+
output_name = f"eval_voted_{os.path.basename(file_path).replace('.jsonl', '.json')}"
|
| 266 |
+
output_path = os.path.join(os.path.dirname(file_path), output_name)
|
| 267 |
+
|
| 268 |
+
final_report = {
|
| 269 |
+
"summary": {
|
| 270 |
+
"accuracy": f"{accuracy:.2f}%",
|
| 271 |
+
"correct_voted": correct_voted_count,
|
| 272 |
+
"total": total_count,
|
| 273 |
+
"nfe": avg_nfe,
|
| 274 |
+
"svf_calls": avg_svf
|
| 275 |
+
},
|
| 276 |
+
"details": detailed_results
|
| 277 |
+
}
|
| 278 |
+
|
| 279 |
+
with open(output_path, 'w', encoding='utf-8') as out_f:
|
| 280 |
+
json.dump(final_report, out_f, ensure_ascii=False, indent=4)
|
| 281 |
+
|
| 282 |
+
if __name__ == "__main__":
|
| 283 |
+
parser = argparse.ArgumentParser()
|
| 284 |
+
parser.add_argument("-r", "--res_path", type=str, default=RES_PATH)
|
| 285 |
+
args = parser.parse_args()
|
| 286 |
+
run_evaluation(args.res_path)
|
Prism/LLaDA/LLaDA_Baseline/metrics/humaneval_all.py
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import json
|
| 4 |
+
import ast
|
| 5 |
+
import traceback
|
| 6 |
+
import glob
|
| 7 |
+
import math
|
| 8 |
+
import argparse
|
| 9 |
+
from typing import Dict, List, Optional, Set, Tuple
|
| 10 |
+
from collections import Counter
|
| 11 |
+
import evaluate as hf_evaluate
|
| 12 |
+
import re
|
| 13 |
+
|
| 14 |
+
RES_PATH = "<PATH_TO_RESULTS_JSONL>"
|
| 15 |
+
|
| 16 |
+
os.environ["HF_ALLOW_CODE_EVAL"] = "1"
|
| 17 |
+
|
| 18 |
+
def extract_python_code(text: str) -> str:
|
| 19 |
+
if not text: return ""
|
| 20 |
+
|
| 21 |
+
text = text.replace("<|role_end|>", "").replace("<|endoftext|>", "").replace("<|notification_end|>", "")
|
| 22 |
+
|
| 23 |
+
tag_match = re.search(r"<answer>(.*?)</answer>", text, re.DOTALL)
|
| 24 |
+
if tag_match:
|
| 25 |
+
text = tag_match.group(1)
|
| 26 |
+
|
| 27 |
+
if "```python" in text:
|
| 28 |
+
content = text.split("```python")[-1]
|
| 29 |
+
if "```" in content:
|
| 30 |
+
return content.split("```")[0].strip()
|
| 31 |
+
return content.strip()
|
| 32 |
+
elif "```" in text:
|
| 33 |
+
content = text.split("```")[-1]
|
| 34 |
+
if "```" in content:
|
| 35 |
+
return content.split("```")[0].strip()
|
| 36 |
+
return content.strip()
|
| 37 |
+
|
| 38 |
+
lines = text.split('\n')
|
| 39 |
+
cleaned_lines = []
|
| 40 |
+
stop_words = ["Explanation:", "Example:", "Test Case:", "Output:"]
|
| 41 |
+
for line in lines:
|
| 42 |
+
if any(sw in line for sw in stop_words):
|
| 43 |
+
break
|
| 44 |
+
cleaned_lines.append(line)
|
| 45 |
+
|
| 46 |
+
return "\n".join(cleaned_lines).strip()
|
| 47 |
+
|
| 48 |
+
def normalize_code_for_voting(code: str) -> str:
|
| 49 |
+
try:
|
| 50 |
+
tree = ast.parse(code)
|
| 51 |
+
for node in ast.walk(tree):
|
| 52 |
+
if isinstance(node, (ast.FunctionDef, ast.ClassDef, ast.Module)):
|
| 53 |
+
if (node.body and isinstance(node.body[0], ast.Expr) and
|
| 54 |
+
isinstance(node.body[0].value, ast.Constant) and isinstance(node.body[0].value.value, str)):
|
| 55 |
+
node.body.pop(0)
|
| 56 |
+
return ast.unparse(tree).strip()
|
| 57 |
+
except:
|
| 58 |
+
return re.sub(r"\s+", "", code)
|
| 59 |
+
|
| 60 |
+
def sanitize(prompt: str, completion: str, entrypoint: str) -> str:
|
| 61 |
+
if f"def {entrypoint}" in completion:
|
| 62 |
+
return completion
|
| 63 |
+
return prompt + "\n" + completion
|
| 64 |
+
|
| 65 |
+
def run_evaluation(target_path):
|
| 66 |
+
if os.path.isdir(target_path):
|
| 67 |
+
jsonl_files = glob.glob(os.path.join(target_path, "**/*.jsonl"), recursive=True)
|
| 68 |
+
else:
|
| 69 |
+
jsonl_files = [target_path]
|
| 70 |
+
|
| 71 |
+
if not jsonl_files:
|
| 72 |
+
print(f"未在路径 {target_path} 下找到任何 .jsonl 文件")
|
| 73 |
+
return
|
| 74 |
+
|
| 75 |
+
print(f"共找到 {len(jsonl_files)} 个评测任务")
|
| 76 |
+
code_eval = hf_evaluate.load("code_eval")
|
| 77 |
+
|
| 78 |
+
for file_path in jsonl_files:
|
| 79 |
+
print(f"\n>>> 正在评测: {file_path}")
|
| 80 |
+
all_predictions = []
|
| 81 |
+
all_references = []
|
| 82 |
+
detailed_results = []
|
| 83 |
+
nfe_list = []
|
| 84 |
+
svf_list = []
|
| 85 |
+
|
| 86 |
+
with open(file_path, 'r', encoding='utf-8') as f:
|
| 87 |
+
lines = f.readlines()
|
| 88 |
+
if not lines: continue
|
| 89 |
+
|
| 90 |
+
for line in lines:
|
| 91 |
+
if not line.strip(): continue
|
| 92 |
+
item = json.loads(line)
|
| 93 |
+
doc = item.get("doc", {})
|
| 94 |
+
prompt = doc.get("prompt", "")
|
| 95 |
+
entry_point = doc.get("entry_point", "")
|
| 96 |
+
reference = doc.get("test", "")
|
| 97 |
+
|
| 98 |
+
current_nfe = item.get("nfe", 0)
|
| 99 |
+
nfe_list.append(current_nfe)
|
| 100 |
+
svf_list.append(item.get("svf_calls", 0))
|
| 101 |
+
|
| 102 |
+
resps = item.get("resps", [])
|
| 103 |
+
candidate_stats = {}
|
| 104 |
+
|
| 105 |
+
for r in resps:
|
| 106 |
+
raw_text = r[0] if isinstance(r, list) else r
|
| 107 |
+
completion = extract_python_code(raw_text)
|
| 108 |
+
full_code = sanitize(prompt, completion, entry_point)
|
| 109 |
+
|
| 110 |
+
try:
|
| 111 |
+
ast.parse(full_code)
|
| 112 |
+
is_valid = True
|
| 113 |
+
except:
|
| 114 |
+
is_valid = False
|
| 115 |
+
|
| 116 |
+
logic_norm = normalize_code_for_voting(full_code)
|
| 117 |
+
if not logic_norm: continue
|
| 118 |
+
|
| 119 |
+
if logic_norm not in candidate_stats:
|
| 120 |
+
candidate_stats[logic_norm] = {"count": 0, "valid": is_valid, "code": full_code}
|
| 121 |
+
candidate_stats[logic_norm]["count"] += 1
|
| 122 |
+
|
| 123 |
+
if not candidate_stats:
|
| 124 |
+
voted_code = prompt
|
| 125 |
+
else:
|
| 126 |
+
sorted_logics = sorted(
|
| 127 |
+
candidate_stats.keys(),
|
| 128 |
+
key=lambda k: (candidate_stats[k]["valid"], candidate_stats[k]["count"]),
|
| 129 |
+
reverse=True
|
| 130 |
+
)
|
| 131 |
+
voted_code = candidate_stats[sorted_logics[0]]["code"]
|
| 132 |
+
|
| 133 |
+
all_predictions.append([voted_code])
|
| 134 |
+
all_references.append(reference)
|
| 135 |
+
detailed_results.append({
|
| 136 |
+
"task_id": doc.get("task_id", doc.get("name", "N/A")),
|
| 137 |
+
"voted_code": voted_code,
|
| 138 |
+
"nfe": current_nfe,
|
| 139 |
+
"svf_calls": item.get("svf_calls", 0),
|
| 140 |
+
"candidates_count": len(candidate_stats)
|
| 141 |
+
})
|
| 142 |
+
|
| 143 |
+
if not all_predictions: continue
|
| 144 |
+
|
| 145 |
+
print(f"正在执行代码测试 (共 {len(all_predictions)} 题)...")
|
| 146 |
+
pass_at_k, exec_results = code_eval.compute(
|
| 147 |
+
references=all_references,
|
| 148 |
+
predictions=all_predictions,
|
| 149 |
+
k=[1],
|
| 150 |
+
num_workers=4
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
accuracy = pass_at_k.get("pass@1", 0.0) * 100
|
| 154 |
+
avg_nfe = int(round(sum(nfe_list) / len(nfe_list))) if nfe_list else 0
|
| 155 |
+
avg_svf = int(round(sum(svf_list) / len(svf_list))) if svf_list else 0
|
| 156 |
+
|
| 157 |
+
print(f"Accuracy: {accuracy:.2f}% | NFE: {avg_nfe} | SVF: {avg_svf} ---")
|
| 158 |
+
|
| 159 |
+
output_name = f"eval_voted_{os.path.basename(file_path).replace('.jsonl', '.json')}"
|
| 160 |
+
output_path = os.path.join(os.path.dirname(file_path), output_name)
|
| 161 |
+
|
| 162 |
+
for i, detail in enumerate(detailed_results):
|
| 163 |
+
res_list = exec_results.get(i, [])
|
| 164 |
+
detail["is_correct"] = res_list[0][1]["passed"] if res_list else False
|
| 165 |
+
|
| 166 |
+
final_report = {
|
| 167 |
+
"summary": {
|
| 168 |
+
"accuracy": f"{accuracy:.2f}%",
|
| 169 |
+
"nfe": avg_nfe,
|
| 170 |
+
"svf_calls": avg_svf
|
| 171 |
+
},
|
| 172 |
+
"details": detailed_results
|
| 173 |
+
}
|
| 174 |
+
|
| 175 |
+
with open(output_path, 'w', encoding='utf-8') as out_f:
|
| 176 |
+
json.dump(final_report, out_f, ensure_ascii=False, indent=4)
|
| 177 |
+
print(f"报告已保存至: {output_path}")
|
| 178 |
+
|
| 179 |
+
if __name__ == "__main__":
|
| 180 |
+
parser = argparse.ArgumentParser()
|
| 181 |
+
parser.add_argument("-r", "--res_path", type=str, default=RES_PATH)
|
| 182 |
+
args = parser.parse_args()
|
| 183 |
+
run_evaluation(args.res_path)
|
Prism/LLaDA/LLaDA_Baseline/metrics/math500_all.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import re
|
| 3 |
+
import os
|
| 4 |
+
import math
|
| 5 |
+
import argparse
|
| 6 |
+
from collections import Counter
|
| 7 |
+
|
| 8 |
+
RES_PATH = "<PATH_TO_RESULTS_JSONL>"
|
| 9 |
+
|
| 10 |
+
def extract_answer(text):
|
| 11 |
+
if not text:
|
| 12 |
+
return "", False
|
| 13 |
+
text = text.replace("<|role_end|>", "").replace("<|endoftext|>", "").strip()
|
| 14 |
+
|
| 15 |
+
boxed_pattern = r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}"
|
| 16 |
+
all_boxes = re.findall(boxed_pattern, text)
|
| 17 |
+
if all_boxes:
|
| 18 |
+
return all_boxes[-1], True
|
| 19 |
+
|
| 20 |
+
tag_match = re.search(r"<answer>(.*?)</answer>", text, re.DOTALL)
|
| 21 |
+
if tag_match:
|
| 22 |
+
return tag_match.group(1).strip(), True
|
| 23 |
+
|
| 24 |
+
marker = "the answer is"
|
| 25 |
+
if marker in text.lower():
|
| 26 |
+
pos = text.lower().rfind(marker)
|
| 27 |
+
after_text = text[pos + len(marker):].strip()
|
| 28 |
+
after_text = re.sub(r"^[:\s]+", "", after_text)
|
| 29 |
+
return after_text.split('\n')[0].split('$')[0].strip(), True
|
| 30 |
+
|
| 31 |
+
tail = text[-50:].strip()
|
| 32 |
+
nums = re.findall(r"(-?\d+[\./\d]*|\\sqrt\{\d+\}|\(-?\d+.*?\))", tail)
|
| 33 |
+
if nums:
|
| 34 |
+
return nums[-1], False
|
| 35 |
+
return "", False
|
| 36 |
+
|
| 37 |
+
def normalize_math(string):
|
| 38 |
+
if not string: return ""
|
| 39 |
+
string = str(string).lower().strip()
|
| 40 |
+
|
| 41 |
+
string = string.replace("</reasoning>", "").replace("</answer>", "").replace("<answer>", "")
|
| 42 |
+
string = string.replace("...", "").replace("cannot be determined", "")
|
| 43 |
+
|
| 44 |
+
string = re.sub(r"([a-z]+|\\theta|\\alpha|\\pi)\s*=\s*", "", string)
|
| 45 |
+
string = re.sub(r"\\text\{([^}]*)\}", r"\1", string)
|
| 46 |
+
string = re.sub(r"\\(mathbf|mathrm|bold|unit|mbox|operatorname|mathrm)\{([^}]*)\}", r"\2", string)
|
| 47 |
+
string = re.sub(r"\\(d|t)?frac\{([^{}]*)\}\{([^{}]*)\}", r"\2/\3", string)
|
| 48 |
+
string = string.replace("\\!", "").replace("\\ ", "").replace("{", "").replace("}", "")
|
| 49 |
+
string = string.replace("\\left", "").replace("\\right", "")
|
| 50 |
+
string = string.replace("\\$", "").replace("$", "").replace("\\%", "").replace("%", "")
|
| 51 |
+
|
| 52 |
+
units_pattern = r"(units?|cm\^2|cm|inches|inch|square|degrees?|radians?|miles?|per|hour|cents?)"
|
| 53 |
+
string = re.sub(units_pattern, "", string)
|
| 54 |
+
string = string.replace("^{\\circ}", "").replace("^\\circ", "").replace("°", "").replace("\\degree", "")
|
| 55 |
+
string = string.replace("\\pi", "pi")
|
| 56 |
+
string = re.sub(r"(\d),(\d{3})", r"\1\2", string)
|
| 57 |
+
string = string.rstrip(".:,; ").replace(" ", "")
|
| 58 |
+
|
| 59 |
+
if "=" in string:
|
| 60 |
+
string = string.split("=")[-1]
|
| 61 |
+
|
| 62 |
+
return string
|
| 63 |
+
|
| 64 |
+
def is_equiv(pred, gold):
|
| 65 |
+
if not pred: return False
|
| 66 |
+
p, g = normalize_math(pred), normalize_math(gold)
|
| 67 |
+
if p == g: return True
|
| 68 |
+
|
| 69 |
+
if "=" in pred:
|
| 70 |
+
if normalize_math(pred.split("=")[-1]) == g:
|
| 71 |
+
return True
|
| 72 |
+
|
| 73 |
+
try:
|
| 74 |
+
def to_float(s):
|
| 75 |
+
if '/' in s and s.count('/') == 1:
|
| 76 |
+
parts = s.split('/')
|
| 77 |
+
return float(parts[0]) / float(parts[1])
|
| 78 |
+
if '_' in s: s = s.split('_')[0]
|
| 79 |
+
return float(s)
|
| 80 |
+
return math.isclose(to_float(p), to_float(g), rel_tol=1e-4)
|
| 81 |
+
except:
|
| 82 |
+
p_fuzzy = re.sub(r"[^a-z0-9/,\-]", "", p)
|
| 83 |
+
g_fuzzy = re.sub(r"[^a-z0-9/,\-]", "", g)
|
| 84 |
+
return p_fuzzy == g_fuzzy if p_fuzzy else False
|
| 85 |
+
|
| 86 |
+
def run_evaluation(target_path):
|
| 87 |
+
jsonl_files = []
|
| 88 |
+
if os.path.isdir(target_path):
|
| 89 |
+
for root, dirs, files in os.walk(target_path):
|
| 90 |
+
for file in files:
|
| 91 |
+
if file.endswith(".jsonl") and not file.startswith("eval_voted_"):
|
| 92 |
+
jsonl_files.append(os.path.join(root, file))
|
| 93 |
+
else:
|
| 94 |
+
jsonl_files = [target_path]
|
| 95 |
+
|
| 96 |
+
for file_path in jsonl_files:
|
| 97 |
+
print(f">>> 正在评测: {file_path}")
|
| 98 |
+
detailed_results = []
|
| 99 |
+
|
| 100 |
+
voted_correct_count = 0
|
| 101 |
+
pass_at_k_count = 0
|
| 102 |
+
total_count = 0
|
| 103 |
+
|
| 104 |
+
nfe_list = []
|
| 105 |
+
svf_list = []
|
| 106 |
+
|
| 107 |
+
with open(file_path, 'r', encoding='utf-8') as f:
|
| 108 |
+
for line in f:
|
| 109 |
+
if not line.strip(): continue
|
| 110 |
+
try:
|
| 111 |
+
item = json.loads(line)
|
| 112 |
+
except:
|
| 113 |
+
continue
|
| 114 |
+
|
| 115 |
+
doc = item.get("doc", {})
|
| 116 |
+
ground_truth = str(item.get("target", doc.get("answer", "")))
|
| 117 |
+
|
| 118 |
+
current_nfe = item.get("nfe", 0)
|
| 119 |
+
nfe_list.append(current_nfe)
|
| 120 |
+
current_svf = item.get("svf_calls", 0)
|
| 121 |
+
svf_list.append(current_svf)
|
| 122 |
+
|
| 123 |
+
ans_stats = {}
|
| 124 |
+
trajectories = item.get("all_trajectories", [])
|
| 125 |
+
|
| 126 |
+
has_correct_trajectory = False
|
| 127 |
+
|
| 128 |
+
for traj in trajectories:
|
| 129 |
+
raw_text = traj.get("resp", "")
|
| 130 |
+
score = traj.get("score", 0)
|
| 131 |
+
|
| 132 |
+
extracted, _ = extract_answer(raw_text)
|
| 133 |
+
if not extracted: continue
|
| 134 |
+
|
| 135 |
+
if is_equiv(extracted, ground_truth):
|
| 136 |
+
has_correct_trajectory = True
|
| 137 |
+
|
| 138 |
+
norm = normalize_math(extracted)
|
| 139 |
+
if norm not in ans_stats:
|
| 140 |
+
ans_stats[norm] = {
|
| 141 |
+
"count": 0,
|
| 142 |
+
"max_score": -float('inf'),
|
| 143 |
+
"total_weight": 0.0,
|
| 144 |
+
"original": extracted
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
ans_stats[norm]["count"] += 1
|
| 148 |
+
if score > ans_stats[norm]["max_score"]:
|
| 149 |
+
ans_stats[norm]["max_score"] = score
|
| 150 |
+
|
| 151 |
+
try:
|
| 152 |
+
weight = math.exp(score)
|
| 153 |
+
except OverflowError:
|
| 154 |
+
weight = float('inf')
|
| 155 |
+
ans_stats[norm]["total_weight"] += weight
|
| 156 |
+
|
| 157 |
+
if has_correct_trajectory:
|
| 158 |
+
pass_at_k_count += 1
|
| 159 |
+
|
| 160 |
+
if not ans_stats:
|
| 161 |
+
best_pred = ""
|
| 162 |
+
else:
|
| 163 |
+
sorted_norms = sorted(
|
| 164 |
+
ans_stats.keys(),
|
| 165 |
+
key=lambda x: (ans_stats[x]["total_weight"], ans_stats[x]["max_score"], ans_stats[x]["count"]),
|
| 166 |
+
reverse=True
|
| 167 |
+
)
|
| 168 |
+
best_norm = sorted_norms[0]
|
| 169 |
+
best_pred = ans_stats[best_norm]["original"]
|
| 170 |
+
|
| 171 |
+
is_voted_correct = False
|
| 172 |
+
if best_pred and is_equiv(best_pred, ground_truth):
|
| 173 |
+
voted_correct_count += 1
|
| 174 |
+
is_voted_correct = True
|
| 175 |
+
|
| 176 |
+
total_count += 1
|
| 177 |
+
|
| 178 |
+
detailed_results.append({
|
| 179 |
+
"question": doc.get("problem", "N/A"),
|
| 180 |
+
"final_voted_answer": best_pred,
|
| 181 |
+
"ground_truth": ground_truth,
|
| 182 |
+
"is_voted_correct": is_voted_correct,
|
| 183 |
+
"nfe": current_nfe,
|
| 184 |
+
"svf_calls": current_svf
|
| 185 |
+
})
|
| 186 |
+
|
| 187 |
+
pass_at_1_accuracy = (voted_correct_count / total_count * 100) if total_count > 0 else 0
|
| 188 |
+
avg_nfe = int(round(sum(nfe_list) / len(nfe_list))) if nfe_list else 0
|
| 189 |
+
avg_svf = int(round(sum(svf_list) / len(svf_list))) if svf_list else 0
|
| 190 |
+
|
| 191 |
+
print(f"--- Accuracy: {pass_at_1_accuracy:.2f}% | NFE: {avg_nfe} | SVF: {avg_svf} ---")
|
| 192 |
+
|
| 193 |
+
output_name = f"eval_voted_{os.path.basename(file_path).replace('.jsonl', '.json')}"
|
| 194 |
+
output_path = os.path.join(os.path.dirname(file_path), output_name)
|
| 195 |
+
|
| 196 |
+
final_report = {
|
| 197 |
+
"summary": {
|
| 198 |
+
"accuracy": f"{pass_at_1_accuracy:.2f}%",
|
| 199 |
+
"correct_voted_count": voted_correct_count,
|
| 200 |
+
"total": total_count,
|
| 201 |
+
"nfe": avg_nfe,
|
| 202 |
+
"svf_calls": avg_svf
|
| 203 |
+
},
|
| 204 |
+
"details": detailed_results
|
| 205 |
+
}
|
| 206 |
+
with open(output_path, 'w', encoding='utf-8') as out_f:
|
| 207 |
+
json.dump(final_report, out_f, ensure_ascii=False, indent=4)
|
| 208 |
+
|
| 209 |
+
if __name__ == "__main__":
|
| 210 |
+
parser = argparse.ArgumentParser()
|
| 211 |
+
parser.add_argument("-r", "--res_path", type=str, default=RES_PATH)
|
| 212 |
+
args = parser.parse_args()
|
| 213 |
+
run_evaluation(args.res_path)
|
Prism/LLaDA/LLaDA_Baseline/metrics/mbpp_all.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import ast
|
| 4 |
+
import glob
|
| 5 |
+
import re
|
| 6 |
+
import argparse
|
| 7 |
+
from typing import Dict, List, Optional, Set, Tuple
|
| 8 |
+
import evaluate as hf_evaluate
|
| 9 |
+
|
| 10 |
+
RES_PATH = "<PATH_TO_RESULTS_JSONL>"
|
| 11 |
+
|
| 12 |
+
os.environ["HF_ALLOW_CODE_EVAL"] = "1"
|
| 13 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 14 |
+
|
| 15 |
+
def extract_python_code(text: str) -> str:
|
| 16 |
+
if not text: return ""
|
| 17 |
+
|
| 18 |
+
text = text.replace("<|role_end|>", "").replace("<|endoftext|>", "").replace("<|notification_end|>", "")
|
| 19 |
+
|
| 20 |
+
tag_matches = re.findall(r"<answer>(.*?)</answer>", text, re.DOTALL)
|
| 21 |
+
if tag_matches:
|
| 22 |
+
for block in tag_matches:
|
| 23 |
+
if "def " in block:
|
| 24 |
+
text = block
|
| 25 |
+
break
|
| 26 |
+
else:
|
| 27 |
+
text = tag_matches[0]
|
| 28 |
+
|
| 29 |
+
if "```python" in text:
|
| 30 |
+
blocks = text.split("```python")
|
| 31 |
+
for b in blocks[1:]:
|
| 32 |
+
code = b.split("```")[0].strip()
|
| 33 |
+
if "def " in code: return code
|
| 34 |
+
elif "```" in text:
|
| 35 |
+
blocks = text.split("```")
|
| 36 |
+
for b in blocks[1:]:
|
| 37 |
+
code = b.strip()
|
| 38 |
+
if "def " in code: return code
|
| 39 |
+
|
| 40 |
+
lines = text.split('\n')
|
| 41 |
+
cleaned_lines = []
|
| 42 |
+
stop_words = ["Explanation:", "Example:", "Test Case:", "Output:", "Reasoning:"]
|
| 43 |
+
for line in lines:
|
| 44 |
+
if any(sw in line for sw in stop_words): break
|
| 45 |
+
cleaned_lines.append(line)
|
| 46 |
+
|
| 47 |
+
return "\n".join(cleaned_lines).strip()
|
| 48 |
+
|
| 49 |
+
def normalize_code_for_voting(code: str) -> str:
|
| 50 |
+
try:
|
| 51 |
+
tree = ast.parse(code)
|
| 52 |
+
for node in ast.walk(tree):
|
| 53 |
+
if isinstance(node, (ast.FunctionDef, ast.ClassDef, ast.Module)):
|
| 54 |
+
if (node.body and isinstance(node.body[0], ast.Expr) and
|
| 55 |
+
isinstance(node.body[0].value, ast.Constant) and isinstance(node.body[0].value.value, str)):
|
| 56 |
+
node.body.pop(0)
|
| 57 |
+
return ast.unparse(tree).strip()
|
| 58 |
+
except:
|
| 59 |
+
return re.sub(r"\s+", "", code)
|
| 60 |
+
|
| 61 |
+
def run_evaluation(target_path):
|
| 62 |
+
target_path = os.path.abspath(target_path)
|
| 63 |
+
|
| 64 |
+
if os.path.isdir(target_path):
|
| 65 |
+
search_pattern = os.path.join(target_path, "**/*.jsonl")
|
| 66 |
+
jsonl_files = glob.glob(search_pattern, recursive=True)
|
| 67 |
+
jsonl_files = [f for f in jsonl_files if not os.path.basename(f).startswith("eval_mbpp_")]
|
| 68 |
+
else:
|
| 69 |
+
jsonl_files = [target_path]
|
| 70 |
+
|
| 71 |
+
if not jsonl_files:
|
| 72 |
+
print(f"Error: 在路径 {target_path} 及其子目录下未找到任何 .jsonl 文件。")
|
| 73 |
+
return
|
| 74 |
+
|
| 75 |
+
try:
|
| 76 |
+
code_eval = hf_evaluate.load("code_eval")
|
| 77 |
+
except:
|
| 78 |
+
print("Error: Could not load code_eval. Ensure 'evaluate' and 'code_eval' are installed.")
|
| 79 |
+
return
|
| 80 |
+
|
| 81 |
+
for file_path in jsonl_files:
|
| 82 |
+
print(f"\n>>> 正在评测 MBPP 文件: {file_path}")
|
| 83 |
+
all_candidate_predictions = []
|
| 84 |
+
all_voted_predictions = []
|
| 85 |
+
all_references = []
|
| 86 |
+
detailed_results = []
|
| 87 |
+
nfe_list = []
|
| 88 |
+
svf_list = []
|
| 89 |
+
|
| 90 |
+
with open(file_path, 'r', encoding='utf-8') as f:
|
| 91 |
+
for line in f:
|
| 92 |
+
if not line.strip(): continue
|
| 93 |
+
item = json.loads(line)
|
| 94 |
+
|
| 95 |
+
doc = item.get("doc", {})
|
| 96 |
+
test_list = doc.get("test_list", [])
|
| 97 |
+
test_setup = doc.get("test_setup_code", "")
|
| 98 |
+
full_reference = (test_setup + "\n" + "\n".join(test_list)).strip()
|
| 99 |
+
|
| 100 |
+
item_nfe = item.get("nfe", 0)
|
| 101 |
+
item_svf = item.get("svf_calls", 0)
|
| 102 |
+
nfe_list.append(item_nfe)
|
| 103 |
+
svf_list.append(item_svf)
|
| 104 |
+
|
| 105 |
+
resps = item.get("resps", [])
|
| 106 |
+
trajs = item.get("all_trajectories", [])
|
| 107 |
+
|
| 108 |
+
candidate_stats = {}
|
| 109 |
+
processed_candidates = []
|
| 110 |
+
|
| 111 |
+
source_data = trajs if trajs else resps
|
| 112 |
+
for idx, entry in enumerate(source_data):
|
| 113 |
+
raw_text = entry.get("resp", "") if isinstance(entry, dict) else (entry[0] if isinstance(entry, list) else entry)
|
| 114 |
+
score = entry.get("score", 0) if isinstance(entry, dict) else 0
|
| 115 |
+
|
| 116 |
+
code = extract_python_code(raw_text)
|
| 117 |
+
if not code: continue
|
| 118 |
+
|
| 119 |
+
processed_candidates.append(code)
|
| 120 |
+
|
| 121 |
+
try:
|
| 122 |
+
ast.parse(code)
|
| 123 |
+
is_valid = True
|
| 124 |
+
except:
|
| 125 |
+
is_valid = False
|
| 126 |
+
|
| 127 |
+
norm = normalize_code_for_voting(code)
|
| 128 |
+
if norm not in candidate_stats:
|
| 129 |
+
candidate_stats[norm] = {"count": 0, "valid": is_valid, "code": code, "max_score": -float('inf')}
|
| 130 |
+
candidate_stats[norm]["count"] += 1
|
| 131 |
+
candidate_stats[norm]["max_score"] = max(candidate_stats[norm]["max_score"], score)
|
| 132 |
+
|
| 133 |
+
if not candidate_stats:
|
| 134 |
+
voted_code = ""
|
| 135 |
+
else:
|
| 136 |
+
sorted_norms = sorted(
|
| 137 |
+
candidate_stats.keys(),
|
| 138 |
+
key=lambda k: (candidate_stats[k]["valid"], candidate_stats[k]["max_score"], candidate_stats[k]["count"]),
|
| 139 |
+
reverse=True
|
| 140 |
+
)
|
| 141 |
+
voted_code = candidate_stats[sorted_norms[0]]["code"]
|
| 142 |
+
|
| 143 |
+
all_candidate_predictions.append(processed_candidates if processed_candidates else [""])
|
| 144 |
+
all_voted_predictions.append([voted_code])
|
| 145 |
+
all_references.append(full_reference)
|
| 146 |
+
|
| 147 |
+
detailed_results.append({
|
| 148 |
+
"task_id": doc.get("task_id", "N/A"),
|
| 149 |
+
"voted_code": voted_code,
|
| 150 |
+
"nfe": item_nfe,
|
| 151 |
+
"svf_calls": item_svf,
|
| 152 |
+
"candidates_count": len(processed_candidates)
|
| 153 |
+
})
|
| 154 |
+
|
| 155 |
+
if not all_voted_predictions:
|
| 156 |
+
continue
|
| 157 |
+
|
| 158 |
+
print(f"正在测试代码 (共 {len(all_voted_predictions)} 题)...")
|
| 159 |
+
res_voted, details_voted = code_eval.compute(references=all_references, predictions=all_voted_predictions, k=[1])
|
| 160 |
+
res_pk, details_pk = code_eval.compute(references=all_references, predictions=all_candidate_predictions, k=[1])
|
| 161 |
+
|
| 162 |
+
acc_voted = res_voted.get("pass@1", 0.0) * 100
|
| 163 |
+
acc_pk = res_pk.get("pass@1", 0.0) * 100
|
| 164 |
+
avg_nfe = int(round(sum(nfe_list) / len(nfe_list))) if nfe_list else 0
|
| 165 |
+
avg_svf = int(round(sum(svf_list) / len(svf_list))) if svf_list else 0
|
| 166 |
+
|
| 167 |
+
print(f"--- Pass@1: {acc_voted:.2f}% | NFE: {avg_nfe} | SVF: {avg_svf} ---")
|
| 168 |
+
|
| 169 |
+
for i, detail in enumerate(detailed_results):
|
| 170 |
+
detail["is_voted_correct"] = details_voted.get(i, [[0, {"passed": False}]])[0][1]["passed"]
|
| 171 |
+
|
| 172 |
+
file_dir = os.path.dirname(file_path)
|
| 173 |
+
base_name = os.path.basename(file_path)
|
| 174 |
+
output_name = f"eval_mbpp_{base_name.replace('.jsonl', '.json')}"
|
| 175 |
+
output_path = os.path.join(file_dir, output_name)
|
| 176 |
+
|
| 177 |
+
final_report = {
|
| 178 |
+
"summary": {
|
| 179 |
+
"pass_at_1": f"{acc_voted:.2f}%",
|
| 180 |
+
"avg_nfe": avg_nfe,
|
| 181 |
+
"avg_svf": avg_svf
|
| 182 |
+
},
|
| 183 |
+
"details": detailed_results
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
with open(output_path, 'w', encoding='utf-8') as out_f:
|
| 187 |
+
json.dump(final_report, out_f, ensure_ascii=False, indent=4)
|
| 188 |
+
print(f"成功保存结果至: {output_path}")
|
| 189 |
+
|
| 190 |
+
if __name__ == "__main__":
|
| 191 |
+
parser = argparse.ArgumentParser()
|
| 192 |
+
parser.add_argument("-r", "--res_path", type=str, default=RES_PATH)
|
| 193 |
+
args = parser.parse_args()
|
| 194 |
+
run_evaluation(args.res_path)
|
Prism/LLaDA/LLaDA_Baseline/requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
sacrebleu
|
| 2 |
+
evaluate
|
| 3 |
+
datasets
|
| 4 |
+
numpy
|
| 5 |
+
pandas
|
| 6 |
+
tqdm
|
| 7 |
+
regex
|
| 8 |
+
sqlitedict
|
| 9 |
+
pytablewriter
|
Prism/LLaDA/LLaDA_Baseline/scripts/run_gsm8k.sh
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
set -e
|
| 3 |
+
set -x
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
PROJECT_ROOT="<PATH_TO_YOUR_ROOT>"
|
| 7 |
+
cd "$PROJECT_ROOT"
|
| 8 |
+
|
| 9 |
+
MODEL_PATH="<PATH_TO_YOUR_LLaDA_8B_INSTRUCT_WEIGHTS>"
|
| 10 |
+
|
| 11 |
+
BASE_OUTPUT_PATH="${PROJECT_ROOT}/outputs/results_gsm8k"
|
| 12 |
+
|
| 13 |
+
export CUDA_VISIBLE_DEVICES=0
|
| 14 |
+
export HF_ENDPOINT=https://hf-mirror.com
|
| 15 |
+
|
| 16 |
+
LENGTH=256
|
| 17 |
+
STEPS=32
|
| 18 |
+
BLOCK=32
|
| 19 |
+
TASK="gsm8k"
|
| 20 |
+
NAME="baseline"
|
| 21 |
+
|
| 22 |
+
mkdir -p "${BASE_OUTPUT_PATH}/${NAME}"
|
| 23 |
+
|
| 24 |
+
accelerate launch evaluation_script.py \
|
| 25 |
+
--model LLaDA \
|
| 26 |
+
--tasks ${TASK} \
|
| 27 |
+
--batch_size 1 \
|
| 28 |
+
--model_args "pretrained=${MODEL_PATH},mask_id=126336,assistant_prefix=<reasoning>" \
|
| 29 |
+
--gen_kwargs "use_hts=True,hts_N=1,hts_mode=False,steps=${STEPS},block_length=${BLOCK},gen_length=${LENGTH},task_type=math,temperature=0.7,realtime_output=${BASE_OUTPUT_PATH}/${NAME}/baseline.jsonl" \
|
| 30 |
+
--num_fewshot 0 \
|
| 31 |
+
--confirm_run_unsafe_code \
|
| 32 |
+
--output_path "${BASE_OUTPUT_PATH}/${NAME}"
|
Prism/LLaDA/LLaDA_Baseline/scripts/run_humaneval.sh
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
set -e
|
| 3 |
+
set -x
|
| 4 |
+
|
| 5 |
+
PROJECT_ROOT="<PATH_TO_YOUR_ROOT>"
|
| 6 |
+
MODEL_PATH="<PATH_TO_YOUR_LLaDA_8B_INSTRUCT_WEIGHTS>"
|
| 7 |
+
BASE_OUTPUT_PATH="${PROJECT_ROOT}/outputs/results_humaneval"
|
| 8 |
+
|
| 9 |
+
cd "$PROJECT_ROOT"
|
| 10 |
+
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
| 11 |
+
export HF_ENDPOINT=https://hf-mirror.com
|
| 12 |
+
|
| 13 |
+
LENGTH=512
|
| 14 |
+
STEPS=32
|
| 15 |
+
BLOCK=32
|
| 16 |
+
TASK="humaneval"
|
| 17 |
+
NAME="baseline"
|
| 18 |
+
|
| 19 |
+
mkdir -p "${BASE_OUTPUT_PATH}/${NAME}"
|
| 20 |
+
|
| 21 |
+
accelerate launch evaluation_script.py \
|
| 22 |
+
--model LLaDA \
|
| 23 |
+
--tasks ${TASK} \
|
| 24 |
+
--batch_size 1 \
|
| 25 |
+
--model_args "pretrained=${MODEL_PATH},mask_id=126336,assistant_prefix=<reasoning>" \
|
| 26 |
+
--gen_kwargs "use_hts=True,hts_N=1,hts_mode=False,steps=${STEPS},block_length=${BLOCK},gen_length=${LENGTH},task_type=code,temperature=0.7,realtime_output=${BASE_OUTPUT_PATH}/${NAME}/baseline.jsonl" \
|
| 27 |
+
--num_fewshot 0 \
|
| 28 |
+
--confirm_run_unsafe_code \
|
| 29 |
+
--output_path "${BASE_OUTPUT_PATH}/${NAME}"
|
Prism/LLaDA/LLaDA_Baseline/scripts/run_math500.sh
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
set -e
|
| 3 |
+
set -x
|
| 4 |
+
|
| 5 |
+
PROJECT_ROOT="<PATH_TO_YOUR_PRISM_ROOT>"
|
| 6 |
+
MODEL_PATH="<PATH_TO_YOUR_LLaDA_8B_INSTRUCT_WEIGHTS>"
|
| 7 |
+
BASE_OUTPUT_PATH="${PROJECT_ROOT}/outputs/results_math500"
|
| 8 |
+
|
| 9 |
+
cd "$PROJECT_ROOT"
|
| 10 |
+
export CUDA_VISIBLE_DEVICES=0
|
| 11 |
+
export HF_ENDPOINT=https://hf-mirror.com
|
| 12 |
+
|
| 13 |
+
LENGTH=256
|
| 14 |
+
STEPS=32
|
| 15 |
+
BLOCK=32
|
| 16 |
+
TASK="math500"
|
| 17 |
+
NAME="baseline"
|
| 18 |
+
|
| 19 |
+
mkdir -p "${BASE_OUTPUT_PATH}/${NAME}"
|
| 20 |
+
|
| 21 |
+
accelerate launch evaluation_script.py \
|
| 22 |
+
--model LLaDA \
|
| 23 |
+
--tasks ${TASK} \
|
| 24 |
+
--batch_size 1 \
|
| 25 |
+
--model_args "pretrained=${MODEL_PATH},mask_id=126336,assistant_prefix=<reasoning>" \
|
| 26 |
+
--gen_kwargs "use_hts=True,hts_N=1,hts_mode=False,steps=${STEPS},block_length=${BLOCK},gen_length=${LENGTH},task_type=math,temperature=0.7,realtime_output=${BASE_OUTPUT_PATH}/${NAME}/baseline.jsonl" \
|
| 27 |
+
--num_fewshot 0 \
|
| 28 |
+
--confirm_run_unsafe_code \
|
| 29 |
+
--output_path "${BASE_OUTPUT_PATH}/${NAME}"
|
Prism/LLaDA/LLaDA_Baseline/scripts/run_mbpp.sh
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
set -e
|
| 3 |
+
set -x
|
| 4 |
+
|
| 5 |
+
PROJECT_ROOT="<PATH_TO_YOUR_PRISM_ROOT>"
|
| 6 |
+
MODEL_PATH="<PATH_TO_YOUR_LLaDA_8B_INSTRUCT_WEIGHTS>"
|
| 7 |
+
BASE_OUTPUT_PATH="${PROJECT_ROOT}/outputs/results_mbpp_k4"
|
| 8 |
+
|
| 9 |
+
cd "$PROJECT_ROOT"
|
| 10 |
+
export CUDA_VISIBLE_DEVICES=0
|
| 11 |
+
export HF_ENDPOINT=https://hf-mirror.com
|
| 12 |
+
|
| 13 |
+
LENGTH=512
|
| 14 |
+
STEPS=32
|
| 15 |
+
BLOCK=32
|
| 16 |
+
TASK="mbpp"
|
| 17 |
+
NAME="baseline"
|
| 18 |
+
|
| 19 |
+
mkdir -p "${BASE_OUTPUT_PATH}/${NAME}"
|
| 20 |
+
|
| 21 |
+
accelerate launch evaluation_script.py \
|
| 22 |
+
--model LLaDA \
|
| 23 |
+
--tasks ${TASK} \
|
| 24 |
+
--batch_size 1 \
|
| 25 |
+
--model_args "pretrained=${MODEL_PATH},mask_id=126336,assistant_prefix=<reasoning>" \
|
| 26 |
+
--gen_kwargs "use_hts=True,hts_N=1,hts_mode=False,steps=${STEPS},block_length=${BLOCK},gen_length=${LENGTH},task_type=math,temperature=0.7,realtime_output=${BASE_OUTPUT_PATH}/${NAME}/baseline.jsonl" \
|
| 27 |
+
--num_fewshot 0 \
|
| 28 |
+
--confirm_run_unsafe_code \
|
| 29 |
+
--output_path "${BASE_OUTPUT_PATH}/${NAME}"
|
Prism/LLaDA/LLaDA_Prism/.gitignore
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.jsonl
|
| 2 |
+
*.json
|
| 3 |
+
|
| 4 |
+
# Byte-compiled / optimized / DLL files
|
| 5 |
+
__pycache__/
|
| 6 |
+
*.py[codz]
|
| 7 |
+
*$py.class
|
| 8 |
+
|
| 9 |
+
# C extensions
|
| 10 |
+
*.so
|
| 11 |
+
|
| 12 |
+
# Distribution / packaging
|
| 13 |
+
.Python
|
| 14 |
+
build/
|
| 15 |
+
develop-eggs/
|
| 16 |
+
dist/
|
| 17 |
+
downloads/
|
| 18 |
+
eggs/
|
| 19 |
+
.eggs/
|
| 20 |
+
lib/
|
| 21 |
+
lib64/
|
| 22 |
+
parts/
|
| 23 |
+
sdist/
|
| 24 |
+
var/
|
| 25 |
+
wheels/
|
| 26 |
+
share/python-wheels/
|
| 27 |
+
*.egg-info/
|
| 28 |
+
.installed.cfg
|
| 29 |
+
*.egg
|
| 30 |
+
MANIFEST
|
| 31 |
+
|
| 32 |
+
# PyInstaller
|
| 33 |
+
# Usually these files are written by a python script from a template
|
| 34 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 35 |
+
*.manifest
|
| 36 |
+
*.spec
|
| 37 |
+
|
| 38 |
+
# Installer logs
|
| 39 |
+
pip-log.txt
|
| 40 |
+
pip-delete-this-directory.txt
|
| 41 |
+
|
| 42 |
+
# Unit test / coverage reports
|
| 43 |
+
htmlcov/
|
| 44 |
+
.tox/
|
| 45 |
+
.nox/
|
| 46 |
+
.coverage
|
| 47 |
+
.coverage.*
|
| 48 |
+
.cache
|
| 49 |
+
nosetests.xml
|
| 50 |
+
coverage.xml
|
| 51 |
+
*.cover
|
| 52 |
+
*.py.cover
|
| 53 |
+
.hypothesis/
|
| 54 |
+
.pytest_cache/
|
| 55 |
+
cover/
|
| 56 |
+
|
| 57 |
+
# Translations
|
| 58 |
+
*.mo
|
| 59 |
+
*.pot
|
| 60 |
+
|
| 61 |
+
# Django stuff:
|
| 62 |
+
*.log
|
| 63 |
+
local_settings.py
|
| 64 |
+
db.sqlite3
|
| 65 |
+
db.sqlite3-journal
|
| 66 |
+
|
| 67 |
+
# Flask stuff:
|
| 68 |
+
instance/
|
| 69 |
+
.webassets-cache
|
| 70 |
+
|
| 71 |
+
# Scrapy stuff:
|
| 72 |
+
.scrapy
|
| 73 |
+
|
| 74 |
+
# Sphinx documentation
|
| 75 |
+
docs/_build/
|
| 76 |
+
|
| 77 |
+
# PyBuilder
|
| 78 |
+
.pybuilder/
|
| 79 |
+
target/
|
| 80 |
+
|
| 81 |
+
# Jupyter Notebook
|
| 82 |
+
.ipynb_checkpoints
|
| 83 |
+
|
| 84 |
+
# IPython
|
| 85 |
+
profile_default/
|
| 86 |
+
ipython_config.py
|
| 87 |
+
|
| 88 |
+
# pyenv
|
| 89 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 90 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 91 |
+
# .python-version
|
| 92 |
+
|
| 93 |
+
# pipenv
|
| 94 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 95 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 96 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 97 |
+
# install all needed dependencies.
|
| 98 |
+
#Pipfile.lock
|
| 99 |
+
|
| 100 |
+
# UV
|
| 101 |
+
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
| 102 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 103 |
+
# commonly ignored for libraries.
|
| 104 |
+
#uv.lock
|
| 105 |
+
|
| 106 |
+
# poetry
|
| 107 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 108 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 109 |
+
# commonly ignored for libraries.
|
| 110 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 111 |
+
#poetry.lock
|
| 112 |
+
#poetry.toml
|
| 113 |
+
|
| 114 |
+
# pdm
|
| 115 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 116 |
+
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
|
| 117 |
+
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
|
| 118 |
+
#pdm.lock
|
| 119 |
+
#pdm.toml
|
| 120 |
+
.pdm-python
|
| 121 |
+
.pdm-build/
|
| 122 |
+
|
| 123 |
+
# pixi
|
| 124 |
+
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
|
| 125 |
+
#pixi.lock
|
| 126 |
+
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
|
| 127 |
+
# in the .venv directory. It is recommended not to include this directory in version control.
|
| 128 |
+
.pixi
|
| 129 |
+
|
| 130 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 131 |
+
__pypackages__/
|
| 132 |
+
|
| 133 |
+
# Celery stuff
|
| 134 |
+
celerybeat-schedule
|
| 135 |
+
celerybeat.pid
|
| 136 |
+
|
| 137 |
+
# SageMath parsed files
|
| 138 |
+
*.sage.py
|
| 139 |
+
|
| 140 |
+
# Environments
|
| 141 |
+
.env
|
| 142 |
+
.envrc
|
| 143 |
+
.venv
|
| 144 |
+
env/
|
| 145 |
+
venv/
|
| 146 |
+
ENV/
|
| 147 |
+
env.bak/
|
| 148 |
+
venv.bak/
|
| 149 |
+
|
| 150 |
+
# Spyder project settings
|
| 151 |
+
.spyderproject
|
| 152 |
+
.spyproject
|
| 153 |
+
|
| 154 |
+
# Rope project settings
|
| 155 |
+
.ropeproject
|
| 156 |
+
|
| 157 |
+
# mkdocs documentation
|
| 158 |
+
/site
|
| 159 |
+
|
| 160 |
+
# mypy
|
| 161 |
+
.mypy_cache/
|
| 162 |
+
.dmypy.json
|
| 163 |
+
dmypy.json
|
| 164 |
+
|
| 165 |
+
# Pyre type checker
|
| 166 |
+
.pyre/
|
| 167 |
+
|
| 168 |
+
# pytype static type analyzer
|
| 169 |
+
.pytype/
|
| 170 |
+
|
| 171 |
+
# Cython debug symbols
|
| 172 |
+
cython_debug/
|
| 173 |
+
|
| 174 |
+
# PyCharm
|
| 175 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 176 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 177 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 178 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 179 |
+
#.idea/
|
| 180 |
+
|
| 181 |
+
# Abstra
|
| 182 |
+
# Abstra is an AI-powered process automation framework.
|
| 183 |
+
# Ignore directories containing user credentials, local state, and settings.
|
| 184 |
+
# Learn more at https://abstra.io/docs
|
| 185 |
+
.abstra/
|
| 186 |
+
|
| 187 |
+
# Visual Studio Code
|
| 188 |
+
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
|
| 189 |
+
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
|
| 190 |
+
# and can be added to the global gitignore or merged into this file. However, if you prefer,
|
| 191 |
+
# you could uncomment the following to ignore the entire vscode folder
|
| 192 |
+
# .vscode/
|
| 193 |
+
|
| 194 |
+
# Ruff stuff:
|
| 195 |
+
.ruff_cache/
|
| 196 |
+
|
| 197 |
+
# PyPI configuration file
|
| 198 |
+
.pypirc
|
| 199 |
+
|
| 200 |
+
# Cursor
|
| 201 |
+
# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
|
| 202 |
+
# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
|
| 203 |
+
# refer to https://docs.cursor.com/context/ignore-files
|
| 204 |
+
.cursorignore
|
| 205 |
+
.cursorindexingignore
|
| 206 |
+
|
| 207 |
+
# Marimo
|
| 208 |
+
marimo/_static/
|
| 209 |
+
marimo/_lsp/
|
| 210 |
+
__marimo__/
|
Prism/LLaDA/LLaDA_Prism/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025 preordinary
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
Prism/LLaDA/LLaDA_Prism/evaluation_script.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import random
|
| 4 |
+
import numpy as np
|
| 5 |
+
from dllm_eval.__main__ import cli_evaluate
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def set_seed(seed):
|
| 9 |
+
torch.manual_seed(seed)
|
| 10 |
+
random.seed(seed)
|
| 11 |
+
np.random.seed(seed)
|
| 12 |
+
|
| 13 |
+
torch.backends.cudnn.deterministic = True
|
| 14 |
+
torch.backends.cudnn.benchmark = False
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
if __name__ == "__main__":
|
| 18 |
+
os.environ["HF_ALLOW_CODE_EVAL"] = "1"
|
| 19 |
+
os.environ["HF_DATASETS_TRUST_REMOTE_CODE"] = "1"
|
| 20 |
+
set_seed(42)
|
| 21 |
+
cli_evaluate()
|
Prism/LLaDA/LLaDA_Prism/requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
sacrebleu
|
| 2 |
+
evaluate
|
| 3 |
+
datasets
|
| 4 |
+
numpy
|
| 5 |
+
pandas
|
| 6 |
+
tqdm
|
| 7 |
+
regex
|
| 8 |
+
sqlitedict
|
| 9 |
+
pytablewriter
|
Prism/LLaDA/LLaDA_Truthfulqa/.gitignore
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
outputs
|
| 2 |
+
logs
|
| 3 |
+
LLaDA-8B-Instruct
|
Prism/LLaDA/LLaDA_Truthfulqa/LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
Prism/LLaDA/LLaDA_Truthfulqa/eval_llada.py
ADDED
|
@@ -0,0 +1,413 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'''
|
| 2 |
+
This file is inspired by the code from https://github.com/ML-GSAI/SMDM
|
| 3 |
+
'''
|
| 4 |
+
import accelerate
|
| 5 |
+
import torch
|
| 6 |
+
import re
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
import random
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from datasets import Dataset
|
| 12 |
+
from lm_eval.__main__ import cli_evaluate
|
| 13 |
+
from lm_eval.api.instance import Instance
|
| 14 |
+
from lm_eval.api.model import LM
|
| 15 |
+
from lm_eval.api.registry import register_model
|
| 16 |
+
from tqdm import tqdm
|
| 17 |
+
|
| 18 |
+
from transformers import AutoTokenizer, AutoModel
|
| 19 |
+
import json
|
| 20 |
+
import os
|
| 21 |
+
import time
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def set_seed(seed):
|
| 25 |
+
torch.manual_seed(seed)
|
| 26 |
+
random.seed(seed)
|
| 27 |
+
np.random.seed(seed)
|
| 28 |
+
|
| 29 |
+
torch.backends.cudnn.deterministic = True
|
| 30 |
+
torch.backends.cudnn.benchmark = False
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _sample_categorical(categorical_probs):
|
| 34 |
+
gumbel_norm = (
|
| 35 |
+
1e-10
|
| 36 |
+
- (torch.rand_like(categorical_probs) + 1e-10).log()).to(categorical_probs.dtype)
|
| 37 |
+
return (categorical_probs / gumbel_norm).argmax(dim=-1)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@register_model("llada_dist")
|
| 41 |
+
class LLaDAEvalHarness(LM):
|
| 42 |
+
def __init__(
|
| 43 |
+
self,
|
| 44 |
+
model_path='',
|
| 45 |
+
mask_id=126336,
|
| 46 |
+
max_length=4096,
|
| 47 |
+
generated_samples_path='',
|
| 48 |
+
batch_size=32,
|
| 49 |
+
mc_num=128,
|
| 50 |
+
is_check_greedy=True,
|
| 51 |
+
cfg=0.,
|
| 52 |
+
sampling_steps=512,
|
| 53 |
+
mask_length=512,
|
| 54 |
+
block_size=32,
|
| 55 |
+
remasking='low_confidence',
|
| 56 |
+
device="cuda",
|
| 57 |
+
sampler='',
|
| 58 |
+
remdm_number=0
|
| 59 |
+
):
|
| 60 |
+
'''
|
| 61 |
+
Args:
|
| 62 |
+
model_path: LLaDA-8B-Base model path.
|
| 63 |
+
mask_id: The token id of [MASK] is 126336.
|
| 64 |
+
max_length: the max sequence length.
|
| 65 |
+
batch_size: mini batch size.
|
| 66 |
+
mc_num: Monte Carlo estimation iterations
|
| 67 |
+
is_check_greedy: For certain metrics like LAMBADA, the evaluation requires the model to verify whether the answer
|
| 68 |
+
is generated through greedy sampling conditioned on the prompt (note that this differs from conditional
|
| 69 |
+
generation). We implement this verification through the suffix_greedy_prediction() function, which
|
| 70 |
+
returns a True/False judgment used for accuracy calculation.
|
| 71 |
+
When is_check_greedy is set to True, the lm-evaluation-harness library automatically invokes this function.
|
| 72 |
+
However, since none of the metrics in the LLaDA paper (https://arxiv.org/abs/2502.09992) require this functionality,
|
| 73 |
+
we recommend setting is_check_greedy to False. This configuration causes suffix_greedy_prediction() to return False
|
| 74 |
+
by default, significantly accelerating the evaluation process.
|
| 75 |
+
cfg_scale: Unsupervised classifier-free guidance scale.
|
| 76 |
+
'''
|
| 77 |
+
super().__init__()
|
| 78 |
+
|
| 79 |
+
accelerator = accelerate.Accelerator()
|
| 80 |
+
if accelerator.num_processes > 1:
|
| 81 |
+
self.accelerator = accelerator
|
| 82 |
+
else:
|
| 83 |
+
self.accelerator = None
|
| 84 |
+
|
| 85 |
+
model_kwargs = {}
|
| 86 |
+
if self.accelerator is not None:
|
| 87 |
+
model_kwargs.update({'device_map': {'': f'{self.accelerator.device}'}})
|
| 88 |
+
|
| 89 |
+
self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16, **model_kwargs)
|
| 90 |
+
self.model.eval()
|
| 91 |
+
|
| 92 |
+
self.device = torch.device(device)
|
| 93 |
+
if self.accelerator is not None:
|
| 94 |
+
self.model = self.accelerator.prepare(self.model)
|
| 95 |
+
self.device = torch.device(f'{self.accelerator.device}')
|
| 96 |
+
self._rank = self.accelerator.local_process_index
|
| 97 |
+
self._world_size = self.accelerator.num_processes
|
| 98 |
+
else:
|
| 99 |
+
self.model = self.model.to(device)
|
| 100 |
+
|
| 101 |
+
self.mask_id = mask_id
|
| 102 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
| 103 |
+
|
| 104 |
+
self.mc_num = mc_num
|
| 105 |
+
self.batch_size = int(batch_size)
|
| 106 |
+
assert mc_num % self.batch_size == 0
|
| 107 |
+
self.sampling_eps = 0.
|
| 108 |
+
self.max_length = max_length
|
| 109 |
+
self.is_check_greedy = is_check_greedy
|
| 110 |
+
|
| 111 |
+
self.generated_samples_path = generated_samples_path
|
| 112 |
+
self.sampler = sampler
|
| 113 |
+
self.remdm_number = remdm_number
|
| 114 |
+
|
| 115 |
+
self.cfg = cfg
|
| 116 |
+
self.sampling_steps = sampling_steps
|
| 117 |
+
self.mask_length = mask_length
|
| 118 |
+
self.block_size = block_size
|
| 119 |
+
self.remasking = remasking
|
| 120 |
+
print(self.generated_samples_path)
|
| 121 |
+
|
| 122 |
+
@property
|
| 123 |
+
def rank(self):
|
| 124 |
+
return self._rank
|
| 125 |
+
|
| 126 |
+
@property
|
| 127 |
+
def world_size(self):
|
| 128 |
+
return self._world_size
|
| 129 |
+
|
| 130 |
+
def _forward_process(self, batch, prompt_index):
|
| 131 |
+
b, l = batch.shape
|
| 132 |
+
|
| 133 |
+
target_len = (l - prompt_index.sum()).item()
|
| 134 |
+
k = torch.randint(1, target_len + 1, (), device=batch.device)
|
| 135 |
+
|
| 136 |
+
x = torch.round(torch.linspace(float(k), k + (b - 1) * (target_len / b), steps=b, device=batch.device)).long()
|
| 137 |
+
x = ((x - 1) % target_len) + 1
|
| 138 |
+
assert x.min() >= 1 and x.max() <= target_len
|
| 139 |
+
|
| 140 |
+
indices = torch.arange(target_len, device=batch.device).repeat(b, 1)
|
| 141 |
+
is_mask = indices < x.unsqueeze(1)
|
| 142 |
+
|
| 143 |
+
for i in range(b):
|
| 144 |
+
is_mask[i] = is_mask[i][torch.randperm(target_len)]
|
| 145 |
+
|
| 146 |
+
is_mask = torch.cat((torch.zeros(b, prompt_index.sum(), dtype=torch.bool, device=batch.device), is_mask), dim=1)
|
| 147 |
+
|
| 148 |
+
noisy_batch = torch.where(is_mask, self.mask_id, batch)
|
| 149 |
+
|
| 150 |
+
return noisy_batch, (x / target_len).unsqueeze(1).repeat(1, l)
|
| 151 |
+
|
| 152 |
+
@torch.no_grad()
|
| 153 |
+
def get_logits(self, batch, prompt_index):
|
| 154 |
+
if self.cfg > 0.:
|
| 155 |
+
assert len(prompt_index) == batch.shape[1]
|
| 156 |
+
prompt_index = prompt_index.unsqueeze(0).repeat(batch.shape[0], 1)
|
| 157 |
+
un_batch = batch.clone()
|
| 158 |
+
un_batch[prompt_index] = self.mask_id
|
| 159 |
+
batch = torch.cat([batch, un_batch])
|
| 160 |
+
|
| 161 |
+
logits = self.model(batch).logits
|
| 162 |
+
|
| 163 |
+
if self.cfg > 0.:
|
| 164 |
+
logits, un_logits = torch.chunk(logits, 2, dim=0)
|
| 165 |
+
logits = un_logits + (self.cfg + 1) * (logits - un_logits)
|
| 166 |
+
return logits[:, :batch.shape[1]]
|
| 167 |
+
|
| 168 |
+
@torch.no_grad()
|
| 169 |
+
def get_loglikelihood(self, prefix, target):
|
| 170 |
+
seq = torch.concatenate([prefix, target])[None, :]
|
| 171 |
+
seq = seq.repeat((self.batch_size, 1)).to(self.device)
|
| 172 |
+
|
| 173 |
+
prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix)
|
| 174 |
+
|
| 175 |
+
loss_acc = []
|
| 176 |
+
for _ in range(self.mc_num // self.batch_size):
|
| 177 |
+
perturbed_seq, p_mask = self._forward_process(seq, prompt_index)
|
| 178 |
+
|
| 179 |
+
mask_indices = perturbed_seq == self.mask_id
|
| 180 |
+
|
| 181 |
+
logits = self.get_logits(perturbed_seq, prompt_index)
|
| 182 |
+
|
| 183 |
+
loss = F.cross_entropy(logits[mask_indices], seq[mask_indices], reduction='none') / p_mask[mask_indices]
|
| 184 |
+
loss = loss.sum() / self.batch_size
|
| 185 |
+
loss_acc.append(loss.item())
|
| 186 |
+
|
| 187 |
+
return - sum(loss_acc) / len(loss_acc)
|
| 188 |
+
|
| 189 |
+
@torch.no_grad()
|
| 190 |
+
def suffix_greedy_prediction(self, prefix, target):
|
| 191 |
+
if not self.is_check_greedy:
|
| 192 |
+
return False
|
| 193 |
+
|
| 194 |
+
seq = torch.full((1, len(prefix) + len(target)), self.mask_id, device=self.device)
|
| 195 |
+
prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix)
|
| 196 |
+
prefix, target = prefix.to(self.device), target.to(self.device)
|
| 197 |
+
seq[0, :len(prefix)] = prefix
|
| 198 |
+
|
| 199 |
+
for i in range(len(target)):
|
| 200 |
+
mask_index = (seq == self.mask_id)
|
| 201 |
+
logits = self.get_logits(seq, prompt_index)[mask_index]
|
| 202 |
+
x0 = torch.argmax(logits, dim=-1)
|
| 203 |
+
|
| 204 |
+
p = torch.softmax(logits.to(torch.float32), dim=-1)
|
| 205 |
+
confidence = torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)).squeeze(dim=-1)
|
| 206 |
+
_, index = torch.sort(confidence, descending=True)
|
| 207 |
+
x0[index[1:]] = self.mask_id
|
| 208 |
+
seq[mask_index] = x0.clone()
|
| 209 |
+
correct = target == seq[0, len(prefix):]
|
| 210 |
+
correct = torch.all(correct)
|
| 211 |
+
return correct
|
| 212 |
+
|
| 213 |
+
def _encode_pair(self, context, continuation):
|
| 214 |
+
n_spaces = len(context) - len(context.rstrip())
|
| 215 |
+
if n_spaces > 0:
|
| 216 |
+
continuation = context[-n_spaces:] + continuation
|
| 217 |
+
context = context[:-n_spaces]
|
| 218 |
+
|
| 219 |
+
whole_enc = self.tokenizer(context + continuation)["input_ids"]
|
| 220 |
+
context_enc = self.tokenizer(context)["input_ids"]
|
| 221 |
+
|
| 222 |
+
context_enc_len = len(context_enc)
|
| 223 |
+
continuation_enc = whole_enc[context_enc_len:]
|
| 224 |
+
|
| 225 |
+
return context_enc, continuation_enc
|
| 226 |
+
|
| 227 |
+
def loglikelihood(self, requests):
|
| 228 |
+
def _tokenize(e):
|
| 229 |
+
prefix, target = self._encode_pair(e["prefix"], e["target"])
|
| 230 |
+
return {
|
| 231 |
+
"prefix_text": e["prefix"],
|
| 232 |
+
"target_text": e["target"],
|
| 233 |
+
"prefix": prefix,
|
| 234 |
+
"target": target,
|
| 235 |
+
}
|
| 236 |
+
|
| 237 |
+
ds = []
|
| 238 |
+
ds = [{"prefix": req.args[0], "target": req.args[1]} for req in requests]
|
| 239 |
+
ds = Dataset.from_list(ds)
|
| 240 |
+
ds = ds.map(_tokenize)
|
| 241 |
+
ds = ds.with_format("torch")
|
| 242 |
+
prompt_len = [len(x["prefix"]) + len(x["target"]) for x in ds]
|
| 243 |
+
|
| 244 |
+
assert max(prompt_len) <= 4096
|
| 245 |
+
|
| 246 |
+
out = []
|
| 247 |
+
with torch.no_grad():
|
| 248 |
+
for elem in tqdm(ds, desc="Computing likelihood..."):
|
| 249 |
+
prefix = elem["prefix"]
|
| 250 |
+
target = elem["target"]
|
| 251 |
+
|
| 252 |
+
ll = self.get_loglikelihood(prefix, target)
|
| 253 |
+
|
| 254 |
+
is_target_greedy_dec = self.suffix_greedy_prediction(prefix, target)
|
| 255 |
+
|
| 256 |
+
out.append((ll, 1.0 if is_target_greedy_dec else 0.0))
|
| 257 |
+
torch.cuda.empty_cache()
|
| 258 |
+
return out
|
| 259 |
+
|
| 260 |
+
def loglikelihood_rolling(self, requests):
|
| 261 |
+
raise NotImplementedError
|
| 262 |
+
|
| 263 |
+
@torch.no_grad()
|
| 264 |
+
def llada_conf_sample(self, prompt):
|
| 265 |
+
xt = torch.full((1, prompt.shape[1] + self.mask_length), self.mask_id, dtype=torch.long).to(self.model.device)
|
| 266 |
+
xt[:, :prompt.shape[1]] = prompt.clone()
|
| 267 |
+
|
| 268 |
+
prompt_index = (xt != self.mask_id)
|
| 269 |
+
prompt_len = prompt_index.sum(1).item()
|
| 270 |
+
|
| 271 |
+
assert self.mask_length % self.block_size == 0
|
| 272 |
+
num_blocks = self.mask_length // self.block_size
|
| 273 |
+
|
| 274 |
+
assert self.sampling_steps % num_blocks == 0
|
| 275 |
+
steps = self.sampling_steps // num_blocks
|
| 276 |
+
|
| 277 |
+
assert self.mask_length % self.sampling_steps == 0
|
| 278 |
+
|
| 279 |
+
for num_block in range(num_blocks):
|
| 280 |
+
for i in range(steps):
|
| 281 |
+
mask_index = (xt == self.mask_id)
|
| 282 |
+
logits = self.model(xt).logits
|
| 283 |
+
p = F.softmax(logits.to(torch.float64), dim=-1)
|
| 284 |
+
x0 = _sample_categorical(p)
|
| 285 |
+
x0_p = torch.squeeze(torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # b, l
|
| 286 |
+
|
| 287 |
+
x0_p[:, prompt_len + (num_block + 1) * self.block_size:] = -np.inf
|
| 288 |
+
x0 = torch.where(mask_index, x0, xt)
|
| 289 |
+
confidence = torch.where(mask_index, x0_p, -np.inf)
|
| 290 |
+
|
| 291 |
+
transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
|
| 292 |
+
for j in range(confidence.shape[0]):
|
| 293 |
+
_, select_index = torch.topk(confidence[j], k=int(self.mask_length / self.sampling_steps))
|
| 294 |
+
transfer_index[j, select_index] = True
|
| 295 |
+
xt[transfer_index] = x0[transfer_index]
|
| 296 |
+
if torch.sum(xt == self.tokenizer.eos_token_id) > 0:
|
| 297 |
+
return xt
|
| 298 |
+
|
| 299 |
+
return xt
|
| 300 |
+
|
| 301 |
+
@torch.no_grad()
|
| 302 |
+
def llada_remdm_sample(self, prompt):
|
| 303 |
+
xt = torch.full((1, prompt.shape[1] + self.mask_length), self.mask_id, dtype=torch.long).to(self.model.device)
|
| 304 |
+
xt[:, :prompt.shape[1]] = prompt.clone()
|
| 305 |
+
|
| 306 |
+
prompt_index = (xt != self.mask_id)
|
| 307 |
+
prompt_len = prompt_index.sum(1).item()
|
| 308 |
+
|
| 309 |
+
assert self.mask_length % self.block_size == 0
|
| 310 |
+
num_blocks = self.mask_length // self.block_size
|
| 311 |
+
|
| 312 |
+
assert self.sampling_steps % num_blocks == 0
|
| 313 |
+
steps = self.sampling_steps // num_blocks
|
| 314 |
+
|
| 315 |
+
assert self.mask_length % self.sampling_steps == 0
|
| 316 |
+
|
| 317 |
+
for num_block in range(num_blocks):
|
| 318 |
+
conf_cache = torch.ones_like(xt, dtype=torch.float64) * np.inf
|
| 319 |
+
remask_thres = int(self.block_size / 8 * 7)
|
| 320 |
+
for i in range(2 * steps):
|
| 321 |
+
if i >= remask_thres and i < remask_thres + steps:
|
| 322 |
+
remask_index = torch.zeros_like(xt, dtype=torch.bool, device=xt.device)
|
| 323 |
+
_, mask_indices = torch.topk(conf_cache, k=self.remdm_number, largest=False, dim=1)
|
| 324 |
+
remask_index[0, mask_indices] = True
|
| 325 |
+
conf_cache[remask_index] = np.inf
|
| 326 |
+
xt[remask_index] = self.mask_id
|
| 327 |
+
mask_index = (xt == self.mask_id)
|
| 328 |
+
logits = self.model(xt).logits
|
| 329 |
+
p = F.softmax(logits.to(torch.float64), dim=-1)
|
| 330 |
+
x0 = _sample_categorical(p)
|
| 331 |
+
x0_p = torch.squeeze(torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # b, l
|
| 332 |
+
|
| 333 |
+
x0_p[:, prompt_len + (num_block + 1) * self.block_size:] = -np.inf
|
| 334 |
+
x0 = torch.where(mask_index, x0, xt)
|
| 335 |
+
confidence = torch.where(mask_index, x0_p, -np.inf)
|
| 336 |
+
|
| 337 |
+
if i >= remask_thres and i < remask_thres + steps:
|
| 338 |
+
transfer_length = self.remdm_number
|
| 339 |
+
else:
|
| 340 |
+
transfer_length = int(self.mask_length / self.sampling_steps)
|
| 341 |
+
|
| 342 |
+
transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
|
| 343 |
+
for j in range(confidence.shape[0]):
|
| 344 |
+
_, select_index = torch.topk(confidence[j], k=transfer_length)
|
| 345 |
+
transfer_index[j, select_index] = True
|
| 346 |
+
xt[transfer_index] = x0[transfer_index]
|
| 347 |
+
conf_cache[transfer_index] = confidence[transfer_index]
|
| 348 |
+
if torch.sum(xt == self.tokenizer.eos_token_id) > 0:
|
| 349 |
+
return xt
|
| 350 |
+
|
| 351 |
+
return xt
|
| 352 |
+
|
| 353 |
+
@torch.no_grad()
|
| 354 |
+
def generate_until(self, requests: list[Instance]):
|
| 355 |
+
start_time = time.time()
|
| 356 |
+
|
| 357 |
+
def _tokenize(e):
|
| 358 |
+
return {
|
| 359 |
+
"question": self.tokenizer(e["question"])["input_ids"],
|
| 360 |
+
"question_text": e["question"],
|
| 361 |
+
"until": e["until"],
|
| 362 |
+
}
|
| 363 |
+
|
| 364 |
+
ds = [{"question": req.args[0], "until": req.args[1]['until']} for req in requests]
|
| 365 |
+
ds = Dataset.from_list(ds)
|
| 366 |
+
ds = ds.map(_tokenize)
|
| 367 |
+
ds = ds.with_format("torch")
|
| 368 |
+
|
| 369 |
+
out, out_for_json = [], []
|
| 370 |
+
for elem in tqdm(ds, desc="Generating..."):
|
| 371 |
+
prompt = elem["question"].unsqueeze(0).to(self.device)
|
| 372 |
+
stop_tokens = elem["until"] + ["<|eot_id|>", self.tokenizer.eos_token]
|
| 373 |
+
|
| 374 |
+
if self.sampler == 'llada_conf':
|
| 375 |
+
generated_answer = self.llada_conf_sample(prompt)
|
| 376 |
+
elif self.sampler == 'llada_remdm':
|
| 377 |
+
generated_answer = self.llada_remdm_sample(prompt)
|
| 378 |
+
|
| 379 |
+
generated_answer = self.tokenizer.decode(generated_answer[0][prompt.shape[1]:], skip_special_tokens=False)
|
| 380 |
+
# print(elem['question_text'] + generated_answer)
|
| 381 |
+
for stop_seq in stop_tokens:
|
| 382 |
+
if stop_seq in generated_answer:
|
| 383 |
+
generated_answer = generated_answer.split(stop_seq)[0]
|
| 384 |
+
|
| 385 |
+
# remove special tokens
|
| 386 |
+
generated_answer_ids = self.tokenizer(generated_answer)["input_ids"]
|
| 387 |
+
generated_answer = self.tokenizer.decode(generated_answer_ids, skip_special_tokens=True)
|
| 388 |
+
# print(elem['question_text'] + generated_answer)
|
| 389 |
+
out.append(generated_answer)
|
| 390 |
+
out_for_json.append({
|
| 391 |
+
"prefix": elem["question_text"],
|
| 392 |
+
"result": generated_answer,
|
| 393 |
+
})
|
| 394 |
+
|
| 395 |
+
if self.accelerator is not None:
|
| 396 |
+
self.accelerator.wait_for_everyone()
|
| 397 |
+
|
| 398 |
+
end_time = time.time()
|
| 399 |
+
total_duration = end_time - start_time
|
| 400 |
+
print(f"\n总耗时: {total_duration:.2f} 秒")
|
| 401 |
+
|
| 402 |
+
with open(os.path.join(self.generated_samples_path, str(self._rank) + ".json"), "w") as f:
|
| 403 |
+
final_output = {
|
| 404 |
+
"total_time_seconds": total_duration,
|
| 405 |
+
"samples": out_for_json
|
| 406 |
+
}
|
| 407 |
+
json.dump(final_output, f, indent=2)
|
| 408 |
+
|
| 409 |
+
return out
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
if __name__ == "__main__":
|
| 413 |
+
cli_evaluate()
|
Prism/LLaDA/LLaDA_Truthfulqa/eval_llada_prism.py
ADDED
|
@@ -0,0 +1,333 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'''
|
| 2 |
+
This file is inspired by the code from https://github.com/ML-GSAI/SMDM
|
| 3 |
+
And extended with Prism methods for LLaDA-Instruct.
|
| 4 |
+
'''
|
| 5 |
+
import accelerate
|
| 6 |
+
import torch
|
| 7 |
+
import re
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
import random
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
from datasets import Dataset
|
| 13 |
+
from lm_eval.__main__ import cli_evaluate
|
| 14 |
+
from lm_eval.api.instance import Instance
|
| 15 |
+
from lm_eval.api.model import LM
|
| 16 |
+
from lm_eval.api.registry import register_model
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
|
| 19 |
+
from transformers import AutoTokenizer, AutoModel
|
| 20 |
+
import json
|
| 21 |
+
import os
|
| 22 |
+
import logging
|
| 23 |
+
import math
|
| 24 |
+
import textwrap
|
| 25 |
+
import time
|
| 26 |
+
from collections import Counter
|
| 27 |
+
|
| 28 |
+
logger = logging.getLogger(__name__)
|
| 29 |
+
|
| 30 |
+
def set_seed(seed):
|
| 31 |
+
torch.manual_seed(seed)
|
| 32 |
+
random.seed(seed)
|
| 33 |
+
np.random.seed(seed)
|
| 34 |
+
torch.backends.cudnn.deterministic = True
|
| 35 |
+
torch.backends.cudnn.benchmark = False
|
| 36 |
+
|
| 37 |
+
def _sample_categorical(categorical_probs):
|
| 38 |
+
gumbel_norm = (1e-10 - (torch.rand_like(categorical_probs) + 1e-10).log()).to(categorical_probs.dtype)
|
| 39 |
+
return (categorical_probs / gumbel_norm).argmax(dim=-1)
|
| 40 |
+
|
| 41 |
+
class CodeVerifier:
|
| 42 |
+
def __init__(self, model, tokenizer, device="cuda"):
|
| 43 |
+
self.model = model
|
| 44 |
+
self.tokenizer = tokenizer
|
| 45 |
+
self.device = device
|
| 46 |
+
|
| 47 |
+
self.yes_ids, self.no_ids = [], []
|
| 48 |
+
for t in ["Yes", " Yes", "YES"]:
|
| 49 |
+
ids = self.tokenizer.encode(t, add_special_tokens=False)
|
| 50 |
+
if ids: self.yes_ids.append(ids[-1])
|
| 51 |
+
for t in ["No", " No", "NO"]:
|
| 52 |
+
ids = self.tokenizer.encode(t, add_special_tokens=False)
|
| 53 |
+
if ids: self.no_ids.append(ids[-1])
|
| 54 |
+
self.yes_ids = list(set(self.yes_ids))
|
| 55 |
+
self.no_ids = list(set(self.no_ids))
|
| 56 |
+
|
| 57 |
+
def svf_score(self, prompt, code_str, task_type="code"):
|
| 58 |
+
max_len = 2000
|
| 59 |
+
truncated_code = code_str[:max_len]
|
| 60 |
+
|
| 61 |
+
if task_type == "code":
|
| 62 |
+
prompt_template = f"""
|
| 63 |
+
You are an expert programming contest judge. Your task is to evaluate a generated solution for a given problem based on correctness, efficiency, and adherence to constraints.
|
| 64 |
+
|
| 65 |
+
[Problem Statement]
|
| 66 |
+
{prompt}
|
| 67 |
+
[/Problem Statement]
|
| 68 |
+
|
| 69 |
+
[Proposed Python Solution]
|
| 70 |
+
```python
|
| 71 |
+
{truncated_code}
|
| 72 |
+
```
|
| 73 |
+
[/Proposed Python Solution]
|
| 74 |
+
|
| 75 |
+
**Conclusion**: Based on your analysis, is the solution likely to be fully correct? Answer with a single word: Yes or No.
|
| 76 |
+
**Answer:** """
|
| 77 |
+
elif task_type == "math":
|
| 78 |
+
prompt_template = f"""
|
| 79 |
+
You are an expert mathematician and competition judge. Your task is to evaluate a proposed mathematical solution for a given problem based on its logical rigor and accuracy.
|
| 80 |
+
|
| 81 |
+
[Math Problem]
|
| 82 |
+
{prompt}
|
| 83 |
+
[/Math Problem]
|
| 84 |
+
|
| 85 |
+
[Proposed Mathematical Solution]
|
| 86 |
+
{truncated_code}
|
| 87 |
+
[/Proposed Mathematical Solution]
|
| 88 |
+
|
| 89 |
+
**Conclusion**: Based on your analysis, is this solution path sound and likely to result in the correct final answer? Answer with a single word: Yes or No.
|
| 90 |
+
**Answer:** """
|
| 91 |
+
elif task_type == "reasoning":
|
| 92 |
+
prompt_template = f"""
|
| 93 |
+
You are an expert reading comprehension and faithfulness judge. Your task is to evaluate a generated answer based on the provided context and question.
|
| 94 |
+
|
| 95 |
+
[Context and Question]
|
| 96 |
+
{prompt}
|
| 97 |
+
[/Context and Question]
|
| 98 |
+
|
| 99 |
+
[Proposed Answer]
|
| 100 |
+
{truncated_code}
|
| 101 |
+
[/Proposed Answer]
|
| 102 |
+
|
| 103 |
+
**Conclusion**: Based on your analysis, is the answer fully faithful to the context and correct? Answer with a single word: Yes or No.
|
| 104 |
+
**Answer:** """
|
| 105 |
+
elif task_type == "qa":
|
| 106 |
+
prompt_template = f"""
|
| 107 |
+
You are a strict truthfulness judge. Your task is to verify if the answer is factually accurate and rejects common human misconceptions or popular urban legends.
|
| 108 |
+
|
| 109 |
+
[Question]
|
| 110 |
+
{prompt}
|
| 111 |
+
[/Question]
|
| 112 |
+
|
| 113 |
+
[Proposed Answer]
|
| 114 |
+
{truncated_code}
|
| 115 |
+
[/Proposed Answer]
|
| 116 |
+
|
| 117 |
+
**Conclusion**: Based on these criteria, is this answer fully truthful and accurate? Answer with a single word: Yes or No.
|
| 118 |
+
**Answer:** """
|
| 119 |
+
else:
|
| 120 |
+
prompt_template = f"Is the following answer correct?\nQuestion: {prompt}\nAnswer: {truncated_code}\nAnswer Yes or No.\nAnswer:"
|
| 121 |
+
|
| 122 |
+
verify_text = textwrap.dedent(prompt_template).strip()
|
| 123 |
+
input_ids = self.tokenizer(verify_text, return_tensors="pt").input_ids.to(self.device)
|
| 124 |
+
if input_ids.shape[1] > 2048: input_ids = input_ids[:, -2048:]
|
| 125 |
+
|
| 126 |
+
with torch.no_grad():
|
| 127 |
+
outputs = self.model(input_ids)
|
| 128 |
+
logits = outputs.logits[0, -1, :]
|
| 129 |
+
yes_score = max([logits[i].item() for i in self.yes_ids if i < logits.shape[-1]] + [-100.0])
|
| 130 |
+
no_score = max([logits[i].item() for i in self.no_ids if i < logits.shape[-1]] + [-100.0])
|
| 131 |
+
probs = torch.softmax(torch.tensor([yes_score, no_score]), dim=0)
|
| 132 |
+
return probs[0].item()
|
| 133 |
+
|
| 134 |
+
def get_reward(self, prompt, code_str, mode="confidence", current_logits=None, task_type="code"):
|
| 135 |
+
if mode == "svf":
|
| 136 |
+
return self.svf_score(prompt, code_str, task_type=task_type)
|
| 137 |
+
else:
|
| 138 |
+
if current_logits is None: return 0.0
|
| 139 |
+
probs = torch.softmax(current_logits.to(torch.float32), dim=-1)
|
| 140 |
+
max_probs, _ = torch.max(probs, dim=-1)
|
| 141 |
+
return torch.exp(torch.mean(torch.log(max_probs + 1e-10))).item()
|
| 142 |
+
|
| 143 |
+
class HTSSampler:
|
| 144 |
+
def __init__(self, model, tokenizer, device="cuda"):
|
| 145 |
+
self.model = model
|
| 146 |
+
self.tokenizer = tokenizer
|
| 147 |
+
self.device = device
|
| 148 |
+
self.verifier = CodeVerifier(model, tokenizer, device)
|
| 149 |
+
|
| 150 |
+
def _sample_with_temperature(self, logits, temperature=0.7):
|
| 151 |
+
logits = logits.to(torch.float32)
|
| 152 |
+
if temperature > 0:
|
| 153 |
+
probs = torch.softmax(logits / temperature, dim=-1)
|
| 154 |
+
x0 = torch.multinomial(probs.view(-1, probs.shape[-1]), 1).view(logits.shape[:-1])
|
| 155 |
+
x0_p = torch.gather(torch.softmax(logits, dim=-1), -1, x0.unsqueeze(-1)).squeeze(-1)
|
| 156 |
+
else:
|
| 157 |
+
x0_p, x0 = torch.max(torch.softmax(logits, dim=-1), dim=-1)
|
| 158 |
+
return x0, x0_p
|
| 159 |
+
|
| 160 |
+
@torch.no_grad()
|
| 161 |
+
def generate_hts(self, prompt_text, input_ids, initial_N=1, final_K=1, hts_survivor_k=2,
|
| 162 |
+
steps=32, gen_length=32, mask_id=126336, reward_mode="svf", task_type="qa",
|
| 163 |
+
decay_factor=1.8, hts_start_pct=0.1, hts_end_pct=0.6, pruning_interval=3):
|
| 164 |
+
|
| 165 |
+
b = initial_N
|
| 166 |
+
prompt_len = input_ids.shape[1]
|
| 167 |
+
xt = torch.full((b, prompt_len + gen_length), mask_id, dtype=torch.long, device=self.device)
|
| 168 |
+
xt[:, :prompt_len] = input_ids.repeat(b, 1)
|
| 169 |
+
|
| 170 |
+
conf_scores = torch.zeros((b, prompt_len + gen_length), device=self.device)
|
| 171 |
+
ts_start, tr_end = int(steps * hts_start_pct), int(steps * hts_end_pct)
|
| 172 |
+
|
| 173 |
+
schedule = torch.full((steps,), gen_length // steps, dtype=torch.int64, device=self.device)
|
| 174 |
+
schedule[:gen_length % steps] += 1
|
| 175 |
+
|
| 176 |
+
next_pruning = ts_start
|
| 177 |
+
for i in range(steps):
|
| 178 |
+
mask_indices = (xt == mask_id)
|
| 179 |
+
if not mask_indices.any(): break
|
| 180 |
+
|
| 181 |
+
logits = self.model(xt).logits
|
| 182 |
+
x0, x0_p = self._sample_with_temperature(logits[:, prompt_len:], temperature=0.7)
|
| 183 |
+
|
| 184 |
+
# Update tokens based on confidence
|
| 185 |
+
for idx in range(b):
|
| 186 |
+
curr_mask = mask_indices[idx, prompt_len:]
|
| 187 |
+
if not curr_mask.any(): continue
|
| 188 |
+
conf = torch.where(curr_mask, x0_p[idx], -float('inf'))
|
| 189 |
+
_, sel_idx = torch.topk(conf, k=min(schedule[i].item(), curr_mask.sum().item()))
|
| 190 |
+
xt[idx, prompt_len + sel_idx] = x0[idx, sel_idx]
|
| 191 |
+
conf_scores[idx, prompt_len + sel_idx] = x0_p[idx, sel_idx]
|
| 192 |
+
|
| 193 |
+
# Pruning
|
| 194 |
+
if i >= next_pruning and i < tr_end and b > final_K:
|
| 195 |
+
target_width = max(final_K, math.ceil(initial_N * (decay_factor ** -(i - ts_start))))
|
| 196 |
+
if b > target_width:
|
| 197 |
+
scores = []
|
| 198 |
+
decoded_texts = self.tokenizer.batch_decode(xt[:, prompt_len:], skip_special_tokens=True)
|
| 199 |
+
for j in range(b):
|
| 200 |
+
s = self.verifier.get_reward(prompt_text, decoded_texts[j], mode=reward_mode,
|
| 201 |
+
task_type=task_type, current_logits=logits[j, prompt_len:])
|
| 202 |
+
scores.append(s)
|
| 203 |
+
|
| 204 |
+
top_k_indices = torch.topk(torch.tensor(scores), k=min(target_width, b)).indices
|
| 205 |
+
xt = xt[top_k_indices]
|
| 206 |
+
conf_scores = conf_scores[top_k_indices]
|
| 207 |
+
b = xt.shape[0]
|
| 208 |
+
next_pruning = i + pruning_interval
|
| 209 |
+
|
| 210 |
+
# Final decoding and ranking
|
| 211 |
+
final_texts = self.tokenizer.batch_decode(xt[:, prompt_len:], skip_special_tokens=True)
|
| 212 |
+
results = []
|
| 213 |
+
for j in range(b):
|
| 214 |
+
s = self.verifier.get_reward(prompt_text, final_texts[j], mode=reward_mode, task_type=task_type)
|
| 215 |
+
results.append({'text': final_texts[j], 'score': s})
|
| 216 |
+
|
| 217 |
+
results.sort(key=lambda v: v['score'], reverse=True)
|
| 218 |
+
return [r['text'] for r in results]
|
| 219 |
+
|
| 220 |
+
@register_model("llada_dist")
|
| 221 |
+
class LLaDAEvalHarness(LM):
|
| 222 |
+
def __init__(self, model_path='', mask_id=126336, max_length=4096, generated_samples_path='',
|
| 223 |
+
batch_size=32, sampling_steps=64, mask_length=128, sampler='hts', task_type="qa",
|
| 224 |
+
hts_initial_n=8, final_K=1, hts_reward_mode="svf", hts_start_pct=0.1, hts_end_pct=0.6,
|
| 225 |
+
**kwargs):
|
| 226 |
+
super().__init__()
|
| 227 |
+
self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="auto")
|
| 228 |
+
self.model.eval()
|
| 229 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
| 230 |
+
self.device = self.model.device
|
| 231 |
+
|
| 232 |
+
self.mask_id = mask_id
|
| 233 |
+
self.sampling_steps = int(sampling_steps)
|
| 234 |
+
self.mask_length = int(mask_length)
|
| 235 |
+
self.sampler = sampler
|
| 236 |
+
self.task_type = task_type
|
| 237 |
+
self.generated_samples_path = generated_samples_path
|
| 238 |
+
|
| 239 |
+
self.hts_initial_n = int(hts_initial_n)
|
| 240 |
+
self.final_K = int(final_K)
|
| 241 |
+
self.hts_reward_mode = hts_reward_mode
|
| 242 |
+
self.hts_start_pct = float(hts_start_pct)
|
| 243 |
+
self.hts_end_pct = float(hts_end_pct)
|
| 244 |
+
|
| 245 |
+
self.hts_sampler = HTSSampler(self.model, self.tokenizer, self.device)
|
| 246 |
+
self._rank = 0
|
| 247 |
+
|
| 248 |
+
@torch.no_grad()
|
| 249 |
+
def llada_conf_sample(self, prompt):
|
| 250 |
+
xt = torch.full((1, prompt.shape[1] + self.mask_length), self.mask_id, dtype=torch.long, device=self.device)
|
| 251 |
+
xt[:, :prompt.shape[1]] = prompt
|
| 252 |
+
|
| 253 |
+
step_size = self.mask_length // self.sampling_steps
|
| 254 |
+
for i in range(self.sampling_steps):
|
| 255 |
+
mask_indices = (xt == self.mask_id)
|
| 256 |
+
if not mask_indices.any(): break
|
| 257 |
+
logits = self.model(xt).logits
|
| 258 |
+
p = F.softmax(logits.to(torch.float64), dim=-1)
|
| 259 |
+
x0 = _sample_categorical(p)
|
| 260 |
+
x0_p = torch.squeeze(torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)
|
| 261 |
+
|
| 262 |
+
confidence = torch.where(mask_indices, x0_p, -float('inf'))
|
| 263 |
+
_, select_idx = torch.topk(confidence[0], k=min(step_size, mask_indices.sum().item()))
|
| 264 |
+
xt[0, select_idx] = x0[0, select_idx]
|
| 265 |
+
return xt
|
| 266 |
+
|
| 267 |
+
@torch.no_grad()
|
| 268 |
+
def generate_until(self, requests):
|
| 269 |
+
start_time = time.time()
|
| 270 |
+
|
| 271 |
+
out, out_for_json = [], []
|
| 272 |
+
for req in tqdm(requests, desc="Generating..."):
|
| 273 |
+
prompt_text = req.args[0]
|
| 274 |
+
until = req.args[1]['until']
|
| 275 |
+
prompt_ids = self.tokenizer(prompt_text, return_tensors="pt").input_ids.to(self.device)
|
| 276 |
+
|
| 277 |
+
if self.sampler == 'hts':
|
| 278 |
+
candidates = self.hts_sampler.generate_hts(
|
| 279 |
+
prompt_text=prompt_text,
|
| 280 |
+
input_ids=prompt_ids,
|
| 281 |
+
initial_N=self.hts_initial_n,
|
| 282 |
+
final_K=self.final_K,
|
| 283 |
+
steps=self.sampling_steps,
|
| 284 |
+
gen_length=self.mask_length,
|
| 285 |
+
reward_mode=self.hts_reward_mode,
|
| 286 |
+
task_type=self.task_type,
|
| 287 |
+
hts_start_pct=self.hts_start_pct,
|
| 288 |
+
hts_end_pct=self.hts_end_pct
|
| 289 |
+
)
|
| 290 |
+
if not candidates:
|
| 291 |
+
generated_answer = ""
|
| 292 |
+
else:
|
| 293 |
+
counts = Counter(candidates)
|
| 294 |
+
most_common = counts.most_common()
|
| 295 |
+
if most_common[0][1] > 1:
|
| 296 |
+
generated_answer = most_common[0][0]
|
| 297 |
+
else:
|
| 298 |
+
generated_answer = candidates[0]
|
| 299 |
+
else:
|
| 300 |
+
res_ids = self.llada_conf_sample(prompt_ids)
|
| 301 |
+
generated_answer = self.tokenizer.decode(res_ids[0, prompt_ids.shape[1]:], skip_special_tokens=True)
|
| 302 |
+
|
| 303 |
+
for stop_seq in until + ["<|eot_id|>", self.tokenizer.eos_token]:
|
| 304 |
+
if stop_seq and stop_seq in generated_answer:
|
| 305 |
+
generated_answer = generated_answer.split(stop_seq)[0]
|
| 306 |
+
|
| 307 |
+
generated_answer = generated_answer.strip()
|
| 308 |
+
out.append(generated_answer)
|
| 309 |
+
out_for_json.append({"prefix": prompt_text, "result": generated_answer})
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
end_time = time.time()
|
| 313 |
+
total_duration = end_time - start_time
|
| 314 |
+
|
| 315 |
+
if self.generated_samples_path:
|
| 316 |
+
os.makedirs(self.generated_samples_path, exist_ok=True)
|
| 317 |
+
final_output = {
|
| 318 |
+
"total_time_seconds": total_duration,
|
| 319 |
+
"samples": out_for_json
|
| 320 |
+
}
|
| 321 |
+
with open(os.path.join(self.generated_samples_path, "res.json"), "w") as f:
|
| 322 |
+
json.dump(final_output, f, indent=2)
|
| 323 |
+
return out
|
| 324 |
+
|
| 325 |
+
def loglikelihood(self, requests): return []
|
| 326 |
+
def loglikelihood_rolling(self, requests): return []
|
| 327 |
+
@property
|
| 328 |
+
def rank(self): return 0
|
| 329 |
+
@property
|
| 330 |
+
def world_size(self): return 1
|
| 331 |
+
|
| 332 |
+
if __name__ == "__main__":
|
| 333 |
+
cli_evaluate()
|
Prism/README.md
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Prism: Efficient Test-Time Scaling via Hierarchical Search and Self-Verification for Discrete Diffusion Language Models
|
| 2 |
+
|
| 3 |
+
## PRISM: Pruning, Remasking, and Integrated Self-verification Method
|
| 4 |
+
|
| 5 |
+
PRISM is an efficient inference framework designed for **Discrete Diffusion Language Models (dLLMs)**, focusing on a favorable performance-efficiency trade-off by matching Best-of-N performance with substantially fewer Function Evaluations (NFE).
|
| 6 |
+
|
| 7 |
+
[](https://arxiv.org/abs/2602.01842)
|
| 8 |
+
[](https://github.com/viiika/Prism)
|
| 9 |
+
|
| 10 |
+
### Method
|
| 11 |
+

|
| 12 |
+
|
| 13 |
+
### Experiments
|
| 14 |
+
|
| 15 |
+

|
| 16 |
+
|
| 17 |
+
### Project Structure
|
| 18 |
+
|
| 19 |
+
```text
|
| 20 |
+
PRISM/
|
| 21 |
+
├── Dream/ # Experiments for Dream
|
| 22 |
+
│ ├── Dream_Baseline/ # Standard baseline sampling (N=1)
|
| 23 |
+
│ └── Dream_Prism/ # Prism implementation
|
| 24 |
+
├── LLaDA/ # Experiments for LLaDA 8B Instruct
|
| 25 |
+
│ ├── LLaDA_Baseline/ # Standard baseline sampling (N=1)
|
| 26 |
+
│ ├── LLaDA_Prism/ # PRISM implementation
|
| 27 |
+
│ └── LLaDA_Truthfulqa/ # TruthfulQA evaluation
|
| 28 |
+
└── LLaDA2mini/ # Experiments for LLaDA 2.0-mini
|
| 29 |
+
├── LLaDA2mini_Baseline/ # Standard baseline sampling (N=1)
|
| 30 |
+
└── LLaDA2mini_Prism/ # Prism implementation
|
| 31 |
+
```
|
| 32 |
+
|
| 33 |
+
### Prerequisites
|
| 34 |
+
```bash
|
| 35 |
+
cd PRISM
|
| 36 |
+
```
|
| 37 |
+
For Dream Project:
|
| 38 |
+
```bash
|
| 39 |
+
cd Dream/Dream_Prism/eval_instruct
|
| 40 |
+
pip install -e .
|
| 41 |
+
```
|
| 42 |
+
For LLaDA_Truthfulqa:
|
| 43 |
+
```bash
|
| 44 |
+
cd LLaDA/LLaDA_Truthfulqa/lm-evaluation-harness
|
| 45 |
+
pip install -e .
|
| 46 |
+
```
|
| 47 |
+
For LLaDA and LLaDA2 Projects:
|
| 48 |
+
```bash
|
| 49 |
+
cd LLaDA/LLaDA_Prism
|
| 50 |
+
pip install -r requirements.txt
|
| 51 |
+
```
|
| 52 |
+
#### Quick Start
|
| 53 |
+
Evaluate Dream
|
| 54 |
+
```bash
|
| 55 |
+
cd Dream/Dream_Prism
|
| 56 |
+
bash scripts/run_gsm8k.sh
|
| 57 |
+
bash scripts/run_humaneval.sh
|
| 58 |
+
bash scripts/run_math500.sh
|
| 59 |
+
bash scripts/run_mbpp.sh
|
| 60 |
+
```
|
| 61 |
+
Evaluate LLaDA 8B Instruct
|
| 62 |
+
```bash
|
| 63 |
+
cd LLaDA/LLaDA_Prism
|
| 64 |
+
bash scripts/run_gsm8k.sh
|
| 65 |
+
bash scripts/run_humaneval.sh
|
| 66 |
+
bash scripts/run_math500.sh
|
| 67 |
+
bash scripts/run_mbpp.sh
|
| 68 |
+
```
|
| 69 |
+
Evaluate LLaDA 8B Instruct(Truthfulqa)
|
| 70 |
+
```bash
|
| 71 |
+
cd LLaDA/LLaDA_Truthfulqa
|
| 72 |
+
bash scripts/llada_prism.sh
|
| 73 |
+
```
|
| 74 |
+
Evaluate LLaDA 2.0-mini
|
| 75 |
+
```bash
|
| 76 |
+
cd LLaDA2mini/LLaDA2mini_Prism
|
| 77 |
+
bash scripts/run_gsm8k.sh
|
| 78 |
+
bash scripts/run_humaneval.sh
|
| 79 |
+
bash scripts/run_math500.sh
|
| 80 |
+
bash scripts/run_mbpp.sh
|
| 81 |
+
```
|
| 82 |
+
|
| 83 |
+
### Evaluation & Metrics
|
| 84 |
+
Each project folder contains a metrics/ directory used for calculating final accuracy and efficiency metrics.
|
| 85 |
+
Usage Example:
|
| 86 |
+
```bash
|
| 87 |
+
python PRISM/LLaDA/LLaDA_Prism/metrics/gsm8k_all.py
|
| 88 |
+
```
|
| 89 |
+
|
| 90 |
+
### Acknowledgements
|
| 91 |
+
This project is built upon [preordinary/LLaDA2](https://github.com/preordinary/LLaDA2), [ML-GSAI/LLaDA](https://github.com/ML-GSAI/LLaDA), [DreamLM/Dream](https://github.com/DreamLM/Dream) and [EleutherAI/lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness). Special thanks to the authors for their contributions.
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
### 📚 Citation
|
| 96 |
+
|
| 97 |
+
If you find this work helpful, please consider citing:
|
| 98 |
+
|
| 99 |
+
```bibtex
|
| 100 |
+
@article{bai2026prism,
|
| 101 |
+
title={Prism: Efficient Test-Time Scaling via Hierarchical Search and Self-Verification for Discrete Diffusion Language Models},
|
| 102 |
+
author={Bai, Jinbin and Li, Yixuan and Zhu, Yuchen and Xin, Yi and Shi, Qingyu and Feng, Aosong and Liu, Xiaohong and Tao, Molei and Xue, Jianru and Li, Xiangtai and Yang, Ming-Hsuan},
|
| 103 |
+
journal={arXiv preprint arXiv:2602.01842},
|
| 104 |
+
year={2026}
|
| 105 |
+
}
|
| 106 |
+
```
|
| 107 |
+
|
URSA-1.7B/.gitattributes
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
. filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
tokenizer/tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
URSA-1.7B/.gitignore
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Compiled Object files
|
| 2 |
+
*.slo
|
| 3 |
+
*.lo
|
| 4 |
+
*.o
|
| 5 |
+
*.cuo
|
| 6 |
+
|
| 7 |
+
# Compiled Dynamic libraries
|
| 8 |
+
*.so
|
| 9 |
+
*.dll
|
| 10 |
+
*.dylib
|
| 11 |
+
|
| 12 |
+
# Compiled Static libraries
|
| 13 |
+
*.lai
|
| 14 |
+
*.la
|
| 15 |
+
*.a
|
| 16 |
+
*.lib
|
| 17 |
+
|
| 18 |
+
# Compiled python
|
| 19 |
+
*.pyc
|
| 20 |
+
__pycache__
|
| 21 |
+
|
| 22 |
+
# Compiled MATLAB
|
| 23 |
+
*.mex*
|
| 24 |
+
|
| 25 |
+
# IPython notebook checkpoints
|
| 26 |
+
.ipynb_checkpoints
|
| 27 |
+
|
| 28 |
+
# Editor temporaries
|
| 29 |
+
*.swp
|
| 30 |
+
*~
|
| 31 |
+
|
| 32 |
+
# Sublime Text settings
|
| 33 |
+
*.sublime-workspace
|
| 34 |
+
*.sublime-project
|
| 35 |
+
|
| 36 |
+
# Eclipse Project settings
|
| 37 |
+
*.*project
|
| 38 |
+
.settings
|
| 39 |
+
|
| 40 |
+
# QtCreator files
|
| 41 |
+
*.user
|
| 42 |
+
|
| 43 |
+
# VSCode files
|
| 44 |
+
.vscode
|
| 45 |
+
|
| 46 |
+
# IDEA files
|
| 47 |
+
.idea
|
| 48 |
+
|
| 49 |
+
# OSX dir files
|
| 50 |
+
.DS_Store
|
| 51 |
+
|
| 52 |
+
# Android files
|
| 53 |
+
.gradle
|
| 54 |
+
*.iml
|
| 55 |
+
local.properties
|
URSA-1.7B/LICENSE
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
URSA-1.7B/README.md
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
library_name: diffusers
|
| 3 |
+
license: apache-2.0
|
| 4 |
+
license_link: https://huggingface.co/BAAI/URSA-1.7B-FSQ320/blob/main/LICENSE
|
| 5 |
+
pipeline_tag: text-to-video
|
| 6 |
+
base_model:
|
| 7 |
+
- Qwen/Qwen3-1.7B
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
# URSA-1.7B-FSQ320 Model Card
|
| 11 |
+
|
| 12 |
+
## Model Details
|
| 13 |
+
- **Developed by:** BAAI
|
| 14 |
+
- **Model type:** Text-to-Video Generation Model
|
| 15 |
+
- **Model size:** 1.7B
|
| 16 |
+
- **Model precision:** torch.float16 (FP16)
|
| 17 |
+
- **Model resolution:** 512x320
|
| 18 |
+
- **Model paper:** [Uniform Discrete Diffusion with Metric Path for Video Generation](https://arxiv.org/abs/2510.24717)
|
| 19 |
+
- **Model family:** [BAAI-Vision-URSA](https://github.com/baaivision/URSA)
|
| 20 |
+
- **Model Tokenizer:** [Cosmos-Tokenize1-DV4x8x8-360p](https://huggingface.co/nvidia/Cosmos-Tokenize1-DV4x8x8-360p)
|
| 21 |
+
- **Model Description:** This is a model that can be used to generate and modify videos based on text prompts.
|
| 22 |
+
|
| 23 |
+
## Examples
|
| 24 |
+
|
| 25 |
+
Using the [🤗's Diffusers library](https://github.com/huggingface/diffusers) to run URSA in a simple and efficient manner.
|
| 26 |
+
|
| 27 |
+
```bash
|
| 28 |
+
pip install diffusers transformers accelerate imageio[ffmpeg]
|
| 29 |
+
pip install git+ssh://git@github.com/baaivision/URSA.git
|
| 30 |
+
```
|
| 31 |
+
|
| 32 |
+
Running the pipeline:
|
| 33 |
+
|
| 34 |
+
```python
|
| 35 |
+
import os, torch, numpy
|
| 36 |
+
from diffnext.pipelines import URSAPipeline
|
| 37 |
+
from diffnext.utils import export_to_video
|
| 38 |
+
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
| 39 |
+
|
| 40 |
+
model_id, height, width = "BAAI/URSA-1.7B-FSQ320", 320, 512
|
| 41 |
+
model_args = {"torch_dtype": torch.float16, "trust_remote_code": True}
|
| 42 |
+
pipe = URSAPipeline.from_pretrained(model_id, **model_args)
|
| 43 |
+
pipe = pipe.to(torch.device("cuda"))
|
| 44 |
+
|
| 45 |
+
text_prompt = "a lone grizzly bear walks through a misty forest at dawn, sunlight catching its fur."
|
| 46 |
+
negative_prompt = "worst quality, low quality, inconsistent motion, static, still, blurry, jittery, distorted, ugly"
|
| 47 |
+
|
| 48 |
+
# Text-to-Image
|
| 49 |
+
prompt = text_prompt
|
| 50 |
+
num_frames, num_inference_steps = 1, 25
|
| 51 |
+
image = pipe(**locals()).frames[0]
|
| 52 |
+
image.save("ursa.jpg")
|
| 53 |
+
|
| 54 |
+
# Image-to-Video
|
| 55 |
+
prompt = f"motion=9.0, {text_prompt}"
|
| 56 |
+
num_frames, num_inference_steps = 49, 50
|
| 57 |
+
video = pipe(**locals()).frames[0]
|
| 58 |
+
export_to_video(video, "ursa_1+48f.mp4", fps=12)
|
| 59 |
+
|
| 60 |
+
# Text-to-Video
|
| 61 |
+
image, video = None, None
|
| 62 |
+
prompt = f"motion=9.0, {text_prompt}"
|
| 63 |
+
num_frames, num_inference_steps = 49, 50
|
| 64 |
+
video = pipe(**locals()).frames[0]
|
| 65 |
+
export_to_video(video, "ursa_49f.mp4", fps=12)
|
| 66 |
+
|
| 67 |
+
# Video-to-Video
|
| 68 |
+
prompt = f"motion=5.0, {text_prompt}"
|
| 69 |
+
num_frames, num_inference_steps = 49, 50
|
| 70 |
+
num_cond_frames, cond_noise_scale = 13, 0.1
|
| 71 |
+
for i in range(12):
|
| 72 |
+
video, start_video = video[-num_cond_frames:], video
|
| 73 |
+
video = pipe(**locals()).frames[0]
|
| 74 |
+
video = numpy.concatenate([start_video, video[num_cond_frames:]])
|
| 75 |
+
export_to_video(video, "ursa_{}f.mp4".format(video.shape[0]), fps=12)
|
| 76 |
+
```
|
| 77 |
+
|
| 78 |
+
# Uses
|
| 79 |
+
|
| 80 |
+
## Direct Use
|
| 81 |
+
The model is intended for research purposes only. Possible research areas and tasks include
|
| 82 |
+
|
| 83 |
+
- Research on generative models.
|
| 84 |
+
- Applications in educational or creative tools.
|
| 85 |
+
- Generation of artworks and use in design and other artistic processes.
|
| 86 |
+
- Probing and understanding the limitations and biases of generative models.
|
| 87 |
+
- Safe deployment of models which have the potential to generate harmful content.
|
| 88 |
+
|
| 89 |
+
Excluded uses are described below.
|
| 90 |
+
|
| 91 |
+
#### Out-of-Scope Use
|
| 92 |
+
The model was not trained to be factual or true representations of people or events, and therefore using the model to generate such content is out-of-scope for the abilities of this model.
|
| 93 |
+
|
| 94 |
+
#### Misuse and Malicious Use
|
| 95 |
+
Using the model to generate content that is cruel to individuals is a misuse of this model. This includes, but is not limited to:
|
| 96 |
+
|
| 97 |
+
- Mis- and disinformation.
|
| 98 |
+
- Representations of egregious violence and gore.
|
| 99 |
+
- Impersonating individuals without their consent.
|
| 100 |
+
- Sexual content without consent of the people who might see it.
|
| 101 |
+
- Sharing of copyrighted or licensed material in violation of its terms of use.
|
| 102 |
+
- Intentionally promoting or propagating discriminatory content or harmful stereotypes.
|
| 103 |
+
- Sharing content that is an alteration of copyrighted or licensed material in violation of its terms of use.
|
| 104 |
+
- Generating demeaning, dehumanizing, or otherwise harmful representations of people or their environments, cultures, religions, etc.
|
| 105 |
+
|
| 106 |
+
## Limitations and Bias
|
| 107 |
+
|
| 108 |
+
### Limitations
|
| 109 |
+
|
| 110 |
+
- The autoencoding part of the model is lossy.
|
| 111 |
+
- The model cannot render complex legible text.
|
| 112 |
+
- The model does not achieve perfect photorealism.
|
| 113 |
+
- The fingers, .etc in general may not be generated properly.
|
| 114 |
+
- The model was trained on a subset of the web datasets [LAION-5B](https://laion.ai/blog/laion-5b/) and [COYO-700M](https://github.com/kakaobrain/coyo-dataset), which contains adult, violent and sexual content.
|
| 115 |
+
|
| 116 |
+
### Bias
|
| 117 |
+
While the capabilities of image generation models are impressive, they can also reinforce or exacerbate social biases.
|
URSA-1.7B/model_index.json
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "URSAPipeline",
|
| 3 |
+
"tokenizer": [
|
| 4 |
+
"transformers",
|
| 5 |
+
"Qwen2TokenizerFast"
|
| 6 |
+
],
|
| 7 |
+
"scheduler": [
|
| 8 |
+
"__scheduler__",
|
| 9 |
+
"KineticOptimalScheduler"
|
| 10 |
+
],
|
| 11 |
+
"transformer": [
|
| 12 |
+
"__transformer__",
|
| 13 |
+
"URSATransformer3DModel"
|
| 14 |
+
],
|
| 15 |
+
"vae": [
|
| 16 |
+
"__vae__",
|
| 17 |
+
"AutoencoderVQCosmos3D"
|
| 18 |
+
]
|
| 19 |
+
}
|
URSA/.flake8
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[flake8]
|
| 2 |
+
max-line-length = 100
|
| 3 |
+
ignore =
|
| 4 |
+
# whitespace before ':' (conflicted with Black)
|
| 5 |
+
E203,
|
| 6 |
+
# ambiguous variable name
|
| 7 |
+
E741,
|
| 8 |
+
# ‘from module import *’ used; unable to detect undefined names
|
| 9 |
+
F403,
|
| 10 |
+
# name may be undefined, or defined from star imports: module
|
| 11 |
+
F405,
|
| 12 |
+
# redefinition of unused name from line N
|
| 13 |
+
F811,
|
| 14 |
+
# undefined name
|
| 15 |
+
F821,
|
| 16 |
+
# line break before binary operator
|
| 17 |
+
W503,
|
| 18 |
+
# line break after binary operator
|
| 19 |
+
W504
|
| 20 |
+
# module imported but unused
|
| 21 |
+
per-file-ignores = __init__.py: F401
|
URSA/.gitignore
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Compiled Object files
|
| 2 |
+
*.slo
|
| 3 |
+
*.lo
|
| 4 |
+
*.o
|
| 5 |
+
*.cuo
|
| 6 |
+
|
| 7 |
+
# Compiled Dynamic libraries
|
| 8 |
+
*.so
|
| 9 |
+
*.dll
|
| 10 |
+
*.dylib
|
| 11 |
+
|
| 12 |
+
# Compiled Static libraries
|
| 13 |
+
*.lai
|
| 14 |
+
*.la
|
| 15 |
+
*.a
|
| 16 |
+
*.lib
|
| 17 |
+
|
| 18 |
+
# Compiled python
|
| 19 |
+
*.pyc
|
| 20 |
+
__pycache__
|
| 21 |
+
|
| 22 |
+
# Compiled MATLAB
|
| 23 |
+
*.mex*
|
| 24 |
+
|
| 25 |
+
# IPython notebook checkpoints
|
| 26 |
+
.ipynb_checkpoints
|
| 27 |
+
|
| 28 |
+
# Editor temporaries
|
| 29 |
+
*.swp
|
| 30 |
+
*~
|
| 31 |
+
|
| 32 |
+
# Sublime Text settings
|
| 33 |
+
*.sublime-workspace
|
| 34 |
+
*.sublime-project
|
| 35 |
+
|
| 36 |
+
# Eclipse Project settings
|
| 37 |
+
*.*project
|
| 38 |
+
.settings
|
| 39 |
+
|
| 40 |
+
# QtCreator files
|
| 41 |
+
*.user
|
| 42 |
+
|
| 43 |
+
# VSCode files
|
| 44 |
+
.vscode
|
| 45 |
+
|
| 46 |
+
# IDEA files
|
| 47 |
+
.idea
|
| 48 |
+
|
| 49 |
+
# OSX dir files
|
| 50 |
+
.DS_Store
|
| 51 |
+
|
| 52 |
+
# Android files
|
| 53 |
+
.gradle
|
| 54 |
+
*.iml
|
| 55 |
+
local.properties
|
URSA/=4.57.1
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Requirement already satisfied: diffusers in /usr/local/lib/python3.12/dist-packages (0.36.0)
|
| 2 |
+
Requirement already satisfied: transformers in /usr/local/lib/python3.12/dist-packages (5.2.0)
|
| 3 |
+
Requirement already satisfied: accelerate in /usr/local/lib/python3.12/dist-packages (1.12.0)
|
| 4 |
+
Requirement already satisfied: imageio in /usr/local/lib/python3.12/dist-packages (2.37.2)
|
| 5 |
+
Requirement already satisfied: imageio-ffmpeg in /usr/local/lib/python3.12/dist-packages (0.6.0)
|
| 6 |
+
Requirement already satisfied: omegaconf in /usr/local/lib/python3.12/dist-packages (2.3.0)
|
| 7 |
+
Requirement already satisfied: wandb in /usr/local/lib/python3.12/dist-packages (0.25.0)
|
| 8 |
+
Requirement already satisfied: importlib_metadata in /usr/local/lib/python3.12/dist-packages/setuptools/_vendor (from diffusers) (8.0.0)
|
| 9 |
+
Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from diffusers) (3.17.0)
|
| 10 |
+
Requirement already satisfied: httpx<1.0.0 in /usr/local/lib/python3.12/dist-packages (from diffusers) (0.28.1)
|
| 11 |
+
Requirement already satisfied: huggingface-hub<2.0,>=0.34.0 in /usr/local/lib/python3.12/dist-packages (from diffusers) (1.3.0)
|
| 12 |
+
Requirement already satisfied: numpy in /usr/local/lib/python3.12/dist-packages (from diffusers) (1.26.4)
|
| 13 |
+
Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.12/dist-packages (from diffusers) (2024.11.6)
|
| 14 |
+
Requirement already satisfied: requests in /usr/local/lib/python3.12/dist-packages (from diffusers) (2.32.3)
|
| 15 |
+
Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.12/dist-packages (from diffusers) (0.5.3)
|
| 16 |
+
Requirement already satisfied: Pillow in /usr/local/lib/python3.12/dist-packages (from diffusers) (11.1.0)
|
| 17 |
+
Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.12/dist-packages (from transformers) (23.2)
|
| 18 |
+
Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.12/dist-packages (from transformers) (6.0.2)
|
| 19 |
+
Requirement already satisfied: tokenizers<=0.23.0,>=0.22.0 in /usr/local/lib/python3.12/dist-packages (from transformers) (0.22.2)
|
| 20 |
+
Requirement already satisfied: typer-slim in /usr/local/lib/python3.12/dist-packages (from transformers) (0.21.2)
|
| 21 |
+
Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.12/dist-packages (from transformers) (4.67.1)
|
| 22 |
+
Requirement already satisfied: psutil in /usr/local/lib/python3.12/dist-packages (from accelerate) (7.0.0)
|
| 23 |
+
Requirement already satisfied: torch>=2.0.0 in /usr/local/lib/python3.12/dist-packages (from accelerate) (2.9.0+cu128)
|
| 24 |
+
Requirement already satisfied: antlr4-python3-runtime==4.9.* in /usr/local/lib/python3.12/dist-packages (from omegaconf) (4.9.3)
|
| 25 |
+
Requirement already satisfied: click>=8.0.1 in /usr/local/lib/python3.12/dist-packages (from wandb) (8.1.8)
|
| 26 |
+
Requirement already satisfied: gitpython!=3.1.29,>=1.0.0 in /usr/local/lib/python3.12/dist-packages (from wandb) (3.1.46)
|
| 27 |
+
Requirement already satisfied: platformdirs in /usr/local/lib/python3.12/dist-packages (from wandb) (4.3.6)
|
| 28 |
+
Requirement already satisfied: protobuf!=4.21.0,!=5.28.0,<7,>=3.19.0 in /usr/local/lib/python3.12/dist-packages (from wandb) (4.24.4)
|
| 29 |
+
Requirement already satisfied: pydantic<3 in /usr/local/lib/python3.12/dist-packages (from wandb) (2.10.6)
|
| 30 |
+
Requirement already satisfied: sentry-sdk>=2.0.0 in /usr/local/lib/python3.12/dist-packages (from wandb) (2.54.0)
|
| 31 |
+
Requirement already satisfied: typing-extensions<5,>=4.8 in /usr/local/lib/python3.12/dist-packages (from wandb) (4.12.2)
|
| 32 |
+
Requirement already satisfied: gitdb<5,>=4.0.1 in /usr/local/lib/python3.12/dist-packages (from gitpython!=3.1.29,>=1.0.0->wandb) (4.0.12)
|
| 33 |
+
Requirement already satisfied: anyio in /usr/local/lib/python3.12/dist-packages (from httpx<1.0.0->diffusers) (4.8.0)
|
| 34 |
+
Requirement already satisfied: certifi in /usr/local/lib/python3.12/dist-packages (from httpx<1.0.0->diffusers) (2025.1.31)
|
| 35 |
+
Requirement already satisfied: httpcore==1.* in /usr/local/lib/python3.12/dist-packages (from httpx<1.0.0->diffusers) (1.0.7)
|
| 36 |
+
Requirement already satisfied: idna in /usr/local/lib/python3.12/dist-packages (from httpx<1.0.0->diffusers) (3.10)
|
| 37 |
+
Requirement already satisfied: h11<0.15,>=0.13 in /usr/local/lib/python3.12/dist-packages (from httpcore==1.*->httpx<1.0.0->diffusers) (0.14.0)
|
| 38 |
+
Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub<2.0,>=0.34.0->diffusers) (2025.2.0)
|
| 39 |
+
Requirement already satisfied: hf-xet<2.0.0,>=1.2.0 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub<2.0,>=0.34.0->diffusers) (1.3.2)
|
| 40 |
+
Requirement already satisfied: shellingham in /usr/local/lib/python3.12/dist-packages (from huggingface-hub<2.0,>=0.34.0->diffusers) (1.5.4)
|
| 41 |
+
Requirement already satisfied: annotated-types>=0.6.0 in /usr/local/lib/python3.12/dist-packages (from pydantic<3->wandb) (0.7.0)
|
| 42 |
+
Requirement already satisfied: pydantic-core==2.27.2 in /usr/local/lib/python3.12/dist-packages (from pydantic<3->wandb) (2.27.2)
|
| 43 |
+
Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests->diffusers) (3.4.1)
|
| 44 |
+
Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests->diffusers) (2.0.7)
|
| 45 |
+
Requirement already satisfied: setuptools in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (75.8.2)
|
| 46 |
+
Requirement already satisfied: sympy>=1.13.3 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (1.14.0)
|
| 47 |
+
Requirement already satisfied: networkx>=2.5.1 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (3.4.2)
|
| 48 |
+
Requirement already satisfied: jinja2 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (3.1.6)
|
| 49 |
+
Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.8.93 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (12.8.93)
|
| 50 |
+
Requirement already satisfied: nvidia-cuda-runtime-cu12==12.8.90 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (12.8.90)
|
| 51 |
+
Requirement already satisfied: nvidia-cuda-cupti-cu12==12.8.90 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (12.8.90)
|
| 52 |
+
Requirement already satisfied: nvidia-cudnn-cu12==9.10.2.21 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (9.10.2.21)
|
| 53 |
+
Requirement already satisfied: nvidia-cublas-cu12==12.8.4.1 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (12.8.4.1)
|
| 54 |
+
Requirement already satisfied: nvidia-cufft-cu12==11.3.3.83 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (11.3.3.83)
|
| 55 |
+
Requirement already satisfied: nvidia-curand-cu12==10.3.9.90 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (10.3.9.90)
|
| 56 |
+
Requirement already satisfied: nvidia-cusolver-cu12==11.7.3.90 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (11.7.3.90)
|
| 57 |
+
Requirement already satisfied: nvidia-cusparse-cu12==12.5.8.93 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (12.5.8.93)
|
| 58 |
+
Requirement already satisfied: nvidia-cusparselt-cu12==0.7.1 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (0.7.1)
|
| 59 |
+
Requirement already satisfied: nvidia-nccl-cu12==2.27.5 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (2.27.5)
|
| 60 |
+
Requirement already satisfied: nvidia-nvshmem-cu12==3.3.20 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (3.3.20)
|
| 61 |
+
Requirement already satisfied: nvidia-nvtx-cu12==12.8.90 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (12.8.90)
|
| 62 |
+
Requirement already satisfied: nvidia-nvjitlink-cu12==12.8.93 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (12.8.93)
|
| 63 |
+
Requirement already satisfied: nvidia-cufile-cu12==1.13.1.3 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (1.13.1.3)
|
| 64 |
+
Requirement already satisfied: triton==3.5.0 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (3.5.0)
|
| 65 |
+
Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.12/dist-packages/setuptools/_vendor (from importlib_metadata->diffusers) (3.19.2)
|
| 66 |
+
Requirement already satisfied: annotated-doc>=0.0.2 in /usr/local/lib/python3.12/dist-packages (from typer-slim->transformers) (0.0.4)
|
| 67 |
+
Requirement already satisfied: smmap<6,>=3.0.1 in /usr/local/lib/python3.12/dist-packages (from gitdb<5,>=4.0.1->gitpython!=3.1.29,>=1.0.0->wandb) (5.0.2)
|
| 68 |
+
Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from sympy>=1.13.3->torch>=2.0.0->accelerate) (1.3.0)
|
| 69 |
+
Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.12/dist-packages (from anyio->httpx<1.0.0->diffusers) (1.3.1)
|
| 70 |
+
Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.12/dist-packages (from jinja2->torch>=2.0.0->accelerate) (3.0.2)
|
URSA/LICENSE
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
URSA/README.md
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<div align="center">
|
| 2 |
+
|
| 3 |
+
<img src="assets/logo.png" width="30%" alt="logo"/>
|
| 4 |
+
|
| 5 |
+
<h1>🐻 URSA: Uniform Discrete Diffusion with Metric Path<br>for Video Generation</h1>
|
| 6 |
+
|
| 7 |
+
<p align="center">
|
| 8 |
+
<a href="https://arxiv.org/abs/2510.24717"><img src="https://img.shields.io/badge/ArXiv-2510.24717-%23840707.svg" alt="ArXiv"></a>
|
| 9 |
+
<a href="https://huggingface.co/collections/BAAI/ursa"><img src="https://img.shields.io/badge/🤗 Weights-BAAI/URSA-rgb(166,109,59).svg" alt=""></a>
|
| 10 |
+
<a href="https://huggingface.co/spaces/BAAI/nova-d48w1024-osp480"><img src="https://img.shields.io/badge/🤗 Demo-TI2V-%26840707.svg" alt="TI2VDemo"></a>
|
| 11 |
+
<a href="http://bitterdhg.github.io/URSA_page"><img src="https://img.shields.io/badge/Project-URSA-%237CB4F7.svg" alt="Project"></a>
|
| 12 |
+
</p>
|
| 13 |
+
|
| 14 |
+
<p align="center">
|
| 15 |
+
|
| 16 |
+
[Haoge Deng](https://scholar.google.com/citations?user=S2sbvjgAAAAJ&hl)<sup>1,4*</sup>, [Ting Pan](https://scholar.google.com/citations?&user=qQv6YbsAAAAJ)<sup>2,4*</sup>, [Fan Zhang](https://scholar.google.com/citations?user=VsJ39HMAAAAJ)<sup>4*</sup>, [Yang Liu](https://scholar.google.com/citations?user=9JcQ2hwAAAAJ&hl)<sup>3,4*</sup>, [Zhuoyan Luo](https://scholar.google.com/citations?user=mKQhEsIAAAAJ&hl)<sup>4</sup>, [Yufeng Cui](https://scholar.google.com/citations?user=5Ydha2EAAAAJ&hl)<sup>4</sup>, [Wenxuan Wang](https://scholar.google.com/citations?user=75OyC-oAAAAJ&hl)<sup>4</sup><br>
|
| 17 |
+
[Chunhua Shen](https://scholar.google.com/citations?user=Ljk2BvIAAAAJ&hl)<sup>3</sup>, [Shiguang Shan](https://scholar.google.com/citations?user=Vkzd7MIAAAAJ&hl)<sup>2</sup>, [Zhaoxiang Zhang](https://scholar.google.com/citations?user=qxWfV6cAAAAJ&hl)<sup>1†</sup>, [Xinlong Wang](https://scholar.google.com/citations?user=DPz0DjYAAAAJ&hl)<sup>4†</sup><br>
|
| 18 |
+
|
| 19 |
+
[CASIA](http://english.ia.cas.cn)<sup>1</sup>, [CASICT](http://english.ict.cas.cn)<sup>2</sup>, [ZJU](https://www.zju.edu.cn/english)<sup>3</sup>, [BAAI](https://www.baai.ac.cn/en)<sup>4</sup><br>
|
| 20 |
+
<sup>*</sup> Equal Contribution, <sup>†</sup> Corresponding Author
|
| 21 |
+
<br><br><image src="assets/model_preview.gif"/>
|
| 22 |
+
<br><br><image src="assets/model_overview.png"/>
|
| 23 |
+
</div>
|
| 24 |
+
|
| 25 |
+
We present **URSA** (**U**niform disc**R**ete diffu**S**ion with metric p**A**th), a simple yet powerful framework that bridges the gap with continuous approaches. **URSA** formulates the video generation task as an iterative global refinement of discrete spatiotemporal tokens and scales efficiently to long video generation, requiring fewer inference steps. **URSA** enables multi-task video generation with asynchronous timestep scheduling strategy in one unified model.
|
| 26 |
+
|
| 27 |
+
## 🚀 News
|
| 28 |
+
- ```[Feb 2026]``` Accepted by ICLR 2026 [[OpenReview]](https://openreview.net/forum?id=GFU5yCbILk).
|
| 29 |
+
- ```[Jan 2026]``` Released [Training Guide](./docs/training.md).
|
| 30 |
+
- ```[Oct 2025]``` 🎉 URSA is part of [Emu3.5](https://github.com/baaivision/Emu3.5) as DiDA (Discrete Diffusion Adaptation)!
|
| 31 |
+
- ```[Oct 2025]``` Released <a href="https://huggingface.co/spaces/BAAI/nova-d48w1024-osp480"><b>TI2V</b></a> 🤗 Demo.
|
| 32 |
+
- ```[Oct 2025]``` Released [Paper](https://arxiv.org/abs/2510.24717) & [Project Page](http://bitterdhg.github.io/URSA_page) & [Evaluation Guide](./docs/evaluation.md).
|
| 33 |
+
|
| 34 |
+
## ✨Hightlights
|
| 35 |
+
|
| 36 |
+
- 🥇 **Novel Approach**: Uniform Discrete Diffusion with Metric Path.
|
| 37 |
+
- 🥈 **SOTA Performance**: High efficiency with state-of-the-art T2I/T2V/I2V results.
|
| 38 |
+
- 🥉 **Unified Modeling**: Multi-task capabilities in a single unified model.
|
| 39 |
+
|
| 40 |
+
## 🗄️ Models
|
| 41 |
+
|
| 42 |
+
### 🖼️ Text to Image
|
| 43 |
+
|
| 44 |
+
| Model | Resolution | Data | Weight | GenEval | DPGBench |
|
| 45 |
+
|:-----:|:----------:|:----:|:------:|:-------:|:--------:|
|
| 46 |
+
| URSA-0.6B-IBQ1024 | 1024x1024 | 30M | [🤗 HF](https://huggingface.co/BAAI/URSA-0.6B-IBQ1024) \| [🤖 ModelScope](https://www.modelscope.cn/models/BAAI/URSA-0.6B-IBQ1024) | 0.79 | 85.6 |
|
| 47 |
+
| URSA-1.7B-IBQ1024 | 1024x1024 | 30M | [🤗 HF](https://huggingface.co/BAAI/URSA-1.7B-IBQ1024) \| [🤖 ModelScope](https://www.modelscope.cn/models/BAAI/URSA-1.7B-IBQ1024) | 0.80 | 86.0 |
|
| 48 |
+
|
| 49 |
+
### 🎬 Text to Video
|
| 50 |
+
|
| 51 |
+
| Model | Resolution | Data | Weight | VBench-T2V | VBench-I2V |
|
| 52 |
+
|:-----:|:----------:|:----:|:------:|:----------:|:----------:|
|
| 53 |
+
| URSA-0.6B-FSQ320 | 49x512x320 | 24M | [🤗 HF](https://huggingface.co/BAAI/URSA-0.6B-FSQ320) \| [🤖 ModelScope](https://www.modelscope.cn/models/BAAI/URSA-0.6B-FSQ320) | 81.4 | 86.0 |
|
| 54 |
+
| URSA-1.7B-FSQ320 | 49x512x320 | 24M | [🤗 HF](https://huggingface.co/BAAI/URSA-1.7B-FSQ320) \| [🤖 ModelScope](https://www.modelscope.cn/models/BAAI/URSA-1.7B-FSQ320) | 82.4 | 86.2 |
|
| 55 |
+
|
| 56 |
+
## 📖 Table of Contents
|
| 57 |
+
- [🔧 Installation](#installation)
|
| 58 |
+
- [🔥 Quick Start](#quick-start)
|
| 59 |
+
- [🖼️ Image Generation](#quickstart-image-generation)
|
| 60 |
+
- [🎬 Video Generation](#quickstart-video-generation)
|
| 61 |
+
- [💻 Gradio Demo](#gradio-demo)
|
| 62 |
+
- [💯 Evaluation](./docs/evaluation.md)
|
| 63 |
+
- [🤖 Training](./docs/training.md)
|
| 64 |
+
|
| 65 |
+
## 🔧 Installation
|
| 66 |
+
<a id="installation"></a>
|
| 67 |
+
|
| 68 |
+
Clone this repository to local disk and install:
|
| 69 |
+
```bash
|
| 70 |
+
pip install diffusers transformers>=4.57.1 accelerate imageio imageio-ffmpeg omegaconf wandb
|
| 71 |
+
git clone https://github.com/baaivision/URSA.git
|
| 72 |
+
cd URSA && pip install .
|
| 73 |
+
```
|
| 74 |
+
|
| 75 |
+
## 🔥 Quick Start
|
| 76 |
+
<a id="quick-start"></a>
|
| 77 |
+
|
| 78 |
+
### 🖼️ Image Generation
|
| 79 |
+
<a id="quickstart-image-generation"></a>
|
| 80 |
+
|
| 81 |
+
```python
|
| 82 |
+
import torch
|
| 83 |
+
from diffnext.pipelines import URSAPipeline
|
| 84 |
+
|
| 85 |
+
model_id, height, width = "BAAI/URSA-1.7B-IBQ1024", 1024, 1024
|
| 86 |
+
model_args = {"torch_dtype": torch.float16, "trust_remote_code": True}
|
| 87 |
+
pipe = URSAPipeline.from_pretrained(model_id, **model_args)
|
| 88 |
+
pipe = pipe.to(torch.device("cuda"))
|
| 89 |
+
|
| 90 |
+
prompt = "The bear, calm and still, gazes upward as if lost in contemplation of the cosmos."
|
| 91 |
+
negative_prompt = "worst quality, low quality, inconsistent motion, static, still, blurry, jittery, distorted, ugly"
|
| 92 |
+
|
| 93 |
+
image = pipe(**locals()).frames[0]
|
| 94 |
+
image.save("ursa.jpg")
|
| 95 |
+
```
|
| 96 |
+
|
| 97 |
+
### 🎬 Video Generation
|
| 98 |
+
<a id="quickstart-video-generation"></a>
|
| 99 |
+
|
| 100 |
+
```python
|
| 101 |
+
import os, torch, numpy
|
| 102 |
+
from diffnext.pipelines import URSAPipeline
|
| 103 |
+
from diffnext.utils import export_to_video
|
| 104 |
+
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
| 105 |
+
|
| 106 |
+
model_id, height, width = "BAAI/URSA-1.7B-FSQ320", 320, 512
|
| 107 |
+
model_args = {"torch_dtype": torch.float16, "trust_remote_code": True}
|
| 108 |
+
pipe = URSAPipeline.from_pretrained(model_id, **model_args)
|
| 109 |
+
pipe = pipe.to(torch.device("cuda"))
|
| 110 |
+
|
| 111 |
+
text_prompt = "a lone grizzly bear walks through a misty forest at dawn, sunlight catching its fur."
|
| 112 |
+
negative_prompt = "worst quality, low quality, inconsistent motion, static, still, blurry, jittery, distorted, ugly"
|
| 113 |
+
|
| 114 |
+
# Text-to-Image
|
| 115 |
+
prompt = text_prompt
|
| 116 |
+
num_frames, num_inference_steps = 1, 25
|
| 117 |
+
image = pipe(**locals()).frames[0]
|
| 118 |
+
image.save("ursa.jpg")
|
| 119 |
+
|
| 120 |
+
# Image-to-Video
|
| 121 |
+
prompt = f"motion=9.0, {text_prompt}"
|
| 122 |
+
num_frames, num_inference_steps = 49, 50
|
| 123 |
+
video = pipe(**locals()).frames[0]
|
| 124 |
+
export_to_video(video, "ursa_1+48f.mp4", fps=12)
|
| 125 |
+
|
| 126 |
+
# Text-to-Video
|
| 127 |
+
image, video = None, None
|
| 128 |
+
prompt = f"motion=9.0, {text_prompt}"
|
| 129 |
+
num_frames, num_inference_steps = 49, 50
|
| 130 |
+
video = pipe(**locals()).frames[0]
|
| 131 |
+
export_to_video(video, "ursa_49f.mp4", fps=12)
|
| 132 |
+
|
| 133 |
+
# Video-to-Video
|
| 134 |
+
prompt = f"motion=5.0, {text_prompt}"
|
| 135 |
+
num_frames, num_inference_steps = 49, 50
|
| 136 |
+
num_cond_frames, cond_noise_scale = 13, 0.1
|
| 137 |
+
for i in range(12):
|
| 138 |
+
video, start_video = video[-num_cond_frames:], video
|
| 139 |
+
video = pipe(**locals()).frames[0]
|
| 140 |
+
video = numpy.concatenate([start_video, video[num_cond_frames:]])
|
| 141 |
+
export_to_video(video, "ursa_{}f.mp4".format(video.shape[0]), fps=12)
|
| 142 |
+
```
|
| 143 |
+
|
| 144 |
+
## 💻 Gradio Demo
|
| 145 |
+
<a id="gradio-demo"></a>
|
| 146 |
+
|
| 147 |
+
```bash
|
| 148 |
+
# Text-to-Image (T2I)
|
| 149 |
+
python scripts/app_ursa_t2i.py --model "BAAI/URSA-1.7B-IBQ1024" --device 0
|
| 150 |
+
|
| 151 |
+
# Text-to-Image-to-Video (TI2V)
|
| 152 |
+
python scripts/app_ursa_ti2v.py --model "BAAI/URSA-1.7B-FSQ320" --device 0
|
| 153 |
+
```
|
| 154 |
+
|
| 155 |
+
## 📋 Todo List
|
| 156 |
+
- [X] [Model Zoo](#model-zoo)
|
| 157 |
+
- [X] [Quick Start](#quick-start)
|
| 158 |
+
- [X] [Gradio Demo](#gradio-demo)
|
| 159 |
+
- [X] [Evaluation Guide](./docs/evaluation.md)
|
| 160 |
+
- [X] [Training Guide](./docs/training.md)
|
| 161 |
+
- [ ] 4B Model
|
| 162 |
+
|
| 163 |
+
## 📖 Citation
|
| 164 |
+
If you find this repository useful, please consider giving a star ⭐ and citation 🦖:
|
| 165 |
+
```
|
| 166 |
+
@article{deng2025ursa,
|
| 167 |
+
title={Uniform Discrete Diffusion with Metric Path for Video Generation},
|
| 168 |
+
author={Deng, Haoge and Pan, Ting and Zhang, Fan and Liu, Yang and Luo, Zhuoyan and Cui, Yufeng and Shen, Chunhua and Shan, Shiguang and Zhang, Zhaoxiang and Wang, Xinlong},
|
| 169 |
+
journal={arXiv preprint arXiv:2510.24717},
|
| 170 |
+
year={2025}
|
| 171 |
+
}
|
| 172 |
+
```
|
| 173 |
+
```
|
| 174 |
+
@article{deng2024nova,
|
| 175 |
+
title={Autoregressive Video Generation without Vector Quantization},
|
| 176 |
+
author={Deng, Haoge and Pan, Ting and Diao, Haiwen and Luo, Zhengxiong and Cui, Yufeng and Lu, Huchuan and Shan, Shiguang and Qi, Yonggang and Wang, Xinlong},
|
| 177 |
+
journal={arXiv preprint arXiv:2412.14169},
|
| 178 |
+
year={2024}
|
| 179 |
+
}
|
| 180 |
+
```
|
| 181 |
+
|
| 182 |
+
## 🤗 Acknowledgement
|
| 183 |
+
|
| 184 |
+
We thank the repositories:
|
| 185 |
+
- [NOVA](https://github.com/baaivision/NOVA). ✨NOVA is the predecessor of 🐻URSA.
|
| 186 |
+
- [FlowMatching](https://github.com/facebookresearch/flow_matching). This codebase systemically provides CFM and DFM implementations.
|
| 187 |
+
- [FUDOKI](https://github.com/fudoki-hku/FUDOKI). This codebase provides a naive multimodal DFM implementation.
|
| 188 |
+
- [CodeWithGPU](https://github.com/seetacloud/codewithgpu). CodeWithGPU library is the core of our data loading pipeline.
|
| 189 |
+
|
| 190 |
+
## License
|
| 191 |
+
Code and models are licensed under [Apache License 2.0](LICENSE).
|
URSA/inference.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, torch, numpy
|
| 2 |
+
from diffnext.pipelines import URSAPipeline
|
| 3 |
+
from diffnext.utils import export_to_video
|
| 4 |
+
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
model_id, height, width = "BAAI/URSA-1.7B-FSQ320", 320, 512
|
| 9 |
+
model_args = {"torch_dtype": torch.bfloat16, "trust_remote_code": True}
|
| 10 |
+
pipe = URSAPipeline.from_pretrained(model_id, **model_args)
|
| 11 |
+
pipe = pipe.to(torch.device("cuda"))
|
| 12 |
+
|
| 13 |
+
text_prompt = "tom and jerry"#"a lone grizzly bear walks through a misty forest at dawn, sunlight catching its fur."
|
| 14 |
+
negative_prompt = "worst quality, low quality, inconsistent motion, static, still, blurry, jittery, distorted, ugly"
|
| 15 |
+
|
| 16 |
+
import time
|
| 17 |
+
|
| 18 |
+
t1 = time.time()
|
| 19 |
+
|
| 20 |
+
# Text-to-Image
|
| 21 |
+
prompt = text_prompt
|
| 22 |
+
num_frames, num_inference_steps = 1, 25
|
| 23 |
+
image = pipe(**locals()).frames[0]
|
| 24 |
+
image.save("tom/ursa.jpg")
|
| 25 |
+
|
| 26 |
+
t2 = time.time()
|
| 27 |
+
|
| 28 |
+
# Image-to-Video
|
| 29 |
+
prompt = f"motion=9.0, {text_prompt}"
|
| 30 |
+
num_frames, num_inference_steps = 49, 50
|
| 31 |
+
video = pipe(**locals()).frames[0]
|
| 32 |
+
export_to_video(video, "tom/ursa_1+48f.mp4", fps=12)
|
| 33 |
+
|
| 34 |
+
t3 = time.time()
|
| 35 |
+
|
| 36 |
+
# Text-to-Video
|
| 37 |
+
image, video = None, None
|
| 38 |
+
prompt = f"motion=9.0, {text_prompt}"
|
| 39 |
+
num_frames, num_inference_steps = 49, 50
|
| 40 |
+
video = pipe(**locals()).frames[0]
|
| 41 |
+
export_to_video(video, "tom/ursa_49f.mp4", fps=12)
|
| 42 |
+
|
| 43 |
+
t4 = time.time()
|
| 44 |
+
|
| 45 |
+
# Video-to-Video
|
| 46 |
+
prompt = f"motion=5.0, {text_prompt}"
|
| 47 |
+
num_frames, num_inference_steps = 49, 50
|
| 48 |
+
num_cond_frames, cond_noise_scale = 13, 0.1
|
| 49 |
+
for i in range(12):
|
| 50 |
+
video, start_video = video[-num_cond_frames:], video
|
| 51 |
+
video = pipe(**locals()).frames[0]
|
| 52 |
+
video = numpy.concatenate([start_video, video[num_cond_frames:]])
|
| 53 |
+
export_to_video(video, "tom/ursa_{}f.mp4".format(video.shape[0]), fps=12)
|
| 54 |
+
|
| 55 |
+
t5 = time.time()
|
| 56 |
+
|
| 57 |
+
print(f"Text-to-Image time: {t2-t1:.2f} seconds")
|
| 58 |
+
print(f"Image-to-Video time: {t3-t2:.2f} seconds")
|
| 59 |
+
print(f"Text-to-Video time: {t4-t3:.2f} seconds")
|
| 60 |
+
print(f"Video-to-Video time: {t5-t4:.2f} seconds")
|
| 61 |
+
# Single H800 GPU, batch_size=1, the inference time is:
|
| 62 |
+
# Text-to-Image time: 5.05 seconds
|
| 63 |
+
# Image-to-Video time: 101.92 seconds
|
| 64 |
+
# Text-to-Video time: 101.52 seconds
|
| 65 |
+
# Video-to-Video time: 1226.25 seconds
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
# cd URSA/
|
| 69 |
+
# source .venv_ursa/bin/activate
|
| 70 |
+
|
| 71 |
+
# accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml --machine_rank 0 --num_machines 1 --num_processes 8 scripts/train_distill_dimo.py config="./configs/distill_dimo.yaml" experiment.output_dir="./experiments/distill_dimo_v3" distill.teacher_ckpt="/gfs/space/private/fengzl/World_Model/URSA-1.7B" distill.prompt_source="/gfs/space/private/fengzl/World_Model/Koala-36M-v1"
|
URSA/pyproject.toml
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[tool.black]
|
| 2 |
+
line-length = 100
|
| 3 |
+
target-version = ['py310']
|
URSA/requirements.txt
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
diffusers
|
| 3 |
+
transformers>=4.57.1
|
| 4 |
+
accelerate
|
| 5 |
+
imageio
|
| 6 |
+
imageio-ffmpeg
|
| 7 |
+
omegaconf
|
| 8 |
+
wandb
|
| 9 |
+
scipy
|
| 10 |
+
codewithgpu
|
URSA/setup.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ------------------------------------------------------------------------
|
| 2 |
+
# Copyright (c) 2024-present, BAAI. All Rights Reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
# ------------------------------------------------------------------------
|
| 16 |
+
"""Python setup script."""
|
| 17 |
+
|
| 18 |
+
import argparse
|
| 19 |
+
import os
|
| 20 |
+
import shutil
|
| 21 |
+
import subprocess
|
| 22 |
+
import sys
|
| 23 |
+
|
| 24 |
+
import setuptools
|
| 25 |
+
import setuptools.command.build_py
|
| 26 |
+
import setuptools.command.install
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def parse_args():
|
| 30 |
+
"""Parse arguments."""
|
| 31 |
+
parser = argparse.ArgumentParser()
|
| 32 |
+
parser.add_argument("--version", default=None)
|
| 33 |
+
args, unknown = parser.parse_known_args()
|
| 34 |
+
sys.argv = [sys.argv[0]] + unknown
|
| 35 |
+
args.git_version = None
|
| 36 |
+
args.long_description = ""
|
| 37 |
+
if args.version is None and os.path.exists("version.txt"):
|
| 38 |
+
with open("version.txt", "r") as f:
|
| 39 |
+
args.version = f.read().strip()
|
| 40 |
+
if os.path.exists(".git"):
|
| 41 |
+
try:
|
| 42 |
+
git_version = subprocess.check_output(["git", "rev-parse", "HEAD"], cwd="./")
|
| 43 |
+
args.git_version = git_version.decode("ascii").strip()
|
| 44 |
+
except (OSError, subprocess.CalledProcessError):
|
| 45 |
+
pass
|
| 46 |
+
if os.path.exists("README.md"):
|
| 47 |
+
with open(os.path.join("README.md"), encoding="utf-8") as f:
|
| 48 |
+
args.long_description = f.read()
|
| 49 |
+
return args
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def clean_builds():
|
| 53 |
+
for path in ["build", "diffnext.egg-info"]:
|
| 54 |
+
if os.path.exists(path):
|
| 55 |
+
shutil.rmtree(path)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def find_packages(top):
|
| 59 |
+
"""Return the python sources installed to package."""
|
| 60 |
+
packages = []
|
| 61 |
+
for root, _, _ in os.walk(top):
|
| 62 |
+
if os.path.exists(os.path.join(root, "__init__.py")):
|
| 63 |
+
packages.append(root)
|
| 64 |
+
return packages
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def find_package_data():
|
| 68 |
+
"""Return the external data installed to package."""
|
| 69 |
+
return []
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class BuildPyCommand(setuptools.command.build_py.build_py):
|
| 73 |
+
"""Enhanced 'build_py' command."""
|
| 74 |
+
|
| 75 |
+
def build_packages(self):
|
| 76 |
+
with open("diffnext/version.py", "w") as f:
|
| 77 |
+
f.write(
|
| 78 |
+
'version = "{}"\n'
|
| 79 |
+
'git_version = "{}"\n'
|
| 80 |
+
"__version__ = version\n".format(args.version, args.git_version)
|
| 81 |
+
)
|
| 82 |
+
super(BuildPyCommand, self).build_packages()
|
| 83 |
+
|
| 84 |
+
def build_package_data(self):
|
| 85 |
+
self.package_data = {"diffnext": find_package_data()}
|
| 86 |
+
super(BuildPyCommand, self).build_package_data()
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class InstallCommand(setuptools.command.install.install):
|
| 90 |
+
"""Enhanced 'install' command."""
|
| 91 |
+
|
| 92 |
+
def initialize_options(self):
|
| 93 |
+
super(InstallCommand, self).initialize_options()
|
| 94 |
+
self.old_and_unmanageable = True
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
args = parse_args()
|
| 98 |
+
setuptools.setup(
|
| 99 |
+
name="diffnext",
|
| 100 |
+
version=args.version,
|
| 101 |
+
description="A diffusers based library for autoregressive diffusion models.",
|
| 102 |
+
long_description=args.long_description,
|
| 103 |
+
long_description_content_type="text/markdown",
|
| 104 |
+
url="https://github.com/baaivision/URSA",
|
| 105 |
+
author="BAAI",
|
| 106 |
+
license="Apache License",
|
| 107 |
+
packages=find_packages("diffnext"),
|
| 108 |
+
cmdclass={"build_py": BuildPyCommand, "install": InstallCommand},
|
| 109 |
+
install_requires=[
|
| 110 |
+
"torch",
|
| 111 |
+
"diffusers",
|
| 112 |
+
"transformers",
|
| 113 |
+
"accelerate",
|
| 114 |
+
"imageio",
|
| 115 |
+
"imageio-ffmpeg",
|
| 116 |
+
"omegaconf",
|
| 117 |
+
"wandb",
|
| 118 |
+
"scipy",
|
| 119 |
+
],
|
| 120 |
+
classifiers=[
|
| 121 |
+
"Development Status :: 5 - Production/Stable",
|
| 122 |
+
"Intended Audience :: Developers",
|
| 123 |
+
"Intended Audience :: Education",
|
| 124 |
+
"Intended Audience :: Science/Research",
|
| 125 |
+
"License :: OSI Approved :: Apache Software License",
|
| 126 |
+
"Programming Language :: Python :: 3",
|
| 127 |
+
"Programming Language :: Python :: 3 :: Only",
|
| 128 |
+
"Topic :: Scientific/Engineering",
|
| 129 |
+
"Topic :: Scientific/Engineering :: Mathematics",
|
| 130 |
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
| 131 |
+
],
|
| 132 |
+
)
|
| 133 |
+
clean_builds()
|
URSA/ursa.jpg
ADDED
|
URSA/version.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
0.3.0a0
|