BryanW committed on
Commit
d2253eb
·
verified ·
1 Parent(s): 33030f3

Add files using upload-large-folder tool

Browse files
Files changed (50) hide show
  1. Koala-36M-v1/.gitattributes +68 -0
  2. Prism/LICENSE +201 -0
  3. Prism/LLaDA/LLaDA_Baseline/.gitignore +210 -0
  4. Prism/LLaDA/LLaDA_Baseline/LICENSE +21 -0
  5. Prism/LLaDA/LLaDA_Baseline/dllm_eval/__init__.py +7 -0
  6. Prism/LLaDA/LLaDA_Baseline/dllm_eval/__main__.py +527 -0
  7. Prism/LLaDA/LLaDA_Baseline/dllm_eval/evaluator.py +765 -0
  8. Prism/LLaDA/LLaDA_Baseline/dllm_eval/evaluator_utils.py +554 -0
  9. Prism/LLaDA/LLaDA_Baseline/dllm_eval/filters/__init__.py +25 -0
  10. Prism/LLaDA/LLaDA_Baseline/dllm_eval/filters/custom.py +17 -0
  11. Prism/LLaDA/LLaDA_Baseline/dllm_eval/filters/decontamination.py +25 -0
  12. Prism/LLaDA/LLaDA_Baseline/dllm_eval/filters/extraction.py +233 -0
  13. Prism/LLaDA/LLaDA_Baseline/dllm_eval/filters/selection.py +61 -0
  14. Prism/LLaDA/LLaDA_Baseline/dllm_eval/filters/transformation.py +122 -0
  15. Prism/LLaDA/LLaDA_Baseline/dllm_eval/utils.py +552 -0
  16. Prism/LLaDA/LLaDA_Baseline/evaluation_script.py +21 -0
  17. Prism/LLaDA/LLaDA_Baseline/metrics/gsm8k_all.py +286 -0
  18. Prism/LLaDA/LLaDA_Baseline/metrics/humaneval_all.py +183 -0
  19. Prism/LLaDA/LLaDA_Baseline/metrics/math500_all.py +213 -0
  20. Prism/LLaDA/LLaDA_Baseline/metrics/mbpp_all.py +194 -0
  21. Prism/LLaDA/LLaDA_Baseline/requirements.txt +9 -0
  22. Prism/LLaDA/LLaDA_Baseline/scripts/run_gsm8k.sh +32 -0
  23. Prism/LLaDA/LLaDA_Baseline/scripts/run_humaneval.sh +29 -0
  24. Prism/LLaDA/LLaDA_Baseline/scripts/run_math500.sh +29 -0
  25. Prism/LLaDA/LLaDA_Baseline/scripts/run_mbpp.sh +29 -0
  26. Prism/LLaDA/LLaDA_Prism/.gitignore +210 -0
  27. Prism/LLaDA/LLaDA_Prism/LICENSE +21 -0
  28. Prism/LLaDA/LLaDA_Prism/evaluation_script.py +21 -0
  29. Prism/LLaDA/LLaDA_Prism/requirements.txt +9 -0
  30. Prism/LLaDA/LLaDA_Truthfulqa/.gitignore +3 -0
  31. Prism/LLaDA/LLaDA_Truthfulqa/LICENSE +201 -0
  32. Prism/LLaDA/LLaDA_Truthfulqa/eval_llada.py +413 -0
  33. Prism/LLaDA/LLaDA_Truthfulqa/eval_llada_prism.py +333 -0
  34. Prism/README.md +107 -0
  35. URSA-1.7B/.gitattributes +37 -0
  36. URSA-1.7B/.gitignore +55 -0
  37. URSA-1.7B/LICENSE +176 -0
  38. URSA-1.7B/README.md +117 -0
  39. URSA-1.7B/model_index.json +19 -0
  40. URSA/.flake8 +21 -0
  41. URSA/.gitignore +55 -0
  42. URSA/=4.57.1 +70 -0
  43. URSA/LICENSE +176 -0
  44. URSA/README.md +191 -0
  45. URSA/inference.py +71 -0
  46. URSA/pyproject.toml +3 -0
  47. URSA/requirements.txt +10 -0
  48. URSA/setup.py +133 -0
  49. URSA/ursa.jpg +0 -0
  50. URSA/version.txt +1 -0
Koala-36M-v1/.gitattributes ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.lz4 filter=lfs diff=lfs merge=lfs -text
12
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
13
+ *.model filter=lfs diff=lfs merge=lfs -text
14
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
15
+ *.npy filter=lfs diff=lfs merge=lfs -text
16
+ *.npz filter=lfs diff=lfs merge=lfs -text
17
+ *.onnx filter=lfs diff=lfs merge=lfs -text
18
+ *.ot filter=lfs diff=lfs merge=lfs -text
19
+ *.parquet filter=lfs diff=lfs merge=lfs -text
20
+ *.pb filter=lfs diff=lfs merge=lfs -text
21
+ *.pickle filter=lfs diff=lfs merge=lfs -text
22
+ *.pkl filter=lfs diff=lfs merge=lfs -text
23
+ *.pt filter=lfs diff=lfs merge=lfs -text
24
+ *.pth filter=lfs diff=lfs merge=lfs -text
25
+ *.rar filter=lfs diff=lfs merge=lfs -text
26
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
27
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
29
+ *.tar filter=lfs diff=lfs merge=lfs -text
30
+ *.tflite filter=lfs diff=lfs merge=lfs -text
31
+ *.tgz filter=lfs diff=lfs merge=lfs -text
32
+ *.wasm filter=lfs diff=lfs merge=lfs -text
33
+ *.xz filter=lfs diff=lfs merge=lfs -text
34
+ *.zip filter=lfs diff=lfs merge=lfs -text
35
+ *.zst filter=lfs diff=lfs merge=lfs -text
36
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
37
+ # Audio files - uncompressed
38
+ *.pcm filter=lfs diff=lfs merge=lfs -text
39
+ *.sam filter=lfs diff=lfs merge=lfs -text
40
+ *.raw filter=lfs diff=lfs merge=lfs -text
41
+ # Audio files - compressed
42
+ *.aac filter=lfs diff=lfs merge=lfs -text
43
+ *.flac filter=lfs diff=lfs merge=lfs -text
44
+ *.mp3 filter=lfs diff=lfs merge=lfs -text
45
+ *.ogg filter=lfs diff=lfs merge=lfs -text
46
+ *.wav filter=lfs diff=lfs merge=lfs -text
47
+ # Image files - uncompressed
48
+ *.bmp filter=lfs diff=lfs merge=lfs -text
49
+ *.gif filter=lfs diff=lfs merge=lfs -text
50
+ *.png filter=lfs diff=lfs merge=lfs -text
51
+ *.tiff filter=lfs diff=lfs merge=lfs -text
52
+ # Image files - compressed
53
+ *.jpg filter=lfs diff=lfs merge=lfs -text
54
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
55
+ *.webp filter=lfs diff=lfs merge=lfs -text
56
+ # Video files - compressed
57
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
58
+ *.webm filter=lfs diff=lfs merge=lfs -text
59
+ Koala_36M_1.csv filter=lfs diff=lfs merge=lfs -text
60
+ Koala_36M_2.csv filter=lfs diff=lfs merge=lfs -text
61
+ Koala_36M_3.csv filter=lfs diff=lfs merge=lfs -text
62
+ Koala_36M_4.csv filter=lfs diff=lfs merge=lfs -text
63
+ Koala_36M_5.csv filter=lfs diff=lfs merge=lfs -text
64
+ Koala_36M_6.csv filter=lfs diff=lfs merge=lfs -text
65
+ Koala_36M_7.csv filter=lfs diff=lfs merge=lfs -text
66
+ Koala_36M_8.csv filter=lfs diff=lfs merge=lfs -text
67
+ Koala_36M_9.csv filter=lfs diff=lfs merge=lfs -text
68
+ Koala_36M_10.csv filter=lfs diff=lfs merge=lfs -text
Prism/LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
Prism/LLaDA/LLaDA_Baseline/.gitignore ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.jsonl
2
+ *.json
3
+
4
+ # Byte-compiled / optimized / DLL files
5
+ __pycache__/
6
+ *.py[codz]
7
+ *$py.class
8
+
9
+ # C extensions
10
+ *.so
11
+
12
+ # Distribution / packaging
13
+ .Python
14
+ build/
15
+ develop-eggs/
16
+ dist/
17
+ downloads/
18
+ eggs/
19
+ .eggs/
20
+ lib/
21
+ lib64/
22
+ parts/
23
+ sdist/
24
+ var/
25
+ wheels/
26
+ share/python-wheels/
27
+ *.egg-info/
28
+ .installed.cfg
29
+ *.egg
30
+ MANIFEST
31
+
32
+ # PyInstaller
33
+ # Usually these files are written by a python script from a template
34
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
35
+ *.manifest
36
+ *.spec
37
+
38
+ # Installer logs
39
+ pip-log.txt
40
+ pip-delete-this-directory.txt
41
+
42
+ # Unit test / coverage reports
43
+ htmlcov/
44
+ .tox/
45
+ .nox/
46
+ .coverage
47
+ .coverage.*
48
+ .cache
49
+ nosetests.xml
50
+ coverage.xml
51
+ *.cover
52
+ *.py.cover
53
+ .hypothesis/
54
+ .pytest_cache/
55
+ cover/
56
+
57
+ # Translations
58
+ *.mo
59
+ *.pot
60
+
61
+ # Django stuff:
62
+ *.log
63
+ local_settings.py
64
+ db.sqlite3
65
+ db.sqlite3-journal
66
+
67
+ # Flask stuff:
68
+ instance/
69
+ .webassets-cache
70
+
71
+ # Scrapy stuff:
72
+ .scrapy
73
+
74
+ # Sphinx documentation
75
+ docs/_build/
76
+
77
+ # PyBuilder
78
+ .pybuilder/
79
+ target/
80
+
81
+ # Jupyter Notebook
82
+ .ipynb_checkpoints
83
+
84
+ # IPython
85
+ profile_default/
86
+ ipython_config.py
87
+
88
+ # pyenv
89
+ # For a library or package, you might want to ignore these files since the code is
90
+ # intended to run in multiple environments; otherwise, check them in:
91
+ # .python-version
92
+
93
+ # pipenv
94
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
96
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
97
+ # install all needed dependencies.
98
+ #Pipfile.lock
99
+
100
+ # UV
101
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
102
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
103
+ # commonly ignored for libraries.
104
+ #uv.lock
105
+
106
+ # poetry
107
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
108
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
109
+ # commonly ignored for libraries.
110
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
111
+ #poetry.lock
112
+ #poetry.toml
113
+
114
+ # pdm
115
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
116
+ # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
117
+ # https://pdm-project.org/en/latest/usage/project/#working-with-version-control
118
+ #pdm.lock
119
+ #pdm.toml
120
+ .pdm-python
121
+ .pdm-build/
122
+
123
+ # pixi
124
+ # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
125
+ #pixi.lock
126
+ # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
127
+ # in the .venv directory. It is recommended not to include this directory in version control.
128
+ .pixi
129
+
130
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
131
+ __pypackages__/
132
+
133
+ # Celery stuff
134
+ celerybeat-schedule
135
+ celerybeat.pid
136
+
137
+ # SageMath parsed files
138
+ *.sage.py
139
+
140
+ # Environments
141
+ .env
142
+ .envrc
143
+ .venv
144
+ env/
145
+ venv/
146
+ ENV/
147
+ env.bak/
148
+ venv.bak/
149
+
150
+ # Spyder project settings
151
+ .spyderproject
152
+ .spyproject
153
+
154
+ # Rope project settings
155
+ .ropeproject
156
+
157
+ # mkdocs documentation
158
+ /site
159
+
160
+ # mypy
161
+ .mypy_cache/
162
+ .dmypy.json
163
+ dmypy.json
164
+
165
+ # Pyre type checker
166
+ .pyre/
167
+
168
+ # pytype static type analyzer
169
+ .pytype/
170
+
171
+ # Cython debug symbols
172
+ cython_debug/
173
+
174
+ # PyCharm
175
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
176
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
177
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
178
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
179
+ #.idea/
180
+
181
+ # Abstra
182
+ # Abstra is an AI-powered process automation framework.
183
+ # Ignore directories containing user credentials, local state, and settings.
184
+ # Learn more at https://abstra.io/docs
185
+ .abstra/
186
+
187
+ # Visual Studio Code
188
+ # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
189
+ # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
190
+ # and can be added to the global gitignore or merged into this file. However, if you prefer,
191
+ # you could uncomment the following to ignore the entire vscode folder
192
+ # .vscode/
193
+
194
+ # Ruff stuff:
195
+ .ruff_cache/
196
+
197
+ # PyPI configuration file
198
+ .pypirc
199
+
200
+ # Cursor
201
+ # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
202
+ # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
203
+ # refer to https://docs.cursor.com/context/ignore-files
204
+ .cursorignore
205
+ .cursorindexingignore
206
+
207
+ # Marimo
208
+ marimo/_static/
209
+ marimo/_lsp/
210
+ __marimo__/
Prism/LLaDA/LLaDA_Baseline/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2025 preordinary
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
Prism/LLaDA/LLaDA_Baseline/dllm_eval/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+
4
+ from .evaluator import evaluate, simple_evaluate
5
+
6
+
7
+ __version__ = "0.4.9"
Prism/LLaDA/LLaDA_Baseline/dllm_eval/__main__.py ADDED
@@ -0,0 +1,527 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import logging
4
+ import os
5
+ import sys
6
+ from functools import partial
7
+ from pathlib import Path
8
+ from typing import Union
9
+
10
+ from dllm_eval import evaluator, utils
11
+ from dllm_eval.evaluator import request_caching_arg_to_dict
12
+ from dllm_eval.loggers import EvaluationTracker, WandbLogger
13
+ from dllm_eval.tasks import TaskManager
14
+ from dllm_eval.utils import (
15
+ handle_non_serializable,
16
+ make_table,
17
+ simple_parse_args_string,
18
+ )
19
+
20
+
21
def try_parse_json(value: str) -> Union[str, dict, None]:
    """Parse a CLI argument as JSON, falling back to the raw string.

    Args:
        value: Raw argument text, or ``None``.

    Returns:
        ``None`` when ``value`` is ``None``; the result of
        ``json.loads(value)`` when the text decodes as JSON; otherwise
        ``value`` itself, unchanged.

    Raises:
        argparse.ArgumentTypeError: If decoding fails and ``value`` contains
            a ``{`` character. The underlying ``json.JSONDecodeError`` is
            attached as ``__cause__``.
    """
    if value is None:
        return None
    try:
        return json.loads(value)
    except json.JSONDecodeError as err:
        # A "{" suggests the caller meant to pass JSON, so report the failure
        # instead of silently treating the text as a plain string.
        if "{" in value:
            raise argparse.ArgumentTypeError(
                f"Invalid JSON: {value}. Hint: Use double quotes for JSON strings."
            ) from err
        return value
32
+
33
+
34
def _int_or_none_list_arg_type(
    min_len: int, max_len: int, defaults: str, value: str, split_char: str = ","
):
    """Parse a delimited CLI value into a list of ``int`` / ``None`` entries.

    Args:
        min_len: Smallest number of entries accepted (other than exactly one).
        max_len: Length of the returned list.
        defaults: Delimited string, parsed the same way as ``value``, whose
            trailing entries fill in for missing ones.
        value: The raw argument text to parse.
        split_char: Delimiter used to split ``value`` and ``defaults``.

    Returns:
        A list of ``int`` or ``None``. A single entry is repeated ``max_len``
        times; a list shorter than ``max_len`` (but at least ``min_len``) is
        extended with ``defaults[num_items:]``.

    Raises:
        argparse.ArgumentTypeError: If an entry is neither an integer nor
            "none" (case-insensitive), or if the number of entries is outside
            ``min_len..max_len`` and is not exactly one.
    """

    def parse_value(item):
        # Normalise so that " None " and "NONE" are both accepted.
        item = item.strip().lower()
        if item == "none":
            return None
        try:
            return int(item)
        except ValueError as err:
            raise argparse.ArgumentTypeError(
                f"{item} is not an integer or None"
            ) from err

    items = [parse_value(v) for v in value.split(split_char)]
    num_items = len(items)

    if num_items == 1:
        # Makes downstream handling the same for single and multiple values
        items = items * max_len
    elif num_items < min_len or num_items > max_len:
        raise argparse.ArgumentTypeError(
            f"Argument requires {max_len} integers or None, separated by '{split_char}'"
        )
    elif num_items != max_len:
        logging.warning(
            f"Argument requires {max_len} integers or None, separated by '{split_char}'. "
            "Missing values will be filled with defaults."
        )
        default_items = [parse_value(v) for v in defaults.split(split_char)]
        # Extend the items list with the missing trailing defaults.
        items.extend(default_items[num_items:])

    return items
67
+
68
+
69
def check_argument_types(parser: argparse.ArgumentParser):
    """
    Verify that every argument registered on `parser` declares a type.

    The action whose dest is "help" and any action with a truthy `const`
    are exempt. Raises ValueError naming the first remaining action whose
    `type` is None.
    """
    for arg in parser._actions:
        exempt = arg.dest == "help" or bool(arg.const)
        if exempt or arg.type is not None:
            continue
        raise ValueError(f"Argument '{arg.dest}' doesn't have a type specified.")
81
+
82
+
83
+ def setup_parser() -> argparse.ArgumentParser:
84
+ parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
85
+ parser.add_argument(
86
+ "--model", "-m", type=str, default="hf", help="Name of model e.g. `hf`"
87
+ )
88
+ parser.add_argument(
89
+ "--tasks",
90
+ "-t",
91
+ default=None,
92
+ type=str,
93
+ metavar="task1,task2",
94
+ help="Comma-separated list of task names or task groupings to evaluate on.\nTo get full list of tasks, use one of the commands `lm-eval --tasks {{list_groups,list_subtasks,list_tags,list}}` to list out all available names for task groupings; only (sub)tasks; tags; or all of the above",
95
+ )
96
+ parser.add_argument(
97
+ "--model_args",
98
+ "-a",
99
+ default="",
100
+ type=try_parse_json,
101
+ help="""Comma separated string or JSON formatted arguments for model, e.g. `pretrained=EleutherAI/pythia-160m,dtype=float32` or '{"pretrained":"EleutherAI/pythia-160m","dtype":"float32"}'""",
102
+ )
103
+ parser.add_argument(
104
+ "--num_fewshot",
105
+ "-f",
106
+ type=int,
107
+ default=None,
108
+ metavar="N",
109
+ help="Number of examples in few-shot context",
110
+ )
111
+ parser.add_argument(
112
+ "--batch_size",
113
+ "-b",
114
+ type=str,
115
+ default=1,
116
+ metavar="auto|auto:N|N",
117
+ help="Acceptable values are 'auto', 'auto:N' or N, where N is an integer. Default 1.",
118
+ )
119
+ parser.add_argument(
120
+ "--max_batch_size",
121
+ type=int,
122
+ default=None,
123
+ metavar="N",
124
+ help="Maximal batch size to try with --batch_size auto.",
125
+ )
126
+ parser.add_argument(
127
+ "--device",
128
+ type=str,
129
+ default=None,
130
+ help="Device to use (e.g. cuda, cuda:0, cpu).",
131
+ )
132
+ parser.add_argument(
133
+ "--output_path",
134
+ "-o",
135
+ default=None,
136
+ type=str,
137
+ metavar="DIR|DIR/file.json",
138
+ help="Path where result metrics will be saved. Can be either a directory or a .json file. If the path is a directory and log_samples is true, the results will be saved in the directory. Else the parent directory will be used.",
139
+ )
140
+ parser.add_argument(
141
+ "--limit",
142
+ "-L",
143
+ type=float,
144
+ default=None,
145
+ metavar="N|0<N<1",
146
+ help="Limit the number of examples per task. "
147
+ "If <1, limit is a percentage of the total number of examples.",
148
+ )
149
+ parser.add_argument(
150
+ "--samples",
151
+ "-E",
152
+ default=None,
153
+ type=str,
154
+ metavar="/path/to/json",
155
+ help='JSON string or path to JSON file containing doc indices of selected examples to test. Format: {"task_name":[indices],...}',
156
+ )
157
+ parser.add_argument(
158
+ "--use_cache",
159
+ "-c",
160
+ type=str,
161
+ default=None,
162
+ metavar="DIR",
163
+ help="A path to a sqlite db file for caching model responses. `None` if not caching.",
164
+ )
165
+ parser.add_argument(
166
+ "--cache_requests",
167
+ type=str,
168
+ default=None,
169
+ choices=["true", "refresh", "delete"],
170
+ help="Speed up evaluation by caching the building of dataset requests. `None` if not caching.",
171
+ )
172
+ parser.add_argument(
173
+ "--check_integrity",
174
+ action="store_true",
175
+ help="Whether to run the relevant part of the test suite for the tasks.",
176
+ )
177
+ parser.add_argument(
178
+ "--write_out",
179
+ "-w",
180
+ action="store_true",
181
+ default=False,
182
+ help="Prints the prompt for the first few documents.",
183
+ )
184
+ parser.add_argument(
185
+ "--log_samples",
186
+ "-s",
187
+ action="store_true",
188
+ default=False,
189
+ help="If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis. Use with --output_path.",
190
+ )
191
+ parser.add_argument(
192
+ "--system_instruction",
193
+ type=str,
194
+ default=None,
195
+ help="System instruction to be used in the prompt",
196
+ )
197
+ parser.add_argument(
198
+ "--apply_chat_template",
199
+ type=str,
200
+ nargs="?",
201
+ const=True,
202
+ default=False,
203
+ help=(
204
+ "If True, apply chat template to the prompt. "
205
+ "Providing `--apply_chat_template` without an argument will apply the default chat template to the prompt. "
206
+ "To apply a specific template from the available list of templates, provide the template name as an argument. "
207
+ "E.g. `--apply_chat_template template_name`"
208
+ ),
209
+ )
210
+ parser.add_argument(
211
+ "--fewshot_as_multiturn",
212
+ action="store_true",
213
+ default=False,
214
+ help="If True, uses the fewshot as a multi-turn conversation",
215
+ )
216
+ parser.add_argument(
217
+ "--show_config",
218
+ action="store_true",
219
+ default=False,
220
+ help="If True, shows the the full config of all tasks at the end of the evaluation.",
221
+ )
222
+ parser.add_argument(
223
+ "--include_path",
224
+ type=str,
225
+ default=None,
226
+ metavar="DIR",
227
+ help="Additional path to include if there are external tasks to include.",
228
+ )
229
+ parser.add_argument(
230
+ "--gen_kwargs",
231
+ type=try_parse_json,
232
+ default=None,
233
+ help=(
234
+ "Either comma delimited string or JSON formatted arguments for model generation on greedy_until tasks,"
235
+ """ e.g. '{"temperature":0.7,"until":["hello"]}' or temperature=0,top_p=0.1."""
236
+ ),
237
+ )
238
+ parser.add_argument(
239
+ "--verbosity",
240
+ "-v",
241
+ type=str.upper,
242
+ default=None,
243
+ metavar="CRITICAL|ERROR|WARNING|INFO|DEBUG",
244
+ help="(Deprecated) Controls logging verbosity level. Use the `LOGLEVEL` environment variable instead. Set to DEBUG for detailed output when testing or adding new task configurations.",
245
+ )
246
+ parser.add_argument(
247
+ "--wandb_args",
248
+ type=str,
249
+ default="",
250
+ help="Comma separated string arguments passed to wandb.init, e.g. `project=lm-eval,job_type=eval",
251
+ )
252
+ parser.add_argument(
253
+ "--wandb_config_args",
254
+ type=str,
255
+ default="",
256
+ help="Comma separated string arguments passed to wandb.config.update. Use this to trace parameters that aren't already traced by default. eg. `lr=0.01,repeats=3",
257
+ )
258
+ parser.add_argument(
259
+ "--hf_hub_log_args",
260
+ type=str,
261
+ default="",
262
+ help="Comma separated string arguments passed to Hugging Face Hub's log function, e.g. `hub_results_org=EleutherAI,hub_repo_name=lm-eval-results`",
263
+ )
264
+ parser.add_argument(
265
+ "--predict_only",
266
+ "-x",
267
+ action="store_true",
268
+ default=False,
269
+ help="Use with --log_samples. Only model outputs will be saved and metrics will not be evaluated.",
270
+ )
271
+ default_seed_string = "0,1234,1234,1234"
272
+ parser.add_argument(
273
+ "--seed",
274
+ type=partial(_int_or_none_list_arg_type, 3, 4, default_seed_string),
275
+ default=default_seed_string, # for backward compatibility
276
+ help=(
277
+ "Set seed for python's random, numpy, torch, and fewshot sampling.\n"
278
+ "Accepts a comma-separated list of 4 values for python's random, numpy, torch, and fewshot sampling seeds, "
279
+ "respectively, or a single integer to set the same seed for all four.\n"
280
+ f"The values are either an integer or 'None' to not set the seed. Default is `{default_seed_string}` "
281
+ "(for backward compatibility).\n"
282
+ "E.g. `--seed 0,None,8,52` sets `random.seed(0)`, `torch.manual_seed(8)`, and fewshot sampling seed to 52. "
283
+ "Here numpy's seed is not set since the second value is `None`.\n"
284
+ "E.g, `--seed 42` sets all four seeds to 42."
285
+ ),
286
+ )
287
+ parser.add_argument(
288
+ "--trust_remote_code",
289
+ action="store_true",
290
+ help="Sets trust_remote_code to True to execute code to create HF Datasets from the Hub",
291
+ )
292
+ parser.add_argument(
293
+ "--confirm_run_unsafe_code",
294
+ action="store_true",
295
+ help="Confirm that you understand the risks of running unsafe code for tasks that require it",
296
+ )
297
+ parser.add_argument(
298
+ "--metadata",
299
+ type=json.loads,
300
+ default=None,
301
+ help="""JSON string metadata to pass to task configs, for example '{"max_seq_lengths":[4096,8192]}'. Will be merged with model_args. Can also be set in task config.""",
302
+ )
303
+ return parser
304
+
305
+
306
def parse_eval_args(parser: argparse.ArgumentParser) -> argparse.Namespace:
    """Check the parser's argument declarations, then parse the command line.

    :param parser: argparse.ArgumentParser
        Parser whose arguments have already been registered.
    :return:
        Namespace produced by `parser.parse_args()` (argparse reads the
        process arguments when none are passed explicitly).
    """
    # Run the declaration check (helper defined elsewhere in this module)
    # before any parsing happens.
    check_argument_types(parser)
    parsed_args = parser.parse_args()
    return parsed_args
309
+
310
+
311
def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
    """Command-line entry point: resolve tasks, run the evaluation, save results.

    :param args: Optional[argparse.Namespace]
        Pre-parsed arguments. When omitted (or falsy) a parser is built with
        `setup_parser()` and parsed with `parse_eval_args()`.

    Outline of what happens, in order:
      1. A `WandbLogger` is created when `--wandb_args` is non-empty, logging is
         configured and `TOKENIZERS_PARALLELISM` is set to "false".
      2. An `EvaluationTracker` is built from `--hf_hub_log_args`, extended with
         the output path and the `HF_TOKEN` environment variable when present.
      3. Flag combinations are validated and a `TaskManager` is created.
      4. `--tasks` is resolved: the special values `list`, `list_groups`,
         `list_tags` and `list_subtasks` print and exit; a directory is scanned
         for `*.yaml` configs; otherwise the comma-separated names are matched
         against the task manager (unmatched entries that are files are loaded
         as YAML configs).
      5. `evaluator.simple_evaluate` is called; when it returns results they are
         saved through the evaluation tracker, optionally logged to W&B, and a
         one-line summary is printed.

    :raises ValueError:
        If `--log_samples`/`--predict_only` is used without `--output_path`,
        if `--fewshot_as_multiturn` is used without `--apply_chat_template`,
        or if a requested (non-wildcard) task cannot be found.
    """
    if not args:
        # we allow for args to be passed externally, else we parse them ourselves
        parser = setup_parser()
        args = parse_eval_args(parser)

    if args.wandb_args:
        wandb_args_dict = simple_parse_args_string(args.wandb_args)
        wandb_config_args_dict = simple_parse_args_string(args.wandb_config_args)
        wandb_logger = WandbLogger(wandb_args_dict, wandb_config_args_dict)

    utils.setup_logging(args.verbosity)
    eval_logger = logging.getLogger(__name__)
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # update the evaluation tracker args with the output path and the HF token
    if args.output_path:
        args.hf_hub_log_args += f",output_path={args.output_path}"
    if os.environ.get("HF_TOKEN", None):
        args.hf_hub_log_args += f",token={os.environ.get('HF_TOKEN')}"
    evaluation_tracker_args = simple_parse_args_string(args.hf_hub_log_args)
    evaluation_tracker = EvaluationTracker(**evaluation_tracker_args)

    # --predict_only implies sample logging, and both need somewhere to write.
    if args.predict_only:
        args.log_samples = True
    if (args.log_samples or args.predict_only) and not args.output_path:
        raise ValueError(
            "Specify --output_path if providing --log_samples or --predict_only"
        )

    if args.fewshot_as_multiturn and args.apply_chat_template is False:
        raise ValueError(
            "When `fewshot_as_multiturn` is selected, `apply_chat_template` must be set (either to `True` or to the chosen template name)."
        )

    if args.include_path is not None:
        eval_logger.info(f"Including path: {args.include_path}")
    # Task metadata = model args (string or dict form) merged with --metadata;
    # with the dict union operator the right-hand side wins on duplicate keys.
    metadata = (
        simple_parse_args_string(args.model_args)
        if isinstance(args.model_args, str)
        else args.model_args
        if isinstance(args.model_args, dict)
        else {}
    ) | (
        args.metadata
        if isinstance(args.metadata, dict)
        else simple_parse_args_string(args.metadata)
    )

    task_manager = TaskManager(include_path=args.include_path, metadata=metadata)

    if "push_samples_to_hub" in evaluation_tracker_args and not args.log_samples:
        eval_logger.warning(
            "Pushing samples to the Hub requires --log_samples to be set. Samples will not be pushed to the Hub."
        )

    if args.limit:
        eval_logger.warning(
            " --limit SHOULD ONLY BE USED FOR TESTING."
            "REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
        )
    if args.samples:
        assert args.limit is None, (
            "If --samples is not None, then --limit must be None."
        )
        # --samples is either a path to a JSON file or an inline JSON string.
        if (samples := Path(args.samples)).is_file():
            args.samples = json.loads(samples.read_text())
        else:
            args.samples = json.loads(args.samples)

    if args.tasks is None:
        eval_logger.error("Need to specify task to evaluate.")
        sys.exit()
    elif args.tasks == "list":
        print(task_manager.list_all_tasks())
        sys.exit()
    elif args.tasks == "list_groups":
        print(task_manager.list_all_tasks(list_subtasks=False, list_tags=False))
        sys.exit()
    elif args.tasks == "list_tags":
        print(task_manager.list_all_tasks(list_groups=False, list_subtasks=False))
        sys.exit()
    elif args.tasks == "list_subtasks":
        print(task_manager.list_all_tasks(list_groups=False, list_tags=False))
        sys.exit()
    else:
        if os.path.isdir(args.tasks):
            import glob

            # A directory was given: every *.yaml file in it is loaded as a task config.
            task_names = []
            yaml_path = os.path.join(args.tasks, "*.yaml")
            for yaml_file in glob.glob(yaml_path):
                config = utils.load_yaml_config(yaml_file)
                task_names.append(config)
        else:
            task_list = args.tasks.split(",")
            task_names = task_manager.match_tasks(task_list)
            # Entries that did not match a registered task may be paths to YAML configs.
            for task in [task for task in task_list if task not in task_names]:
                if os.path.isfile(task):
                    config = utils.load_yaml_config(task)
                    task_names.append(config)
            task_missing = [
                task for task in task_list if task not in task_names and "*" not in task
            ]  # we don't want errors if a wildcard ("*") task name was used

            if task_missing:
                missing = ", ".join(task_missing)
                eval_logger.error(
                    f"Tasks were not found: {missing}\n"
                    f"{utils.SPACING}Try `lm-eval --tasks list` for list of available tasks",
                )
                raise ValueError(
                    f"Tasks not found: {missing}. Try `lm-eval --tasks {{list_groups,list_subtasks,list_tags,list}}` to list out all available names for task groupings; only (sub)tasks; tags; or all of the above, or pass '--verbosity DEBUG' to troubleshoot task registration issues."
                )

    # Respect user's value passed in via CLI, otherwise default to True and add to comma-separated model args
    if args.trust_remote_code:
        eval_logger.info(
            "Passed `--trust_remote_code`, setting environment variable `HF_DATASETS_TRUST_REMOTE_CODE=true`"
        )
        # HACK: import datasets and override its HF_DATASETS_TRUST_REMOTE_CODE value internally,
        # because it's already been determined based on the prior env var before launching our
        # script--`datasets` gets imported by dllm_eval internally before these lines can update the env.
        import datasets

        datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True

        # model_args may be a dict (JSON form) or a comma-separated string;
        # string concatenation on a dict would raise TypeError.
        if isinstance(args.model_args, dict):
            # Build a new dict so a caller-supplied mapping is not mutated.
            args.model_args = {**args.model_args, "trust_remote_code": True}
        else:
            args.model_args = (args.model_args or "") + ",trust_remote_code=True"

    # Announce the selection through the logger, or fall back to stdout when
    # the logger's effective level is below INFO.
    if eval_logger.getEffectiveLevel() >= logging.INFO:
        eval_logger.info(f"Selected Tasks: {task_names}")
    else:
        print(f"Selected Tasks: {task_names}")

    request_caching_args = request_caching_arg_to_dict(
        cache_requests=args.cache_requests
    )

    results = evaluator.simple_evaluate(
        model=args.model,
        model_args=args.model_args,
        tasks=task_names,
        num_fewshot=args.num_fewshot,
        batch_size=args.batch_size,
        max_batch_size=args.max_batch_size,
        device=args.device,
        use_cache=args.use_cache,
        limit=args.limit,
        samples=args.samples,
        check_integrity=args.check_integrity,
        write_out=args.write_out,
        log_samples=args.log_samples,
        evaluation_tracker=evaluation_tracker,
        system_instruction=args.system_instruction,
        apply_chat_template=args.apply_chat_template,
        fewshot_as_multiturn=args.fewshot_as_multiturn,
        gen_kwargs=args.gen_kwargs,
        task_manager=task_manager,
        predict_only=args.predict_only,
        random_seed=args.seed[0],
        numpy_random_seed=args.seed[1],
        torch_random_seed=args.seed[2],
        fewshot_random_seed=args.seed[3],
        confirm_run_unsafe_code=args.confirm_run_unsafe_code,
        metadata=metadata,
        **request_caching_args,
    )

    # simple_evaluate returns None on ranks other than 0; only rank 0 reports.
    if results is not None:
        if args.log_samples:
            # Per-sample records are split off so they are saved separately.
            samples = results.pop("samples")
        dumped = json.dumps(
            results, indent=2, default=handle_non_serializable, ensure_ascii=False
        )
        if args.show_config:
            print(dumped)

        batch_sizes = ",".join(map(str, results["config"]["batch_sizes"]))

        # Add W&B logging
        if args.wandb_args:
            try:
                wandb_logger.post_init(results)
                wandb_logger.log_eval_result()
                if args.log_samples:
                    wandb_logger.log_eval_samples(samples)
            except Exception as e:
                # Best effort: a W&B failure must not lose the local results.
                eval_logger.info(f"Logging to Weights and Biases failed due to {e}")

        evaluation_tracker.save_results_aggregated(
            results=results, samples=samples if args.log_samples else None
        )

        if args.log_samples:
            for task_name in results["configs"]:
                evaluation_tracker.save_results_samples(
                    task_name=task_name, samples=samples[task_name]
                )

        if (
            evaluation_tracker.push_results_to_hub
            or evaluation_tracker.push_samples_to_hub
        ):
            evaluation_tracker.recreate_metadata_card()

        print(
            f"{args.model} ({args.model_args}), gen_kwargs: ({args.gen_kwargs}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, "
            f"batch_size: {args.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}"
        )

        if args.wandb_args:
            # Tear down wandb run once all the logging is done.
            wandb_logger.run.finish()
524
+
525
+
526
# Script entry point: run the CLI only when this module is executed directly,
# not when it is imported.
if __name__ == "__main__":
    cli_evaluate()
Prism/LLaDA/LLaDA_Baseline/dllm_eval/evaluator.py ADDED
@@ -0,0 +1,765 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ import json
3
+ import logging
4
+ import random
5
+ import time
6
+ from collections import defaultdict
7
+ from typing import TYPE_CHECKING, List, Optional, Union
8
+
9
+ import numpy as np
10
+ import torch
11
+
12
+ import dllm_eval.api.metrics
13
+ import dllm_eval.api.registry
14
+ import dllm_eval.api.task
15
+ import dllm_eval.models
16
+ from dllm_eval.caching.cache import delete_cache
17
+ from dllm_eval.evaluator_utils import (
18
+ consolidate_group_results,
19
+ consolidate_results,
20
+ get_sample_size,
21
+ get_subtask_list,
22
+ get_task_list,
23
+ prepare_print_tasks,
24
+ print_writeout,
25
+ run_task_tests,
26
+ )
27
+ from dllm_eval.loggers import EvaluationTracker
28
+ from dllm_eval.loggers.utils import add_env_info, add_tokenizer_info, get_git_commit_hash
29
+ from dllm_eval.tasks import TaskManager, get_task_dict
30
+ from dllm_eval.utils import (
31
+ handle_non_serializable,
32
+ hash_string,
33
+ positional_deprecated,
34
+ setup_logging,
35
+ simple_parse_args_string,
36
+ )
37
+
38
+
39
+ if TYPE_CHECKING:
40
+ from dllm_eval.api.model import LM
41
+ from dllm_eval.api.task import Task
42
+
43
+ eval_logger = logging.getLogger(__name__)
44
+
45
+
46
@positional_deprecated
def simple_evaluate(
    model,
    model_args: Optional[Union[str, dict]] = None,
    tasks: Optional[List[Union[str, dict, object]]] = None,
    num_fewshot: Optional[int] = None,
    batch_size: Optional[Union[int, str]] = None,
    max_batch_size: Optional[int] = None,
    device: Optional[str] = None,
    use_cache: Optional[str] = None,
    cache_requests: bool = False,
    rewrite_requests_cache: bool = False,
    delete_requests_cache: bool = False,
    limit: Optional[Union[int, float]] = None,
    samples: Optional[dict] = None,
    bootstrap_iters: int = 100000,
    check_integrity: bool = False,
    write_out: bool = False,
    log_samples: bool = True,
    evaluation_tracker: Optional[EvaluationTracker] = None,
    system_instruction: Optional[str] = None,
    apply_chat_template: Union[bool, str] = False,
    fewshot_as_multiturn: bool = False,
    gen_kwargs: Union[str, dict, None] = None,
    task_manager: Optional[TaskManager] = None,
    verbosity=None,
    predict_only: bool = False,
    random_seed: Optional[int] = 0,
    numpy_random_seed: Optional[int] = 1234,
    torch_random_seed: Optional[int] = 1234,
    fewshot_random_seed: Optional[int] = 1234,
    confirm_run_unsafe_code: bool = False,
    metadata: Optional[dict] = None,
):
    """Instantiate and evaluate a model on a list of tasks.

    :param model: Union[str, LM]
        Name of model or LM object, see dllm_eval.models.get_model
    :param model_args: Optional[str, dict]
        String or dict arguments for each model class, see LM.create_from_arg_string and LM.create_from_arg_obj.
        Ignored if `model` argument is a LM object.
    :param tasks: list[Union[str, dict, Task]]
        List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
    :param num_fewshot: int
        Number of examples in few-shot context
    :param batch_size: int or str, optional
        Batch size for model
    :param max_batch_size: int, optional
        Maximal batch size to try with automatic batch size detection
    :param device: str, optional
        PyTorch device (e.g. "cpu" or "cuda:0") for running models
    :param use_cache: str, optional
        A path to a sqlite db file for caching model responses. `None` if not caching.
    :param cache_requests: bool, optional
        Speed up evaluation by caching the building of dataset requests. `None` if not caching.
    :param rewrite_requests_cache: bool, optional
        Rewrites all the request cache if set to `True`. `None` if not desired.
    :param delete_requests_cache: bool, optional
        Deletes all the request cache if set to `True`. `None` if not desired.
    :param limit: int or float, optional
        Limit the number of examples per task (only use this for testing), If <1, limit is a percentage of the total number of examples.
    :param samples: dictionary, optional
        Dictionary indicating which examples should be tested in each task, e.g., {"mmlu_astronomy":[0,3,6],"mmlu_anatomy":[1,4,7,10]}.
        Mutually exclusive with `limit`.
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics, used when calculating stderrs. set to 0 for no stderr calculations to be performed.
    :param check_integrity: bool
        Whether to run the relevant part of the test suite for the tasks
    :param write_out: bool
        If True, write out an example document and model input for checking task integrity
    :param log_samples: bool
        If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
    :param evaluation_tracker: EvaluationTracker, optional
        If given, the experiment arguments (model source/args, system instruction,
        chat template, fewshot_as_multiturn) are logged on its general config tracker.
    :param system_instruction: str
        System instruction to be applied to the prompt
    :param apply_chat_template: Union[bool, str]
        Specifies whether to apply a chat template to the prompt.
        - If set to True, the default chat template is applied.
        - If set to a string, applies the specified chat template by name.
        Defaults to False (no chat template applied).
    :param fewshot_as_multiturn: bool
        Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
    :param gen_kwargs: dict or comma-separated string
        Arguments for model generation
        Ignored for all tasks with loglikelihood output_type
    :param task_manager: TaskManager, optional
        Task manager used to resolve `tasks`. When None, a new one is created
        with metadata built from `model_args` merged with `metadata`.
    :param verbosity: str
        Verbosity level for logging
    :param predict_only: bool
        If true only model outputs will be generated and returned. Metrics will not be evaluated
    :param random_seed: int
        Random seed for python's random module. If set to None, the seed will not be set.
    :param numpy_random_seed: int
        Random seed for numpy. If set to None, the seed will not be set.
    :param torch_random_seed: int
        Random seed for torch. If set to None, the seed will not be set.
    :param fewshot_random_seed: int
        Random seed for fewshot sampler random generator. If set to None, the seed of generator will be set to None.
    :param confirm_run_unsafe_code: bool
        Forwarded unchanged to `evaluate`.
    :param metadata: dict
        Additional metadata to be added to the task manager. Will get passed to the download function of the task.

    :return:
        Dictionary of results on rank 0; None on every other rank.
    :raises ValueError:
        If both `limit` and `samples` are given, or if no tasks are specified.
    :raises TypeError:
        If `model` is neither a string nor an LM instance.
    """
    if verbosity is not None:
        setup_logging(verbosity=verbosity)
    start_date = time.time()

    if limit is not None and samples is not None:
        raise ValueError(
            "Either 'limit' or 'samples' must be None, but both are not None."
        )

    # Heuristic: the substring "inst" in the model args suggests an instruct
    # model, for which skipping the chat template is usually a mistake.
    if (
        (isinstance(model_args, str) and "inst" in model_args.lower())
        or (
            isinstance(model_args, dict)
            and any("inst" in str(v).lower() for v in model_args.values())
        )
    ) and not apply_chat_template:
        eval_logger.warning(
            "Model appears to be an instruct variant but chat template is not applied. Recommend setting `apply_chat_template` (optionally `fewshot_as_multiturn`)."
        )

    if delete_requests_cache:
        eval_logger.info("Deleting requests cache...")
        delete_cache()

    seed_message = []
    if random_seed is not None:
        # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1412
        seed_message.append(f"Setting random seed to {random_seed}")
        random.seed(random_seed)

    if numpy_random_seed is not None:
        seed_message.append(f"Setting numpy seed to {numpy_random_seed}")
        np.random.seed(numpy_random_seed)

    if torch_random_seed is not None:
        seed_message.append(f"Setting torch manual seed to {torch_random_seed}")
        torch.manual_seed(torch_random_seed)

    if fewshot_random_seed is not None:
        # Only reported here; the seed itself is applied per task in _adjust_config.
        seed_message.append(f"Setting fewshot manual seed to {fewshot_random_seed}")

    if seed_message:
        eval_logger.info(" | ".join(seed_message))

    if tasks is None:
        tasks = []
    if len(tasks) == 0:
        raise ValueError(
            "No tasks specified, or no tasks found. Please verify the task names."
        )

    if gen_kwargs is not None:
        if isinstance(gen_kwargs, str):
            gen_kwargs = simple_parse_args_string(gen_kwargs)
        eval_logger.warning(
            f"generation_kwargs: {gen_kwargs} specified through cli, these settings will update set parameters in yaml tasks. "
            "Ensure 'do_sample=True' for non-greedy decoding!"
        )
        # An empty mapping is treated the same as "not provided".
        if not gen_kwargs:
            gen_kwargs = None

    if isinstance(model, str):
        if model_args is None:
            eval_logger.warning("model_args not specified. Using defaults.")
            model_args = ""

        if isinstance(model_args, dict):
            eval_logger.info(
                f"Initializing {model} model, with arguments: {model_args}"
            )
            lm = dllm_eval.api.registry.get_model(model).create_from_arg_obj(
                model_args,
                {
                    "batch_size": batch_size,
                    "max_batch_size": max_batch_size,
                    "device": device,
                },
            )

        else:
            eval_logger.info(
                f"Initializing {model} model, with arguments: {simple_parse_args_string(model_args)}"
            )
            lm = dllm_eval.api.registry.get_model(model).create_from_arg_string(
                model_args,
                {
                    "batch_size": batch_size,
                    "max_batch_size": max_batch_size,
                    "device": device,
                },
            )
    else:
        if not isinstance(model, dllm_eval.api.model.LM):
            raise TypeError(
                f"The value of `model` passed to simple_evaluate() was of type {type(model)}, but is required to be a subclass of dllm_eval.api.model.LM . This may be because you are passing an initialized Hugging Face PreTrainedModel without having wrapped it in `dllm_eval.models.huggingface.HFLM(pretrained=my_model)` first."
            )
        eval_logger.info("Using pre-initialized model")
        lm = model

    if use_cache is not None:
        eval_logger.info(f"Using cache at {use_cache + '_rank' + str(lm.rank) + '.db'}")
        lm = dllm_eval.api.model.CachingLM(
            lm,
            use_cache
            # each rank receives a different cache db.
            # necessary to avoid multiple writes to cache at once
            + "_rank"
            + str(lm.rank)
            + ".db",
        )

    if task_manager is None:
        # Model args (string or dict form) merged with caller metadata; the
        # caller's metadata wins on duplicate keys.
        metadata = (
            simple_parse_args_string(model_args)
            if isinstance(model_args, str)
            else model_args
            if isinstance(model_args, dict)
            else {}
        ) | (metadata or {})
        task_manager = TaskManager(metadata=metadata)

    task_dict = get_task_dict(
        tasks,
        task_manager,
    )

    # helper function to recursively apply config overrides to leaf subtasks, skipping their constituent groups.
    # (setting of num_fewshot ; bypassing metric calculation ; setting fewshot seed)
    def _adjust_config(task_dict):
        adjusted_task_dict = {}
        for task_name, task_obj in task_dict.items():
            if isinstance(task_obj, dict):
                # Nested group: recurse and store the adjusted sub-dictionary
                # directly instead of rebuilding the whole accumulator.
                adjusted_task_dict[task_name] = _adjust_config(task_obj)

            else:
                if task_obj.get_config("output_type") == "generate_until":
                    if gen_kwargs is not None:
                        task_obj.set_config(
                            key="generation_kwargs", value=gen_kwargs, update=True
                        )
                    eval_logger.info(
                        f"{task_obj.config.task}: Using gen_kwargs: {task_obj.config.generation_kwargs}"
                    )

                if predict_only:
                    eval_logger.info(
                        f"Processing {task_name} in output-only mode. Metrics will not be calculated!"
                    )
                    # we have to change the class properties post-hoc. This is pretty hacky.
                    task_obj.override_metric(metric_name="bypass")

                # override tasks' fewshot values to the provided num_fewshot arg value
                # except if tasks have it set to 0 manually in their configs--then we should never overwrite that
                if num_fewshot is not None:
                    if (default_num_fewshot := task_obj.get_config("num_fewshot")) == 0:
                        eval_logger.info(
                            f"num_fewshot has been set to 0 for {task_name} in its config. Manual configuration will be ignored."
                        )
                    else:
                        eval_logger.warning(
                            f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}"
                        )
                        task_obj.set_config(key="num_fewshot", value=num_fewshot)
                else:
                    # if num_fewshot not provided, and the task does not define a default one, default to 0
                    if task_obj.get_config("num_fewshot") is None:
                        task_obj.set_config(key="num_fewshot", value=0)
                # fewshot_random_seed set for tasks, even with a default num_fewshot (e.g. in the YAML file)
                task_obj.set_fewshot_seed(seed=fewshot_random_seed)

                adjusted_task_dict[task_name] = task_obj

        return adjusted_task_dict

    task_dict = _adjust_config(task_dict)

    if check_integrity:
        run_task_tests(task_list=tasks)

    if evaluation_tracker is not None:
        evaluation_tracker.general_config_tracker.log_experiment_args(
            model_source=model,
            model_args=model_args,
            system_instruction=system_instruction,
            chat_template=lm.chat_template(apply_chat_template)
            if apply_chat_template
            else None,
            fewshot_as_multiturn=fewshot_as_multiturn,
        )

    results = evaluate(
        lm=lm,
        task_dict=task_dict,
        limit=limit,
        samples=samples,
        cache_requests=cache_requests,
        rewrite_requests_cache=rewrite_requests_cache,
        bootstrap_iters=bootstrap_iters,
        write_out=write_out,
        # predict-only runs always need the per-sample outputs.
        log_samples=True if predict_only else log_samples,
        system_instruction=system_instruction,
        apply_chat_template=apply_chat_template,
        fewshot_as_multiturn=fewshot_as_multiturn,
        verbosity=verbosity,
        confirm_run_unsafe_code=confirm_run_unsafe_code,
    )
    if verbosity is not None:
        setup_logging(verbosity=verbosity)

    # Only rank 0 decorates and returns the results; other ranks return None.
    if lm.rank == 0:
        if isinstance(model, str):
            model_name = model
        elif hasattr(model, "config") and hasattr(model.config, "_name_or_path"):
            model_name = model.config._name_or_path
        else:
            model_name = type(model).__name__

        # add info about the model and few shot config
        results["config"] = {
            "model": model_name,
            "model_args": model_args,
        }
        # add more detailed model info if available
        if isinstance(lm, dllm_eval.models.huggingface.HFLM):
            results["config"].update(lm.get_model_info())
        # add info about execution
        results["config"].update(
            {
                "batch_size": batch_size,
                "batch_sizes": (
                    list(lm.batch_sizes.values()) if hasattr(lm, "batch_sizes") else []
                ),
                "device": device,
                "use_cache": use_cache,
                "limit": limit,
                "bootstrap_iters": bootstrap_iters,
                "gen_kwargs": gen_kwargs,
                "random_seed": random_seed,
                "numpy_seed": numpy_random_seed,
                "torch_seed": torch_random_seed,
                "fewshot_seed": fewshot_random_seed,
            }
        )
        results["git_hash"] = get_git_commit_hash()
        results["date"] = start_date
        add_env_info(results)  # additional environment info to results
        add_tokenizer_info(results, lm)  # additional info about tokenizer
        return results
    else:
        return None
402
+
403
+
404
@positional_deprecated
def evaluate(
    lm: "LM",
    task_dict,
    limit: Optional[int] = None,
    samples: Optional[dict] = None,
    cache_requests: bool = False,
    rewrite_requests_cache: bool = False,
    bootstrap_iters: Optional[int] = 100000,
    write_out: bool = False,
    log_samples: bool = True,
    system_instruction: Optional[str] = None,
    apply_chat_template: Union[bool, str] = False,
    fewshot_as_multiturn: bool = False,
    verbosity: str = "INFO",
    confirm_run_unsafe_code: bool = False,
):
    """Evaluate an already-instantiated model on a dictionary of tasks.

    The function builds the requests of every task, runs them through the
    matching ``lm`` method (one call per request type), applies the task
    filters, computes per-document metrics and, on rank 0, aggregates them.

    :param lm: obj
        Language Model
    :param task_dict: dict[str, Task]
        Dictionary of tasks. Tasks will be taken to have name type(task).config.task .
    :param limit: int, optional
        Limit the number of examples per task (only use this for testing)
    :param samples: dictionary, optional
        Dictionary indicating which examples should be tested in each task, e.g., {"mmlu_astronomy":[0,3,6],"mmlu_anatomy":[1,4,7,10]}.
        Mutually exclusive with ``limit``.
    :param cache_requests: bool, optional
        Speed up evaluation by caching the building of dataset requests.
    :param rewrite_requests_cache: bool, optional
        Rewrites all the request cache if set to `True`.
    :param bootstrap_iters: int
        Number of iterations for bootstrap statistics, used when calculating stderr. Set to 0 for skipping all stderr calculations.
    :param write_out: bool
        If True, write out an example document and model input for checking task integrity
    :param log_samples: bool
        If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
    :param system_instruction: str
        System instruction to be applied to the prompt
    :param apply_chat_template: Union[bool, str]
        Specifies whether to apply a chat template to the prompt.
        - If set to True, the default chat template is applied.
        - If set to a string, applies the specified chat template by name.
        Defaults to False (no chat template applied).
    :param fewshot_as_multiturn: bool
        Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
    :param verbosity: str
        Verbosity level for logging. Not read inside this function; kept for
        interface compatibility.
    :param confirm_run_unsafe_code: bool
        Whether to confirm running tasks marked as unsafe.
    :return:
        Dictionary of results on rank 0, ``None`` on every other rank.
    :raises ValueError:
        If both ``limit`` and ``samples`` are given, if ``log_samples`` is
        False while a task uses the 'bypass' metric, if a multimodal task is
        paired with a non-multimodal model, or if an unsafe task is run
        without ``confirm_run_unsafe_code``.
    """

    if limit is not None and samples is not None:
        raise ValueError(
            "Either 'limit' or 'samples' must be None, but both are not None."
        )
    if samples is not None:
        eval_logger.info(f"Evaluating examples for tasks {list(samples.keys())}")
    if apply_chat_template:
        eval_logger.warning(
            "Chat template formatting change affects loglikelihood and multiple-choice tasks. See docs/chat-template-readme.md for details."
        )
    # tracks all Instances/requests a model must generate output on.
    requests = defaultdict(list)
    # stores the amount to pad out reqs per req. type so that
    # number of fwd passes per distributed rank is equal
    padding_requests = defaultdict(int)

    # get lists of group hierarchy and each type of request
    eval_tasks = get_task_list(task_dict)
    if not log_samples:
        if not all(
            "bypass" not in getattr(task_output.task, "_metric_fn_list", {}).keys()
            for task_output in eval_tasks
        ):
            raise ValueError("log_samples must be True for 'bypass' metric-only tasks")

    # validation checks:
    # 1.are we running multimodal task <-> non-multimodal model class, or vice-versa.
    # 2.are we running code that is marked as unsafe.
    incompatible_tasks = []
    for task_output in eval_tasks:
        task: Task = task_output.task

        if getattr(task, "MULTIMODAL", False) and not getattr(lm, "MULTIMODAL", False):
            incompatible_tasks.append(task_output.task_name)
        elif getattr(task, "UNSAFE_CODE", False) and not confirm_run_unsafe_code:
            raise ValueError(
                f"Attempted to run task: {task_output.task_name} which is marked as unsafe. Set confirm_run_unsafe_code=True to run this task."
            )
    if len(incompatible_tasks) > 0:
        if not getattr(lm, "MULTIMODAL", False):
            raise ValueError(
                f"Attempted to run tasks: {incompatible_tasks} which require multimodal input, but the selected model type does not currently implement this. Multimodal support is currently restricted to the ['hf-multimodal', 'vllm-vlm'] model type."
            )
    # end validation check

    # Cache the limit arg: `limit` is re-bound per task below.
    limit_arg = limit
    limits = []
    for task_output in eval_tasks:
        task: Task = task_output.task

        limit = get_sample_size(task, limit_arg)
        limits.append(limit)
        task.build_all_requests(
            limit=limit,
            samples=samples.get(task_output.task_name, None)
            if samples is not None
            else samples,
            rank=lm.rank,
            world_size=lm.world_size,
            cache_requests=cache_requests,
            rewrite_requests_cache=rewrite_requests_cache,
            system_instruction=system_instruction,
            apply_chat_template=bool(apply_chat_template),
            fewshot_as_multiturn=fewshot_as_multiturn,
            chat_template=getattr(lm, "apply_chat_template")
            if apply_chat_template
            else None,
            tokenizer_name=getattr(lm, "tokenizer_name", "")
            if apply_chat_template
            else "",
        )
        eval_logger.debug(
            f"Task: {task_output.task_name}; number of requests on this rank: {len(task.instances)}"
        )
        if write_out:
            print_writeout(task)
        # aggregate Instances by LM method requested to get output.
        for instance in task.instances:
            reqtype = instance.request_type
            requests[reqtype].append(instance)

        if lm.world_size > 1:
            instances_rnk = torch.tensor(len(task._instances), device=lm.device)
            gathered_item = (
                lm.accelerator.gather(instances_rnk).cpu().detach().numpy().tolist()
            )
            # "multiple_choice" task types dispatch (several) "loglikelihood" request types
            reqtype = (
                "loglikelihood"
                if task.OUTPUT_TYPE == "multiple_choice"
                else task.OUTPUT_TYPE
            )
            # compute number of pseudo-batches to pad with (FSDP/DDP require even batches among ranks)
            numpad = max(gathered_item) - gathered_item[lm.rank]
            # todo: may not account for padding in cases like SquadV2 which has multiple req types
            padding_requests[reqtype] += numpad

    ### Run LM on inputs, get all outputs ###
    # execute each type of request
    for reqtype, reqs in requests.items():
        eval_logger.info(f"Running {reqtype} requests")
        # create `K` copies of each request `req` based off `K = req.repeats`
        cloned_reqs = []
        for req in reqs:
            cloned_reqs.extend([req] * req.repeats)

        if (lm.world_size > 1) and (padding_requests[reqtype] > 0):
            # NOTE: `req` is the last request of the loop above; the padding
            # entries are extra copies of it, and the zip below appends their
            # responses to that same request.
            for _ in range(padding_requests[reqtype]):
                cloned_reqs.extend([req] * req.repeats)

        # run requests through model
        resps = getattr(lm, reqtype)(cloned_reqs)

        # put responses from model into a list of length K for each request.
        for x, req in zip(resps, cloned_reqs):
            req.resps.append(x)

        if lm.world_size > 1:
            lm.accelerator.wait_for_everyone()

    RANK = lm.rank
    WORLD_SIZE = lm.world_size
    ### Postprocess outputs ###
    # TODO: del model here, maybe (idea: allow user to specify device of e.g. reward model separately)
    for task_output, limit in zip(eval_tasks, limits):
        task = task_output.task
        task.apply_filters()

        ### Collect values of metrics on all datapoints ###
        # # unpack results and sort back in order and return control to Task
        # TODO: make it possible to use a different metric per filter
        # Pre-process task.instances to group by doc_id
        instances_by_doc_id = defaultdict(list)
        for instance in task.instances:
            instances_by_doc_id[instance.doc_id].append(instance)
        # Sort instances within each group
        for instances in instances_by_doc_id.values():
            instances.sort(key=lambda x: x.idx)
        # iterate over different filters used
        for filter_key in task.instances[0].filtered_resps.keys():
            indices = (
                samples.get(task_output.task_name, None)
                if samples is not None
                else None
            )
            doc_iterator = task.doc_iterator(
                rank=RANK,
                limit=limit,
                world_size=WORLD_SIZE,
                samples=indices,
            )
            for doc_id, doc in doc_iterator:
                # With explicit `samples`, map the position back to the
                # index the caller asked for.
                if indices:
                    doc_id_true = indices[doc_id]
                else:
                    doc_id_true = doc_id
                doc_requests = instances_by_doc_id[doc_id]
                metrics = task.process_results(
                    doc, [req.filtered_resps[filter_key] for req in doc_requests]
                )
                if log_samples:
                    target = task.doc_to_target(doc)
                    example = {
                        "doc_id": doc_id_true,
                        "doc": doc,
                        "target": target,
                        "arguments": [req.args for req in doc_requests],
                        "resps": [req.resps for req in doc_requests],
                        "filtered_resps": [
                            req.filtered_resps[filter_key] for req in doc_requests
                        ],
                        "filter": filter_key,
                        "metrics": list(metrics.keys()),
                        "doc_hash": hash_string(
                            json.dumps(
                                doc_requests[0].doc,
                                indent=2,
                                default=handle_non_serializable,
                                ensure_ascii=False,
                            )
                        ),
                        "prompt_hash": hash_string(doc_requests[0].arguments[0]),
                        "target_hash": hash_string(str(target)),
                    }
                    example.update(metrics)
                    task_output.logged_samples.append(example)
                for metric, value in metrics.items():
                    task_output.sample_metrics[(metric, filter_key)].append(value)

    if WORLD_SIZE > 1:
        # if multigpu, then gather data across all ranks to rank 0
        # first gather logged samples across all ranks
        for task_output in eval_tasks:
            if log_samples:
                full_samples = [None] * WORLD_SIZE if RANK == 0 else None
                torch.distributed.gather_object(
                    obj=task_output.logged_samples,
                    object_gather_list=full_samples,
                    dst=0,
                )

                if RANK == 0:
                    task_output.logged_samples = list(
                        itertools.chain.from_iterable(full_samples)
                    )

            # then collect metrics across all ranks
            for metrics in task_output.sample_metrics:
                metric_list = [None] * WORLD_SIZE if RANK == 0 else None
                torch.distributed.gather_object(
                    obj=task_output.sample_metrics[metrics],
                    object_gather_list=metric_list,
                    dst=0,
                )
                if RANK == 0:
                    task_output.sample_metrics[metrics] = list(
                        itertools.chain.from_iterable(metric_list)
                    )

    if RANK == 0:
        ### Aggregate results over all datapoints ###
        # aggregate results ; run bootstrap CIs
        for task_output in eval_tasks:
            task_output.calculate_aggregate_metric(bootstrap_iters=bootstrap_iters)
        (
            results,
            samples,
            configs,
            versions,
            num_fewshot,
            higher_is_better,
        ) = consolidate_results(eval_tasks)

        ### Calculate group metrics ###
        # Default for the case where there are no results at all; without it
        # the name would be unbound when `results_dict` is built below.
        show_group_table = False
        if bool(results):
            results, versions, show_group_table, *_ = consolidate_group_results(
                results, versions, task_dict
            )

        results_agg, group_agg = prepare_print_tasks(task_dict, results)
        subtask_list = get_subtask_list(task_dict)

        # collect all higher_is_better values for metrics
        # in the group's subtasks.
        # TODO: clean this up ; unify with the below metric_list loop?
        # NOTE(review): this single dict object is shared by every group that
        # has subtasks, so entries accumulate across groups - confirm intent.
        _higher_is_better = {}
        for group, task_list in subtask_list.items():
            if (
                len(task_list) != 0
            ):  # subtask list will list "task_name": [] for solo tasks
                for task in task_list:
                    for m, h in higher_is_better[task].items():
                        if m not in _higher_is_better.keys():
                            _higher_is_better[m] = h

                        if (
                            m in _higher_is_better
                            and _higher_is_better[m] is not None
                            and _higher_is_better[m] != h
                        ):
                            eval_logger.warning(
                                f"Higher_is_better values for metric {m} in group {group} are not consistent. Defaulting to None."
                            )
                            _higher_is_better[m] = None
                higher_is_better[group] = _higher_is_better

        results_dict = {
            "results": dict(results_agg.items()),
            **(
                {"groups": dict(group_agg.items())}
                if (bool(group_agg) and show_group_table)
                else {}
            ),
            "group_subtasks": dict(reversed(subtask_list.items())),
            "configs": dict(sorted(configs.items())),
            "versions": dict(sorted(versions.items())),
            "n-shot": dict(sorted(num_fewshot.items())),
            "higher_is_better": dict(sorted(higher_is_better.items())),
            "n-samples": {
                task_output.task_name: {
                    "original": len(task_output.task.eval_docs),
                    "effective": min(
                        limit if limit else len(task_output.task.eval_docs),
                        len(task_output.task.eval_docs),
                    ),
                }
                for task_output, limit in zip(eval_tasks, limits)
            },
        }
        if log_samples:
            results_dict["samples"] = dict(samples)

        return results_dict

    else:
        return None
756
+
757
+
758
def request_caching_arg_to_dict(cache_requests: str) -> dict:
    """Translate the request-caching mode string into boolean flags.

    ``"true"`` enables the cache, ``"refresh"`` enables it and rewrites it,
    ``"delete"`` only requests deletion; any other value turns every flag off.
    """
    wants_refresh = cache_requests == "refresh"
    wants_cache = wants_refresh or cache_requests == "true"
    wants_delete = cache_requests == "delete"
    return {
        "cache_requests": wants_cache,
        "rewrite_requests_cache": wants_refresh,
        "delete_requests_cache": wants_delete,
    }
Prism/LLaDA/LLaDA_Baseline/dllm_eval/evaluator_utils.py ADDED
@@ -0,0 +1,554 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import logging
3
+ import math
4
+ import pathlib
5
+ import sys
6
+ from typing import List, Optional, Tuple, Union
7
+
8
+ from dllm_eval.api.group import ConfigurableGroup
9
+ from dllm_eval.api.metrics import (
10
+ aggregate_subtask_metrics,
11
+ mean,
12
+ pooled_sample_stderr,
13
+ stderr_for_metric,
14
+ )
15
+ from dllm_eval.api.task import Task
16
+ from dllm_eval.utils import positional_deprecated
17
+
18
+
19
+ eval_logger = logging.getLogger(__name__)
20
+
21
+
22
class TaskOutput:
    """
    Wrapper class for Task outputs. It contains various attributes and methods to manage and calculate metrics for the task.

        Attributes:
            task (object): The task object.
            task_name (str): The name of the task.
            task_config (dict): The configuration of the task.
            version (str): The version of the task.
            group_name (str): The name of the task group.
            n_shot (int): The number of shots for the task.
            task_alias (str): The alias of the task.
            group_alias (str): The alias of the task group.
            is_group (bool): Indicates if the task is a group.
            logged_samples (list): The list of logged samples.
            sample_len (int): The length of the samples.
            sample_metrics (defaultdict): The dictionary of samples' metrics.
            agg_metrics (defaultdict): The dictionary of aggregate metrics.

        Methods:
            from_taskdict(cls, task_name: str, task):
                Creates a TaskOutput instance from a task dictionary.

            calculate_aggregate_metric(bootstrap_iters=100000) -> None:
                Calculates the aggregate metrics for the task.
    """

    def __init__(
        self,
        task=None,
        task_name=None,
        task_config=None,
        version=None,
        group_name=None,
        n_shot=None,
        task_alias=None,
        group_alias=None,
        is_group=None,
    ):
        self.task = task
        self.task_config = task_config
        self.task_name = task_name
        self.group_name = group_name
        self.version = version
        self.n_shot = n_shot
        self.task_alias = task_alias
        self.group_alias = group_alias
        self.is_group = is_group
        self.logged_samples = []
        self.sample_len = None
        # keyed by (metric, filter_key); each value is the list of
        # per-document metric values
        self.sample_metrics = collections.defaultdict(list)
        # keyed by "<metric>,<filter_key>" and "<metric>_stderr,<filter_key>"
        self.agg_metrics = collections.defaultdict(list)

    @classmethod
    def from_taskdict(cls, task_name: str, task):
        """Build a TaskOutput from one ``task_dict`` entry.

        ``task`` may be the task object itself or a ``(group_name, task)``
        tuple. A falsy task marks the entry as a group placeholder.
        """
        if isinstance(task, tuple):
            group_name, task = task
        else:
            group_name = None
        if not task:
            # these gets filtered out in get_task_list
            # once they are added to group hierarchy
            is_group = True
            return cls(
                task=task, task_name=task_name, is_group=is_group, group_name=group_name
            )
        version = task.VERSION
        task_config = dict(task.dump_config())
        if (n_shot := task_config.get("num_fewshot")) == 0:
            # Fall back to the value recorded in the task metadata. The
            # metadata entry may be present but None, so guard against that
            # instead of calling .get() on None.
            metadata = task_config.get("metadata") or {}
            n_shot = metadata.get("num_fewshot", 0)
        task_alias = task_config.get("alias")
        group_alias = task_config.get("group_alias")
        return cls(
            task=task,
            task_name=task_name,
            task_config=task_config,
            group_name=group_name,
            version=version,
            n_shot=n_shot,
            task_alias=task_alias,
            group_alias=group_alias,
        )

    def calculate_aggregate_metric(self, bootstrap_iters=100000) -> None:
        """Fill ``agg_metrics`` (value and stderr) for every collected metric.

        Raises:
            ValueError: if ``bootstrap_iters`` is not an integer.
        """
        for (metric, filter_key), items in self.sample_metrics.items():
            try:
                agg_fn = self.task.aggregation()[metric]
            except KeyError:
                # This is when process results output an arbitrary metric
                # TODO: Handle this better and allow other aggregate functions other than mean.
                agg_fn = mean
            metric_key = f"{metric},{filter_key}"
            self.agg_metrics[metric_key] = agg_fn(items)
            self.sample_len = len(items)  # TODO: same sample size for each metric?
            if isinstance(bootstrap_iters, int):
                # bleu/chrf/ter are capped at 100 bootstrap iterations
                stderr_fn = stderr_for_metric(
                    metric=agg_fn,
                    bootstrap_iters=min(bootstrap_iters, 100)
                    if metric in ["bleu", "chrf", "ter"]
                    else bootstrap_iters,
                )
                self.agg_metrics[f"{metric}_stderr,{filter_key}"] = (
                    stderr_fn(items) if (stderr_fn and len(items) > 1) else "N/A"
                )
            else:
                raise ValueError(
                    f"Received bootstrap_iters '{bootstrap_iters}' but expected an integer. Set to 0 to turn off stderr calculations."
                )

    def __repr__(self):
        return (
            f"TaskOutput(task_name={self.task_name}, "
            f"group_name={self.group_name}, "
            f"version={self.version}, "
            f"n_shot={self.n_shot}, "
            f"task_alias={self.task_alias}, "
            f"group_alias={self.group_alias})"
        )
140
+
141
+
142
def get_task_list(task_dict: dict) -> List[TaskOutput]:
    """Flatten a (possibly nested) task dictionary into TaskOutput objects.

    Nested dictionaries are walked depth-first; every non-dict value is
    wrapped with ``TaskOutput.from_taskdict`` in iteration order.
    """
    collected = []
    for name, entry in task_dict.items():
        if not isinstance(entry, dict):
            collected.append(TaskOutput.from_taskdict(name, entry))
            continue
        # a nested mapping: recurse and splice its outputs in place
        collected.extend(get_task_list(entry))
    return collected
153
+
154
+
155
def get_subtask_list(task_dict, task_root=None, depth=0):
    """Map every group/task name to the names of its direct children.

    While recursing, entries are keyed by ``(name, depth)``; the outermost
    call (``depth == 0``) collapses those keys to plain names before
    returning. Top-level tasks without a parent map to an empty list.
    """
    collected = {}
    for key, value in task_dict.items():
        group_name = key.group_name if isinstance(key, ConfigurableGroup) else key

        if isinstance(value, dict):
            nested = get_subtask_list(value, task_root=group_name, depth=depth + 1)
            if task_root:
                # register the children found exactly one level below
                direct_children = [
                    child
                    for (child, child_depth) in nested.keys()
                    if (child_depth - 1) == depth
                ]
                collected.setdefault((task_root, depth), []).extend(direct_children)
            collected.update(nested)
            continue

        if isinstance(value, ConfigurableGroup):
            group_or_task_name = value.group_name
        elif isinstance(value, Task):
            group_or_task_name = value.task_name

        if task_root is None:
            collected.setdefault((group_or_task_name, depth), [])
        else:
            collected.setdefault((task_root, depth), []).append(group_or_task_name)

    if depth != 0:
        return collected

    # outermost call: drop the depth component of every key
    return {name: children for (name, _depth), children in collected.items()}
200
+
201
+
202
def print_writeout(task) -> None:
    """Log the prompt, target and request of a task's first document.

    Every instance in ``task.instances`` whose ``doc_id`` is below 1 is
    written to ``eval_logger`` at INFO level so the task's formatting can be
    checked by eye.
    """
    for inst in task.instances:
        # only instances of the first document (doc_id < 1) are shown
        if inst.doc_id < 1:
            # Adjacent f-string literals are used instead of a backslash
            # continuation so that source indentation cannot leak into the
            # logged message.
            eval_logger.info(
                f"Task: {task}; document {inst.doc_id}; context prompt (starting on next line):"
                f"\n{inst.args[0]}\n(end of prompt on previous line)\n"
                f"target string or answer choice index (starting on next line):\n{task.doc_to_target(inst.doc)}\n(end of target on previous line)"
            )
            eval_logger.info(f"Request: {str(inst)}")
211
+
212
+
213
def get_sample_size(task, limit: Optional[int]) -> Union[int, None]:
    """Resolve ``limit`` into an absolute number of documents.

    ``None`` is passed through. A value below 1.0 is treated as a fraction of
    ``len(task.eval_docs)`` and rounded up; anything else is truncated to int.
    """
    if limit is None:
        return None
    if limit < 1.0:
        fraction_of_docs = len(task.eval_docs) * limit
        return int(math.ceil(fraction_of_docs))
    return int(limit)
219
+
220
+
221
def prepare_print_tasks(
    task_dict: dict,
    results: dict,
    task_depth=0,
    group_depth=0,
) -> Tuple[dict, dict]:
    """
    Recursively walk the task hierarchy and build the printable result tables.

    @param task_dict: Dictionary representing the group hierarchy of tasks. Keys are
        group objects or names; values are tasks or nested dictionaries.
    @param results: Dictionary holding the results of every task and group, keyed by name.
    @param task_depth: The indentation level for printing the task
        hierarchy. Default is 0.
    @param group_depth: The indentation level for printing the group
        hierarchy. Default is 0.
    @return: A tuple (task_agg, group_agg). task_agg holds a copy of the results of
        every visited entry; group_agg holds only the ConfigurableGroup entries whose
        results do not carry the " " marker key. In both, "alias" is prefixed with
        the indentation string and "samples" is removed.
    """

    def _name_of(item):
        # sort key: group objects sort by their group_name, strings by themselves
        key = item[0]
        return key.group_name if isinstance(key, ConfigurableGroup) else key

    def _sort_task_dict(mapping):
        """
        Return a copy of the mapping ordered alphabetically by entry name, so that
        entries end up sorted within each sub-header.
        """
        return dict(sorted(mapping.items(), key=_name_of))

    task_agg = collections.defaultdict(dict)
    group_agg = collections.defaultdict(dict)

    # the depths are the same for every entry of this level
    tab_string = " " * task_depth + "- " if task_depth > 0 else ""
    group_tab_string = " " * group_depth + "- " if group_depth > 0 else ""

    for entry_key, entry_value in _sort_task_dict(task_dict).items():
        if isinstance(entry_key, ConfigurableGroup):
            name = entry_key.group_name
            from_configurable_group = True
            entry_value = _sort_task_dict(entry_value)
        elif isinstance(entry_key, str):
            name = entry_value.task_name if isinstance(entry_value, Task) else entry_key
            from_configurable_group = False

        task_row = results[name].copy()
        task_agg[name] = task_row
        if from_configurable_group:
            alias = (
                entry_key.group_alias
                if entry_key.group_alias is not None
                else entry_key.group
            )
        else:
            alias = task_row.get("alias", name)

        task_row["alias"] = tab_string + alias
        task_row.pop("samples", None)

        if from_configurable_group and (" " not in results[name]):
            group_row = results[name].copy()
            group_row["alias"] = group_tab_string + alias
            group_row.pop("samples", None)
            group_agg[name] = group_row

        if isinstance(entry_value, dict):
            # descend one level; children are indented one step further
            child_task_agg, child_group_agg = prepare_print_tasks(
                entry_value, results, task_depth + 1, group_depth + 1
            )
            task_agg = {**task_agg, **child_task_agg}
            group_agg = {**group_agg, **child_group_agg}
    return task_agg, group_agg
311
+
312
+
313
def consolidate_results(
    eval_tasks: List[TaskOutput],
) -> Tuple[dict, dict, dict, dict, dict, dict]:
    """
    Merge the per-task outputs into name-keyed dictionaries.

    @param eval_tasks: list(TaskOutput).
    @return: A tuple (results, samples, configs, versions, num_fewshot, higher_is_better).

    Every returned object is a defaultdict keyed by task name:

    - results: dictionaries of "<metric>,<filter>" and "<metric>_stderr,<filter>" values,
      plus "alias" (the task alias from the config, or the task name) and "samples".
    - samples: the logged samples of the task.
    - configs: the task configuration.
    - versions: the task version.
    - num_fewshot: the number of few-shot examples.
    - higher_is_better: the value returned by the task's ``higher_is_better()``.
    """
    # final result per task, for each metric/filter pair
    results = collections.defaultdict(dict)
    # info about each evaluated document
    samples = collections.defaultdict(list)
    # num-fewshot value per task
    num_fewshot = collections.defaultdict(int)
    # configs of all chosen tasks
    configs = collections.defaultdict(dict)
    # version of each task
    versions = collections.defaultdict(dict)
    # `higher_is_better` per metric
    higher_is_better = collections.defaultdict(dict)

    for output in eval_tasks:
        name = output.task_name
        config = output.task_config

        alias = config["task_alias"] if "task_alias" in config else name
        results[name]["alias"] = alias

        group_alias = output.group_alias
        if group_alias and group_alias not in results:
            group_name = output.group_name
            if group_name:
                results[group_name]["alias"] = group_alias

        num_fewshot[name] = output.n_shot
        configs[name] = config
        versions[name] = output.version
        samples[name] = output.logged_samples
        higher_is_better[name] = output.task.higher_is_better()

        task_results = results[name]
        for metric, filter_key in output.sample_metrics:
            value_key = f"{metric},{filter_key}"
            stderr_key = f"{metric}_stderr,{filter_key}"
            task_results[value_key] = output.agg_metrics[value_key]
            task_results["samples"] = output.sample_len
            task_results[stderr_key] = output.agg_metrics[stderr_key]
    return results, samples, configs, versions, num_fewshot, higher_is_better
373
+
374
+
375
def consolidate_group_results(
    results,
    versions,
    task_dict,
    task_root=None,
    show_group_table=False,
    task_aggregation_list=None,
) -> Tuple[dict, dict, bool, dict]:
    """
    (Recursively) calculates groups' aggregated metrics and updates the results and versions dictionaries with this info.

    @param results: per-task results, keyed by name; group entries are added in place.
    @param versions: per-task versions, keyed by name; group versions are added in place.
    @param task_dict: the (nested) group/task hierarchy to walk.
    @param task_root: name of the enclosing group during recursion; falsy at the top level.
    @param show_group_table: running flag carried through the recursion.
    @param task_aggregation_list: running mapping of group name -> subtask names, mutated in place.

    @return: a tuple [results, versions, show_group_table, task_aggregation_list] with formats described below:

    - results: A defaultdict with task names (and, after this function is called, group names of
    groups that perform aggregation) as keys, and dictionaries with "alias" and metric,filter_name pairs as keys.
    - versions: A defaultdict with task names (and, after this function is called, group names of
    groups that perform aggregation) as keys, and float values representing the task or group's version if a version is specified. (defaulting to None).
    - show_group_table: a boolean which is true if there exists a group that requires printing of its aggregated scores in a group table.
    - task_aggregation_list: a defaultdict listing the subtasks to average over to produce a given group's end metric.

    The method then returns the updated results, versions, show_group_table, and task_aggregation_list as a tuple.
    In the top-level invocation of this function, task_aggregation_list is ignored.

    Raises ValueError when a group's metric config asks for an aggregation that is
    neither "mean" nor a callable.
    """
    # An empty dict is falsy, so the `if task_root:` checks below are skipped
    # for the top-level call.
    if task_root is None:
        task_root = {}

    if task_aggregation_list is None:
        task_aggregation_list = {}

    for group_or_task, group_or_task_info in task_dict.items():
        # Convert to string
        if isinstance(group_or_task, ConfigurableGroup):
            group_config = group_or_task.config
            group_or_task = group_or_task.group_name
        else:
            group_config = None

        if isinstance(group_or_task_info, Task):
            # leaf task: record it under the enclosing group, if any
            if task_root:
                task_aggregation_list.setdefault(task_root, []).append(
                    group_or_task_info.task_name
                )
        else:
            # nested hierarchy: recurse with this entry as the new root; the
            # same task_aggregation_list object is passed down and returned
            (
                results,
                versions,
                show_group_table,
                _task_aggregation_list,
            ) = consolidate_group_results(
                results,
                versions,
                group_or_task_info,
                group_or_task,
                show_group_table,
                task_aggregation_list,
            )
            # propagate the subtasks collected for this group to its parent
            if task_root:
                task_aggregation_list.setdefault(task_root, []).extend(
                    task_aggregation_list.get(group_or_task, [])
                )

            # No aggregation requested: tag the entry with the " " marker key
            # (prepare_print_tasks checks for it) and move on.
            if (group_config is None) or (
                group_config["aggregate_metric_list"] is None
            ):
                results[group_or_task][" "] = " "
                continue

            if "aggregate_metric_list" in group_config:
                agg_metric_list = group_config["aggregate_metric_list"]

            show_group_table = show_group_table | bool(
                group_config["aggregate_metric_list"]
            )

            task_list = _task_aggregation_list[group_or_task]

            # every "<metric>,<filter>" key reported by at least one subtask
            metric_list = list(
                {
                    key
                    for task in task_list
                    for key in results[task].keys()
                    if "_stderr" not in key and key not in ["task", "alias", "samples"]
                }
            )
            for metric in metric_list:
                # "acc,none" -> "acc_stderr,none"
                stderr = "_stderr,".join(metric.split(","))

                # gather metrics, sizes, and stderrs from subtasks
                metrics = [
                    results[task][metric]
                    for task in task_list
                    if metric in results[task]
                ]  # TODO: copy?
                stderrs = [
                    results[task][stderr]
                    for task in task_list
                    if stderr in results[task]
                ]
                sizes = [
                    results[task]["samples"]
                    for task in task_list
                    if metric in results[task]
                ]

                for metric_config in agg_metric_list:
                    for filter_name in metric_config["filter_list"]:
                        # only handle the config entry matching this metric/filter pair
                        if metric != ",".join([metric_config["metric"], filter_name]):
                            continue

                        # compute group's pooled metric and stderr
                        if metric_config["aggregation"] == "mean":
                            aggregate_fn = aggregate_subtask_metrics
                        elif callable(metric_config["aggregation"]):
                            aggregate_fn = metric_config["aggregation"]
                        else:
                            raise ValueError(
                                f"Currently, only 'mean' is supported for automatically aggregating scores across groups' subtasks. Got '{metric_config['aggregation']}' for group '{group_or_task}'"
                            )

                        results[group_or_task][metric] = aggregate_fn(
                            metrics,
                            sizes,
                            metric_config["weight_by_size"],
                        )
                        # TODO: calculate groups' metrics using arbitrary agg fns
                        if "N/A" in stderrs:
                            results[group_or_task][stderr] = "N/A"
                        else:
                            # NOTE: this assumes we are using the mean to aggregate. There are warnings about this elsewhere
                            results[group_or_task][stderr] = pooled_sample_stderr(
                                stderrs, sizes
                            )

                        results[group_or_task]["samples"] = sum(sizes)
                        group_metadata = group_config.get("metadata", None)
                        if group_metadata is not None:
                            versions[group_or_task] = group_metadata.get("version", None)
    return results, versions, show_group_table, task_aggregation_list
514
+
515
+
516
@positional_deprecated
def find_test_root(start_path: pathlib.Path) -> pathlib.Path:
    """
    Search upward in the directory tree to a maximum of three layers
    to find and return the package root (containing the 'tests' folder)

    The resolved ``start_path`` itself is the first candidate. A directory is
    accepted when it contains ``tests/test_version_stable.py``.

    Raises FileNotFoundError when no candidate matches.
    """
    cur_path = start_path.resolve()
    max_layers = 3
    for _ in range(max_layers):
        if (cur_path / "tests" / "test_version_stable.py").exists():
            return cur_path
        else:
            cur_path = cur_path.parent.resolve()
    # The two message halves used to be joined without a space ("upwardsof").
    raise FileNotFoundError(
        f"Unable to find package root within {max_layers} levels upwards of {start_path}"
    )
532
+
533
+
534
@positional_deprecated
def run_task_tests(task_list: List[str]):
    """
    Locate the package root and run the version-stability tests for the given tasks.

    The task names are joined with " or " and handed to pytest's ``-k``
    selector. Raises ValueError when pytest returns a non-zero exit code.
    """
    import pytest

    package_root = find_test_root(start_path=pathlib.Path(__file__))
    selector = " or ".join(task_list)
    pytest_args = [
        f"{package_root}/tests/test_version_stable.py",
        f"--rootdir={package_root}",
        "-k",
        selector,
    ]
    sys.path.append(str(package_root))

    exit_code = pytest.main(pytest_args)
    if not exit_code:
        return
    raise ValueError(
        f"Not all tests for the specified tasks ({task_list}) ran successfully! Error code: {exit_code}"
    )
Prism/LLaDA/LLaDA_Baseline/dllm_eval/filters/__init__.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ from typing import List
3
+
4
+ from dllm_eval.api.filter import FilterEnsemble
5
+ from dllm_eval.api.registry import get_filter
6
+
7
+ from . import custom, extraction, selection, transformation
8
+
9
+
10
def build_filter_ensemble(
    filter_name: str, components: List[List[str]]
) -> FilterEnsemble:
    """
    Create a filtering pipeline.

    Each entry of `components` is a `(registry_name, kwargs)` pair; a `None`
    kwargs entry means "no arguments". Every pair becomes a `partial` around
    the filter looked up in the registry, in the order given.
    """
    steps = [
        # look the filter up by name and bind its constructor arguments
        partial(get_filter(name), **({} if options is None else options))
        for name, options in components
    ]
    return FilterEnsemble(name=filter_name, filters=steps)
Prism/LLaDA/LLaDA_Baseline/dllm_eval/filters/custom.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dllm_eval.api.filter import Filter
2
+ from dllm_eval.api.registry import register_filter
3
+
4
+
5
@register_filter("custom")
class CustomFilter(Filter):
    """
    Filter that delegates to a user-supplied callable.

    The callable is passed as the required `filter_fn` keyword argument and is
    invoked as `filter_fn(resps, docs)`.
    """

    def __init__(self, **kwargs) -> None:
        # `filter_fn` is mandatory: a missing key raises KeyError here.
        user_fn = kwargs.pop("filter_fn")
        self.filter_fn = user_fn

        # Remaining keyword arguments are forwarded to the base class.
        super().__init__(**kwargs)

    def apply(self, resps, docs):
        """Return whatever the user-defined function produces for (resps, docs)."""
        transform = self.filter_fn
        return transform(resps, docs)
Prism/LLaDA/LLaDA_Baseline/dllm_eval/filters/decontamination.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dllm_eval.api.filter import Filter
2
+ from dllm_eval.api.registry import register_filter
3
+
4
+
5
@register_filter("decontaminate")
class DecontaminationFilter(Filter):
    """
    Placeholder for a decontamination filter.

    NOTE(review): this class is an unimplemented stub; `apply` currently
    returns None and `path` is not stored or used.
    """

    name = "track_decontamination"

    def __init__(self, path) -> None:
        """
        Initialize the (empty) result cache.

        TODO: make sure only ever run one time on the train set (should this be cached as a class var? keyed by value for "path").
        should further cache result on a given (task_name, doc_id)
        """
        self._decontam_results = None

    def apply(self, resps, docs) -> None:
        """
        Return {"no_contamination", "only_contamination"} keys for the 2 different subsets

        Not implemented yet: the body is empty, so the call yields None.
        """
Prism/LLaDA/LLaDA_Baseline/dllm_eval/filters/extraction.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import sys
3
+ import unicodedata
4
+
5
+ from dllm_eval.api.filter import Filter
6
+ from dllm_eval.api.registry import register_filter
7
+
8
+
9
@register_filter("regex")
class RegexFilter(Filter):
    """Extract a value from each model response with a regular expression.

    The compiled pattern is run with `findall` over every response; the
    `group_select`-th hit is kept. When nothing matches, `fallback` is
    returned instead. Typical use: pulling a final number out of free text.
    """

    def __init__(
        self,
        regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
        group_select: int = 0,
        fallback: str = "[invalid]",
    ) -> None:
        """
        `regex_pattern` is compiled once with `re.compile`.
        `group_select` indexes into the `findall` result.
        `fallback` is the output used when the regex finds no match.
        """
        self.regex_pattern = regex_pattern
        self.regex = re.compile(regex_pattern)
        self.group_select = group_select
        self.fallback = fallback

    def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
        # `resps` holds one inner list per document (all responses for the
        # same input/target pair); each inner list is processed independently
        # and the nesting is preserved.
        def extract(resp):
            hits = self.regex.findall(resp)
            if not hits:
                return self.fallback
            chosen = hits[self.group_select]
            if isinstance(chosen, tuple):
                # several capture groups: keep the first non-empty one
                non_empty = [piece for piece in chosen if piece]
                chosen = non_empty[0] if non_empty else self.fallback
            return chosen.strip()

        return [[extract(resp) for resp in inst] for inst in resps]
58
+
59
+
60
@register_filter("regex_pos")
class POSFilter(Filter):
    """Extract the POS tags from responses formatted as ('token', 'tag') pairs."""

    def __init__(
        self,
        regex_pattern: str = r"\['(.*?)'\]",
        group_select=0,
        fallback=None,
    ) -> None:
        """
        `regex_pattern` is compiled with `re.compile` and stored.
        `fallback` is the output used when no tags are found; it defaults to
        `["invalid"]`.
        NOTE(review): `apply` uses its own hard-coded pattern, not `self.regex`.
        """
        self.regex_pattern = regex_pattern
        self.regex = re.compile(regex_pattern)
        self.group_select = group_select
        self.fallback = ["invalid"] if fallback is None else fallback

    def apply(self, resps, docs):
        def tags_of(result):
            if isinstance(result, str):
                # parse "('token', 'TAG')" pairs out of the raw text
                result = [
                    (token, pos)
                    for token, pos in re.findall(r"\('([^']*)', '([^']*)'\)", result)
                ]
            tags = [pos for _, pos in result]
            return tags if tags else self.fallback

        def filter_set(inst):
            return [tags_of(resp) for resp in inst]

        # lazily evaluated, one list of tag-lists per document
        return map(filter_set, resps)
104
+
105
+
106
@register_filter("remove_whitespace")
class WhitespaceFilter(Filter):
    """Strip leading whitespace from every response."""

    def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
        # The per-document nesting is preserved; only the left side is stripped.
        return [[text.lstrip() for text in inst] for inst in resps]
121
+
122
+
123
@register_filter("multi_choice_regex")
class MultiChoiceRegexFilter(RegexFilter):
    """
    A filter used to extract a model's answer on multiple choice questions with
    letter answers. assumes each document has a "choices" field
    containing the list of answer choices and that the answer label symbols
    are of the form (A), (B), (C), ... or A, B, C.
    """

    # Translation table mapping every Unicode punctuation code point to None.
    # Building it scans the whole code-point range, so it is created lazily
    # (only when `ignore_punctuation` is used) and shared by all instances.
    _punct_tbl = None

    def __init__(
        self,
        regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
        group_select=0,
        fallback: str = "[invalid]",
        ignore_case=False,
        ignore_punctuation=False,
        regexes_to_ignore=None,
    ) -> None:
        r"""
        regex_pattern: The basic regex pattern to use. If fails to match, we will use the customized match procedure
                        - step 1 : We parse the choices between ([A-Z])s then try to find these choices in the response.
                        - step 2 : We parse the choice with regex :[\s]*([A-?]), where ? varies by number of choices.
        group_select: Selects the (group_select)th match from the findall result.
        ignore_case: Ignores the case during step 1 matching
        ignore_punctuation: Remove the punctuation during step 1 matching
        regexes_to_ignore: Remove these regexes during step 1 matching
        """
        super().__init__(regex_pattern, group_select, fallback)
        self.ignore_case = ignore_case
        self.ignore_punctuation = ignore_punctuation
        self.regexes_to_ignore = regexes_to_ignore

    @staticmethod
    def _punctuation_table():
        """Return the shared punctuation-removal table, building it on first use."""
        if MultiChoiceRegexFilter._punct_tbl is None:
            MultiChoiceRegexFilter._punct_tbl = dict.fromkeys(
                i
                for i in range(sys.maxunicode)
                if unicodedata.category(chr(i)).startswith("P")
            )
        return MultiChoiceRegexFilter._punct_tbl

    def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
        # here, we assume we have a list, in which each element is
        # a list of model responses for some particular input/target pair.
        # so we process each of these (same input/target response sets)
        # independently (and keep them a list.)

        def find_match(regex, resp, convert_dict=None):
            """Return the selected match (optionally remapped), or a falsy value."""
            found = regex.findall(resp)
            if not found:
                return None
            picked = found[self.group_select]
            if isinstance(picked, tuple):
                groups = [g for g in picked if g]
                if not groups:
                    # every capture group was empty: treat as "no match"
                    # instead of raising IndexError
                    return None
                picked = groups[0]
            picked = picked.strip()
            if picked and convert_dict and picked in convert_dict:
                picked = convert_dict[picked]
            return picked

        def filter_ignores(st):
            """Apply the configured normalizations used for step 1 matching."""
            if self.regexes_to_ignore is not None:
                for s in self.regexes_to_ignore:
                    st = re.sub(s, "", st)

            if self.ignore_case:
                st = st.lower()

            if self.ignore_punctuation:
                # https://stackoverflow.com/a/266162
                st = st.translate(self._punctuation_table())
            return st

        filtered_resps = []

        for r, doc in zip(resps, docs):
            fallback_regexes = []
            choice_to_alpha = {}
            next_alpha = "A"

            without_paren_fallback_regexes = []
            without_paren_to_target = {}

            choices = doc["choices"]
            for c in choices:
                m = filter_ignores(c.strip())
                fallback_regexes.append(f"{re.escape(m)}")
                choice_to_alpha[m] = f"({next_alpha})"

                without_paren_fallback_regexes.append(next_alpha)
                without_paren_to_target[next_alpha] = f"({next_alpha})"

                next_alpha = chr(ord(next_alpha) + 1)
            fallback_regex = re.compile("|".join(fallback_regexes))
            without_paren_fallback_regex = "|".join(without_paren_fallback_regexes)
            without_paren_fallback_regex = re.compile(
                rf":[\s]*({without_paren_fallback_regex})"
            )

            filtered = []
            for resp in r:
                # 1) the configured pattern, 2) the literal choice text,
                # 3) a bare letter after a colon, 4) the fallback value
                match = find_match(self.regex, resp)
                if not match:
                    match = find_match(
                        fallback_regex, filter_ignores(resp), choice_to_alpha
                    )
                    if not match:
                        match = find_match(
                            without_paren_fallback_regex, resp, without_paren_to_target
                        )
                if not match:
                    match = self.fallback
                filtered.append(match)
            filtered_resps.append(filtered)

        return filtered_resps
Prism/LLaDA/LLaDA_Baseline/dllm_eval/filters/selection.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import Counter
2
+
3
+ from dllm_eval.api.filter import Filter
4
+ from dllm_eval.api.registry import register_filter
5
+
6
+
7
+ # TODO: implement "arg_max" filter. either it should take in an arbitrary "scoring"/reward function
8
+ # that takes an input and returns a scalar and then should select the max reward,
9
+ # or should implement different filters for different ways of handling a reward model's inference.
10
+
11
+
12
@register_filter("take_first")
class TakeFirstFilter(Filter):
    def __init__(self) -> None:
        """
        Stateless filter; nothing to configure.
        """

    def apply(self, resps, docs):
        """
        Keep only the first response of every entry in `resps`.

        Each entry is expected to be an indexable collection of model
        responses; the result is a lazy `map` over those first elements.
        """

        def first(responses):
            return responses[0]

        return map(first, resps)
24
+
25
+
26
@register_filter("take_first_k")
class TakeKFilter(Filter):
    """Keep only the first `k` responses of every document."""

    def __init__(self, **kwargs) -> None:
        # `k` is mandatory: a missing key raises KeyError here.
        self.k = kwargs.pop("k")

        super().__init__(**kwargs)

    def apply(self, resps, docs):
        """
        Truncate each entry of `resps` to its first `k` responses.

        Returns a lazy `map` over the truncated entries. An empty `resps`
        yields an empty result instead of failing on `resps[0]`.
        """
        # need resp to be subscriptable to check below
        resps = list(resps)
        # check we have at least k responses per doc, else we can't take the first k
        # (only the first document is inspected; skipped when there are no documents)
        if resps:
            assert len(resps[0]) >= self.k, (
                f"Need at least {self.k} responses per doc to take first {self.k}, but got {len(resps[0])} only! Please increase TaskConfig.repeats ."
            )
        return map(lambda r: r[: self.k], resps)
41
+
42
+
43
@register_filter("majority_vote")
class MajorityVoteFilter(Filter):
    def __init__(self) -> None:
        """
        Stateless filter; nothing to configure.
        """

    def apply(self, resps, docs):
        """
        Reduce every entry of `resps` to its most frequent response.

        Each entry is a list of model responses; the output for that entry is
        a one-element list holding the winner of `Counter.most_common(1)`.
        The result is a lazy `map`.
        """

        def majority(responses):
            tally = Counter(responses)
            return [tally.most_common(1)[0][0]]

        return map(majority, resps)
Prism/LLaDA/LLaDA_Baseline/dllm_eval/filters/transformation.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ from dllm_eval.api.filter import Filter
4
+ from dllm_eval.api.registry import register_filter
5
+
6
+
7
@register_filter("lowercase")
class LowercaseFilter(Filter):
    """Lower-case every response, keeping the per-document nesting."""

    def __init__(self) -> None:
        pass

    def apply(self, resps, docs):
        return [[text.lower() for text in inst] for inst in resps]
17
+
18
+
19
@register_filter("uppercase")
class UppercaseFilter(Filter):
    """Upper-case every response, keeping the per-document nesting."""

    def __init__(self) -> None:
        pass

    def apply(self, resps, docs):
        return [[text.upper() for text in inst] for inst in resps]
29
+
30
+
31
@register_filter("map")
class MapFilter(Filter):
    """Translate each response through a lookup dictionary."""

    def __init__(self, mapping_dict: dict = None, default_value=None) -> None:
        """
        Store the lookup table and the value used for unknown responses.

        Args:
        - mapping_dict (dict): key-value mappings; `None` means an empty dict.
        - default_value (Any): returned for responses that are not keys of
          `mapping_dict`. Default is None.

        Example:
        mapper = MapFilter({'A': 1, 'B': 2}, default_value=0)
        """
        table = {} if mapping_dict is None else mapping_dict
        assert isinstance(table, dict), (
            "Provided mapping_dict is not a dictionary"
        )
        self.mapping_dict = table
        self.default_value = default_value

    def apply(self, resps, docs):
        return [
            [self.mapping_dict.get(resp, self.default_value) for resp in inst]
            for inst in resps
        ]
59
+
60
+
61
@register_filter("format_span")
class SPANFilter(Filter):
    """Normalize free-form NER output into 'label: entity $ label: entity' form."""

    def __init__(self) -> None:
        pass

    def apply(self, resps, docs):
        # Ordered (spelled-out label, canonical tag) pairs. The order matters:
        # replacements are applied one after another on the same string.
        label_pairs = (
            ("person", "PER"),
            ("location", "LOC"),
            ("organization", "ORG"),
            ("counties", "LOC"),
            ("places", "LOC"),
            ("people", "PER"),
            ("persons", "PER"),
            ("company", "ORG"),
            ("country", "LOC"),
            ("continent", "LOC"),
            ("time", "DATE"),
            ("date", "DATE"),
            ("per", "PER"),
            ("loc", "LOC"),
            ("org", "ORG"),
        )

        def format_ner_text(text):
            """Lower-case the text, canonicalize labels, collapse '$$' separators."""
            normalized = text.lower()
            for spelled_out, tag in label_pairs:
                normalized = normalized.replace(spelled_out, tag)

            # collapse each "$$" into a single "$", then drop trailing "$"
            normalized = normalized.replace("$$", "$")
            return normalized.rstrip("$$")

        def format_named_entities(text):
            """
            Extract named entities from text and format them as 'label: value $ label: value'.
            Handles grouped entities (e.g., LOC: kenya, uganda) and excludes 'none' values.
            """
            # label followed by everything up to the next '$'
            pattern = r"\b(PER|LOC|ORG|DATE):\s*([^$]+)"
            # newlines act as entity separators too
            flattened = text.replace("\n", "$").strip()

            pieces = [
                f"{label.lower()}: {entity}"
                for label, values in re.findall(pattern, flattened)
                # comma-separated entities share one label
                for entity in (value.strip() for value in values.split(","))
                if entity.lower() != "none"
            ]
            return " $ ".join(pieces)

        return [
            [format_named_entities(format_ner_text(resp.lower())) for resp in inst]
            for inst in resps
        ]
Prism/LLaDA/LLaDA_Baseline/dllm_eval/utils.py ADDED
@@ -0,0 +1,552 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import fnmatch
3
+ import functools
4
+ import hashlib
5
+ import importlib.util
6
+ import inspect
7
+ import json
8
+ import logging
9
+ import os
10
+ import re
11
+ from dataclasses import asdict, is_dataclass
12
+ from itertools import islice
13
+ from pathlib import Path
14
+ from typing import Any, Callable, Generator, List, Optional, Tuple
15
+
16
+ import numpy as np
17
+ import yaml
18
+ from jinja2 import BaseLoader, Environment, StrictUndefined
19
+
20
+
21
# Run of 47 space characters (its consumer is not visible in this part of the file).
SPACING = " " * 47

# Arrow shown next to a metric in the results table, keyed by whether a
# higher value of that metric is better.
HIGHER_IS_BETTER_SYMBOLS = {
    True: "↑",
    False: "↓",
}
27
+
28
+
29
def setup_logging(verbosity=logging.INFO):
    """Configure the root logger's level and, if needed, its handler.

    Args:
        verbosity: desired level, either a `logging` integer constant
            (e.g. `logging.DEBUG`) or a level name such as "debug". A
            non-empty `LOGLEVEL` environment variable takes precedence.

    Integer levels are honored as given; previously they were converted to a
    string, missed the name lookup, and silently became INFO. Unrecognized
    names still fall back to INFO.
    """

    # Strip the "dllm_eval." prefix from logger names to keep lines short.
    class CustomFormatter(logging.Formatter):
        def format(self, record):
            if record.name.startswith("dllm_eval."):
                record.name = record.name[len("dllm_eval.") :]
            return super().format(record)

    formatter = CustomFormatter(
        "%(asctime)s %(levelname)-8s [%(name)s:%(lineno)d] %(message)s",
        datefmt="%Y-%m-%d:%H:%M:%S",
    )

    requested = os.environ.get("LOGLEVEL", verbosity) or verbosity

    level_map = {
        "DEBUG": logging.DEBUG,
        "INFO": logging.INFO,
        "WARNING": logging.WARNING,
        "ERROR": logging.ERROR,
        "CRITICAL": logging.CRITICAL,
    }

    if isinstance(requested, int):
        # already a numeric logging level
        log_level = requested
    else:
        level_name = str(requested).strip().upper()
        if level_name.isdigit():
            # numeric level given as text, e.g. LOGLEVEL=10
            log_level = int(level_name)
        else:
            log_level = level_map.get(level_name, logging.INFO)

    if not logging.root.handlers:
        handler = logging.StreamHandler()
        handler.setFormatter(formatter)

        root_logger = logging.getLogger()
        root_logger.addHandler(handler)
        root_logger.setLevel(log_level)

        if log_level == logging.DEBUG:
            # keep chatty third-party libraries at INFO even in debug mode
            third_party_loggers = ["urllib3", "filelock", "fsspec"]
            for logger_name in third_party_loggers:
                logging.getLogger(logger_name).setLevel(logging.INFO)
    else:
        # logging was configured elsewhere: only adjust the level
        logging.getLogger().setLevel(log_level)
68
+
69
+
70
def hash_string(string: str) -> str:
    """Return the hexadecimal SHA-256 digest of `string` encoded as UTF-8."""
    hasher = hashlib.sha256()
    hasher.update(string.encode("utf-8"))
    return hasher.hexdigest()
72
+
73
+
74
def escaped_split(text, sep_char, maxsplit=-1):
    """Split text into a list on occurrences of the given separation
    character `sep_char`. The separation character may be escaped by a
    backslash to avoid splitting at that location.

    The separation character must be a string of size 1.

    If `maxsplit` is given, at most `maxsplit` splits are done (thus,
    the list will have at most `maxsplit + 1` elements). If `maxsplit`
    is not specified or less than 0, then there is no limit on the
    number of splits (all possible splits are made).

    With `maxsplit == 0` no split is made and a one-element list is returned.
    Backslashes are kept in the output; they only suppress the split.
    """
    assert len(sep_char) == 1, (
        "separation string must be a single character for escaped splitting"
    )

    if maxsplit == 0:
        # zero splits requested: still return a list, as documented
        return [text]
    # re.split treats 0 as "no limit", which is what negative values mean here
    maxsplit = max(0, maxsplit)

    # Escape the separator so regex metacharacters such as "." or "|" are
    # matched literally; the lookbehind skips backslash-escaped separators.
    return re.split(r"(?<!\\)" + re.escape(sep_char), text, maxsplit=maxsplit)
95
+
96
+
97
def handle_arg_string(arg):
    """Convert a command-line argument string to bool, int, float or str.

    "true"/"false" (any case) become booleans. Otherwise the value is parsed
    as an int, then as a float; if both fail the original string is returned.

    Parsing is attempted directly instead of pre-checking with
    `str.isnumeric()`, which accepts characters such as "²" that `int()`
    rejects, and which made negative integers come back as floats.
    """
    lowered = arg.lower()
    if lowered == "true":
        return True
    if lowered == "false":
        return False
    for convert in (int, float):
        try:
            return convert(arg)
        except ValueError:
            continue
    return arg
108
+
109
+
110
def handle_non_serializable(o):
    """Fallback converter for objects `json` cannot encode.

    NumPy int64/int32 scalars become Python ints, sets become lists, and
    anything else is converted with `str`.
    """
    if isinstance(o, (np.int64, np.int32)):
        return int(o)
    if isinstance(o, set):
        return list(o)
    return str(o)
117
+
118
+
119
def sanitize_list(sub):
    """
    Recursively convert the leaves of a possibly nested list/tuple to strings.

    Lists stay lists and tuples stay tuples; every other value is passed
    through `str`.
    """
    if isinstance(sub, list):
        return list(map(sanitize_list, sub))
    if isinstance(sub, tuple):
        return tuple(map(sanitize_list, sub))
    return str(sub)
129
+
130
+
131
def simple_parse_args_string(args_string: Optional[str]) -> dict:
    """
    Parse a string such as
        args1=val1,arg2=val2
    into a dictionary.

    `None` and blank input give an empty dict. Empty comma-separated items are
    skipped. Everything after the first "=" of an item is its value, which is
    converted with `handle_arg_string`.
    """
    if args_string is None:
        return {}
    trimmed = args_string.strip()
    if not trimmed:
        return {}
    parsed = {}
    for item in trimmed.split(","):
        if not item:
            continue
        key, _, raw_value = item.partition("=")
        parsed[key] = handle_arg_string(raw_value)
    return parsed
148
+
149
+
150
def join_iters(iters):
    """Yield the items of each iterable in `iters`, one iterable after another."""
    for iterable in iters:
        for item in iterable:
            yield item
153
+
154
+
155
def group(arr, fn):
    """Bucket the items of `arr` by `fn(item)`.

    Returns a list of buckets (lists), ordered by first appearance of each
    key; items keep their original relative order inside a bucket.
    """
    buckets = {}
    for item in arr:
        buckets.setdefault(fn(item), []).append(item)
    return list(buckets.values())
162
+
163
+
164
def pattern_match(patterns, source_list):
    """Return the sorted, de-duplicated entries of `source_list` that match
    at least one of the glob-style `patterns`.

    A single pattern may be passed as a plain string.
    """
    if isinstance(patterns, str):
        patterns = [patterns]

    matched = {
        name for pattern in patterns for name in fnmatch.filter(source_list, pattern)
    }
    return sorted(matched)
175
+
176
+
177
def softmax(x) -> np.ndarray:
    """Return exp(x - max(x)) normalized by its total sum.

    The global maximum is subtracted before exponentiating; the sum is taken
    over all elements of the array.
    """
    shifted = x - np.max(x)
    exps = np.exp(shifted)
    total = exps.sum()
    return exps / total
181
+
182
+
183
def general_detokenize(string) -> str:
    """Undo common tokenizer spacing around contractions, brackets, quotes
    and the punctuation marks ' . ,"""
    # literal replacements, applied in this order
    literal_fixes = (
        (" n't", "n't"),
        (" )", ")"),
        ("( ", "("),
        ('" ', '"'),
        (' "', '"'),
    )
    for old, new in literal_fixes:
        string = string.replace(old, new)
    # drop the space before an apostrophe, period or comma
    return re.sub(r" (['.,])", r"\1", string)
191
+
192
+
193
def get_file_task_name(filename: str) -> str:
    """
    Return the part of a sample-results filename between its first and last
    underscore, which is where the task name sits.
    """
    start = filename.find("_") + 1
    stop = filename.rfind("_")
    return filename[start:stop]
198
+
199
+
200
def get_file_datetime(filename: str) -> str:
    """
    Return the text after the last underscore of a results or sample-results
    filename, with every ".jsonl" occurrence removed.
    """
    tail_start = filename.rfind("_") + 1
    tail = filename[tail_start:]
    return tail.replace(".jsonl", "")
205
+
206
+
207
def sanitize_model_name(model_name: str) -> str:
    """
    Replace every run of the characters " < > : / | \\ ? * [ ] in the model
    name with a double underscore.
    """
    forbidden_run = re.compile(r"[\"<>:/\|\\?\*\[\]]+")
    return forbidden_run.sub("__", model_name)
212
+
213
+
214
def sanitize_task_name(task_name: str) -> str:
    """
    Replace each non-word character (regex `\\W`) of the task name with a
    single underscore.
    """
    non_word = re.compile(r"\W")
    return non_word.sub("_", task_name)
219
+
220
+
221
def get_latest_filename(filenames: List[str]) -> str:
    """
    Return the filename whose datetime part, as extracted by
    `get_file_datetime`, compares greatest as a string.
    """
    return max(filenames, key=get_file_datetime)
226
+
227
+
228
def get_results_filenames(filenames: List[str]) -> List[str]:
    """
    Keep the filenames of aggregated results: those containing both
    "/results_" and ".json".
    """

    def is_results_file(path):
        return "/results_" in path and ".json" in path

    return list(filter(is_results_file, filenames))
233
+
234
+
235
def get_sample_results_filenames(filenames: List[str]) -> List[str]:
    """
    Keep the filenames of per-sample results: those containing both
    "/samples_" and ".json".
    """

    def is_samples_file(path):
        return "/samples_" in path and ".json" in path

    return list(filter(is_samples_file, filenames))
240
+
241
+
242
def get_rolling_token_windows(
    token_list: List[int], prefix_token: int, max_seq_len: int, context_len: int
) -> Generator[Tuple[List[int], List[int]], None, None]:
    """
    Split `token_list` into (input_tokens, pred_tokens) windows for scoring.

    - context_len allows for a rolling window context, allowing each prediction window to potentially
      condition on some context

    :param token_list: list
        List of tokens to be PREDICTED
    :param max_seq_len: int
        max_seq_len of model (or max_seq_len we want to use)
    :param context_len: int
        Amount of desired token context for prediction. Needs to be at least 1.
    :param prefix_token: token
        Dummy token like <eos> so the first token has something to condition on
    :return: generator
        Generator of tuples
            (input_tokens, pred_tokens)
        Note: Score only the last len(pred_tokens) logits of the LM
    """
    assert 1 <= context_len <= max_seq_len
    if not token_list:
        return

    total = len(token_list)
    # +1 offset, going from input->preds
    stride = max_seq_len - context_len + 1

    # First window: every token in it is predicted, conditioned on the prefix.
    head_len = min(max_seq_len, total)
    yield [prefix_token] + token_list[: head_len - 1], token_list[:head_len]
    done = head_len

    # Later windows: predict up to `stride` new tokens, with the input being
    # the `max_seq_len` tokens that end one position before the window end.
    while done < total:
        step = min(total - done, stride)
        end = done + step
        yield (
            token_list[end - max_seq_len - 1 : end - 1],
            token_list[done:end],
        )
        done = end
283
+
284
+
285
def make_disjoint_window(
    pair: Tuple[List[int], List[int]],
) -> Tuple[List[int], List[int]]:
    """Trim the context of a (context, continuation) pair produced by
    get_rolling_token_windows so it no longer overlaps the continuation."""
    context, continuation = pair
    overlap = len(continuation) - 1
    keep = len(context) - overlap
    return context[:keep], continuation
291
+
292
+
293
class EnhancedJSONEncoder(json.JSONEncoder):
    """
    JSON encoder used for the loggers' and trackers' dumps.
    Dataclasses are serialized through `dataclasses.asdict`; everything else
    is left to the standard encoder.
    """

    def default(self, o):
        return asdict(o) if is_dataclass(o) else super().default(o)
303
+
304
+
305
class Reorderer:
    """Sort an array by a key function and map results back to input order."""

    def __init__(self, arr: List[Any], fn: Callable) -> None:
        """Reorder an array according to some function

        Args:
            arr (List[Any]): The initial array
            fn (Callable[[Any], Any]): A function to determine the priority of elements
        """
        self.size = len(arr)
        buckets = group(list(enumerate(arr)), lambda pair: fn(pair[1]))
        # TODO: overhaul reorderer. It currently grouped requests by content but we don't want this
        # One entry per original index; note every entry carries the value of
        # the FIRST element of its bucket.
        entries = [([index], bucket[0][1]) for bucket in buckets for index, _ in bucket]
        entries.sort(key=lambda entry: fn(entry[1]))

        self.arr = entries

    def get_reordered(self):
        """Gets the reordered array

        Returns:
            List[Any]: The reordered array
        """
        return [value for _, value in self.arr]

    def get_original(self, newarr):
        """Restores the original order of a new array based on the old array's order

        Args:
            newarr (List[Any]): The array to be restored

        Returns:
            List[Any]: The array restored to the original order
        """
        restored = [None] * self.size
        covered = [False] * self.size

        for (indices, _), value in zip(self.arr, newarr):
            for index in indices:
                restored[index] = value
                covered[index] = True

        # every original position must have received a value
        assert all(covered)

        return restored
351
+
352
+
353
def make_table(result_dict, column: str = "results", sort_results: bool = False):
    """Generate a Markdown table of results.

    Args:
        result_dict: results dictionary; `result_dict[column]` maps a task or
            group name to its `"metric,filter"` -> value entries, and
            `result_dict["versions"]` must exist. `"n-shot"` and
            `"higher_is_better"` are optional.
        column: "results" (per-task table) or "groups" (per-group table).
        sort_results: sort rows alphabetically by task/group name.

    Returns:
        The table rendered by `MarkdownTableWriter.dumps()`.

    Raises:
        ValueError: if `column` is neither "results" nor "groups".

    The input dictionary is not modified.
    """
    from pytablewriter import LatexTableWriter, MarkdownTableWriter

    if column == "results":
        column_name = "Tasks"
    elif column == "groups":
        column_name = "Groups"
    else:
        # previously fell through and failed later with UnboundLocalError
        raise ValueError(
            f"Unknown column {column!r}; expected 'results' or 'groups'."
        )

    all_headers = [
        column_name,
        "Version",
        "Filter",
        "n-shot",
        "Metric",
        "",
        "Value",
        "",
        "Stderr",
    ]

    md_writer = MarkdownTableWriter()
    latex_writer = LatexTableWriter()
    md_writer.headers = all_headers
    latex_writer.headers = all_headers

    values = []

    keys = result_dict[column].keys()
    if sort_results:
        # sort entries alphabetically by task or group name.
        # NOTE: we default here to false, because order matters for multi-level table printing a la mmlu.
        # sorting here would mess that up
        keys = sorted(keys)
    for k in keys:
        # work on a shallow copy so removing "alias" does not alter the caller's data
        dic = dict(result_dict[column][k])
        version = result_dict["versions"].get(k, " N/A")
        # the default must be a mapping: a str default has no .get()
        n = str(result_dict.get("n-shot", {}).get(k, " "))
        higher_is_better = result_dict.get("higher_is_better", {}).get(k, {})

        if "alias" in dic:
            k = dic.pop("alias")

        metric_items = sorted(dic.items())

        for mf, v in metric_items:
            # keys look like "<metric>,<filter>"
            m, _, f = mf.partition(",")
            if m.endswith("_stderr"):
                # stderr values are shown next to their metric, not as rows
                continue

            hib = HIGHER_IS_BETTER_SYMBOLS.get(higher_is_better.get(m), "")

            v = "%.4f" % v if isinstance(v, float) else v

            stderr_key = m + "_stderr" + "," + f
            if stderr_key in dic:
                se = dic[stderr_key]
                se = " N/A" if se == "N/A" else "%.4f" % se
                values.append([k, version, f, n, m, hib, v, "±", se])
            else:
                values.append([k, version, f, n, m, hib, v, "", ""])
            # name and version are printed only on the first row of a task
            k = ""
            version = ""
    md_writer.value_matrix = values
    latex_writer.value_matrix = values

    # todo: make latex table look good
    # print(latex_writer.dumps())

    return md_writer.dumps()
423
+
424
+
425
def positional_deprecated(fn):
    """
    A decorator to nudge users into passing only keyword args (`kwargs`) to the
    wrapped function, `fn`.

    A warning is printed when more positional arguments are passed than
    allowed: one for bound methods (as reported by `inspect.ismethod`), none
    otherwise. The call itself is always forwarded unchanged.
    """
    # Parenthesized on purpose: without the parentheses the expression parses
    # as `(len(args) != 1) if ismethod else 0`, which is always falsy for
    # plain functions, so the warning could never fire for them.
    allowed_positional = 1 if inspect.ismethod(fn) else 0

    @functools.wraps(fn)
    def _wrapper(*args, **kwargs):
        if len(args) != allowed_positional:
            print(
                f"WARNING: using {fn.__name__} with positional arguments is "
                "deprecated and will be disallowed in a future version of "
                "lm-evaluation-harness!"
            )
        return fn(*args, **kwargs)

    return _wrapper
442
+
443
+
444
def ignore_constructor(loader, node):
    """YAML constructor that performs no conversion: the node is returned as-is.

    `load_yaml_config` registers it for the `!function` tag in "simple" mode.
    """
    return node
446
+
447
+
448
def import_function(loader: yaml.Loader, node, yaml_path: Path):
    """YAML constructor that resolves a `module.function` scalar to a callable.

    The module part is looked up as `<module>.py` in the directory that
    contains `yaml_path`; the file is executed and the named attribute is
    returned.

    Raises:
        ImportError: if no import spec or no loader can be created for the file.
    """
    dotted_name = loader.construct_scalar(node)

    # everything before the last dot is the module, the rest the function
    parts = dotted_name.split(".")
    module_name = ".".join(parts[:-1])
    function_name = parts[-1]
    module_path = yaml_path.parent / f"{module_name}.py"

    spec = importlib.util.spec_from_file_location(module_name, module_path.as_posix())
    if spec is None:
        raise ImportError(f"Could not import module {module_name} from {module_path}.")

    module = importlib.util.module_from_spec(spec)
    if spec.loader is None:
        raise ImportError(f"Module loader is None, {module_name} from {module_path}.")

    spec.loader.exec_module(module)
    return getattr(module, function_name)
468
+
469
+
470
def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None, mode="full"):
    """Load a YAML config, resolving its `include` entries recursively.

    Args:
        yaml_path: path of the YAML file; required when `mode` is "full" or
            when `yaml_config` is not given.
        yaml_config: an already-parsed config dict; when given, the file is
            not read. Its "include" key, if present, is removed.
        yaml_dir: directory used to resolve relative include paths; defaults
            to the directory of `yaml_path`.
        mode: "full" imports `!function` targets via `import_function`;
            "simple" leaves those nodes unresolved.

    Returns:
        The config dict. Values of the including file override values of the
        included files; among several includes, earlier entries override
        later ones.

    Raises:
        ValueError: if `mode` is unknown, or "full" is used without `yaml_path`.
    """
    if mode == "simple":
        constructor_fn = ignore_constructor
    elif mode == "full":
        if yaml_path is None:
            raise ValueError("yaml_path must be provided if mode is 'full'.")
        # Attach yaml_path to the import function so that it can be used later
        constructor_fn = functools.partial(import_function, yaml_path=Path(yaml_path))
    else:
        # previously fell through and failed with UnboundLocalError
        raise ValueError(f"Unknown mode {mode!r}; expected 'simple' or 'full'.")

    # NOTE(review): these loaders (and the `!function` constructor in "full"
    # mode) can execute code, so only trusted config files should be loaded.
    loader = yaml.CLoader if yaml.__with_libyaml__ else yaml.FullLoader
    # Add the import_function constructor to the YAML loader
    yaml.add_constructor("!function", constructor_fn, Loader=loader)
    if yaml_config is None:
        with open(yaml_path, "rb") as file:
            yaml_config = yaml.load(file, Loader=loader)

    if yaml_dir is None:
        yaml_dir = os.path.dirname(yaml_path)

    assert yaml_dir is not None

    if "include" not in yaml_config:
        return yaml_config

    include_path = yaml_config.pop("include")
    if isinstance(include_path, str):
        include_path = [include_path]

    final_yaml_config = {}
    # Load from the last one first, so earlier includes win on conflicts.
    # `reversed` avoids mutating the list that came from the config.
    for path in reversed(include_path):
        # Assumes that path is a full path.
        # If not found, assume the included yaml
        # is in the same dir as the original yaml
        if not os.path.isfile(path):
            path = os.path.join(yaml_dir, path)

        # Failures while loading an include propagate to the caller.
        included_yaml_config = load_yaml_config(yaml_path=path, mode=mode)
        final_yaml_config.update(included_yaml_config)

    final_yaml_config.update(yaml_config)
    return final_yaml_config
518
+
519
+
520
def regex_replace(string, pattern, repl, count: int = 0):
    """Jinja filter wrapping regex substitution: replace matches of `pattern`
    in `string` with `repl`; `count=0` replaces every match."""
    compiled = re.compile(pattern)
    return compiled.sub(repl, string, count=count)
523
+
524
+
525
# Shared Jinja environment used by `apply_template`. StrictUndefined is
# passed as the `undefined` handler and trailing newlines are kept.
env = Environment(
    loader=BaseLoader, undefined=StrictUndefined, keep_trailing_newline=True
)
# Expose `regex_replace` to templates as a filter of the same name.
env.filters["regex_replace"] = regex_replace
529
+
530
+
531
def apply_template(template: str, doc: dict) -> str:
    """Render the Jinja `template` string with the keys of `doc` as variables,
    using the module-level `env`."""
    compiled = env.from_string(template)
    rendered = compiled.render(**doc)
    return rendered
534
+
535
+
536
def create_iterator(raw_iterator, *, rank=0, world_size=1, limit=None):
    """
    Build a (potentially) sliced and limited iterator over a raw document
    iterator: start at index `rank`, stop before index `limit`, and step by
    `world_size`. Used for splitting data among ranks in a multi-GPU setting
    or for pulling only a sample of documents.
    """
    start, stop, step = rank, limit, world_size
    return islice(raw_iterator, start, stop, step)
543
+
544
+
545
def weighted_f1_score(items):
    """Compute scikit-learn's weighted-average F1 over `items`.

    `items` is transposed with `zip`; the first resulting column is passed to
    `f1_score` as the gold labels and the second as the predictions.
    """
    from sklearn.metrics import f1_score

    columns = list(zip(*items))
    gold_labels, predictions = columns[0], columns[1]
    return f1_score(gold_labels, predictions, average="weighted")
Prism/LLaDA/LLaDA_Baseline/evaluation_script.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import random
4
+ import numpy as np
5
+ from dllm_eval.__main__ import cli_evaluate
6
+
7
+
8
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs and pin the cuDNN flags.

    Besides seeding, this sets ``cudnn.deterministic = True`` and
    ``cudnn.benchmark = False`` on ``torch.backends``.
    """
    for seeder in (torch.manual_seed, random.seed, np.random.seed):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
15
+
16
+
17
if __name__ == "__main__":
    # Set the two Hugging Face environment flags before the evaluation CLI
    # starts, then fix the RNG seed so runs are repeatable.
    os.environ.update(
        {
            "HF_ALLOW_CODE_EVAL": "1",
            "HF_DATASETS_TRUST_REMOTE_CODE": "1",
        }
    )
    set_seed(42)
    cli_evaluate()
Prism/LLaDA/LLaDA_Baseline/metrics/gsm8k_all.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import re
3
+ import os
4
+ import math
5
+ import argparse
6
+ from collections import Counter
7
+
8
+ RES_PATH = "<PATH_TO_RESULTS_JSONL>"
9
+
10
def last_boxed_only_string(string):
    """Return the last ``\\boxed``/``\\fbox`` expression found in ``string``.

    Returns ``None`` for empty input, when neither command occurs, or when the
    braces after the command are never balanced. The brace-less form
    ``\\boxed <value>`` is returned as ``"\\boxed " + value``, where the value
    ends at the next ``$``.
    """
    if not string:
        return None

    start = max(string.rfind("\\boxed"), string.rfind("\\fbox"))
    if start < 0:
        return None

    head = string[start:start + 8]
    if "\\boxed " in head and "{" not in head:
        value = string[start:].split("\\boxed ")[-1].split("$")[0].strip()
        return "\\boxed " + value

    # Scan forward until the brace opened after the command is closed again.
    depth = 0
    for pos in range(start, len(string)):
        char = string[pos]
        if char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0:
                return string[start:pos + 1]
    return None
31
+
32
def remove_boxed(s):
    """Strip the ``\\boxed``/``\\fbox`` wrapper from ``s``.

    Returns ``None`` for empty input and ``s`` unchanged when no known wrapper
    is recognised. The prefix is removed by length, so ``s`` is expected to
    start with the wrapper.
    """
    if not s:
        return None
    if "\\boxed " in s:
        return s[len("\\boxed "):]
    for opener in ("\\boxed{", "\\fbox{"):
        if opener in s and s.endswith("}"):
            return s[len(opener):-1]
    return s
38
+
39
def strip_string(string):
    """Normalise an answer string for comparison.

    Removes thousands separators, newlines, a handful of LaTeX decorations
    (``\\!``, ``\\left``/``\\right``, degree marks, escaped ``$``/``%``),
    rewrites ``tfrac``/``dfrac`` to ``frac``, drops a short ``lhs =`` prefix,
    removes all spaces and strips trailing periods. ``None`` becomes ``""``.
    """
    if string is None:
        return ""
    string = str(string).strip()

    # Matches cannot overlap, so one pass turns "1,234,567" into "1234,567";
    # repeat until no separator is left.
    thousands = r"(\d),(\d{3})"
    while re.search(thousands, string):
        string = re.sub(thousands, r"\1\2", string)

    # The original also replaced "\%", an invalid escape sequence that is the
    # same string as "\\%"; only the valid spelling is kept.
    replacements = (
        ("\n", ""),
        ("\\!", ""),
        ("tfrac", "frac"),
        ("dfrac", "frac"),
        ("\\left", ""),
        ("\\right", ""),
        ("^{\\circ}", ""),
        ("^\\circ", ""),
        ("\\$", ""),
        ("\\%", ""),
    )
    for old, new in replacements:
        string = string.replace(old, new)

    # "x = 5" -> "5": only when the left-hand side is at most 5 characters.
    if "=" in string and len(string.split("=")[0]) <= 5:
        string = string.split("=")[1].strip()

    string = string.replace(" ", "")
    return string.rstrip(".")
57
+
58
def normalize_to_number(s):
    """Convert an answer to ``float`` when possible.

    The input is first cleaned with ``strip_string``. A value with exactly one
    ``/`` is evaluated as a fraction. If the cleaned text is not numeric (or
    the fraction divides by zero), the cleaned string is returned instead.
    """
    s_clean = strip_string(s)
    try:
        parts = s_clean.split('/')
        if len(parts) == 2:
            return float(parts[0]) / float(parts[1])
        return float(s_clean)
    except (ValueError, ArithmeticError):
        # Not a number (or division by zero): fall back to text comparison.
        return s_clean
67
+
68
def extract_answer_gsm8k_debug(text):
    """Extract the final answer from a model response.

    Returns ``(answer, method)`` where ``method`` names the rule that fired:
    ``"boxed"``, ``"xml_tag"``, ``"text_marker"``, ``"regex_last_num"``, or
    ``"empty"`` / ``"failed"`` (both with an empty answer). Rules are tried in
    that order and answers are normalised with ``strip_string``.
    """
    if not text:
        return "", "empty"
    for token in ("<|role_end|>", "<|endoftext|>"):
        text = text.replace(token, "")
    text = text.strip()

    boxed = last_boxed_only_string(text)
    if boxed:
        ans = remove_boxed(boxed)
        if ans:
            return strip_string(ans), "boxed"

    tag_match = re.search(r"<answer>(.*?)</answer>", text, re.DOTALL)
    if tag_match:
        return strip_string(tag_match.group(1)), "xml_tag"

    # Only the last 200 characters are searched for the textual marker.
    last_text = text[-200:]
    marker = "the answer is"
    idx = last_text.lower().rfind(marker)
    if idx >= 0:
        after = last_text[idx + len(marker):].strip()
        # Cut at the end of the sentence or line, but keep a period that is
        # followed by a digit so decimals such as "3.5" are not truncated.
        after = re.split(r"\.(?!\d)|\n", after)[0]
        after = after.replace(":", "").replace("$", "").strip()
        return strip_string(after), "text_marker"

    # Last resort: the final number within the last 50 characters.
    nums = re.findall(r"(?<!\d)-?\d+\.?\d*(?!\d)", text[-50:])
    if nums:
        return strip_string(nums[-1]), "regex_last_num"

    return "", "failed"
97
+
98
def extract_gold_gsm8k(target_str):
    """Return the normalised gold answer from a GSM8K target string.

    The answer is whatever follows the last ``####`` marker; a target without
    the marker is used as a whole.
    """
    # split() yields the whole string when the marker is absent, so taking the
    # last piece covers both cases.
    gold_part = target_str.split("####")[-1]
    return strip_string(gold_part)
102
+
103
def is_equiv(pred, gold):
    """Return True when ``pred`` and ``gold`` denote the same answer.

    Both values are passed through ``normalize_to_number``. Two floats are
    compared with a relative tolerance of 1e-4; anything else is compared by
    string equality.
    """
    p_val, g_val = normalize_to_number(pred), normalize_to_number(gold)
    both_numeric = isinstance(p_val, float) and isinstance(g_val, float)
    if not both_numeric:
        return str(p_val) == str(g_val)
    return math.isclose(p_val, g_val, rel_tol=1e-4)
110
+
111
def _rank_answers(parsed_paths):
    """Rank distinct answers by summed ``exp(score)`` over the best-scored 60% of paths.

    Sorts ``parsed_paths`` in place by descending score, keeps the top 60%
    (at least one path) and groups them by ``val_key``. Returns a list of
    ``(val_key, stats)`` pairs ordered by total weight, then by best score.
    """
    parsed_paths.sort(key=lambda x: x['score'], reverse=True)
    top_k_count = max(1, int(len(parsed_paths) * 0.6))

    ans_stats = {}
    for p in parsed_paths[:top_k_count]:
        stats = ans_stats.setdefault(p['val_key'], {
            "total_weight": 0.0,
            "count": 0,
            "max_score": -float('inf'),
            "best_repr": p['original_text'],
        })

        try:
            weight = math.exp(p['score'])
        except OverflowError:
            weight = float('inf')

        stats["total_weight"] += weight
        stats["count"] += 1
        if p['score'] > stats["max_score"]:
            stats["max_score"] = p['score']
            stats["best_repr"] = p['original_text']

    return sorted(
        ans_stats.items(),
        key=lambda x: (x[1]["total_weight"], x[1]["max_score"]),
        reverse=True,
    )


def run_evaluation(target_path):
    """Score GSM8K result files with score-weighted answer voting.

    ``target_path`` is a single results file, or a directory that is walked
    recursively for ``*.jsonl`` files (names starting with ``eval_voted_`` are
    skipped). Each JSON line supplies ``target``, ``nfe``, ``svf_calls`` and
    either ``all_trajectories`` (dicts with ``resp``/``score``) or ``resps``.
    Lines that are not valid JSON are skipped. For every input file a report
    named ``eval_voted_<name>.json`` is written next to it.
    """
    if os.path.isdir(target_path):
        jsonl_files = [
            os.path.join(root, file)
            for root, _dirs, files in os.walk(target_path)
            for file in files
            if file.endswith(".jsonl") and not file.startswith("eval_voted_")
        ]
    else:
        jsonl_files = [target_path]

    for file_path in jsonl_files:
        print(f">>> 正在评测: {file_path}")
        detailed_results = []

        correct_voted_count = 0
        correct_any_count = 0
        total_count = 0
        nfe_list = []
        svf_list = []

        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue
                try:
                    item = json.loads(line)
                except json.JSONDecodeError:
                    # Best effort: a corrupt line is skipped, not fatal.
                    continue

                doc = item.get("doc", {})
                ground_truth = extract_gold_gsm8k(str(item.get("target", "")))

                total_nfe_item = item.get("nfe", 0)
                svf_calls_item = item.get("svf_calls", 0)
                nfe_list.append(total_nfe_item)
                svf_list.append(svf_calls_item)

                # Fall back to plain responses (score 0.0) when no scored
                # trajectories were recorded; `or` also covers an explicit null.
                trajectories = item.get("all_trajectories") or [
                    {"resp": r[0] if isinstance(r, list) else r, "score": 0.0}
                    for r in item.get("resps", [])
                ]

                parsed_paths = []
                traj_debug_info = []

                for idx, traj in enumerate(trajectories):
                    raw_text = traj.get("resp", "")
                    score = traj.get("score", 0.0)

                    extracted, method = extract_answer_gsm8k_debug(raw_text)

                    is_correct_single = False
                    if extracted:
                        is_correct_single = is_equiv(extracted, ground_truth)
                        # Only trajectories with an extracted answer can vote.
                        parsed_paths.append({
                            "original_text": extracted,
                            "val_key": normalize_to_number(extracted),
                            "score": score,
                            "method": method
                        })

                    traj_debug_info.append({
                        "id": idx,
                        "extracted": extracted,
                        "score": score,
                        "is_correct": is_correct_single,
                        "extract_method": method
                    })

                total_count += 1

                if not parsed_paths:
                    detailed_results.append({
                        "question": doc.get("question", "N/A"),
                        "final_voted_answer": "",
                        "ground_truth": ground_truth,
                        "is_voted_correct": False,
                        "trajectory_details": traj_debug_info,
                        "nfe": total_nfe_item,
                        "svf_calls": svf_calls_item
                    })
                    continue

                if any(p['score'] > -999 and is_equiv(p['original_text'], ground_truth) for p in parsed_paths):
                    correct_any_count += 1

                sorted_answers = _rank_answers(parsed_paths)

                best_pred = str(sorted_answers[0][1]["best_repr"])
                is_voted_correct = is_equiv(best_pred, ground_truth)
                if is_voted_correct:
                    correct_voted_count += 1

                vote_summary = [
                    {
                        "answer": str(val),
                        "count": info["count"],
                        "total_weight": info["total_weight"],
                        "is_correct": is_equiv(str(val), ground_truth)
                    }
                    for val, info in sorted_answers
                ]

                detailed_results.append({
                    "question": doc.get("question", "N/A"),
                    "final_voted_answer": best_pred,
                    "ground_truth": ground_truth,
                    "is_voted_correct": is_voted_correct,
                    "vote_stats": vote_summary,
                    "trajectory_details": traj_debug_info,
                    "nfe": total_nfe_item,
                    "svf_calls": svf_calls_item
                })

        accuracy = (correct_voted_count / total_count * 100) if total_count > 0 else 0
        pass_at_k = (correct_any_count / total_count * 100) if total_count > 0 else 0
        avg_nfe = int(round(sum(nfe_list) / len(nfe_list))) if nfe_list else 0
        avg_svf = int(round(sum(svf_list) / len(svf_list))) if svf_list else 0

        print(f"--- Accuracy: {accuracy:.2f}% | NFE: {avg_nfe} | SVF: {avg_svf} ---")

        output_name = f"eval_voted_{os.path.basename(file_path).replace('.jsonl', '.json')}"
        output_path = os.path.join(os.path.dirname(file_path), output_name)

        final_report = {
            "summary": {
                "accuracy": f"{accuracy:.2f}%",
                # Share of items where any parsed trajectory was correct.
                "pass_at_k": f"{pass_at_k:.2f}%",
                "correct_voted": correct_voted_count,
                "total": total_count,
                "nfe": avg_nfe,
                "svf_calls": avg_svf
            },
            "details": detailed_results
        }

        with open(output_path, 'w', encoding='utf-8') as out_f:
            json.dump(final_report, out_f, ensure_ascii=False, indent=4)
281
+
282
if __name__ == "__main__":
    # -r/--res_path selects the results file or directory; it defaults to
    # the RES_PATH placeholder defined in this module.
    cli = argparse.ArgumentParser()
    cli.add_argument("-r", "--res_path", type=str, default=RES_PATH)
    run_evaluation(cli.parse_args().res_path)
Prism/LLaDA/LLaDA_Baseline/metrics/humaneval_all.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import json
4
+ import ast
5
+ import traceback
6
+ import glob
7
+ import math
8
+ import argparse
9
+ from typing import Dict, List, Optional, Set, Tuple
10
+ from collections import Counter
11
+ import evaluate as hf_evaluate
12
+ import re
13
+
14
+ RES_PATH = "<PATH_TO_RESULTS_JSONL>"
15
+
16
+ os.environ["HF_ALLOW_CODE_EVAL"] = "1"
17
+
18
def extract_python_code(text: str) -> str:
    """Extract the code portion of a model response.

    Special end-of-turn tokens are removed and, if present, only the content
    of the first ``<answer>...</answer>`` tag is considered. The last
    `````python`` block wins; otherwise the last generic fenced block is
    used. Without any fence, the text up to the first line containing a stop
    word (e.g. ``Explanation:``) is returned.
    """
    if not text:
        return ""

    for token in ("<|role_end|>", "<|endoftext|>", "<|notification_end|>"):
        text = text.replace(token, "")

    tag_match = re.search(r"<answer>(.*?)</answer>", text, re.DOTALL)
    if tag_match:
        text = tag_match.group(1)

    if "```python" in text:
        # Body of the last ```python block, up to its closing fence if any.
        return text.split("```python")[-1].split("```")[0].strip()
    if "```" in text:
        # Splitting on the fence puts fenced bodies at odd indices. The last
        # one is the final block whether or not it was closed. (Previously the
        # text *after* the closing fence was returned instead of the code.)
        return text.split("```")[1::2][-1].strip()

    stop_words = ["Explanation:", "Example:", "Test Case:", "Output:"]
    cleaned_lines = []
    for line in text.split('\n'):
        if any(sw in line for sw in stop_words):
            break
        cleaned_lines.append(line)

    return "\n".join(cleaned_lines).strip()
47
+
48
def normalize_code_for_voting(code: str) -> str:
    """Return a canonical form of ``code`` used to group equivalent candidates.

    Parsable code is round-tripped through ``ast`` with module, class and
    (async) function docstrings removed, so formatting, comments and
    docstrings do not split votes. Code that cannot be parsed falls back to
    the source text with all whitespace removed.
    """
    docstring_owners = (
        ast.FunctionDef,
        ast.AsyncFunctionDef,
        ast.ClassDef,
        ast.Module,
    )
    try:
        tree = ast.parse(code)
        for node in ast.walk(tree):
            if not isinstance(node, docstring_owners) or not node.body:
                continue
            first = node.body[0]
            if (isinstance(first, ast.Expr)
                    and isinstance(first.value, ast.Constant)
                    and isinstance(first.value.value, str)):
                node.body.pop(0)
        return ast.unparse(tree).strip()
    except Exception:
        # Any parse/unparse failure (typically SyntaxError) uses the textual
        # fallback; KeyboardInterrupt/SystemExit are no longer swallowed.
        return re.sub(r"\s+", "", code)
59
+
60
def sanitize(prompt: str, completion: str, entrypoint: str) -> str:
    """Build the program to execute for one candidate.

    A completion that already contains ``def <entrypoint>`` is returned
    as-is; otherwise it is appended to ``prompt`` on a new line.
    """
    defines_entrypoint = f"def {entrypoint}" in completion
    return completion if defines_entrypoint else f"{prompt}\n{completion}"
64
+
65
def _with_check_call(test_code, entry_point):
    """Return ``test_code`` followed by ``check(<entry_point>)`` when that call is missing.

    NOTE(review): assumes ``doc["test"]`` is the raw HumanEval ``test`` field,
    which defines ``check(candidate)`` without calling it -- confirm against
    the code that writes the results file. Test code that does not define
    ``check`` or already calls it is returned unchanged.
    """
    call = f"check({entry_point})"
    if entry_point and "def check(" in test_code and call not in test_code:
        return f"{test_code}\n{call}"
    return test_code


def run_evaluation(target_path):
    """Score HumanEval result files by executing one voted candidate per task.

    ``target_path`` is a single results file, or a directory searched
    recursively for ``*.jsonl`` files. For each JSON line the candidates in
    ``resps`` are turned into programs, grouped by normalised code, and the
    group ranked highest by (parses, vote count) is executed with the
    ``code_eval`` metric. A report named ``eval_voted_<name>.json`` is written
    next to each input file.
    """
    if os.path.isdir(target_path):
        jsonl_files = glob.glob(os.path.join(target_path, "**/*.jsonl"), recursive=True)
    else:
        jsonl_files = [target_path]

    if not jsonl_files:
        print(f"未在路径 {target_path} 下找到任何 .jsonl 文件")
        return

    print(f"共找到 {len(jsonl_files)} 个评测任务")
    code_eval = hf_evaluate.load("code_eval")

    for file_path in jsonl_files:
        print(f"\n>>> 正在评测: {file_path}")
        all_predictions = []
        all_references = []
        detailed_results = []
        nfe_list = []
        svf_list = []

        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue
                item = json.loads(line)
                doc = item.get("doc", {})
                prompt = doc.get("prompt", "")
                entry_point = doc.get("entry_point", "")
                # Without the check(...) call the test code would only define
                # the checker and every runnable candidate would "pass".
                reference = _with_check_call(doc.get("test", ""), entry_point)

                current_nfe = item.get("nfe", 0)
                current_svf = item.get("svf_calls", 0)
                nfe_list.append(current_nfe)
                svf_list.append(current_svf)

                candidate_stats = {}
                for r in item.get("resps", []):
                    raw_text = r[0] if isinstance(r, list) else r
                    completion = extract_python_code(raw_text)
                    full_code = sanitize(prompt, completion, entry_point)

                    try:
                        ast.parse(full_code)
                        is_valid = True
                    except Exception:
                        is_valid = False

                    logic_norm = normalize_code_for_voting(full_code)
                    if not logic_norm:
                        continue

                    stats = candidate_stats.setdefault(
                        logic_norm, {"count": 0, "valid": is_valid, "code": full_code}
                    )
                    stats["count"] += 1

                if candidate_stats:
                    # max() returns the first best entry, matching a stable
                    # descending sort followed by taking element 0.
                    winner = max(
                        candidate_stats.values(),
                        key=lambda s: (s["valid"], s["count"]),
                    )
                    voted_code = winner["code"]
                else:
                    voted_code = prompt

                all_predictions.append([voted_code])
                all_references.append(reference)
                detailed_results.append({
                    "task_id": doc.get("task_id", doc.get("name", "N/A")),
                    "voted_code": voted_code,
                    "nfe": current_nfe,
                    "svf_calls": current_svf,
                    "candidates_count": len(candidate_stats)
                })

        if not all_predictions:
            continue

        print(f"正在执行代码测试 (共 {len(all_predictions)} 题)...")
        pass_at_k, exec_results = code_eval.compute(
            references=all_references,
            predictions=all_predictions,
            k=[1],
            num_workers=4
        )

        accuracy = pass_at_k.get("pass@1", 0.0) * 100
        avg_nfe = int(round(sum(nfe_list) / len(nfe_list))) if nfe_list else 0
        avg_svf = int(round(sum(svf_list) / len(svf_list))) if svf_list else 0

        print(f"Accuracy: {accuracy:.2f}% | NFE: {avg_nfe} | SVF: {avg_svf} ---")

        output_name = f"eval_voted_{os.path.basename(file_path).replace('.jsonl', '.json')}"
        output_path = os.path.join(os.path.dirname(file_path), output_name)

        for i, detail in enumerate(detailed_results):
            res_list = exec_results.get(i, [])
            detail["is_correct"] = res_list[0][1]["passed"] if res_list else False

        final_report = {
            "summary": {
                "accuracy": f"{accuracy:.2f}%",
                "nfe": avg_nfe,
                "svf_calls": avg_svf
            },
            "details": detailed_results
        }

        with open(output_path, 'w', encoding='utf-8') as out_f:
            json.dump(final_report, out_f, ensure_ascii=False, indent=4)
        print(f"报告已保存至: {output_path}")
178
+
179
if __name__ == "__main__":
    # -r/--res_path selects the results file or directory; it defaults to
    # the RES_PATH placeholder defined in this module.
    cli = argparse.ArgumentParser()
    cli.add_argument("-r", "--res_path", type=str, default=RES_PATH)
    run_evaluation(cli.parse_args().res_path)
Prism/LLaDA/LLaDA_Baseline/metrics/math500_all.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import re
3
+ import os
4
+ import math
5
+ import argparse
6
+ from collections import Counter
7
+
8
+ RES_PATH = "<PATH_TO_RESULTS_JSONL>"
9
+
10
def extract_answer(text):
    """Extract the final answer from a model response.

    Returns ``(answer, explicit)``. ``explicit`` is True when the answer came
    from a ``\\boxed{...}`` expression, an ``<answer>`` tag or a
    "the answer is" phrase, and False for the last-number fallback or when
    nothing was found (empty answer).
    """
    if not text:
        return "", False
    text = text.replace("<|role_end|>", "").replace("<|endoftext|>", "").strip()

    # Matches \boxed{...} with up to two levels of nested braces inside.
    boxed_pattern = r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}"
    all_boxes = re.findall(boxed_pattern, text)
    if all_boxes:
        return all_boxes[-1], True

    tag_match = re.search(r"<answer>(.*?)</answer>", text, re.DOTALL)
    if tag_match:
        return tag_match.group(1).strip(), True

    marker = "the answer is"
    pos = text.lower().rfind(marker)
    if pos >= 0:
        after_text = text[pos + len(marker):].strip()
        # Drop leading colons/whitespace and an opening "$"; otherwise
        # "the answer is $42$" would be cut at its first "$" and come out empty.
        after_text = re.sub(r"^[:\s$]+", "", after_text)
        return after_text.split('\n')[0].split('$')[0].strip(), True

    # Fallback: last number-like token within the final 50 characters.
    tail = text[-50:].strip()
    nums = re.findall(r"(-?\d+[\./\d]*|\\sqrt\{\d+\}|\(-?\d+.*?\))", tail)
    if nums:
        return nums[-1], False
    return "", False
36
+
37
def normalize_math(string):
    """Normalise a math answer for comparison.

    Lower-cases the text, removes leftover tags and filler phrases, drops
    ``name =`` prefixes, unwraps common LaTeX text/font commands, rewrites
    simple ``\\frac{a}{b}`` as ``a/b``, strips braces, currency/percent signs,
    unit words, degree marks, thousands separators and spaces, and keeps only
    the right-hand side of any remaining ``=``. Falsy input yields ``""``.
    """
    if not string:
        return ""
    string = str(string).lower().strip()

    string = string.replace("</reasoning>", "").replace("</answer>", "").replace("<answer>", "")
    string = string.replace("...", "").replace("cannot be determined", "")

    string = re.sub(r"([a-z]+|\\theta|\\alpha|\\pi)\s*=\s*", "", string)
    string = re.sub(r"\\text\{([^}]*)\}", r"\1", string)
    string = re.sub(r"\\(mathbf|mathrm|bold|unit|mbox|operatorname|mathrm)\{([^}]*)\}", r"\2", string)
    string = re.sub(r"\\(d|t)?frac\{([^{}]*)\}\{([^{}]*)\}", r"\2/\3", string)
    string = string.replace("\\!", "").replace("\\ ", "").replace("{", "").replace("}", "")
    string = string.replace("\\left", "").replace("\\right", "")
    string = string.replace("\\$", "").replace("$", "").replace("\\%", "").replace("%", "")

    units_pattern = r"(units?|cm\^2|cm|inches|inch|square|degrees?|radians?|miles?|per|hour|cents?)"
    string = re.sub(units_pattern, "", string)
    string = string.replace("^{\\circ}", "").replace("^\\circ", "").replace("°", "").replace("\\degree", "")
    string = string.replace("\\pi", "pi")
    # Look-arounds do not consume the digits, so every separator in
    # "1,234,567" is removed in one pass (the capturing form left "1234,567").
    string = re.sub(r"(?<=\d),(?=\d{3})", "", string)
    string = string.rstrip(".:,; ").replace(" ", "")

    if "=" in string:
        string = string.split("=")[-1]

    return string
63
+
64
def is_equiv(pred, gold):
    """Return True when ``pred`` matches ``gold`` after normalisation.

    Checks, in order: equality of the ``normalize_math`` forms; equality of
    the normalised right-hand side of ``pred`` when it contains ``=``;
    numeric closeness (relative tolerance 1e-4, with ``a/b`` evaluated as a
    fraction); and finally equality after dropping every character other than
    ``a-z``, digits, ``/``, ``,`` and ``-``. An empty ``pred`` never matches.
    """
    if not pred:
        return False
    p, g = normalize_math(pred), normalize_math(gold)
    if p == g:
        return True

    if "=" in pred and normalize_math(pred.split("=")[-1]) == g:
        return True

    def to_float(s):
        if '/' in s and s.count('/') == 1:
            parts = s.split('/')
            return float(parts[0]) / float(parts[1])
        if '_' in s:
            # Keep only the part before an underscore suffix.
            s = s.split('_')[0]
        return float(s)

    try:
        return math.isclose(to_float(p), to_float(g), rel_tol=1e-4)
    except (ValueError, ArithmeticError):
        # Not numeric (or division by zero): compare the fuzzy text forms.
        p_fuzzy = re.sub(r"[^a-z0-9/,\-]", "", p)
        g_fuzzy = re.sub(r"[^a-z0-9/,\-]", "", g)
        return p_fuzzy == g_fuzzy if p_fuzzy else False
85
+
86
def run_evaluation(target_path):
    """Score MATH-500 result files with score-weighted answer voting.

    ``target_path`` is a single results file, or a directory that is walked
    recursively for ``*.jsonl`` files (names starting with ``eval_voted_`` are
    skipped). Each JSON line supplies the gold answer (``target`` or
    ``doc["answer"]``), ``nfe``, ``svf_calls`` and either ``all_trajectories``
    (dicts with ``resp``/``score``) or ``resps``. Answers are grouped by their
    normalised form and ranked by summed ``exp(score)``, then best score,
    then count. Lines that are not valid JSON are skipped. A report named
    ``eval_voted_<name>.json`` is written next to each input file.
    """
    if os.path.isdir(target_path):
        jsonl_files = [
            os.path.join(root, file)
            for root, _dirs, files in os.walk(target_path)
            for file in files
            if file.endswith(".jsonl") and not file.startswith("eval_voted_")
        ]
    else:
        jsonl_files = [target_path]

    for file_path in jsonl_files:
        print(f">>> 正在评测: {file_path}")
        detailed_results = []

        voted_correct_count = 0
        pass_at_k_count = 0
        total_count = 0

        nfe_list = []
        svf_list = []

        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue
                try:
                    item = json.loads(line)
                except json.JSONDecodeError:
                    # Best effort: a corrupt line is skipped, not fatal.
                    continue

                doc = item.get("doc", {})
                ground_truth = str(item.get("target", doc.get("answer", "")))

                current_nfe = item.get("nfe", 0)
                nfe_list.append(current_nfe)
                current_svf = item.get("svf_calls", 0)
                svf_list.append(current_svf)

                # Fall back to plain responses (score 0.0) when no scored
                # trajectories were recorded, like the GSM8K and MBPP scorers.
                trajectories = item.get("all_trajectories") or [
                    {"resp": r[0] if isinstance(r, list) else r, "score": 0.0}
                    for r in item.get("resps", [])
                ]

                ans_stats = {}
                has_correct_trajectory = False

                for traj in trajectories:
                    raw_text = traj.get("resp", "")
                    score = traj.get("score", 0)

                    extracted, _ = extract_answer(raw_text)
                    if not extracted:
                        continue

                    if is_equiv(extracted, ground_truth):
                        has_correct_trajectory = True

                    stats = ans_stats.setdefault(normalize_math(extracted), {
                        "count": 0,
                        "max_score": -float('inf'),
                        "total_weight": 0.0,
                        "original": extracted
                    })

                    stats["count"] += 1
                    if score > stats["max_score"]:
                        stats["max_score"] = score

                    try:
                        weight = math.exp(score)
                    except OverflowError:
                        weight = float('inf')
                    stats["total_weight"] += weight

                if has_correct_trajectory:
                    pass_at_k_count += 1

                if ans_stats:
                    # max() returns the first best entry, matching a stable
                    # descending sort followed by taking element 0.
                    winner = max(
                        ans_stats.values(),
                        key=lambda s: (s["total_weight"], s["max_score"], s["count"]),
                    )
                    best_pred = winner["original"]
                else:
                    best_pred = ""

                is_voted_correct = bool(best_pred) and is_equiv(best_pred, ground_truth)
                if is_voted_correct:
                    voted_correct_count += 1

                total_count += 1

                detailed_results.append({
                    "question": doc.get("problem", "N/A"),
                    "final_voted_answer": best_pred,
                    "ground_truth": ground_truth,
                    "is_voted_correct": is_voted_correct,
                    "nfe": current_nfe,
                    "svf_calls": current_svf
                })

        pass_at_1_accuracy = (voted_correct_count / total_count * 100) if total_count > 0 else 0
        pass_at_k = (pass_at_k_count / total_count * 100) if total_count > 0 else 0
        avg_nfe = int(round(sum(nfe_list) / len(nfe_list))) if nfe_list else 0
        avg_svf = int(round(sum(svf_list) / len(svf_list))) if svf_list else 0

        print(f"--- Accuracy: {pass_at_1_accuracy:.2f}% | NFE: {avg_nfe} | SVF: {avg_svf} ---")

        output_name = f"eval_voted_{os.path.basename(file_path).replace('.jsonl', '.json')}"
        output_path = os.path.join(os.path.dirname(file_path), output_name)

        final_report = {
            "summary": {
                "accuracy": f"{pass_at_1_accuracy:.2f}%",
                # Share of items where any trajectory was correct.
                "pass_at_k": f"{pass_at_k:.2f}%",
                "correct_voted_count": voted_correct_count,
                "total": total_count,
                "nfe": avg_nfe,
                "svf_calls": avg_svf
            },
            "details": detailed_results
        }
        with open(output_path, 'w', encoding='utf-8') as out_f:
            json.dump(final_report, out_f, ensure_ascii=False, indent=4)
208
+
209
if __name__ == "__main__":
    # -r/--res_path selects the results file or directory; it defaults to
    # the RES_PATH placeholder defined in this module.
    cli = argparse.ArgumentParser()
    cli.add_argument("-r", "--res_path", type=str, default=RES_PATH)
    run_evaluation(cli.parse_args().res_path)
Prism/LLaDA/LLaDA_Baseline/metrics/mbpp_all.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import ast
4
+ import glob
5
+ import re
6
+ import argparse
7
+ from typing import Dict, List, Optional, Set, Tuple
8
+ import evaluate as hf_evaluate
9
+
10
+ RES_PATH = "<PATH_TO_RESULTS_JSONL>"
11
+
12
+ os.environ["HF_ALLOW_CODE_EVAL"] = "1"
13
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
14
+
15
def extract_python_code(text: str) -> str:
    """Extract the code portion of a model response.

    Special end-of-turn tokens are removed. If ``<answer>`` tags are present,
    the first tag body containing ``def `` is used (or the first tag body when
    none does). The first fenced segment containing ``def `` is returned;
    otherwise the text up to the first line containing a stop word
    (e.g. ``Explanation:``) is returned.
    """
    if not text:
        return ""

    for token in ("<|role_end|>", "<|endoftext|>", "<|notification_end|>"):
        text = text.replace(token, "")

    answers = re.findall(r"<answer>(.*?)</answer>", text, re.DOTALL)
    if answers:
        text = next((body for body in answers if "def " in body), answers[0])

    if "```python" in text:
        for segment in text.split("```python")[1:]:
            candidate = segment.split("```")[0].strip()
            if "def " in candidate:
                return candidate
    elif "```" in text:
        for segment in text.split("```")[1:]:
            candidate = segment.strip()
            if "def " in candidate:
                return candidate

    # No usable fenced block: keep the lines before the first stop word.
    stop_words = ["Explanation:", "Example:", "Test Case:", "Output:", "Reasoning:"]
    kept = []
    for row in text.split('\n'):
        if any(sw in row for sw in stop_words):
            break
        kept.append(row)

    return "\n".join(kept).strip()
48
+
49
def normalize_code_for_voting(code: str) -> str:
    """Return a canonical form of ``code`` used to group equivalent candidates.

    Parsable code is round-tripped through ``ast`` with module, class and
    (async) function docstrings removed, so formatting, comments and
    docstrings do not split votes. Code that cannot be parsed falls back to
    the source text with all whitespace removed.
    """
    docstring_owners = (
        ast.FunctionDef,
        ast.AsyncFunctionDef,
        ast.ClassDef,
        ast.Module,
    )
    try:
        tree = ast.parse(code)
        for node in ast.walk(tree):
            if not isinstance(node, docstring_owners) or not node.body:
                continue
            first = node.body[0]
            if (isinstance(first, ast.Expr)
                    and isinstance(first.value, ast.Constant)
                    and isinstance(first.value.value, str)):
                node.body.pop(0)
        return ast.unparse(tree).strip()
    except Exception:
        # Any parse/unparse failure (typically SyntaxError) uses the textual
        # fallback; KeyboardInterrupt/SystemExit are no longer swallowed.
        return re.sub(r"\s+", "", code)
60
+
61
def run_evaluation(target_path):
    """Score MBPP result files by executing one voted candidate per task.

    ``target_path`` is a single results file, or a directory searched
    recursively for ``*.jsonl`` files (names starting with ``eval_mbpp_`` are
    skipped). For each JSON line the candidates come from
    ``all_trajectories`` (dicts with ``resp``/``score``) or, if that is empty,
    from ``resps``. Candidates are grouped by normalised code and the group
    ranked highest by (parses, best score, vote count) is run against the
    task's ``test_setup_code`` and ``test_list`` with the ``code_eval``
    metric. A report named ``eval_mbpp_<name>.json`` is written next to each
    input file.
    """
    target_path = os.path.abspath(target_path)

    if os.path.isdir(target_path):
        search_pattern = os.path.join(target_path, "**/*.jsonl")
        jsonl_files = [
            path
            for path in glob.glob(search_pattern, recursive=True)
            if not os.path.basename(path).startswith("eval_mbpp_")
        ]
    else:
        jsonl_files = [target_path]

    if not jsonl_files:
        print(f"Error: 在路径 {target_path} 及其子目录下未找到任何 .jsonl 文件。")
        return

    try:
        code_eval = hf_evaluate.load("code_eval")
    except Exception:
        print("Error: Could not load code_eval. Ensure 'evaluate' and 'code_eval' are installed.")
        return

    for file_path in jsonl_files:
        print(f"\n>>> 正在评测 MBPP 文件: {file_path}")
        all_voted_predictions = []
        all_references = []
        detailed_results = []
        nfe_list = []
        svf_list = []

        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue
                item = json.loads(line)

                doc = item.get("doc", {})
                test_list = doc.get("test_list", [])
                test_setup = doc.get("test_setup_code", "")
                full_reference = (test_setup + "\n" + "\n".join(test_list)).strip()

                item_nfe = item.get("nfe", 0)
                item_svf = item.get("svf_calls", 0)
                nfe_list.append(item_nfe)
                svf_list.append(item_svf)

                resps = item.get("resps", [])
                trajs = item.get("all_trajectories", [])

                candidate_stats = {}
                candidates_count = 0

                for entry in (trajs if trajs else resps):
                    if isinstance(entry, dict):
                        raw_text = entry.get("resp", "")
                        score = entry.get("score", 0)
                    else:
                        raw_text = entry[0] if isinstance(entry, list) else entry
                        score = 0

                    code = extract_python_code(raw_text)
                    if not code:
                        continue
                    candidates_count += 1

                    try:
                        ast.parse(code)
                        is_valid = True
                    except Exception:
                        is_valid = False

                    stats = candidate_stats.setdefault(
                        normalize_code_for_voting(code),
                        {"count": 0, "valid": is_valid, "code": code, "max_score": -float('inf')},
                    )
                    stats["count"] += 1
                    stats["max_score"] = max(stats["max_score"], score)

                if candidate_stats:
                    # max() returns the first best entry, matching a stable
                    # descending sort followed by taking element 0.
                    winner = max(
                        candidate_stats.values(),
                        key=lambda s: (s["valid"], s["max_score"], s["count"]),
                    )
                    voted_code = winner["code"]
                else:
                    voted_code = ""

                all_voted_predictions.append([voted_code])
                all_references.append(full_reference)

                detailed_results.append({
                    "task_id": doc.get("task_id", "N/A"),
                    "voted_code": voted_code,
                    "nfe": item_nfe,
                    "svf_calls": item_svf,
                    "candidates_count": candidates_count
                })

        if not all_voted_predictions:
            continue

        print(f"正在测试代码 (共 {len(all_voted_predictions)} 题)...")
        # Only the voted prediction is executed. The previous second pass over
        # every candidate produced a value that was never reported.
        res_voted, details_voted = code_eval.compute(
            references=all_references, predictions=all_voted_predictions, k=[1]
        )

        acc_voted = res_voted.get("pass@1", 0.0) * 100
        avg_nfe = int(round(sum(nfe_list) / len(nfe_list))) if nfe_list else 0
        avg_svf = int(round(sum(svf_list) / len(svf_list))) if svf_list else 0

        print(f"--- Pass@1: {acc_voted:.2f}% | NFE: {avg_nfe} | SVF: {avg_svf} ---")

        for i, detail in enumerate(detailed_results):
            detail["is_voted_correct"] = details_voted.get(i, [[0, {"passed": False}]])[0][1]["passed"]

        file_dir = os.path.dirname(file_path)
        base_name = os.path.basename(file_path)
        output_name = f"eval_mbpp_{base_name.replace('.jsonl', '.json')}"
        output_path = os.path.join(file_dir, output_name)

        final_report = {
            "summary": {
                "pass_at_1": f"{acc_voted:.2f}%",
                "avg_nfe": avg_nfe,
                "avg_svf": avg_svf
            },
            "details": detailed_results
        }

        with open(output_path, 'w', encoding='utf-8') as out_f:
            json.dump(final_report, out_f, ensure_ascii=False, indent=4)
        print(f"成功保存结果至: {output_path}")
189
+
190
if __name__ == "__main__":
    # -r/--res_path selects the results file or directory; it defaults to
    # the RES_PATH placeholder defined in this module.
    cli = argparse.ArgumentParser()
    cli.add_argument("-r", "--res_path", type=str, default=RES_PATH)
    run_evaluation(cli.parse_args().res_path)
Prism/LLaDA/LLaDA_Baseline/requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ sacrebleu
2
+ evaluate
3
+ datasets
4
+ numpy
5
+ pandas
6
+ tqdm
7
+ regex
8
+ sqlitedict
9
+ pytablewriter
Prism/LLaDA/LLaDA_Baseline/scripts/run_gsm8k.sh ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ set -e
3
+ set -x
4
+
5
+
6
+ PROJECT_ROOT="<PATH_TO_YOUR_ROOT>"
7
+ cd "$PROJECT_ROOT"
8
+
9
+ MODEL_PATH="<PATH_TO_YOUR_LLaDA_8B_INSTRUCT_WEIGHTS>"
10
+
11
+ BASE_OUTPUT_PATH="${PROJECT_ROOT}/outputs/results_gsm8k"
12
+
13
+ export CUDA_VISIBLE_DEVICES=0
14
+ export HF_ENDPOINT=https://hf-mirror.com
15
+
16
+ LENGTH=256
17
+ STEPS=32
18
+ BLOCK=32
19
+ TASK="gsm8k"
20
+ NAME="baseline"
21
+
22
+ mkdir -p "${BASE_OUTPUT_PATH}/${NAME}"
23
+
24
+ accelerate launch evaluation_script.py \
25
+ --model LLaDA \
26
+ --tasks ${TASK} \
27
+ --batch_size 1 \
28
+ --model_args "pretrained=${MODEL_PATH},mask_id=126336,assistant_prefix=<reasoning>" \
29
+ --gen_kwargs "use_hts=True,hts_N=1,hts_mode=False,steps=${STEPS},block_length=${BLOCK},gen_length=${LENGTH},task_type=math,temperature=0.7,realtime_output=${BASE_OUTPUT_PATH}/${NAME}/baseline.jsonl" \
30
+ --num_fewshot 0 \
31
+ --confirm_run_unsafe_code \
32
+ --output_path "${BASE_OUTPUT_PATH}/${NAME}"
Prism/LLaDA/LLaDA_Baseline/scripts/run_humaneval.sh ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Baseline LLaDA run on HumanEval through evaluation_script.py.
# Replace the <...> placeholders below before launching.
set -ex   # abort on the first failing command and echo every command

# --- Paths -------------------------------------------------------------
PROJECT_ROOT="<PATH_TO_YOUR_ROOT>"
MODEL_PATH="<PATH_TO_YOUR_LLaDA_8B_INSTRUCT_WEIGHTS>"
BASE_OUTPUT_PATH="${PROJECT_ROOT}/outputs/results_humaneval"

# evaluation_script.py is referenced relative to the project root.
cd "$PROJECT_ROOT"

# --- Environment -------------------------------------------------------
export CUDA_VISIBLE_DEVICES=0,1,2,3
export HF_ENDPOINT=https://hf-mirror.com

# --- Run settings ------------------------------------------------------
TASK="humaneval"
NAME="baseline"
LENGTH=512   # forwarded as gen_length
STEPS=32     # forwarded as steps
BLOCK=32     # forwarded as block_length

mkdir -p "${BASE_OUTPUT_PATH}/${NAME}"

accelerate launch evaluation_script.py \
    --model LLaDA \
    --tasks "${TASK}" \
    --batch_size 1 \
    --model_args "pretrained=${MODEL_PATH},mask_id=126336,assistant_prefix=<reasoning>" \
    --gen_kwargs "use_hts=True,hts_N=1,hts_mode=False,steps=${STEPS},block_length=${BLOCK},gen_length=${LENGTH},task_type=code,temperature=0.7,realtime_output=${BASE_OUTPUT_PATH}/${NAME}/baseline.jsonl" \
    --num_fewshot 0 \
    --confirm_run_unsafe_code \
    --output_path "${BASE_OUTPUT_PATH}/${NAME}"
Prism/LLaDA/LLaDA_Baseline/scripts/run_math500.sh ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Baseline LLaDA run on MATH500 through evaluation_script.py.
# Replace the <...> placeholders below before launching.
set -ex   # abort on the first failing command and echo every command

# --- Paths -------------------------------------------------------------
PROJECT_ROOT="<PATH_TO_YOUR_PRISM_ROOT>"
MODEL_PATH="<PATH_TO_YOUR_LLaDA_8B_INSTRUCT_WEIGHTS>"
BASE_OUTPUT_PATH="${PROJECT_ROOT}/outputs/results_math500"

# evaluation_script.py is referenced relative to the project root.
cd "$PROJECT_ROOT"

# --- Environment -------------------------------------------------------
export CUDA_VISIBLE_DEVICES=0
export HF_ENDPOINT=https://hf-mirror.com

# --- Run settings ------------------------------------------------------
TASK="math500"
NAME="baseline"
LENGTH=256   # forwarded as gen_length
STEPS=32     # forwarded as steps
BLOCK=32     # forwarded as block_length

mkdir -p "${BASE_OUTPUT_PATH}/${NAME}"

accelerate launch evaluation_script.py \
    --model LLaDA \
    --tasks "${TASK}" \
    --batch_size 1 \
    --model_args "pretrained=${MODEL_PATH},mask_id=126336,assistant_prefix=<reasoning>" \
    --gen_kwargs "use_hts=True,hts_N=1,hts_mode=False,steps=${STEPS},block_length=${BLOCK},gen_length=${LENGTH},task_type=math,temperature=0.7,realtime_output=${BASE_OUTPUT_PATH}/${NAME}/baseline.jsonl" \
    --num_fewshot 0 \
    --confirm_run_unsafe_code \
    --output_path "${BASE_OUTPUT_PATH}/${NAME}"
Prism/LLaDA/LLaDA_Baseline/scripts/run_mbpp.sh ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Baseline LLaDA run on MBPP through evaluation_script.py.
# Replace the <...> placeholders below before launching.
set -e   # abort on the first failing command
set -x   # echo every command for easier debugging

PROJECT_ROOT="<PATH_TO_YOUR_PRISM_ROOT>"
MODEL_PATH="<PATH_TO_YOUR_LLaDA_8B_INSTRUCT_WEIGHTS>"
BASE_OUTPUT_PATH="${PROJECT_ROOT}/outputs/results_mbpp_k4"

# evaluation_script.py is referenced relative to the project root.
cd "$PROJECT_ROOT"
export CUDA_VISIBLE_DEVICES=0
export HF_ENDPOINT=https://hf-mirror.com

LENGTH=512   # forwarded as gen_length
STEPS=32     # forwarded as steps
BLOCK=32     # forwarded as block_length
TASK="mbpp"
NAME="baseline"

mkdir -p "${BASE_OUTPUT_PATH}/${NAME}"

# MBPP is a code-generation benchmark, so task_type is "code" (matching
# run_humaneval.sh); the previous "math" value was carried over from the
# math scripts.
accelerate launch evaluation_script.py \
    --model LLaDA \
    --tasks "${TASK}" \
    --batch_size 1 \
    --model_args "pretrained=${MODEL_PATH},mask_id=126336,assistant_prefix=<reasoning>" \
    --gen_kwargs "use_hts=True,hts_N=1,hts_mode=False,steps=${STEPS},block_length=${BLOCK},gen_length=${LENGTH},task_type=code,temperature=0.7,realtime_output=${BASE_OUTPUT_PATH}/${NAME}/baseline.jsonl" \
    --num_fewshot 0 \
    --confirm_run_unsafe_code \
    --output_path "${BASE_OUTPUT_PATH}/${NAME}"
Prism/LLaDA/LLaDA_Prism/.gitignore ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.jsonl
2
+ *.json
3
+
4
+ # Byte-compiled / optimized / DLL files
5
+ __pycache__/
6
+ *.py[codz]
7
+ *$py.class
8
+
9
+ # C extensions
10
+ *.so
11
+
12
+ # Distribution / packaging
13
+ .Python
14
+ build/
15
+ develop-eggs/
16
+ dist/
17
+ downloads/
18
+ eggs/
19
+ .eggs/
20
+ lib/
21
+ lib64/
22
+ parts/
23
+ sdist/
24
+ var/
25
+ wheels/
26
+ share/python-wheels/
27
+ *.egg-info/
28
+ .installed.cfg
29
+ *.egg
30
+ MANIFEST
31
+
32
+ # PyInstaller
33
+ # Usually these files are written by a python script from a template
34
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
35
+ *.manifest
36
+ *.spec
37
+
38
+ # Installer logs
39
+ pip-log.txt
40
+ pip-delete-this-directory.txt
41
+
42
+ # Unit test / coverage reports
43
+ htmlcov/
44
+ .tox/
45
+ .nox/
46
+ .coverage
47
+ .coverage.*
48
+ .cache
49
+ nosetests.xml
50
+ coverage.xml
51
+ *.cover
52
+ *.py.cover
53
+ .hypothesis/
54
+ .pytest_cache/
55
+ cover/
56
+
57
+ # Translations
58
+ *.mo
59
+ *.pot
60
+
61
+ # Django stuff:
62
+ *.log
63
+ local_settings.py
64
+ db.sqlite3
65
+ db.sqlite3-journal
66
+
67
+ # Flask stuff:
68
+ instance/
69
+ .webassets-cache
70
+
71
+ # Scrapy stuff:
72
+ .scrapy
73
+
74
+ # Sphinx documentation
75
+ docs/_build/
76
+
77
+ # PyBuilder
78
+ .pybuilder/
79
+ target/
80
+
81
+ # Jupyter Notebook
82
+ .ipynb_checkpoints
83
+
84
+ # IPython
85
+ profile_default/
86
+ ipython_config.py
87
+
88
+ # pyenv
89
+ # For a library or package, you might want to ignore these files since the code is
90
+ # intended to run in multiple environments; otherwise, check them in:
91
+ # .python-version
92
+
93
+ # pipenv
94
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
96
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
97
+ # install all needed dependencies.
98
+ #Pipfile.lock
99
+
100
+ # UV
101
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
102
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
103
+ # commonly ignored for libraries.
104
+ #uv.lock
105
+
106
+ # poetry
107
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
108
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
109
+ # commonly ignored for libraries.
110
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
111
+ #poetry.lock
112
+ #poetry.toml
113
+
114
+ # pdm
115
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
116
+ # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
117
+ # https://pdm-project.org/en/latest/usage/project/#working-with-version-control
118
+ #pdm.lock
119
+ #pdm.toml
120
+ .pdm-python
121
+ .pdm-build/
122
+
123
+ # pixi
124
+ # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
125
+ #pixi.lock
126
+ # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
127
+ # in the .venv directory. It is recommended not to include this directory in version control.
128
+ .pixi
129
+
130
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
131
+ __pypackages__/
132
+
133
+ # Celery stuff
134
+ celerybeat-schedule
135
+ celerybeat.pid
136
+
137
+ # SageMath parsed files
138
+ *.sage.py
139
+
140
+ # Environments
141
+ .env
142
+ .envrc
143
+ .venv
144
+ env/
145
+ venv/
146
+ ENV/
147
+ env.bak/
148
+ venv.bak/
149
+
150
+ # Spyder project settings
151
+ .spyderproject
152
+ .spyproject
153
+
154
+ # Rope project settings
155
+ .ropeproject
156
+
157
+ # mkdocs documentation
158
+ /site
159
+
160
+ # mypy
161
+ .mypy_cache/
162
+ .dmypy.json
163
+ dmypy.json
164
+
165
+ # Pyre type checker
166
+ .pyre/
167
+
168
+ # pytype static type analyzer
169
+ .pytype/
170
+
171
+ # Cython debug symbols
172
+ cython_debug/
173
+
174
+ # PyCharm
175
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
176
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
177
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
178
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
179
+ #.idea/
180
+
181
+ # Abstra
182
+ # Abstra is an AI-powered process automation framework.
183
+ # Ignore directories containing user credentials, local state, and settings.
184
+ # Learn more at https://abstra.io/docs
185
+ .abstra/
186
+
187
+ # Visual Studio Code
188
+ # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
189
+ # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
190
+ # and can be added to the global gitignore or merged into this file. However, if you prefer,
191
+ # you could uncomment the following to ignore the entire vscode folder
192
+ # .vscode/
193
+
194
+ # Ruff stuff:
195
+ .ruff_cache/
196
+
197
+ # PyPI configuration file
198
+ .pypirc
199
+
200
+ # Cursor
201
+ # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
202
+ # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
203
+ # refer to https://docs.cursor.com/context/ignore-files
204
+ .cursorignore
205
+ .cursorindexingignore
206
+
207
+ # Marimo
208
+ marimo/_static/
209
+ marimo/_lsp/
210
+ __marimo__/
Prism/LLaDA/LLaDA_Prism/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2025 preordinary
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
Prism/LLaDA/LLaDA_Prism/evaluation_script.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import random
4
+ import numpy as np
5
+ from dllm_eval.__main__ import cli_evaluate
6
+
7
+
8
def set_seed(seed):
    """Seed torch, ``random`` and NumPy, and force deterministic cuDNN.

    Args:
        seed: integer seed handed to all three random number generators.
    """
    # Same seeding order as before: torch, then random, then numpy.
    for seeder in (torch.manual_seed, random.seed, np.random.seed):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
15
+
16
+
17
if __name__ == "__main__":
    # Set both Hugging Face opt-in flags before the evaluation starts,
    # then fix the random seed and hand control to the dllm_eval CLI.
    os.environ.update({
        "HF_ALLOW_CODE_EVAL": "1",
        "HF_DATASETS_TRUST_REMOTE_CODE": "1",
    })
    set_seed(42)
    cli_evaluate()
Prism/LLaDA/LLaDA_Prism/requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ sacrebleu
2
+ evaluate
3
+ datasets
4
+ numpy
5
+ pandas
6
+ tqdm
7
+ regex
8
+ sqlitedict
9
+ pytablewriter
Prism/LLaDA/LLaDA_Truthfulqa/.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ outputs
2
+ logs
3
+ LLaDA-8B-Instruct
Prism/LLaDA/LLaDA_Truthfulqa/LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
Prism/LLaDA/LLaDA_Truthfulqa/eval_llada.py ADDED
@@ -0,0 +1,413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ This file is inspired by the code from https://github.com/ML-GSAI/SMDM
3
+ '''
4
+ import accelerate
5
+ import torch
6
+ import re
7
+ from pathlib import Path
8
+ import random
9
+ import numpy as np
10
+ import torch.nn.functional as F
11
+ from datasets import Dataset
12
+ from lm_eval.__main__ import cli_evaluate
13
+ from lm_eval.api.instance import Instance
14
+ from lm_eval.api.model import LM
15
+ from lm_eval.api.registry import register_model
16
+ from tqdm import tqdm
17
+
18
+ from transformers import AutoTokenizer, AutoModel
19
+ import json
20
+ import os
21
+ import time
22
+
23
+
24
def set_seed(seed):
    """Seed torch, ``random`` and NumPy, and force deterministic cuDNN.

    Args:
        seed: integer seed handed to all three random number generators.
    """
    # Same seeding order as before: torch, then random, then numpy.
    for seeder in (torch.manual_seed, random.seed, np.random.seed):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
31
+
32
+
33
+ def _sample_categorical(categorical_probs):
34
+ gumbel_norm = (
35
+ 1e-10
36
+ - (torch.rand_like(categorical_probs) + 1e-10).log()).to(categorical_probs.dtype)
37
+ return (categorical_probs / gumbel_norm).argmax(dim=-1)
38
+
39
+
40
+ @register_model("llada_dist")
41
+ class LLaDAEvalHarness(LM):
42
def __init__(
    self,
    model_path='',
    mask_id=126336,
    max_length=4096,
    generated_samples_path='',
    batch_size=32,
    mc_num=128,
    is_check_greedy=True,
    cfg=0.,
    sampling_steps=512,
    mask_length=512,
    block_size=32,
    remasking='low_confidence',
    device="cuda",
    sampler='',
    remdm_number=0
):
    '''
    Load the model and tokenizer and record the evaluation settings.

    Args:
        model_path: path passed to AutoModel / AutoTokenizer from_pretrained.
        mask_id: token id used as the [MASK] placeholder (default 126336).
        max_length: the max sequence length; stored on the instance.
        generated_samples_path: stored on the instance and printed once
            at the end of construction.
        batch_size: mini batch size; cast to int and must divide mc_num.
        mc_num: number of Monte Carlo samples per likelihood estimate.
        is_check_greedy: when False, suffix_greedy_prediction() returns
            False immediately and skips the greedy-decoding check, which
            speeds up evaluation for metrics that do not need that check.
        cfg: classifier-free guidance scale; get_logits() only applies
            guidance when this is greater than 0.
        sampling_steps, mask_length, block_size, remasking: generation
            settings stored on the instance.
        device: device used in the single-process case; with more than one
            accelerate process the accelerator's device is used instead.
        sampler, remdm_number: stored on the instance unchanged.
    '''
    super().__init__()

    # Only keep the accelerator around when several processes are running.
    accelerator = accelerate.Accelerator()
    multi_process = accelerator.num_processes > 1
    self.accelerator = accelerator if multi_process else None

    load_kwargs = {}
    if multi_process:
        load_kwargs['device_map'] = {'': f'{self.accelerator.device}'}

    self.model = AutoModel.from_pretrained(
        model_path,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        **load_kwargs,
    )
    self.model.eval()

    self.device = torch.device(device)
    if multi_process:
        self.model = self.accelerator.prepare(self.model)
        self.device = torch.device(f'{self.accelerator.device}')
        self._rank = self.accelerator.local_process_index
        self._world_size = self.accelerator.num_processes
    else:
        self.model = self.model.to(device)

    self.mask_id = mask_id
    self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    # Likelihood-estimation settings.
    self.mc_num = mc_num
    self.batch_size = int(batch_size)
    assert mc_num % self.batch_size == 0
    self.sampling_eps = 0.
    self.max_length = max_length
    self.is_check_greedy = is_check_greedy

    # Sampler bookkeeping.
    self.generated_samples_path = generated_samples_path
    self.sampler = sampler
    self.remdm_number = remdm_number

    # Generation settings.
    self.cfg = cfg
    self.sampling_steps = sampling_steps
    self.mask_length = mask_length
    self.block_size = block_size
    self.remasking = remasking
    print(self.generated_samples_path)
121
+
122
@property
def rank(self):
    """Return the process rank held in ``self._rank``.

    ``__init__`` assigns ``_rank`` only in the multi-process branch;
    otherwise the value presumably comes from the LM base class
    (TODO confirm).
    """
    return self._rank
125
+
126
@property
def world_size(self):
    """Return the number of processes held in ``self._world_size``.

    ``__init__`` assigns ``_world_size`` only in the multi-process branch;
    otherwise the value presumably comes from the LM base class
    (TODO confirm).
    """
    return self._world_size
129
+
130
+ def _forward_process(self, batch, prompt_index):
131
+ b, l = batch.shape
132
+
133
+ target_len = (l - prompt_index.sum()).item()
134
+ k = torch.randint(1, target_len + 1, (), device=batch.device)
135
+
136
+ x = torch.round(torch.linspace(float(k), k + (b - 1) * (target_len / b), steps=b, device=batch.device)).long()
137
+ x = ((x - 1) % target_len) + 1
138
+ assert x.min() >= 1 and x.max() <= target_len
139
+
140
+ indices = torch.arange(target_len, device=batch.device).repeat(b, 1)
141
+ is_mask = indices < x.unsqueeze(1)
142
+
143
+ for i in range(b):
144
+ is_mask[i] = is_mask[i][torch.randperm(target_len)]
145
+
146
+ is_mask = torch.cat((torch.zeros(b, prompt_index.sum(), dtype=torch.bool, device=batch.device), is_mask), dim=1)
147
+
148
+ noisy_batch = torch.where(is_mask, self.mask_id, batch)
149
+
150
+ return noisy_batch, (x / target_len).unsqueeze(1).repeat(1, l)
151
+
152
@torch.no_grad()
def get_logits(self, batch, prompt_index):
    """Run the model on ``batch``, optionally with classifier-free guidance.

    When ``self.cfg > 0`` a second copy of the batch with every prompt
    token replaced by ``self.mask_id`` is evaluated in the same forward
    pass, and the two results are combined as
    ``uncond + (cfg + 1) * (cond - uncond)``.

    Args:
        batch: 2-D tensor of token ids.
        prompt_index: boolean vector over sequence positions marking the
            prompt; only read when guidance is active.

    Returns:
        The (possibly guided) logits returned by ``self.model``.
    """
    use_guidance = self.cfg > 0.

    if use_guidance:
        assert len(prompt_index) == batch.shape[1]
        prompt_rows = prompt_index.unsqueeze(0).repeat(batch.shape[0], 1)
        unconditional = batch.clone()
        unconditional[prompt_rows] = self.mask_id
        # Conditional rows first, unconditional rows second.
        batch = torch.cat([batch, unconditional])

    logits = self.model(batch).logits

    if use_guidance:
        cond_logits, uncond_logits = torch.chunk(logits, 2, dim=0)
        logits = uncond_logits + (self.cfg + 1) * (cond_logits - uncond_logits)

    return logits[:, :batch.shape[1]]
167
+
168
@torch.no_grad()
def get_loglikelihood(self, prefix, target):
    """Monte Carlo estimate of the log-likelihood of ``target`` given ``prefix``.

    The concatenated sequence is replicated ``self.batch_size`` times and
    perturbed by ``_forward_process``; the cross-entropy on the masked
    positions, divided by the per-position mask ratio, is accumulated over
    ``self.mc_num // self.batch_size`` rounds.

    Args:
        prefix: 1-D tensor of prompt token ids.
        target: 1-D tensor of continuation token ids.

    Returns:
        Negative mean of the per-round loss estimates, as a Python float.
    """
    prefix_len = len(prefix)
    joined = torch.cat([prefix, target]).unsqueeze(0)
    seq = joined.repeat((self.batch_size, 1)).to(self.device)

    prompt_index = torch.arange(seq.shape[1], device=self.device) < prefix_len

    estimates = []
    rounds = self.mc_num // self.batch_size
    for _ in range(rounds):
        noisy_seq, p_mask = self._forward_process(seq, prompt_index)
        masked = noisy_seq == self.mask_id

        logits = self.get_logits(noisy_seq, prompt_index)

        # Divide each token's loss by that position's mask ratio.
        token_loss = F.cross_entropy(logits[masked], seq[masked], reduction='none') / p_mask[masked]
        estimates.append((token_loss.sum() / self.batch_size).item())

    return - sum(estimates) / len(estimates)
188
+
189
@torch.no_grad()
def suffix_greedy_prediction(self, prefix, target):
    """Check whether greedy unmasking after ``prefix`` reproduces ``target``.

    Starting from a fully masked suffix, each of the ``len(target)`` steps
    commits only the single most confident prediction among the positions
    that are still masked.

    Args:
        prefix: 1-D tensor of prompt token ids.
        target: 1-D tensor of expected continuation token ids.

    Returns:
        ``False`` straight away when ``self.is_check_greedy`` is falsy;
        otherwise a boolean tensor that is true when every decoded suffix
        token equals ``target``.
    """
    if not self.is_check_greedy:
        return False

    n_prefix = len(prefix)
    seq = torch.full((1, n_prefix + len(target)), self.mask_id, device=self.device)
    prompt_index = torch.arange(seq.shape[1], device=self.device) < n_prefix
    prefix, target = prefix.to(self.device), target.to(self.device)
    seq[0, :n_prefix] = prefix

    for _ in range(len(target)):
        still_masked = seq == self.mask_id
        logits = self.get_logits(seq, prompt_index)[still_masked]
        guess = logits.argmax(dim=-1)

        probs = torch.softmax(logits.to(torch.float32), dim=-1)
        confidence = probs.gather(-1, guess.unsqueeze(-1)).squeeze(dim=-1)
        order = torch.sort(confidence, descending=True)[1]
        # Keep only the most confident guess; re-mask all the others.
        guess[order[1:]] = self.mask_id
        seq[still_masked] = guess.clone()

    return torch.all(target == seq[0, n_prefix:])
212
+
213
def _encode_pair(self, context, continuation):
    """Tokenize a (context, continuation) pair.

    Trailing whitespace of ``context`` is moved to the front of
    ``continuation`` before tokenizing. The continuation ids are taken as the
    part of the joint encoding that follows the context encoding.

    Args:
        context: prompt text.
        continuation: text to be scored after the prompt.

    Returns:
        tuple: ``(context_ids, continuation_ids)`` as returned by
        ``self.tokenizer(...)["input_ids"]``.
    """
    trailing = len(context) - len(context.rstrip())
    if trailing > 0:
        continuation = context[-trailing:] + continuation
        context = context[:-trailing]

    joint_ids = self.tokenizer(context + continuation)["input_ids"]
    context_ids = self.tokenizer(context)["input_ids"]

    return context_ids, joint_ids[len(context_ids):]
226
+
227
def loglikelihood(self, requests):
    """Score each (context, continuation) request.

    Args:
        requests: sequence of request objects whose ``args`` are
            ``(context_text, continuation_text)``.

    Returns:
        list[tuple[float, float]]: one ``(log_likelihood, is_greedy)`` pair per
        request, where ``is_greedy`` is 1.0 if ``suffix_greedy_prediction``
        reports that greedy decoding reproduces the continuation, else 0.0.
        An empty ``requests`` yields an empty list.

    Raises:
        AssertionError: if any tokenized context + continuation exceeds 4096
            tokens.
    """
    if not requests:
        # Nothing to score; also avoids calling max() on an empty list below.
        return []

    def _tokenize(e):
        prefix, target = self._encode_pair(e["prefix"], e["target"])
        return {
            "prefix_text": e["prefix"],
            "target_text": e["target"],
            "prefix": prefix,
            "target": target,
        }

    ds = [{"prefix": req.args[0], "target": req.args[1]} for req in requests]
    ds = Dataset.from_list(ds)
    ds = ds.map(_tokenize)
    ds = ds.with_format("torch")
    prompt_len = [len(x["prefix"]) + len(x["target"]) for x in ds]

    assert max(prompt_len) <= 4096, "context + continuation exceeds 4096 tokens"

    out = []
    with torch.no_grad():
        for elem in tqdm(ds, desc="Computing likelihood..."):
            prefix = elem["prefix"]
            target = elem["target"]

            ll = self.get_loglikelihood(prefix, target)

            is_target_greedy_dec = self.suffix_greedy_prediction(prefix, target)

            out.append((ll, 1.0 if is_target_greedy_dec else 0.0))
    # NOTE(review): nesting reconstructed from a whitespace-stripped listing;
    # the cache flush is placed after the loop -- confirm against the original.
    torch.cuda.empty_cache()
    return out
259
+
260
def loglikelihood_rolling(self, requests):
    """Rolling log-likelihood is not supported by this model wrapper.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()
262
+
263
@torch.no_grad()
def llada_conf_sample(self, prompt):
    """Block-wise, confidence-ordered unmasking sampler.

    Appends ``self.mask_length`` mask tokens to ``prompt`` and fills them in
    block by block (``self.block_size`` tokens per block). In every step the
    ``mask_length / sampling_steps`` masked positions with the highest sampled
    token probability inside the blocks opened so far are committed.

    Args:
        prompt: token-id tensor indexed as ``prompt.shape[1]``; the ``.item()``
            call below requires a batch of exactly one row.

    Returns:
        The full token tensor (prompt followed by the generated region).
        Generation returns early once an EOS token id is present in it.

    Raises:
        AssertionError: if ``mask_length`` is not a multiple of ``block_size``
            or of ``sampling_steps``, or ``sampling_steps`` is not a multiple
            of the number of blocks.
    """
    xt = torch.full((1, prompt.shape[1] + self.mask_length), self.mask_id, dtype=torch.long).to(self.model.device)
    xt[:, :prompt.shape[1]] = prompt.clone()

    # Everything that is not a mask token right now is the prompt.
    prompt_index = (xt != self.mask_id)
    prompt_len = prompt_index.sum(1).item()

    assert self.mask_length % self.block_size == 0
    num_blocks = self.mask_length // self.block_size

    assert self.sampling_steps % num_blocks == 0
    steps = self.sampling_steps // num_blocks

    assert self.mask_length % self.sampling_steps == 0

    for num_block in range(num_blocks):
        for i in range(steps):
            mask_index = (xt == self.mask_id)
            logits = self.model(xt).logits
            p = F.softmax(logits.to(torch.float64), dim=-1)
            x0 = _sample_categorical(p)
            # Probability the model assigned to each sampled token.
            x0_p = torch.squeeze(torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # b, l

            # Positions beyond the current block may not be committed yet.
            x0_p[:, prompt_len + (num_block + 1) * self.block_size:] = -np.inf
            x0 = torch.where(mask_index, x0, xt)
            # Already-decoded positions get -inf so top-k never re-selects them.
            confidence = torch.where(mask_index, x0_p, -np.inf)

            transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
            for j in range(confidence.shape[0]):
                _, select_index = torch.topk(confidence[j], k=int(self.mask_length / self.sampling_steps))
                transfer_index[j, select_index] = True
            xt[transfer_index] = x0[transfer_index]
        # NOTE(review): nesting reconstructed from a whitespace-stripped
        # listing; the EOS early exit is assumed to run once per finished
        # block rather than after every step -- confirm against the original.
        if torch.sum(xt == self.tokenizer.eos_token_id) > 0:
            return xt

    return xt
300
+
301
@torch.no_grad()
def llada_remdm_sample(self, prompt):
    """Block-wise confidence sampler with a low-confidence remasking phase.

    Works like ``llada_conf_sample`` but runs ``2 * steps`` iterations per
    block. During the window ``[remask_thres, remask_thres + steps)`` each
    iteration first re-masks the ``self.remdm_number`` positions with the
    lowest cached confidence and then commits ``self.remdm_number`` new tokens;
    outside the window it commits ``mask_length / sampling_steps`` tokens.

    Args:
        prompt: token-id tensor indexed as ``prompt.shape[1]``; the ``.item()``
            call below requires a batch of exactly one row.

    Returns:
        The full token tensor (prompt followed by the generated region).
        Generation returns early once an EOS token id is present in it.

    Raises:
        AssertionError: if ``mask_length`` is not a multiple of ``block_size``
            or of ``sampling_steps``, or ``sampling_steps`` is not a multiple
            of the number of blocks.
    """
    xt = torch.full((1, prompt.shape[1] + self.mask_length), self.mask_id, dtype=torch.long).to(self.model.device)
    xt[:, :prompt.shape[1]] = prompt.clone()

    # Everything that is not a mask token right now is the prompt.
    prompt_index = (xt != self.mask_id)
    prompt_len = prompt_index.sum(1).item()

    assert self.mask_length % self.block_size == 0
    num_blocks = self.mask_length // self.block_size

    assert self.sampling_steps % num_blocks == 0
    steps = self.sampling_steps // num_blocks

    assert self.mask_length % self.sampling_steps == 0

    for num_block in range(num_blocks):
        # Confidence at which each position was committed; +inf means "never
        # committed in this block", so such positions sort last for remasking.
        conf_cache = torch.ones_like(xt, dtype=torch.float64) * np.inf
        # NOTE(review): the threshold is derived from block_size but compared
        # against the step index i -- confirm this is intended.
        remask_thres = int(self.block_size / 8 * 7)
        for i in range(2 * steps):
            if i >= remask_thres and i < remask_thres + steps:
                # Remasking phase: put the least confident tokens back to mask.
                # NOTE(review): top-k runs over the whole row, so if fewer than
                # remdm_number positions hold a finite confidence it can select
                # +inf entries (including prompt positions) -- verify.
                remask_index = torch.zeros_like(xt, dtype=torch.bool, device=xt.device)
                _, mask_indices = torch.topk(conf_cache, k=self.remdm_number, largest=False, dim=1)
                remask_index[0, mask_indices] = True
                conf_cache[remask_index] = np.inf
                xt[remask_index] = self.mask_id
            mask_index = (xt == self.mask_id)
            logits = self.model(xt).logits
            p = F.softmax(logits.to(torch.float64), dim=-1)
            x0 = _sample_categorical(p)
            # Probability the model assigned to each sampled token.
            x0_p = torch.squeeze(torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # b, l

            # Positions beyond the current block may not be committed yet.
            x0_p[:, prompt_len + (num_block + 1) * self.block_size:] = -np.inf
            x0 = torch.where(mask_index, x0, xt)
            # Already-decoded positions get -inf so top-k never re-selects them.
            confidence = torch.where(mask_index, x0_p, -np.inf)

            # Refill exactly what was remasked during the remasking window.
            if i >= remask_thres and i < remask_thres + steps:
                transfer_length = self.remdm_number
            else:
                transfer_length = int(self.mask_length / self.sampling_steps)

            transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
            for j in range(confidence.shape[0]):
                _, select_index = torch.topk(confidence[j], k=transfer_length)
                transfer_index[j, select_index] = True
            xt[transfer_index] = x0[transfer_index]
            conf_cache[transfer_index] = confidence[transfer_index]
        # NOTE(review): nesting reconstructed from a whitespace-stripped
        # listing; the EOS early exit is assumed to run once per finished
        # block rather than after every step -- confirm against the original.
        if torch.sum(xt == self.tokenizer.eos_token_id) > 0:
            return xt

    return xt
352
+
353
@torch.no_grad()
def generate_until(self, requests: list[Instance]):
    """Generate one answer per request and dump all samples to a JSON file.

    Each request's prompt is tokenized, completed with the sampler selected by
    ``self.sampler`` (``'llada_conf'`` or ``'llada_remdm'``), cut at the first
    stop sequence and stripped of special tokens.

    Args:
        requests: request objects whose ``args`` are
            ``(prompt_text, {"until": [stop strings], ...})``.

    Returns:
        list[str]: the generated answers, in request order.

    Raises:
        ValueError: if ``self.sampler`` is not a supported sampler name.

    Side effects:
        Writes ``<generated_samples_path>/<rank>.json`` containing the total
        wall-clock time and every (prompt, answer) pair.
    """
    start_time = time.time()

    def _tokenize(e):
        return {
            "question": self.tokenizer(e["question"])["input_ids"],
            "question_text": e["question"],
            "until": e["until"],
        }

    ds = [{"question": req.args[0], "until": req.args[1]['until']} for req in requests]
    ds = Dataset.from_list(ds)
    ds = ds.map(_tokenize)
    ds = ds.with_format("torch")

    out, out_for_json = [], []
    for elem in tqdm(ds, desc="Generating..."):
        prompt = elem["question"].unsqueeze(0).to(self.device)
        stop_tokens = elem["until"] + ["<|eot_id|>", self.tokenizer.eos_token]

        if self.sampler == 'llada_conf':
            generated_answer = self.llada_conf_sample(prompt)
        elif self.sampler == 'llada_remdm':
            generated_answer = self.llada_remdm_sample(prompt)
        else:
            # Without this branch the variable is unbound (or stale from the
            # previous iteration) and fails later with a confusing error.
            raise ValueError(f"Unknown sampler: {self.sampler!r}")

        generated_answer = self.tokenizer.decode(generated_answer[0][prompt.shape[1]:], skip_special_tokens=False)
        for stop_seq in stop_tokens:
            # Skip None / "" (e.g. a tokenizer without an EOS string): both
            # would make the membership test or split() raise.
            if stop_seq and stop_seq in generated_answer:
                generated_answer = generated_answer.split(stop_seq)[0]

        # remove special tokens
        generated_answer_ids = self.tokenizer(generated_answer)["input_ids"]
        generated_answer = self.tokenizer.decode(generated_answer_ids, skip_special_tokens=True)
        out.append(generated_answer)
        out_for_json.append({
            "prefix": elem["question_text"],
            "result": generated_answer,
        })

        # NOTE(review): nesting reconstructed from a whitespace-stripped
        # listing; the barrier is assumed to run once per sample -- confirm.
        if self.accelerator is not None:
            self.accelerator.wait_for_everyone()

    end_time = time.time()
    total_duration = end_time - start_time
    print(f"\n总耗时: {total_duration:.2f} 秒")

    # Create the output directory on first use; an empty path means "cwd".
    if self.generated_samples_path:
        os.makedirs(self.generated_samples_path, exist_ok=True)
    with open(os.path.join(self.generated_samples_path, str(self._rank) + ".json"), "w") as f:
        final_output = {
            "total_time_seconds": total_duration,
            "samples": out_for_json
        }
        json.dump(final_output, f, indent=2)

    return out
410
+
411
+
412
# Script entry point: run cli_evaluate() only when executed directly.
if __name__ == "__main__":
    cli_evaluate()
Prism/LLaDA/LLaDA_Truthfulqa/eval_llada_prism.py ADDED
@@ -0,0 +1,333 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ This file is inspired by the code from https://github.com/ML-GSAI/SMDM
3
+ And extended with Prism methods for LLaDA-Instruct.
4
+ '''
5
+ import accelerate
6
+ import torch
7
+ import re
8
+ from pathlib import Path
9
+ import random
10
+ import numpy as np
11
+ import torch.nn.functional as F
12
+ from datasets import Dataset
13
+ from lm_eval.__main__ import cli_evaluate
14
+ from lm_eval.api.instance import Instance
15
+ from lm_eval.api.model import LM
16
+ from lm_eval.api.registry import register_model
17
+ from tqdm import tqdm
18
+
19
+ from transformers import AutoTokenizer, AutoModel
20
+ import json
21
+ import os
22
+ import logging
23
+ import math
24
+ import textwrap
25
+ import time
26
+ from collections import Counter
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
def set_seed(seed):
    """Seed the PyTorch, Python and NumPy RNGs and make cuDNN deterministic.

    Args:
        seed: integer seed applied to all three generators.
    """
    for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
36
+
37
def _sample_categorical(categorical_probs):
    """Draw one index per row of ``categorical_probs`` along the last dim.

    Each probability is divided by noise of the form ``-log(u)`` with ``u``
    uniform in [0, 1) (small epsilons guard against log(0) and division by
    zero), and the arg-max of the ratio is returned.

    Args:
        categorical_probs: tensor of probabilities; the last dimension indexes
            the categories.

    Returns:
        Tensor of sampled indices with the last dimension removed.
    """
    uniform = torch.rand_like(categorical_probs)
    neg_log_noise = (1e-10 - (uniform + 1e-10).log()).to(categorical_probs.dtype)
    ratio = categorical_probs / neg_log_noise
    return ratio.argmax(dim=-1)
40
+
41
class CodeVerifier:
    """Scores candidate answers, either by self-verification or by confidence.

    In "svf" mode the model is asked a Yes/No question about the candidate and
    the score is the softmax probability of "Yes" against "No" at the last
    position. In any other mode the score is the geometric mean of the
    per-position maximum probabilities of the supplied logits.
    """

    def __init__(self, model, tokenizer, device="cuda"):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device

        # Token ids whose logits count as a "Yes" / "No" vote.
        self.yes_ids = self._last_token_ids(["Yes", " Yes", "YES"])
        self.no_ids = self._last_token_ids(["No", " No", "NO"])

    def _last_token_ids(self, words):
        """Return the de-duplicated last token id of each word's encoding."""
        collected = []
        for word in words:
            ids = self.tokenizer.encode(word, add_special_tokens=False)
            if ids:
                collected.append(ids[-1])
        return list(set(collected))

    @staticmethod
    def _build_verify_prompt(prompt, truncated_code, task_type):
        """Fill the judge prompt template that matches ``task_type``."""
        if task_type == "code":
            return f"""
You are an expert programming contest judge. Your task is to evaluate a generated solution for a given problem based on correctness, efficiency, and adherence to constraints.

[Problem Statement]
{prompt}
[/Problem Statement]

[Proposed Python Solution]
```python
{truncated_code}
```
[/Proposed Python Solution]

**Conclusion**: Based on your analysis, is the solution likely to be fully correct? Answer with a single word: Yes or No.
**Answer:** """
        if task_type == "math":
            return f"""
You are an expert mathematician and competition judge. Your task is to evaluate a proposed mathematical solution for a given problem based on its logical rigor and accuracy.

[Math Problem]
{prompt}
[/Math Problem]

[Proposed Mathematical Solution]
{truncated_code}
[/Proposed Mathematical Solution]

**Conclusion**: Based on your analysis, is this solution path sound and likely to result in the correct final answer? Answer with a single word: Yes or No.
**Answer:** """
        if task_type == "reasoning":
            return f"""
You are an expert reading comprehension and faithfulness judge. Your task is to evaluate a generated answer based on the provided context and question.

[Context and Question]
{prompt}
[/Context and Question]

[Proposed Answer]
{truncated_code}
[/Proposed Answer]

**Conclusion**: Based on your analysis, is the answer fully faithful to the context and correct? Answer with a single word: Yes or No.
**Answer:** """
        if task_type == "qa":
            return f"""
You are a strict truthfulness judge. Your task is to verify if the answer is factually accurate and rejects common human misconceptions or popular urban legends.

[Question]
{prompt}
[/Question]

[Proposed Answer]
{truncated_code}
[/Proposed Answer]

**Conclusion**: Based on these criteria, is this answer fully truthful and accurate? Answer with a single word: Yes or No.
**Answer:** """
        return f"Is the following answer correct?\nQuestion: {prompt}\nAnswer: {truncated_code}\nAnswer Yes or No.\nAnswer:"

    @staticmethod
    def _best_logit(logits, token_ids):
        """Largest logit among in-vocabulary ``token_ids``, floored at -100."""
        vocab_size = logits.shape[-1]
        values = [logits[i].item() for i in token_ids if i < vocab_size]
        values.append(-100.0)
        return max(values)

    def svf_score(self, prompt, code_str, task_type="code"):
        """Return the model's P("Yes") that ``code_str`` answers ``prompt`` well.

        Args:
            prompt: the task or question text.
            code_str: the candidate answer; only its first 2000 characters are
                shown to the judge.
            task_type: "code", "math", "reasoning" or "qa"; any other value
                selects a short generic template.

        Returns:
            float: softmax probability of the best "Yes" logit against the best
            "No" logit at the final position.
        """
        max_len = 2000
        truncated_code = code_str[:max_len]

        prompt_template = self._build_verify_prompt(prompt, truncated_code, task_type)
        verify_text = textwrap.dedent(prompt_template).strip()

        input_ids = self.tokenizer(verify_text, return_tensors="pt").input_ids.to(self.device)
        if input_ids.shape[1] > 2048:
            # Keep the tail so the final "Answer:" cue is never cut off.
            input_ids = input_ids[:, -2048:]

        with torch.no_grad():
            last_logits = self.model(input_ids).logits[0, -1, :]
            yes_score = self._best_logit(last_logits, self.yes_ids)
            no_score = self._best_logit(last_logits, self.no_ids)
            probs = torch.softmax(torch.tensor([yes_score, no_score]), dim=0)
        return probs[0].item()

    def get_reward(self, prompt, code_str, mode="confidence", current_logits=None, task_type="code"):
        """Score a candidate.

        Args:
            prompt: the task or question text.
            code_str: the candidate answer text.
            mode: "svf" for self-verification; anything else scores by
                confidence.
            current_logits: logits used in confidence mode; the last dimension
                indexes the vocabulary.
            task_type: forwarded to ``svf_score``.

        Returns:
            float: the "svf" score, or the geometric mean of the per-position
            maximum probabilities; 0.0 in confidence mode when no logits are
            given.
        """
        if mode == "svf":
            return self.svf_score(prompt, code_str, task_type=task_type)
        if current_logits is None:
            return 0.0

        probs = torch.softmax(current_logits.to(torch.float32), dim=-1)
        max_probs, _ = torch.max(probs, dim=-1)
        return torch.exp(torch.mean(torch.log(max_probs + 1e-10))).item()
142
+
143
class HTSSampler:
    """Parallel diffusion sampler that prunes candidates while decoding.

    ``initial_N`` candidates are unmasked side by side; during a window of the
    schedule the batch is repeatedly narrowed to the best-scoring candidates
    according to ``CodeVerifier.get_reward``.
    """

    def __init__(self, model, tokenizer, device="cuda"):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.verifier = CodeVerifier(model, tokenizer, device)

    def _sample_with_temperature(self, logits, temperature=0.7):
        """Sample token ids from ``logits`` and report their probabilities.

        Args:
            logits: tensor whose last dimension indexes the vocabulary.
            temperature: sampling temperature; values <= 0 select the arg-max.

        Returns:
            tuple: ``(x0, x0_p)`` -- sampled ids and the probability each id
            has under the temperature-free softmax.
        """
        logits = logits.to(torch.float32)
        if temperature > 0:
            probs = torch.softmax(logits / temperature, dim=-1)
            x0 = torch.multinomial(probs.view(-1, probs.shape[-1]), 1).view(logits.shape[:-1])
            x0_p = torch.gather(torch.softmax(logits, dim=-1), -1, x0.unsqueeze(-1)).squeeze(-1)
        else:
            x0_p, x0 = torch.max(torch.softmax(logits, dim=-1), dim=-1)
        return x0, x0_p

    @torch.no_grad()
    def generate_hts(self, prompt_text, input_ids, initial_N=1, final_K=1, hts_survivor_k=2,
                     steps=32, gen_length=32, mask_id=126336, reward_mode="svf", task_type="qa",
                     decay_factor=1.8, hts_start_pct=0.1, hts_end_pct=0.6, pruning_interval=3):
        """Generate candidates with pruning and return them best-first.

        Args:
            prompt_text: prompt string handed to the verifier.
            input_ids: prompt token ids, indexed as ``input_ids.shape[1]`` and
                repeated ``initial_N`` times along dim 0.
            initial_N: number of candidates decoded in parallel at the start.
            final_K: lower bound on the number of surviving candidates.
            hts_survivor_k: accepted for interface compatibility; unused.
            steps: number of denoising steps.
            gen_length: number of tokens to generate.
            mask_id: id of the mask token.
            reward_mode: "svf" scores with the verifier prompt; any other value
                scores by token confidence.
            task_type: verifier template selector.
            decay_factor: base of the geometric decay of the target width.
            hts_start_pct: fraction of ``steps`` at which pruning may start.
            hts_end_pct: fraction of ``steps`` at which pruning stops.
            pruning_interval: steps to wait after a pruning event.

        Returns:
            list[str]: decoded texts of the surviving candidates, sorted by
            descending score.
        """
        b = initial_N
        prompt_len = input_ids.shape[1]
        xt = torch.full((b, prompt_len + gen_length), mask_id, dtype=torch.long, device=self.device)
        xt[:, :prompt_len] = input_ids.repeat(b, 1)

        # Probability with which each generated token was committed.
        conf_scores = torch.zeros((b, prompt_len + gen_length), device=self.device)
        ts_start, tr_end = int(steps * hts_start_pct), int(steps * hts_end_pct)

        # Tokens to commit per step; the remainder goes to the earliest steps.
        schedule = torch.full((steps,), gen_length // steps, dtype=torch.int64, device=self.device)
        schedule[:gen_length % steps] += 1

        next_pruning = ts_start
        for i in range(steps):
            mask_indices = (xt == mask_id)
            if not mask_indices.any():
                break

            logits = self.model(xt).logits
            x0, x0_p = self._sample_with_temperature(logits[:, prompt_len:], temperature=0.7)

            # Commit the most confident masked positions of every candidate.
            for idx in range(b):
                curr_mask = mask_indices[idx, prompt_len:]
                if not curr_mask.any():
                    continue
                conf = torch.where(curr_mask, x0_p[idx], -float('inf'))
                _, sel_idx = torch.topk(conf, k=min(schedule[i].item(), curr_mask.sum().item()))
                xt[idx, prompt_len + sel_idx] = x0[idx, sel_idx]
                conf_scores[idx, prompt_len + sel_idx] = x0_p[idx, sel_idx]

            # Pruning: shrink the batch towards a geometrically decaying width.
            if i >= next_pruning and i < tr_end and b > final_K:
                target_width = max(final_K, math.ceil(initial_N * (decay_factor ** -(i - ts_start))))
                if b > target_width:
                    scores = []
                    decoded_texts = self.tokenizer.batch_decode(xt[:, prompt_len:], skip_special_tokens=True)
                    for j in range(b):
                        s = self.verifier.get_reward(prompt_text, decoded_texts[j], mode=reward_mode,
                                                     task_type=task_type, current_logits=logits[j, prompt_len:])
                        scores.append(s)

                    top_k_indices = torch.topk(torch.tensor(scores), k=min(target_width, b)).indices
                    xt = xt[top_k_indices]
                    conf_scores = conf_scores[top_k_indices]
                    b = xt.shape[0]
                    # NOTE(review): nesting reconstructed from a
                    # whitespace-stripped listing; the interval is assumed to
                    # restart only after an actual pruning event -- confirm.
                    next_pruning = i + pruning_interval

        # Final decoding and ranking
        final_texts = self.tokenizer.batch_decode(xt[:, prompt_len:], skip_special_tokens=True)
        results = []
        for j in range(b):
            if reward_mode == "svf":
                s = self.verifier.get_reward(prompt_text, final_texts[j], mode=reward_mode, task_type=task_type)
            else:
                # No logits exist for the final sequences, so the verifier's
                # confidence path would score every candidate 0.0. Rank by the
                # geometric mean of the confidences recorded at commit time.
                token_conf = conf_scores[j, prompt_len:]
                if token_conf.numel() == 0:
                    s = 0.0
                else:
                    s = torch.exp(torch.mean(torch.log(token_conf + 1e-10))).item()
            results.append({'text': final_texts[j], 'score': s})

        results.sort(key=lambda v: v['score'], reverse=True)
        return [r['text'] for r in results]
219
+
220
@register_model("llada_dist")
class LLaDAEvalHarness(LM):
    """lm-eval model wrapper for LLaDA with HTS (Prism) or plain sampling.

    Only ``generate_until`` is implemented; both log-likelihood entry points
    return an empty list. The wrapper always reports rank 0 and world size 1.
    """

    def __init__(self, model_path='', mask_id=126336, max_length=4096, generated_samples_path='',
                 batch_size=32, sampling_steps=64, mask_length=128, sampler='hts', task_type="qa",
                 hts_initial_n=8, final_K=1, hts_reward_mode="svf", hts_start_pct=0.1, hts_end_pct=0.6,
                 **kwargs):
        """Load the model and tokenizer and store the sampling configuration.

        Args:
            model_path: path or hub id passed to ``from_pretrained``.
            mask_id: id of the mask token.
            max_length: accepted for interface compatibility; unused.
            generated_samples_path: directory for ``res.json``; empty disables
                the dump.
            batch_size: accepted for interface compatibility; unused.
            sampling_steps: number of denoising steps.
            mask_length: number of tokens to generate.
            sampler: ``'hts'`` selects the pruning sampler; any other value
                selects plain confidence sampling.
            task_type: verifier template selector.
            hts_initial_n: number of parallel candidates for HTS.
            final_K: lower bound on surviving HTS candidates.
            hts_reward_mode: reward mode passed to the HTS sampler.
            hts_start_pct: fraction of steps at which pruning may start.
            hts_end_pct: fraction of steps at which pruning stops.
            **kwargs: ignored.
        """
        super().__init__()
        self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="auto")
        self.model.eval()
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        self.device = self.model.device

        self.mask_id = mask_id
        self.sampling_steps = int(sampling_steps)
        self.mask_length = int(mask_length)
        self.sampler = sampler
        self.task_type = task_type
        self.generated_samples_path = generated_samples_path

        self.hts_initial_n = int(hts_initial_n)
        self.final_K = int(final_K)
        self.hts_reward_mode = hts_reward_mode
        self.hts_start_pct = float(hts_start_pct)
        self.hts_end_pct = float(hts_end_pct)

        self.hts_sampler = HTSSampler(self.model, self.tokenizer, self.device)
        self._rank = 0

    @torch.no_grad()
    def llada_conf_sample(self, prompt):
        """Confidence-ordered unmasking of ``mask_length`` tokens after ``prompt``.

        Args:
            prompt: prompt token ids, indexed as ``prompt.shape[1]``; only row
                0 is decoded.

        Returns:
            Token tensor holding the prompt followed by the generated tokens.
        """
        xt = torch.full((1, prompt.shape[1] + self.mask_length), self.mask_id, dtype=torch.long, device=self.device)
        xt[:, :prompt.shape[1]] = prompt

        # Spread the remainder over the first steps (same schedule as the HTS
        # sampler). Plain floor division left mask tokens in the output when
        # mask_length was not a multiple of sampling_steps, and unmasked
        # nothing at all when sampling_steps exceeded mask_length.
        base, extra = divmod(self.mask_length, self.sampling_steps)
        for i in range(self.sampling_steps):
            mask_indices = (xt == self.mask_id)
            if not mask_indices.any():
                break
            logits = self.model(xt).logits
            p = F.softmax(logits.to(torch.float64), dim=-1)
            x0 = _sample_categorical(p)
            x0_p = torch.squeeze(torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)

            step_size = base + (1 if i < extra else 0)
            # Decoded positions get -inf so top-k only picks masked ones.
            confidence = torch.where(mask_indices, x0_p, -float('inf'))
            _, select_idx = torch.topk(confidence[0], k=min(step_size, mask_indices.sum().item()))
            xt[0, select_idx] = x0[0, select_idx]
        return xt

    @torch.no_grad()
    def generate_until(self, requests):
        """Generate one answer per request.

        With the ``'hts'`` sampler the answer is the majority candidate if any
        text occurs more than once, otherwise the top-ranked candidate. The
        answer is cut at the first stop sequence and stripped.

        Args:
            requests: request objects whose ``args`` are
                ``(prompt_text, {"until": stop string or list of them, ...})``.

        Returns:
            list[str]: generated answers in request order.

        Side effects:
            When ``generated_samples_path`` is set, writes ``res.json`` there
            with the total wall-clock time and all (prompt, answer) pairs.
        """
        start_time = time.time()

        out, out_for_json = [], []
        for req in tqdm(requests, desc="Generating..."):
            prompt_text = req.args[0]
            # Accept a missing key or a single string as well as a list.
            until = req.args[1].get('until') or []
            if isinstance(until, str):
                until = [until]
            prompt_ids = self.tokenizer(prompt_text, return_tensors="pt").input_ids.to(self.device)

            if self.sampler == 'hts':
                candidates = self.hts_sampler.generate_hts(
                    prompt_text=prompt_text,
                    input_ids=prompt_ids,
                    initial_N=self.hts_initial_n,
                    final_K=self.final_K,
                    steps=self.sampling_steps,
                    gen_length=self.mask_length,
                    reward_mode=self.hts_reward_mode,
                    task_type=self.task_type,
                    hts_start_pct=self.hts_start_pct,
                    hts_end_pct=self.hts_end_pct
                )
                if not candidates:
                    generated_answer = ""
                else:
                    counts = Counter(candidates)
                    most_common = counts.most_common()
                    if most_common[0][1] > 1:
                        generated_answer = most_common[0][0]
                    else:
                        generated_answer = candidates[0]
            else:
                res_ids = self.llada_conf_sample(prompt_ids)
                generated_answer = self.tokenizer.decode(res_ids[0, prompt_ids.shape[1]:], skip_special_tokens=True)

            for stop_seq in list(until) + ["<|eot_id|>", self.tokenizer.eos_token]:
                if stop_seq and stop_seq in generated_answer:
                    generated_answer = generated_answer.split(stop_seq)[0]

            generated_answer = generated_answer.strip()
            out.append(generated_answer)
            out_for_json.append({"prefix": prompt_text, "result": generated_answer})

        end_time = time.time()
        total_duration = end_time - start_time

        if self.generated_samples_path:
            os.makedirs(self.generated_samples_path, exist_ok=True)
            final_output = {
                "total_time_seconds": total_duration,
                "samples": out_for_json
            }
            with open(os.path.join(self.generated_samples_path, "res.json"), "w", encoding="utf-8") as f:
                json.dump(final_output, f, indent=2)
        return out

    def loglikelihood(self, requests):
        """Not supported by this wrapper; always returns an empty list."""
        return []

    def loglikelihood_rolling(self, requests):
        """Not supported by this wrapper; always returns an empty list."""
        return []

    @property
    def rank(self):
        """Process rank; this wrapper always reports 0."""
        return 0

    @property
    def world_size(self):
        """Number of processes; this wrapper always reports 1."""
        return 1
331
+
332
# Script entry point: hand control to the lm-eval command-line interface.
if __name__ == "__main__":
    cli_evaluate()
Prism/README.md ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Prism: Efficient Test-Time Scaling via Hierarchical Search and Self-Verification for Discrete Diffusion Language Models
2
+
3
+ ## PRISM: Pruning, Remasking, and Integrated Self-verification Method
4
+
5
+ PRISM is an efficient inference framework for **Discrete Diffusion Language Models (dLLMs)**. It achieves a favorable performance-efficiency trade-off by matching Best-of-N performance with a substantially lower Number of Function Evaluations (NFE).
6
+
7
+ [![arXiv](https://img.shields.io/badge/arXiv-2602.01842-b31b1b.svg)](https://arxiv.org/abs/2602.01842)
8
+ [![GitHub](https://img.shields.io/badge/GitHub-Repo-181717?logo=github)](https://github.com/viiika/Prism)
9
+
10
+ ### Method
11
+ ![Prism Method](method.png)
12
+
13
+ ### Experiments
14
+
15
+ ![Prism Exp](exp.png)
16
+
17
+ ### Project Structure
18
+
19
+ ```text
20
+ PRISM/
21
+ ├── Dream/ # Experiments for Dream
22
+ │ ├── Dream_Baseline/ # Standard baseline sampling (N=1)
23
+ │ └── Dream_Prism/ # Prism implementation
24
+ ├── LLaDA/ # Experiments for LLaDA 8B Instruct
25
+ │ ├── LLaDA_Baseline/ # Standard baseline sampling (N=1)
26
+ │ ├── LLaDA_Prism/ # PRISM implementation
27
+ │ └── LLaDA_Truthfulqa/ # TruthfulQA evaluation
28
+ └── LLaDA2mini/ # Experiments for LLaDA 2.0-mini
29
+ ├── LLaDA2mini_Baseline/ # Standard baseline sampling (N=1)
30
+ └── LLaDA2mini_Prism/ # Prism implementation
31
+ ```
32
+
33
+ ### Prerequisites
34
+ ```bash
35
+ cd PRISM
36
+ ```
37
+ For Dream Project:
38
+ ```bash
39
+ cd Dream/Dream_Prism/eval_instruct
40
+ pip install -e .
41
+ ```
42
+ For LLaDA_Truthfulqa:
43
+ ```bash
44
+ cd LLaDA/LLaDA_Truthfulqa/lm-evaluation-harness
45
+ pip install -e .
46
+ ```
47
+ For LLaDA and LLaDA2 Projects:
48
+ ```bash
49
+ cd LLaDA/LLaDA_Prism
50
+ pip install -r requirements.txt
51
+ ```
52
+ ### Quick Start
53
+ Evaluate Dream
54
+ ```bash
55
+ cd Dream/Dream_Prism
56
+ bash scripts/run_gsm8k.sh
57
+ bash scripts/run_humaneval.sh
58
+ bash scripts/run_math500.sh
59
+ bash scripts/run_mbpp.sh
60
+ ```
61
+ Evaluate LLaDA 8B Instruct
62
+ ```bash
63
+ cd LLaDA/LLaDA_Prism
64
+ bash scripts/run_gsm8k.sh
65
+ bash scripts/run_humaneval.sh
66
+ bash scripts/run_math500.sh
67
+ bash scripts/run_mbpp.sh
68
+ ```
69
+ Evaluate LLaDA 8B Instruct (TruthfulQA)
70
+ ```bash
71
+ cd LLaDA/LLaDA_Truthfulqa
72
+ bash scripts/llada_prism.sh
73
+ ```
74
+ Evaluate LLaDA 2.0-mini
75
+ ```bash
76
+ cd LLaDA2mini/LLaDA2mini_Prism
77
+ bash scripts/run_gsm8k.sh
78
+ bash scripts/run_humaneval.sh
79
+ bash scripts/run_math500.sh
80
+ bash scripts/run_mbpp.sh
81
+ ```
82
+
83
+ ### Evaluation & Metrics
84
+ Each project folder contains a `metrics/` directory with scripts for computing the final accuracy and efficiency metrics.
85
+ Usage Example:
86
+ ```bash
87
+ python PRISM/LLaDA/LLaDA_Prism/metrics/gsm8k_all.py
88
+ ```
89
+
90
+ ### Acknowledgements
91
+ This project is built upon [preordinary/LLaDA2](https://github.com/preordinary/LLaDA2), [ML-GSAI/LLaDA](https://github.com/ML-GSAI/LLaDA), [DreamLM/Dream](https://github.com/DreamLM/Dream) and [EleutherAI/lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness). Special thanks to the authors for their contributions.
92
+
93
+
94
+
95
+ ### 📚 Citation
96
+
97
+ If you find this work helpful, please consider citing:
98
+
99
+ ```bibtex
100
+ @article{bai2026prism,
101
+ title={Prism: Efficient Test-Time Scaling via Hierarchical Search and Self-Verification for Discrete Diffusion Language Models},
102
+ author={Bai, Jinbin and Li, Yixuan and Zhu, Yuchen and Xin, Yi and Shi, Qingyu and Feng, Aosong and Liu, Xiaohong and Tao, Molei and Xue, Jianru and Li, Xiangtai and Yang, Ming-Hsuan},
103
+ journal={arXiv preprint arXiv:2602.01842},
104
+ year={2026}
105
+ }
106
+ ```
107
+
URSA-1.7B/.gitattributes ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ . filter=lfs diff=lfs merge=lfs -text
37
+ tokenizer/tokenizer.json filter=lfs diff=lfs merge=lfs -text
URSA-1.7B/.gitignore ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Compiled Object files
2
+ *.slo
3
+ *.lo
4
+ *.o
5
+ *.cuo
6
+
7
+ # Compiled Dynamic libraries
8
+ *.so
9
+ *.dll
10
+ *.dylib
11
+
12
+ # Compiled Static libraries
13
+ *.lai
14
+ *.la
15
+ *.a
16
+ *.lib
17
+
18
+ # Compiled python
19
+ *.pyc
20
+ __pycache__
21
+
22
+ # Compiled MATLAB
23
+ *.mex*
24
+
25
+ # IPython notebook checkpoints
26
+ .ipynb_checkpoints
27
+
28
+ # Editor temporaries
29
+ *.swp
30
+ *~
31
+
32
+ # Sublime Text settings
33
+ *.sublime-workspace
34
+ *.sublime-project
35
+
36
+ # Eclipse Project settings
37
+ *.*project
38
+ .settings
39
+
40
+ # QtCreator files
41
+ *.user
42
+
43
+ # VSCode files
44
+ .vscode
45
+
46
+ # IDEA files
47
+ .idea
48
+
49
+ # OSX dir files
50
+ .DS_Store
51
+
52
+ # Android files
53
+ .gradle
54
+ *.iml
55
+ local.properties
URSA-1.7B/LICENSE ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
URSA-1.7B/README.md ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: diffusers
3
+ license: apache-2.0
4
+ license_link: https://huggingface.co/BAAI/URSA-1.7B-FSQ320/blob/main/LICENSE
5
+ pipeline_tag: text-to-video
6
+ base_model:
7
+ - Qwen/Qwen3-1.7B
8
+ ---
9
+
10
+ # URSA-1.7B-FSQ320 Model Card
11
+
12
+ ## Model Details
13
+ - **Developed by:** BAAI
14
+ - **Model type:** Text-to-Video Generation Model
15
+ - **Model size:** 1.7B
16
+ - **Model precision:** torch.float16 (FP16)
17
+ - **Model resolution:** 512x320
18
+ - **Model paper:** [Uniform Discrete Diffusion with Metric Path for Video Generation](https://arxiv.org/abs/2510.24717)
19
+ - **Model family:** [BAAI-Vision-URSA](https://github.com/baaivision/URSA)
20
+ - **Model Tokenizer:** [Cosmos-Tokenize1-DV4x8x8-360p](https://huggingface.co/nvidia/Cosmos-Tokenize1-DV4x8x8-360p)
21
+ - **Model Description:** This is a model that can be used to generate and modify videos based on text prompts.
22
+
23
+ ## Examples
24
+
25
+ Use the [🤗 Diffusers library](https://github.com/huggingface/diffusers) to run URSA in a simple and efficient manner.
26
+
27
+ ```bash
28
+ pip install diffusers transformers accelerate imageio[ffmpeg]
29
+ pip install git+ssh://git@github.com/baaivision/URSA.git
30
+ ```
31
+
32
+ Running the pipeline:
33
+
34
+ ```python
35
+ import os, torch, numpy
36
+ from diffnext.pipelines import URSAPipeline
37
+ from diffnext.utils import export_to_video
38
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
39
+
40
+ model_id, height, width = "BAAI/URSA-1.7B-FSQ320", 320, 512
41
+ model_args = {"torch_dtype": torch.float16, "trust_remote_code": True}
42
+ pipe = URSAPipeline.from_pretrained(model_id, **model_args)
43
+ pipe = pipe.to(torch.device("cuda"))
44
+
45
+ text_prompt = "a lone grizzly bear walks through a misty forest at dawn, sunlight catching its fur."
46
+ negative_prompt = "worst quality, low quality, inconsistent motion, static, still, blurry, jittery, distorted, ugly"
47
+
48
+ # Text-to-Image
49
+ prompt = text_prompt
50
+ num_frames, num_inference_steps = 1, 25
51
+ image = pipe(**locals()).frames[0]
52
+ image.save("ursa.jpg")
53
+
54
+ # Image-to-Video
55
+ prompt = f"motion=9.0, {text_prompt}"
56
+ num_frames, num_inference_steps = 49, 50
57
+ video = pipe(**locals()).frames[0]
58
+ export_to_video(video, "ursa_1+48f.mp4", fps=12)
59
+
60
+ # Text-to-Video
61
+ image, video = None, None
62
+ prompt = f"motion=9.0, {text_prompt}"
63
+ num_frames, num_inference_steps = 49, 50
64
+ video = pipe(**locals()).frames[0]
65
+ export_to_video(video, "ursa_49f.mp4", fps=12)
66
+
67
+ # Video-to-Video
68
+ prompt = f"motion=5.0, {text_prompt}"
69
+ num_frames, num_inference_steps = 49, 50
70
+ num_cond_frames, cond_noise_scale = 13, 0.1
71
+ for i in range(12):
72
+ video, start_video = video[-num_cond_frames:], video
73
+ video = pipe(**locals()).frames[0]
74
+ video = numpy.concatenate([start_video, video[num_cond_frames:]])
75
+ export_to_video(video, "ursa_{}f.mp4".format(video.shape[0]), fps=12)
76
+ ```
77
+
78
+ # Uses
79
+
80
+ ## Direct Use
81
+ The model is intended for research purposes only. Possible research areas and tasks include
82
+
83
+ - Research on generative models.
84
+ - Applications in educational or creative tools.
85
+ - Generation of artworks and use in design and other artistic processes.
86
+ - Probing and understanding the limitations and biases of generative models.
87
+ - Safe deployment of models which have the potential to generate harmful content.
88
+
89
+ Excluded uses are described below.
90
+
91
+ #### Out-of-Scope Use
92
+ The model was not trained to produce factual or true representations of people or events, and therefore using the model to generate such content is out-of-scope for the abilities of this model.
93
+
94
+ #### Misuse and Malicious Use
95
+ Using the model to generate content that is cruel to individuals is a misuse of this model. This includes, but is not limited to:
96
+
97
+ - Mis- and disinformation.
98
+ - Representations of egregious violence and gore.
99
+ - Impersonating individuals without their consent.
100
+ - Sexual content without consent of the people who might see it.
101
+ - Sharing of copyrighted or licensed material in violation of its terms of use.
102
+ - Intentionally promoting or propagating discriminatory content or harmful stereotypes.
103
+ - Sharing content that is an alteration of copyrighted or licensed material in violation of its terms of use.
104
+ - Generating demeaning, dehumanizing, or otherwise harmful representations of people or their environments, cultures, religions, etc.
105
+
106
+ ## Limitations and Bias
107
+
108
+ ### Limitations
109
+
110
+ - The autoencoding part of the model is lossy.
111
+ - The model cannot render complex legible text.
112
+ - The model does not achieve perfect photorealism.
113
+ - Fingers, etc., in general may not be generated properly.
114
+ - The model was trained on a subset of the web datasets [LAION-5B](https://laion.ai/blog/laion-5b/) and [COYO-700M](https://github.com/kakaobrain/coyo-dataset), which contain adult, violent and sexual content.
115
+
116
+ ### Bias
117
+ While the capabilities of image generation models are impressive, they can also reinforce or exacerbate social biases.
URSA-1.7B/model_index.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "URSAPipeline",
3
+ "tokenizer": [
4
+ "transformers",
5
+ "Qwen2TokenizerFast"
6
+ ],
7
+ "scheduler": [
8
+ "__scheduler__",
9
+ "KineticOptimalScheduler"
10
+ ],
11
+ "transformer": [
12
+ "__transformer__",
13
+ "URSATransformer3DModel"
14
+ ],
15
+ "vae": [
16
+ "__vae__",
17
+ "AutoencoderVQCosmos3D"
18
+ ]
19
+ }
URSA/.flake8 ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [flake8]
2
+ max-line-length = 100
3
+ ignore =
4
+ # whitespace before ':' (conflicted with Black)
5
+ E203,
6
+ # ambiguous variable name
7
+ E741,
8
+ # ‘from module import *’ used; unable to detect undefined names
9
+ F403,
10
+ # name may be undefined, or defined from star imports: module
11
+ F405,
12
+ # redefinition of unused name from line N
13
+ F811,
14
+ # undefined name
15
+ F821,
16
+ # line break before binary operator
17
+ W503,
18
+ # line break after binary operator
19
+ W504
20
+ # module imported but unused
21
+ per-file-ignores = __init__.py: F401
URSA/.gitignore ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Compiled Object files
2
+ *.slo
3
+ *.lo
4
+ *.o
5
+ *.cuo
6
+
7
+ # Compiled Dynamic libraries
8
+ *.so
9
+ *.dll
10
+ *.dylib
11
+
12
+ # Compiled Static libraries
13
+ *.lai
14
+ *.la
15
+ *.a
16
+ *.lib
17
+
18
+ # Compiled python
19
+ *.pyc
20
+ __pycache__
21
+
22
+ # Compiled MATLAB
23
+ *.mex*
24
+
25
+ # IPython notebook checkpoints
26
+ .ipynb_checkpoints
27
+
28
+ # Editor temporaries
29
+ *.swp
30
+ *~
31
+
32
+ # Sublime Text settings
33
+ *.sublime-workspace
34
+ *.sublime-project
35
+
36
+ # Eclipse Project settings
37
+ *.*project
38
+ .settings
39
+
40
+ # QtCreator files
41
+ *.user
42
+
43
+ # VSCode files
44
+ .vscode
45
+
46
+ # IDEA files
47
+ .idea
48
+
49
+ # OSX dir files
50
+ .DS_Store
51
+
52
+ # Android files
53
+ .gradle
54
+ *.iml
55
+ local.properties
URSA/=4.57.1 ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Requirement already satisfied: diffusers in /usr/local/lib/python3.12/dist-packages (0.36.0)
2
+ Requirement already satisfied: transformers in /usr/local/lib/python3.12/dist-packages (5.2.0)
3
+ Requirement already satisfied: accelerate in /usr/local/lib/python3.12/dist-packages (1.12.0)
4
+ Requirement already satisfied: imageio in /usr/local/lib/python3.12/dist-packages (2.37.2)
5
+ Requirement already satisfied: imageio-ffmpeg in /usr/local/lib/python3.12/dist-packages (0.6.0)
6
+ Requirement already satisfied: omegaconf in /usr/local/lib/python3.12/dist-packages (2.3.0)
7
+ Requirement already satisfied: wandb in /usr/local/lib/python3.12/dist-packages (0.25.0)
8
+ Requirement already satisfied: importlib_metadata in /usr/local/lib/python3.12/dist-packages/setuptools/_vendor (from diffusers) (8.0.0)
9
+ Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from diffusers) (3.17.0)
10
+ Requirement already satisfied: httpx<1.0.0 in /usr/local/lib/python3.12/dist-packages (from diffusers) (0.28.1)
11
+ Requirement already satisfied: huggingface-hub<2.0,>=0.34.0 in /usr/local/lib/python3.12/dist-packages (from diffusers) (1.3.0)
12
+ Requirement already satisfied: numpy in /usr/local/lib/python3.12/dist-packages (from diffusers) (1.26.4)
13
+ Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.12/dist-packages (from diffusers) (2024.11.6)
14
+ Requirement already satisfied: requests in /usr/local/lib/python3.12/dist-packages (from diffusers) (2.32.3)
15
+ Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.12/dist-packages (from diffusers) (0.5.3)
16
+ Requirement already satisfied: Pillow in /usr/local/lib/python3.12/dist-packages (from diffusers) (11.1.0)
17
+ Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.12/dist-packages (from transformers) (23.2)
18
+ Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.12/dist-packages (from transformers) (6.0.2)
19
+ Requirement already satisfied: tokenizers<=0.23.0,>=0.22.0 in /usr/local/lib/python3.12/dist-packages (from transformers) (0.22.2)
20
+ Requirement already satisfied: typer-slim in /usr/local/lib/python3.12/dist-packages (from transformers) (0.21.2)
21
+ Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.12/dist-packages (from transformers) (4.67.1)
22
+ Requirement already satisfied: psutil in /usr/local/lib/python3.12/dist-packages (from accelerate) (7.0.0)
23
+ Requirement already satisfied: torch>=2.0.0 in /usr/local/lib/python3.12/dist-packages (from accelerate) (2.9.0+cu128)
24
+ Requirement already satisfied: antlr4-python3-runtime==4.9.* in /usr/local/lib/python3.12/dist-packages (from omegaconf) (4.9.3)
25
+ Requirement already satisfied: click>=8.0.1 in /usr/local/lib/python3.12/dist-packages (from wandb) (8.1.8)
26
+ Requirement already satisfied: gitpython!=3.1.29,>=1.0.0 in /usr/local/lib/python3.12/dist-packages (from wandb) (3.1.46)
27
+ Requirement already satisfied: platformdirs in /usr/local/lib/python3.12/dist-packages (from wandb) (4.3.6)
28
+ Requirement already satisfied: protobuf!=4.21.0,!=5.28.0,<7,>=3.19.0 in /usr/local/lib/python3.12/dist-packages (from wandb) (4.24.4)
29
+ Requirement already satisfied: pydantic<3 in /usr/local/lib/python3.12/dist-packages (from wandb) (2.10.6)
30
+ Requirement already satisfied: sentry-sdk>=2.0.0 in /usr/local/lib/python3.12/dist-packages (from wandb) (2.54.0)
31
+ Requirement already satisfied: typing-extensions<5,>=4.8 in /usr/local/lib/python3.12/dist-packages (from wandb) (4.12.2)
32
+ Requirement already satisfied: gitdb<5,>=4.0.1 in /usr/local/lib/python3.12/dist-packages (from gitpython!=3.1.29,>=1.0.0->wandb) (4.0.12)
33
+ Requirement already satisfied: anyio in /usr/local/lib/python3.12/dist-packages (from httpx<1.0.0->diffusers) (4.8.0)
34
+ Requirement already satisfied: certifi in /usr/local/lib/python3.12/dist-packages (from httpx<1.0.0->diffusers) (2025.1.31)
35
+ Requirement already satisfied: httpcore==1.* in /usr/local/lib/python3.12/dist-packages (from httpx<1.0.0->diffusers) (1.0.7)
36
+ Requirement already satisfied: idna in /usr/local/lib/python3.12/dist-packages (from httpx<1.0.0->diffusers) (3.10)
37
+ Requirement already satisfied: h11<0.15,>=0.13 in /usr/local/lib/python3.12/dist-packages (from httpcore==1.*->httpx<1.0.0->diffusers) (0.14.0)
38
+ Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub<2.0,>=0.34.0->diffusers) (2025.2.0)
39
+ Requirement already satisfied: hf-xet<2.0.0,>=1.2.0 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub<2.0,>=0.34.0->diffusers) (1.3.2)
40
+ Requirement already satisfied: shellingham in /usr/local/lib/python3.12/dist-packages (from huggingface-hub<2.0,>=0.34.0->diffusers) (1.5.4)
41
+ Requirement already satisfied: annotated-types>=0.6.0 in /usr/local/lib/python3.12/dist-packages (from pydantic<3->wandb) (0.7.0)
42
+ Requirement already satisfied: pydantic-core==2.27.2 in /usr/local/lib/python3.12/dist-packages (from pydantic<3->wandb) (2.27.2)
43
+ Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests->diffusers) (3.4.1)
44
+ Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests->diffusers) (2.0.7)
45
+ Requirement already satisfied: setuptools in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (75.8.2)
46
+ Requirement already satisfied: sympy>=1.13.3 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (1.14.0)
47
+ Requirement already satisfied: networkx>=2.5.1 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (3.4.2)
48
+ Requirement already satisfied: jinja2 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (3.1.6)
49
+ Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.8.93 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (12.8.93)
50
+ Requirement already satisfied: nvidia-cuda-runtime-cu12==12.8.90 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (12.8.90)
51
+ Requirement already satisfied: nvidia-cuda-cupti-cu12==12.8.90 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (12.8.90)
52
+ Requirement already satisfied: nvidia-cudnn-cu12==9.10.2.21 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (9.10.2.21)
53
+ Requirement already satisfied: nvidia-cublas-cu12==12.8.4.1 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (12.8.4.1)
54
+ Requirement already satisfied: nvidia-cufft-cu12==11.3.3.83 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (11.3.3.83)
55
+ Requirement already satisfied: nvidia-curand-cu12==10.3.9.90 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (10.3.9.90)
56
+ Requirement already satisfied: nvidia-cusolver-cu12==11.7.3.90 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (11.7.3.90)
57
+ Requirement already satisfied: nvidia-cusparse-cu12==12.5.8.93 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (12.5.8.93)
58
+ Requirement already satisfied: nvidia-cusparselt-cu12==0.7.1 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (0.7.1)
59
+ Requirement already satisfied: nvidia-nccl-cu12==2.27.5 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (2.27.5)
60
+ Requirement already satisfied: nvidia-nvshmem-cu12==3.3.20 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (3.3.20)
61
+ Requirement already satisfied: nvidia-nvtx-cu12==12.8.90 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (12.8.90)
62
+ Requirement already satisfied: nvidia-nvjitlink-cu12==12.8.93 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (12.8.93)
63
+ Requirement already satisfied: nvidia-cufile-cu12==1.13.1.3 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (1.13.1.3)
64
+ Requirement already satisfied: triton==3.5.0 in /usr/local/lib/python3.12/dist-packages (from torch>=2.0.0->accelerate) (3.5.0)
65
+ Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.12/dist-packages/setuptools/_vendor (from importlib_metadata->diffusers) (3.19.2)
66
+ Requirement already satisfied: annotated-doc>=0.0.2 in /usr/local/lib/python3.12/dist-packages (from typer-slim->transformers) (0.0.4)
67
+ Requirement already satisfied: smmap<6,>=3.0.1 in /usr/local/lib/python3.12/dist-packages (from gitdb<5,>=4.0.1->gitpython!=3.1.29,>=1.0.0->wandb) (5.0.2)
68
+ Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from sympy>=1.13.3->torch>=2.0.0->accelerate) (1.3.0)
69
+ Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.12/dist-packages (from anyio->httpx<1.0.0->diffusers) (1.3.1)
70
+ Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.12/dist-packages (from jinja2->torch>=2.0.0->accelerate) (3.0.2)
URSA/LICENSE ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
URSA/README.md ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div align="center">
2
+
3
+ <img src="assets/logo.png" width="30%" alt="logo"/>
4
+
5
+ <h1>🐻 URSA: Uniform Discrete Diffusion with Metric Path<br>for Video Generation</h1>
6
+
7
+ <p align="center">
8
+ <a href="https://arxiv.org/abs/2510.24717"><img src="https://img.shields.io/badge/ArXiv-2510.24717-%23840707.svg" alt="ArXiv"></a>
9
+ <a href="https://huggingface.co/collections/BAAI/ursa"><img src="https://img.shields.io/badge/🤗 Weights-BAAI/URSA-rgb(166,109,59).svg" alt=""></a>
10
+ <a href="https://huggingface.co/spaces/BAAI/nova-d48w1024-osp480"><img src="https://img.shields.io/badge/🤗 Demo-TI2V-%23840707.svg" alt="TI2VDemo"></a>
11
+ <a href="http://bitterdhg.github.io/URSA_page"><img src="https://img.shields.io/badge/Project-URSA-%237CB4F7.svg" alt="Project"></a>
12
+ </p>
13
+
14
+ <p align="center">
15
+
16
+ [Haoge Deng](https://scholar.google.com/citations?user=S2sbvjgAAAAJ&hl)<sup>1,4*</sup>, [Ting Pan](https://scholar.google.com/citations?&user=qQv6YbsAAAAJ)<sup>2,4*</sup>, [Fan Zhang](https://scholar.google.com/citations?user=VsJ39HMAAAAJ)<sup>4*</sup>, [Yang Liu](https://scholar.google.com/citations?user=9JcQ2hwAAAAJ&hl)<sup>3,4*</sup>, [Zhuoyan Luo](https://scholar.google.com/citations?user=mKQhEsIAAAAJ&hl)<sup>4</sup>, [Yufeng Cui](https://scholar.google.com/citations?user=5Ydha2EAAAAJ&hl)<sup>4</sup>, [Wenxuan Wang](https://scholar.google.com/citations?user=75OyC-oAAAAJ&hl)<sup>4</sup><br>
17
+ [Chunhua Shen](https://scholar.google.com/citations?user=Ljk2BvIAAAAJ&hl)<sup>3</sup>, [Shiguang Shan](https://scholar.google.com/citations?user=Vkzd7MIAAAAJ&hl)<sup>2</sup>, [Zhaoxiang Zhang](https://scholar.google.com/citations?user=qxWfV6cAAAAJ&hl)<sup>1†</sup>, [Xinlong Wang](https://scholar.google.com/citations?user=DPz0DjYAAAAJ&hl)<sup>4†</sup><br>
18
+
19
+ [CASIA](http://english.ia.cas.cn)<sup>1</sup>, [CASICT](http://english.ict.cas.cn)<sup>2</sup>, [ZJU](https://www.zju.edu.cn/english)<sup>3</sup>, [BAAI](https://www.baai.ac.cn/en)<sup>4</sup><br>
20
+ <sup>*</sup> Equal Contribution, <sup>†</sup> Corresponding Author
21
+ <br><br><img src="assets/model_preview.gif"/>
22
+ <br><br><img src="assets/model_overview.png"/>
23
+ </div>
24
+
25
+ We present **URSA** (**U**niform disc**R**ete diffu**S**ion with metric p**A**th), a simple yet powerful framework that bridges the gap with continuous approaches. **URSA** formulates the video generation task as an iterative global refinement of discrete spatiotemporal tokens and scales efficiently to long video generation, requiring fewer inference steps. **URSA** enables multi-task video generation with an asynchronous timestep scheduling strategy in one unified model.
26
+
27
+ ## 🚀 News
28
+ - ```[Feb 2026]``` Accepted by ICLR 2026 [[OpenReview]](https://openreview.net/forum?id=GFU5yCbILk).
29
+ - ```[Jan 2026]``` Released [Training Guide](./docs/training.md).
30
+ - ```[Oct 2025]``` 🎉 URSA is part of [Emu3.5](https://github.com/baaivision/Emu3.5) as DiDA (Discrete Diffusion Adaptation)!
31
+ - ```[Oct 2025]``` Released <a href="https://huggingface.co/spaces/BAAI/nova-d48w1024-osp480"><b>TI2V</b></a> 🤗 Demo.
32
+ - ```[Oct 2025]``` Released [Paper](https://arxiv.org/abs/2510.24717) & [Project Page](http://bitterdhg.github.io/URSA_page) & [Evaluation Guide](./docs/evaluation.md).
33
+
34
+ ## ✨ Highlights
35
+
36
+ - 🥇 **Novel Approach**: Uniform Discrete Diffusion with Metric Path.
37
+ - 🥈 **SOTA Performance**: High efficiency with state-of-the-art T2I/T2V/I2V results.
38
+ - 🥉 **Unified Modeling**: Multi-task capabilities in a single unified model.
39
+
40
+ ## 🗄️ Models
41
+
42
+ ### 🖼️ Text to Image
43
+
44
+ | Model | Resolution | Data | Weight | GenEval | DPGBench |
45
+ |:-----:|:----------:|:----:|:------:|:-------:|:--------:|
46
+ | URSA-0.6B-IBQ1024 | 1024x1024 | 30M | [🤗 HF](https://huggingface.co/BAAI/URSA-0.6B-IBQ1024) \| [🤖 ModelScope](https://www.modelscope.cn/models/BAAI/URSA-0.6B-IBQ1024) | 0.79 | 85.6 |
47
+ | URSA-1.7B-IBQ1024 | 1024x1024 | 30M | [🤗 HF](https://huggingface.co/BAAI/URSA-1.7B-IBQ1024) \| [🤖 ModelScope](https://www.modelscope.cn/models/BAAI/URSA-1.7B-IBQ1024) | 0.80 | 86.0 |
48
+
49
+ ### 🎬 Text to Video
50
+
51
+ | Model | Resolution | Data | Weight | VBench-T2V | VBench-I2V |
52
+ |:-----:|:----------:|:----:|:------:|:----------:|:----------:|
53
+ | URSA-0.6B-FSQ320 | 49x512x320 | 24M | [🤗 HF](https://huggingface.co/BAAI/URSA-0.6B-FSQ320) \| [🤖 ModelScope](https://www.modelscope.cn/models/BAAI/URSA-0.6B-FSQ320) | 81.4 | 86.0 |
54
+ | URSA-1.7B-FSQ320 | 49x512x320 | 24M | [🤗 HF](https://huggingface.co/BAAI/URSA-1.7B-FSQ320) \| [🤖 ModelScope](https://www.modelscope.cn/models/BAAI/URSA-1.7B-FSQ320) | 82.4 | 86.2 |
55
+
56
+ ## 📖 Table of Contents
57
+ - [🔧 Installation](#installation)
58
+ - [🔥 Quick Start](#quick-start)
59
+ - [🖼️ Image Generation](#quickstart-image-generation)
60
+ - [🎬 Video Generation](#quickstart-video-generation)
61
+ - [💻 Gradio Demo](#gradio-demo)
62
+ - [💯 Evaluation](./docs/evaluation.md)
63
+ - [🤖 Training](./docs/training.md)
64
+
65
+ ## 🔧 Installation
66
+ <a id="installation"></a>
67
+
68
+ Clone this repository to local disk and install:
69
+ ```bash
70
+ pip install diffusers "transformers>=4.57.1" accelerate imageio imageio-ffmpeg omegaconf wandb
71
+ git clone https://github.com/baaivision/URSA.git
72
+ cd URSA && pip install .
73
+ ```
74
+
75
+ ## 🔥 Quick Start
76
+ <a id="quick-start"></a>
77
+
78
+ ### 🖼️ Image Generation
79
+ <a id="quickstart-image-generation"></a>
80
+
81
+ ```python
82
+ import torch
83
+ from diffnext.pipelines import URSAPipeline
84
+
85
+ model_id, height, width = "BAAI/URSA-1.7B-IBQ1024", 1024, 1024
86
+ model_args = {"torch_dtype": torch.float16, "trust_remote_code": True}
87
+ pipe = URSAPipeline.from_pretrained(model_id, **model_args)
88
+ pipe = pipe.to(torch.device("cuda"))
89
+
90
+ prompt = "The bear, calm and still, gazes upward as if lost in contemplation of the cosmos."
91
+ negative_prompt = "worst quality, low quality, inconsistent motion, static, still, blurry, jittery, distorted, ugly"
92
+
93
+ image = pipe(**locals()).frames[0]
94
+ image.save("ursa.jpg")
95
+ ```
96
+
97
+ ### 🎬 Video Generation
98
+ <a id="quickstart-video-generation"></a>
99
+
100
+ ```python
101
+ import os, torch, numpy
102
+ from diffnext.pipelines import URSAPipeline
103
+ from diffnext.utils import export_to_video
104
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
105
+
106
+ model_id, height, width = "BAAI/URSA-1.7B-FSQ320", 320, 512
107
+ model_args = {"torch_dtype": torch.float16, "trust_remote_code": True}
108
+ pipe = URSAPipeline.from_pretrained(model_id, **model_args)
109
+ pipe = pipe.to(torch.device("cuda"))
110
+
111
+ text_prompt = "a lone grizzly bear walks through a misty forest at dawn, sunlight catching its fur."
112
+ negative_prompt = "worst quality, low quality, inconsistent motion, static, still, blurry, jittery, distorted, ugly"
113
+
114
+ # Text-to-Image
115
+ prompt = text_prompt
116
+ num_frames, num_inference_steps = 1, 25
117
+ image = pipe(**locals()).frames[0]
118
+ image.save("ursa.jpg")
119
+
120
+ # Image-to-Video
121
+ prompt = f"motion=9.0, {text_prompt}"
122
+ num_frames, num_inference_steps = 49, 50
123
+ video = pipe(**locals()).frames[0]
124
+ export_to_video(video, "ursa_1+48f.mp4", fps=12)
125
+
126
+ # Text-to-Video
127
+ image, video = None, None
128
+ prompt = f"motion=9.0, {text_prompt}"
129
+ num_frames, num_inference_steps = 49, 50
130
+ video = pipe(**locals()).frames[0]
131
+ export_to_video(video, "ursa_49f.mp4", fps=12)
132
+
133
+ # Video-to-Video
134
+ prompt = f"motion=5.0, {text_prompt}"
135
+ num_frames, num_inference_steps = 49, 50
136
+ num_cond_frames, cond_noise_scale = 13, 0.1
137
+ for i in range(12):
138
+ video, start_video = video[-num_cond_frames:], video
139
+ video = pipe(**locals()).frames[0]
140
+ video = numpy.concatenate([start_video, video[num_cond_frames:]])
141
+ export_to_video(video, "ursa_{}f.mp4".format(video.shape[0]), fps=12)
142
+ ```
143
+
144
+ ## 💻 Gradio Demo
145
+ <a id="gradio-demo"></a>
146
+
147
+ ```bash
148
+ # Text-to-Image (T2I)
149
+ python scripts/app_ursa_t2i.py --model "BAAI/URSA-1.7B-IBQ1024" --device 0
150
+
151
+ # Text-to-Image-to-Video (TI2V)
152
+ python scripts/app_ursa_ti2v.py --model "BAAI/URSA-1.7B-FSQ320" --device 0
153
+ ```
154
+
155
+ ## 📋 Todo List
156
+ - [X] [Model Zoo](#model-zoo)
157
+ - [X] [Quick Start](#quick-start)
158
+ - [X] [Gradio Demo](#gradio-demo)
159
+ - [X] [Evaluation Guide](./docs/evaluation.md)
160
+ - [X] [Training Guide](./docs/training.md)
161
+ - [ ] 4B Model
162
+
163
+ ## 📖 Citation
164
+ If you find this repository useful, please consider giving it a star ⭐ and citing our work 🦖:
165
+ ```
166
+ @article{deng2025ursa,
167
+ title={Uniform Discrete Diffusion with Metric Path for Video Generation},
168
+ author={Deng, Haoge and Pan, Ting and Zhang, Fan and Liu, Yang and Luo, Zhuoyan and Cui, Yufeng and Shen, Chunhua and Shan, Shiguang and Zhang, Zhaoxiang and Wang, Xinlong},
169
+ journal={arXiv preprint arXiv:2510.24717},
170
+ year={2025}
171
+ }
172
+ ```
173
+ ```
174
+ @article{deng2024nova,
175
+ title={Autoregressive Video Generation without Vector Quantization},
176
+ author={Deng, Haoge and Pan, Ting and Diao, Haiwen and Luo, Zhengxiong and Cui, Yufeng and Lu, Huchuan and Shan, Shiguang and Qi, Yonggang and Wang, Xinlong},
177
+ journal={arXiv preprint arXiv:2412.14169},
178
+ year={2024}
179
+ }
180
+ ```
181
+
182
+ ## 🤗 Acknowledgement
183
+
184
+ We thank the repositories:
185
+ - [NOVA](https://github.com/baaivision/NOVA). ✨NOVA is the predecessor of 🐻URSA.
186
+ - [FlowMatching](https://github.com/facebookresearch/flow_matching). This codebase systematically provides CFM and DFM implementations.
187
+ - [FUDOKI](https://github.com/fudoki-hku/FUDOKI). This codebase provides a naive multimodal DFM implementation.
188
+ - [CodeWithGPU](https://github.com/seetacloud/codewithgpu). CodeWithGPU library is the core of our data loading pipeline.
189
+
190
+ ## License
191
+ Code and models are licensed under [Apache License 2.0](LICENSE).
URSA/inference.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, torch, numpy
2
+ from diffnext.pipelines import URSAPipeline
3
+ from diffnext.utils import export_to_video
4
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
5
+
6
+
7
+
8
+ model_id, height, width = "BAAI/URSA-1.7B-FSQ320", 320, 512
9
+ model_args = {"torch_dtype": torch.bfloat16, "trust_remote_code": True}
10
+ pipe = URSAPipeline.from_pretrained(model_id, **model_args)
11
+ pipe = pipe.to(torch.device("cuda"))
12
+
13
+ text_prompt = "tom and jerry"#"a lone grizzly bear walks through a misty forest at dawn, sunlight catching its fur."
14
+ negative_prompt = "worst quality, low quality, inconsistent motion, static, still, blurry, jittery, distorted, ugly"
15
+
16
+ import time
17
+
18
+ t1 = time.time()
19
+
20
+ # Text-to-Image
21
+ prompt = text_prompt
22
+ num_frames, num_inference_steps = 1, 25
23
+ image = pipe(**locals()).frames[0]
24
+ image.save("tom/ursa.jpg")
25
+
26
+ t2 = time.time()
27
+
28
+ # Image-to-Video
29
+ prompt = f"motion=9.0, {text_prompt}"
30
+ num_frames, num_inference_steps = 49, 50
31
+ video = pipe(**locals()).frames[0]
32
+ export_to_video(video, "tom/ursa_1+48f.mp4", fps=12)
33
+
34
+ t3 = time.time()
35
+
36
+ # Text-to-Video
37
+ image, video = None, None
38
+ prompt = f"motion=9.0, {text_prompt}"
39
+ num_frames, num_inference_steps = 49, 50
40
+ video = pipe(**locals()).frames[0]
41
+ export_to_video(video, "tom/ursa_49f.mp4", fps=12)
42
+
43
+ t4 = time.time()
44
+
45
+ # Video-to-Video
46
+ prompt = f"motion=5.0, {text_prompt}"
47
+ num_frames, num_inference_steps = 49, 50
48
+ num_cond_frames, cond_noise_scale = 13, 0.1
49
+ for i in range(12):
50
+ video, start_video = video[-num_cond_frames:], video
51
+ video = pipe(**locals()).frames[0]
52
+ video = numpy.concatenate([start_video, video[num_cond_frames:]])
53
+ export_to_video(video, "tom/ursa_{}f.mp4".format(video.shape[0]), fps=12)
54
+
55
+ t5 = time.time()
56
+
57
+ print(f"Text-to-Image time: {t2-t1:.2f} seconds")
58
+ print(f"Image-to-Video time: {t3-t2:.2f} seconds")
59
+ print(f"Text-to-Video time: {t4-t3:.2f} seconds")
60
+ print(f"Video-to-Video time: {t5-t4:.2f} seconds")
61
+ # Single H800 GPU, batch_size=1, the inference time is:
62
+ # Text-to-Image time: 5.05 seconds
63
+ # Image-to-Video time: 101.92 seconds
64
+ # Text-to-Video time: 101.52 seconds
65
+ # Video-to-Video time: 1226.25 seconds
66
+
67
+
68
+ # cd URSA/
69
+ # source .venv_ursa/bin/activate
70
+
71
+ # accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml --machine_rank 0 --num_machines 1 --num_processes 8 scripts/train_distill_dimo.py config="./configs/distill_dimo.yaml" experiment.output_dir="./experiments/distill_dimo_v3" distill.teacher_ckpt="/gfs/space/private/fengzl/World_Model/URSA-1.7B" distill.prompt_source="/gfs/space/private/fengzl/World_Model/Koala-36M-v1"
URSA/pyproject.toml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [tool.black]
2
+ line-length = 100
3
+ target-version = ['py310']
URSA/requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ diffusers
3
+ transformers>=4.57.1
4
+ accelerate
5
+ imageio
6
+ imageio-ffmpeg
7
+ omegaconf
8
+ wandb
9
+ scipy
10
+ codewithgpu
URSA/setup.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Copyright (c) 2024-present, BAAI. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ------------------------------------------------------------------------
16
+ """Python setup script."""
17
+
18
+ import argparse
19
+ import os
20
+ import shutil
21
+ import subprocess
22
+ import sys
23
+
24
+ import setuptools
25
+ import setuptools.command.build_py
26
+ import setuptools.command.install
27
+
28
+
29
+ def parse_args():
30
+ """Parse arguments."""
31
+ parser = argparse.ArgumentParser()
32
+ parser.add_argument("--version", default=None)
33
+ args, unknown = parser.parse_known_args()
34
+ sys.argv = [sys.argv[0]] + unknown
35
+ args.git_version = None
36
+ args.long_description = ""
37
+ if args.version is None and os.path.exists("version.txt"):
38
+ with open("version.txt", "r") as f:
39
+ args.version = f.read().strip()
40
+ if os.path.exists(".git"):
41
+ try:
42
+ git_version = subprocess.check_output(["git", "rev-parse", "HEAD"], cwd="./")
43
+ args.git_version = git_version.decode("ascii").strip()
44
+ except (OSError, subprocess.CalledProcessError):
45
+ pass
46
+ if os.path.exists("README.md"):
47
+ with open(os.path.join("README.md"), encoding="utf-8") as f:
48
+ args.long_description = f.read()
49
+ return args
50
+
51
+
52
+ def clean_builds():
53
+ for path in ["build", "diffnext.egg-info"]:
54
+ if os.path.exists(path):
55
+ shutil.rmtree(path)
56
+
57
+
58
+ def find_packages(top):
59
+ """Return the python sources installed to package."""
60
+ packages = []
61
+ for root, _, _ in os.walk(top):
62
+ if os.path.exists(os.path.join(root, "__init__.py")):
63
+ packages.append(root)
64
+ return packages
65
+
66
+
67
+ def find_package_data():
68
+ """Return the external data installed to package."""
69
+ return []
70
+
71
+
72
+ class BuildPyCommand(setuptools.command.build_py.build_py):
73
+ """Enhanced 'build_py' command."""
74
+
75
+ def build_packages(self):
76
+ with open("diffnext/version.py", "w") as f:
77
+ f.write(
78
+ 'version = "{}"\n'
79
+ 'git_version = "{}"\n'
80
+ "__version__ = version\n".format(args.version, args.git_version)
81
+ )
82
+ super(BuildPyCommand, self).build_packages()
83
+
84
+ def build_package_data(self):
85
+ self.package_data = {"diffnext": find_package_data()}
86
+ super(BuildPyCommand, self).build_package_data()
87
+
88
+
89
+ class InstallCommand(setuptools.command.install.install):
90
+ """Enhanced 'install' command."""
91
+
92
+ def initialize_options(self):
93
+ super(InstallCommand, self).initialize_options()
94
+ self.old_and_unmanageable = True
95
+
96
+
97
+ args = parse_args()
98
+ setuptools.setup(
99
+ name="diffnext",
100
+ version=args.version,
101
+ description="A diffusers based library for autoregressive diffusion models.",
102
+ long_description=args.long_description,
103
+ long_description_content_type="text/markdown",
104
+ url="https://github.com/baaivision/URSA",
105
+ author="BAAI",
106
+ license="Apache License",
107
+ packages=find_packages("diffnext"),
108
+ cmdclass={"build_py": BuildPyCommand, "install": InstallCommand},
109
+ install_requires=[
110
+ "torch",
111
+ "diffusers",
112
+ "transformers",
113
+ "accelerate",
114
+ "imageio",
115
+ "imageio-ffmpeg",
116
+ "omegaconf",
117
+ "wandb",
118
+ "scipy",
119
+ ],
120
+ classifiers=[
121
+ "Development Status :: 5 - Production/Stable",
122
+ "Intended Audience :: Developers",
123
+ "Intended Audience :: Education",
124
+ "Intended Audience :: Science/Research",
125
+ "License :: OSI Approved :: Apache Software License",
126
+ "Programming Language :: Python :: 3",
127
+ "Programming Language :: Python :: 3 :: Only",
128
+ "Topic :: Scientific/Engineering",
129
+ "Topic :: Scientific/Engineering :: Mathematics",
130
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
131
+ ],
132
+ )
133
+ clean_builds()
URSA/ursa.jpg ADDED
URSA/version.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ 0.3.0a0