Lekr0 commited on
Commit
90afcf2
·
verified ·
1 Parent(s): 5818a32

Add files using upload-large-folder tool

Browse files
Files changed (50) hide show
  1. .claude/settings.local.json +17 -0
  2. ICL/.claude/settings.local.json +32 -0
  3. ICL/DAPO/verl-recipe/.github/workflows/pre-commit.yml +37 -0
  4. ICL/DAPO/verl-recipe/dapo/config/dapo_megatron_trainer.yaml +28 -0
  5. ICL/DAPO/verl-recipe/entropy/reward_score/entropy_math/math_normalize.py +192 -0
  6. ICL/DAPO/verl-recipe/fault_recover/agent_loop/fault_recover_agent_loop.py +137 -0
  7. ICL/DAPO/verl-recipe/spo/agent_loop/spo_agent_loop.py +155 -0
  8. ICL/DAPO/verl-recipe/spo/agent_loop/spo_tool_agent_loop.py +414 -0
  9. ICL/DAPO/verl-recipe/sppo/config/sppo_trainer.yaml +38 -0
  10. ICL/EVAL_GUIDE.md +47 -0
  11. ICL/LV/dataset_inspect.tree.txt +456 -0
  12. ICL/RL_DAPO/__init__.py +1 -0
  13. ICL/SFT_new/README.md +389 -0
  14. ICL/SFT_new/convert_and_eval.sh +87 -0
  15. ICL/SFT_new/ds_zero2.json +37 -0
  16. ICL/SFT_new/ds_zero3.json +28 -0
  17. ICL/SFT_new/eval.py +961 -0
  18. ICL/SFT_new/launch_wrapper.py +13 -0
  19. ICL/SFT_new/rebuild_and_train.sh +86 -0
  20. ICL/SFT_new/run_eval.sh +74 -0
  21. ICL/SFT_new/run_single_node.sh +49 -0
  22. ICL/SFT_new/submit_northjob.sh +38 -0
  23. ICL/SFT_new/train.py +659 -0
  24. ICL/build_embeddings.py +370 -0
  25. ICL/build_index.py +506 -0
  26. ICL/build_sft.py +466 -0
  27. ICL/dataset_inspect.tree.txt +456 -0
  28. ICL/eval_icl.py +524 -0
  29. ICL/extract_images.py +231 -0
  30. ICL/merge_captions.py +70 -0
  31. ICL/sft_model/epoch3_step1406_fp32/chat_template.json +3 -0
  32. ICL/sft_model/epoch3_step1406_fp32/config.json +62 -0
  33. ICL/sft_model/epoch3_step1406_fp32/generation_config.json +14 -0
  34. ICL/sft_model/epoch3_step1406_fp32/merges.txt +0 -0
  35. ICL/sft_model/epoch3_step1406_fp32/model.safetensors.index.json +757 -0
  36. ICL/sft_model/epoch3_step1406_fp32/preprocessor_config.json +21 -0
  37. ICL/sft_model/epoch3_step1406_fp32/tokenizer.json +0 -0
  38. ICL/sft_model/epoch3_step1406_fp32/tokenizer_config.json +239 -0
  39. ICL/sft_model/epoch3_step1406_fp32/video_preprocessor_config.json +21 -0
  40. ICL/sft_model/epoch3_step1406_fp32/vocab.json +0 -0
  41. ICL/sft_model/zero_to_fp32.py +760 -0
  42. RL_dataset/.gitattributes +89 -0
  43. RL_dataset/.msc +0 -0
  44. RL_dataset/.mv +1 -0
  45. RL_dataset/INFOSEEK_DOWNLOAD.md +337 -0
  46. RL_dataset/README.md +171 -0
  47. RL_dataset/dataset_infos.json +1 -0
  48. RL_dataset/download_oven_hf_mirror.sh +189 -0
  49. RL_dataset/download_scienceqa_hf.sh +135 -0
  50. download_hf.py +49 -0
.claude/settings.local.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "permissions": {
3
+ "allow": [
4
+ "Bash(find /workspace/xiaobin/ICL/SFT_new/output/emb_cache/ -type f -name *.json)",
5
+ "Bash(find /workspace/xiaobin/ICL/SFT_new/output -type f -name *.json)",
6
+ "Bash(find /workspace/xiaobin/ICL -type f -name *.json)",
7
+ "Bash(find /workspace/xiaobin/ICL/SFT_new/output/emb_cache -name *.json)",
8
+ "Bash(find /workspace/xiaobin -path */medlab/*vllm_thread* -o -path */medlab/*vllm*)",
9
+ "Bash(find /workspace/xiaobin/ICL -path */emb_cache/*.json)",
10
+ "Bash(python -c \"import py_compile; py_compile.compile\\(''build_sft.py'', doraise=True\\); print\\(''OK''\\)\")",
11
+ "Bash(python -c \"import py_compile; py_compile.compile\\(''generate_captions.py'', doraise=True\\); print\\(''OK''\\)\")",
12
+ "Bash(find /workspace/xiaobin -type f -name *.py)",
13
+ "Bash(python:*)",
14
+ "Bash(find /workspace -path */NorthServe/* -maxdepth 3)"
15
+ ]
16
+ }
17
+ }
ICL/.claude/settings.local.json ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "permissions": {
3
+ "allow": [
4
+ "Bash(python3 -c \"import sys,json; line=sys.stdin.readline\\(\\); d=json.loads\\(line\\); print\\(list\\(d.keys\\(\\)\\)\\); [print\\(f''''{k}: {type\\(d[k]\\).__name__}, len={len\\(str\\(d[k]\\)\\)}''''\\) for k in d.keys\\(\\)]\")",
5
+ "Bash(python3:*)",
6
+ "Bash(find /workspace/xiaobin/ICL -name *sft*.jsonl -o -name output -type d)",
7
+ "Bash(find /workspace/xiaobin/ICL -name *.jsonl)",
8
+ "Bash(wc:*)",
9
+ "Bash(grep -r \"model_path\\\\|model-path\" /workspace/xiaobin/ICL/SFT_new/*.py)",
10
+ "Bash(grep -r Qwen /workspace/xiaobin/ICL/SFT_new/*.py)",
11
+ "Bash(grep -l embedding /workspace/xiaobin/ICL/SFT/*.py)",
12
+ "Bash(du -sh /workspace/xiaobin/dataset/*)",
13
+ "Bash(lscpu)",
14
+ "Bash(/workspace/miniconda3/envs/sft/bin/python -c \"import torch; print\\('torch:', torch.__version__\\); print\\('CXX11_ABI:', torch._C._GLIBCXX_USE_CXX11_ABI\\)\")",
15
+ "Bash(find /workspace/xiaobin/ICL -maxdepth 3 -name *eval* -o -name *inference* -o -name *test*)",
16
+ "Bash(ls /workspace/xiaobin/ICL/sft_model/final/*.py)",
17
+ "Bash(ls /workspace/xiaobin/ICL/sft_model/final/mp_rank*)",
18
+ "Bash(ls /workspace/xiaobin/ICL/sft_model/final/*.json)",
19
+ "Bash(ls /workspace/xiaobin/ICL/sft_model/final/*tag*)",
20
+ "Bash(pip show:*)",
21
+ "Bash(conda run:*)",
22
+ "Read(//workspace/xiaobin/dataset/sft/all/**)",
23
+ "Bash(find /workspace/xiaobin/ICL -type f -name *eval* -o -name *test* -o -name *infer* -o -name *benchmark* -o -name *generate* -o -name *predict*)",
24
+ "Bash(find /workspace/xiaobin/ICL -type f \\\\\\(-name *.jsonl -o -name *.json \\\\\\))",
25
+ "Bash(grep -E \"\\\\.\\(py|sh\\)$\")",
26
+ "Bash(find /workspace/xiaobin/ICL -type f -name *.jsonl)",
27
+ "Read(//workspace/xiaobin/dataset/sft/**)",
28
+ "Read(//workspace/xiaobin/dataset/**)",
29
+ "Bash(python3 -c \"import json; d=json.load\\(open\\(''/workspace/xiaobin/dataset/detail/captioning/coco/train/captions.json''\\)\\); print\\(''keys:'', list\\(d.keys\\(\\)\\)\\); items=d[''items'']; k=list\\(items.keys\\(\\)\\)[0]; print\\(k, ''->'', items[k][:100]\\)\")"
30
+ ]
31
+ }
32
+ }
ICL/DAPO/verl-recipe/.github/workflows/pre-commit.yml ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # c.f. https://github.com/pre-commit/action?tab=readme-ov-file#using-this-action
2
+ name: pre-commit
3
+
4
+ # No need to avoid / cancel lightweight pre-commit jobs
5
+ on:
6
+ schedule:
7
+ - cron: "0 0 * * 0"
8
+ pull_request:
9
+ push:
10
+ branches:
11
+ - main
12
+ - v0.*
13
+ # Allow manual triggering
14
+ workflow_dispatch:
15
+
16
+ # Declare permissions just read content.
17
+ permissions:
18
+ contents: read
19
+
20
+ jobs:
21
+ pre-commit:
22
+ runs-on: ubuntu-latest
23
+ strategy:
24
+ matrix:
25
+ python-version: ["3.12"]
26
+ steps:
27
+ - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
28
+ - name: Set up Python ${{ matrix.python-version }}
29
+ uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
30
+ with:
31
+ python-version: ${{ matrix.python-version }}
32
+ - name: Set ruff --output-format=github
33
+ run: |
34
+ sed -i 's/--output-format=full/--output-format=github/' .pre-commit-config.yaml
35
+ git add .pre-commit-config.yaml
36
+ # Check "--all-files" by default
37
+ - uses: pre-commit/action@v3.0.1
ICL/DAPO/verl-recipe/dapo/config/dapo_megatron_trainer.yaml ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ hydra:
2
+ searchpath:
3
+ - file://verl/trainer/config
4
+
5
+ defaults:
6
+ - ppo_megatron_trainer
7
+ - _self_
8
+
9
+ data:
10
+ gen_batch_size: ${data.train_batch_size}
11
+
12
+ reward_model:
13
+ reward_manager: dapo
14
+ overlong_buffer:
15
+ enable: False # We try to avoid forgetting to set enable
16
+ len: 0
17
+ penalty_factor: 0.0
18
+ log: False
19
+
20
+ algorithm:
21
+ filter_groups:
22
+ _target_: verl.trainer.config.FilterGroupsConfig
23
+ enable: False # We try to avoid forgetting to set enable
24
+ metric: null # acc / score / seq_reward / seq_final_reward / ...
25
+ max_num_gen_batches: 0 # Non-positive values mean no upper limit
26
+
27
+ trainer:
28
+ project_name: verl-dapo
ICL/DAPO/verl-recipe/entropy/reward_score/entropy_math/math_normalize.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 PRIME team and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Copyright (c) 2021 Dan Hendrycks
16
+ #
17
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
18
+ # of this software and associated documentation files (the "Software"), to deal
19
+ # in the Software without restriction, including without limitation the rights
20
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
21
+ # copies of the Software, and to permit persons to whom the Software is
22
+ # furnished to do so, subject to the following conditions:
23
+ #
24
+ # The above copyright notice and this permission notice shall be included in all
25
+ # copies or substantial portions of the Software.
26
+ #
27
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
28
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
29
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
30
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
31
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
32
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
33
+ # SOFTWARE.
34
+ """
35
+ This logic is largely copied from the Hendrycks' MATH release (math_equivalence).
36
+
37
+ From: https://github.com/openai/prm800k/blob/main/prm800k/grading/math_normalize.py
38
+ """
39
+
40
+ import re
41
+ from typing import Optional
42
+
43
+
44
def normalize_answer(answer: Optional[str]) -> Optional[str]:
    """Normalize a raw answer string for grading comparison.

    Strips surrounding whitespace, unwraps a single enclosing ``\\text{...}``,
    then applies :func:`_strip_string`. On any failure the (stripped) input
    is returned unchanged.
    """
    if answer is None:
        return None
    answer = answer.strip()
    try:
        # Remove a single enclosing `\text{...}` wrapper, if present.
        wrapped = re.search(r"^\\text\{(?P<text>.+?)\}$", answer)
        if wrapped:
            answer = wrapped.group("text").strip()
        return _strip_string(answer)
    except Exception:
        # Best-effort: fall back to the lightly cleaned input.
        return answer
56
+
57
+
58
+ def _fix_fracs(string):
59
+ substrs = string.split("\\frac")
60
+ new_str = substrs[0]
61
+ if len(substrs) > 1:
62
+ substrs = substrs[1:]
63
+ for substr in substrs:
64
+ new_str += "\\frac"
65
+ if substr[0] == "{":
66
+ new_str += substr
67
+ else:
68
+ try:
69
+ assert len(substr) >= 2
70
+ except Exception:
71
+ return string
72
+ a = substr[0]
73
+ b = substr[1]
74
+ if b != "{":
75
+ if len(substr) > 2:
76
+ post_substr = substr[2:]
77
+ new_str += "{" + a + "}{" + b + "}" + post_substr
78
+ else:
79
+ new_str += "{" + a + "}{" + b + "}"
80
+ else:
81
+ if len(substr) > 2:
82
+ post_substr = substr[2:]
83
+ new_str += "{" + a + "}" + b + post_substr
84
+ else:
85
+ new_str += "{" + a + "}" + b
86
+ string = new_str
87
+ return string
88
+
89
+
90
+ def _fix_a_slash_b(string):
91
+ if len(string.split("/")) != 2:
92
+ return string
93
+ a = string.split("/")[0]
94
+ b = string.split("/")[1]
95
+ try:
96
+ a = int(a)
97
+ b = int(b)
98
+ assert string == "{}/{}".format(a, b)
99
+ new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
100
+ return new_string
101
+ except Exception:
102
+ return string
103
+
104
+
105
+ def _remove_right_units(string):
106
+ # "\\text{ " only ever occurs (at least in the val set) when describing units
107
+ if "\\text{ " in string:
108
+ splits = string.split("\\text{ ")
109
+ assert len(splits) == 2
110
+ return splits[0]
111
+ else:
112
+ return string
113
+
114
+
115
+ def _fix_sqrt(string):
116
+ if "\\sqrt" not in string:
117
+ return string
118
+ splits = string.split("\\sqrt")
119
+ new_string = splits[0]
120
+ for split in splits[1:]:
121
+ if split[0] != "{":
122
+ a = split[0]
123
+ new_substr = "\\sqrt{" + a + "}" + split[1:]
124
+ else:
125
+ new_substr = "\\sqrt" + split
126
+ new_string += new_substr
127
+ return new_string
128
+
129
+
130
def _strip_string(string):
    """Canonicalize a LaTeX answer string (Hendrycks MATH conventions).

    Applies a fixed sequence of textual normalizations; the order below is
    significant (e.g. ``\\\\`` is collapsed before the percent handling,
    and spaces are removed only after the ``=``-prefix stripping).
    """
    # Straight textual substitutions, applied in the original order.
    simple_substitutions = (
        ("\n", ""),          # linebreaks
        ("\\!", ""),         # inverse spaces
        ("\\\\", "\\"),      # \\ -> \
        ("tfrac", "frac"),   # tfrac/dfrac -> frac
        ("dfrac", "frac"),
        ("\\left", ""),      # \left / \right delimiters
        ("\\right", ""),
        ("^{\\circ}", ""),   # degrees
        ("^\\circ", ""),
        ("\\$", ""),         # dollar signs
    )
    for old, new in simple_substitutions:
        string = string.replace(old, new)

    # remove units (on the right)
    string = _remove_right_units(string)

    # remove percentage signs
    string = string.replace("\\\\%", "")
    string = string.replace("\\%", "")

    # " ." -> " 0." and "{." -> "{0."; also add a leading 0 when the whole
    # string starts with "."
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")
    if not string:
        # nothing left after cleanup
        return string
    if string.startswith("."):
        string = "0" + string

    # drop a short "k = " style prefix
    eq_parts = string.split("=")
    if len(eq_parts) == 2 and len(eq_parts[0]) <= 2:
        string = eq_parts[1]

    # sqrt3 -> sqrt{3}
    string = _fix_sqrt(string)

    # remove spaces
    string = string.replace(" ", "")

    # \frac1b / \frac12 -> \frac{1}{b} / \frac{1}{2}; also works with
    # \frac1{72} (but not \frac{72}1)
    string = _fix_fracs(string)

    # manually map 0.5 -> \frac{1}{2}
    if string == "0.5":
        string = "\\frac{1}{2}"

    # X/Y -> \frac{X}{Y} in simple cases, for model outputs not yet converted
    return _fix_a_slash_b(string)
ICL/DAPO/verl-recipe/fault_recover/agent_loop/fault_recover_agent_loop.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import logging
15
+ import os
16
+ from typing import Any, Optional
17
+ from uuid import uuid4
18
+
19
+ import ray
20
+ from omegaconf import DictConfig
21
+
22
+ from verl.experimental.agent_loop.agent_loop import AgentLoopManager, AgentLoopWorker, AsyncLLMServerManager
23
+ from verl.single_controller.ray.base import RayResourcePool, RayWorkerGroup
24
+ from verl.utils.rollout_trace import rollout_trace_op
25
+ from verl.workers.rollout.replica import TokenOutput
26
+
27
+ logger = logging.getLogger(__file__)
28
+ logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
29
+
30
+
31
class FaultRecoverAsyncLLMServerManager(AsyncLLMServerManager):
    """
    A class to manage multiple OpenAI compatible LLM servers. This class provides
    - Load balance: least requests load balancing
    - Sticky session: send multi-turn chat completions to same server for automatic prefix caching

    Additionally, when a ``global_id`` is supplied, per-request metadata is
    streamed to the fault-recovery tokens queue before and after generation.
    """

    @rollout_trace_op
    async def generate(
        self,
        request_id,
        *,
        prompt_ids: list[int],
        sampling_params: dict[str, Any],
        image_data: Optional[list[Any]] = None,
        video_data: Optional[list[Any]] = None,
        global_id: int = None,
    ) -> TokenOutput:
        """Generate tokens from prompt ids.

        Args:
            request_id (str): request id for sticky session.
            prompt_ids (List[int]): List of prompt token ids.
            sampling_params (Dict[str, Any]): Sampling parameters for the chat completion.
            global_id: Global batch id of req.

        Returns:
            TokenOutput: token output
        """
        chosen_server = self._choose_server(request_id)
        # A fresh request id is issued for every turn.
        turn_request_id = uuid4().hex

        # Fault-recovery bookkeeping is only active when a global batch id
        # is provided.
        queue = None
        if global_id is not None:
            from recipe.fault_recover.fault_manager import get_tokens_queue

            queue = get_tokens_queue()
            if queue is not None:
                # Announce the (request, batch) pairing before generation.
                await queue.put.remote((turn_request_id, global_id))

        result = await chosen_server.generate.remote(
            request_id=turn_request_id,  # use new request_id for each turn
            prompt_ids=prompt_ids,
            sampling_params=sampling_params,
            image_data=image_data,
            video_data=video_data,
        )

        if queue is not None:
            # Report per-request generation metadata for recovery.
            await queue.put.remote(
                {
                    turn_request_id: {
                        "log_probs": result.log_probs,
                        "routed_experts": result.routed_experts,
                        "num_preempted": result.num_preempted,
                    }
                }
            )

        return result
91
+
92
+
93
class FaultRecoverAgentLoopWorker(AgentLoopWorker):
    """Agent loop worker that takes a batch of messages and runs each message
    in an agent loop, routing generation through the fault-recovery-aware
    server manager."""

    def __init__(
        self,
        config: DictConfig,
        server_handles: list[ray.actor.ActorHandle],
        reward_loop_worker_handles: list[ray.actor.ActorHandle] = None,
    ):
        super().__init__(config, server_handles, reward_loop_worker_handles)
        # Replace the manager installed by the base constructor with the
        # fault-recovery variant.
        self.server_manager = FaultRecoverAsyncLLMServerManager(config, server_handles)
104
+
105
+
106
class FaultRecoverAgentLoopManager(AgentLoopManager):
    """Agent loop manager that manages a group of agent loop workers.

    Installs the fault-recovery rollout replica and worker classes as
    defaults (unless a recipe already set them), then brings up the LLM
    servers and agent loop workers.
    """

    def __init__(
        self,
        config: DictConfig,
        worker_group: RayWorkerGroup = None,
        rollout_resource_pool: RayResourcePool = None,
        reward_loop_worker_handles: list[ray.actor.ActorHandle] = None,
    ):
        """Initialize agent loop manager.

        Args:
            config (DictConfig): trainer config.
            worker_group (RayWorkerGroup): ActorRolloutRef worker group for hybrid mode; None for standalone mode.
            rollout_resource_pool (RayResourcePool): Resource pool for actor rollout (Colocate or Standalone mode).
            reward_loop_worker_handles (List[ray.actor.ActorHandle]): Actor handles for streaming reward computation.
        """
        self.config = config
        self.worker_group = worker_group
        self.reward_loop_worker_handles = reward_loop_worker_handles

        # Recipes may pre-set these class attributes; fill in the
        # fault-recovery defaults only when they have not.
        if not hasattr(self, "rollout_replica_class"):
            from recipe.fault_recover.vllm_rollout.vllm_async_server import FaultRecovervLLMReplica

            self.rollout_replica_class = FaultRecovervLLMReplica
        if not hasattr(self, "agent_loop_workers_class"):
            self.agent_loop_workers_class = ray.remote(FaultRecoverAgentLoopWorker)

        self._initialize_llm_servers(rollout_resource_pool)
        self._init_agent_loop_workers()
ICL/DAPO/verl-recipe/spo/agent_loop/spo_agent_loop.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ # Modifications Copyright 2025 SPO authors
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ SPO Agent Loop - Extends base agent loop with code generation support.
17
+
18
+ This module inherits from verl.experimental.agent_loop and only overrides
19
+ the generate_sequences method to add SPO-specific stop tokens for code generation.
20
+ """
21
+
22
+ import asyncio
23
+
24
+ import numpy as np
25
+ import ray
26
+
27
+ from verl import DataProto
28
+
29
+ # Re-export all base classes for backward compatibility
30
+ from verl.experimental.agent_loop.agent_loop import AgentLoopManager, get_trajectory_info
31
+ from verl.experimental.agent_loop.agent_loop import (
32
+ AgentLoopWorkerBase as BaseAgentLoopWorkerBase,
33
+ )
34
+ from verl.utils.transferqueue_utils import tqbridge
35
+
36
+ __all__ = [
37
+ "AgentLoopWorkerBase",
38
+ "SPOAgentLoopWorker",
39
+ "SPOAgentLoopManager",
40
+ ]
41
+
42
+
43
class AgentLoopWorkerBase(BaseAgentLoopWorkerBase):
    """SPO-specific agent loop worker with code generation stop tokens.

    Inherits all functionality from base AgentLoopWorkerBase and only overrides
    the generate_sequences method to add SPO-specific parameters:
    - stop="</code>" for code block termination
    - include_stop_str_in_output=True to include the stop token
    """

    @tqbridge()
    async def generate_sequences(self, batch: DataProto) -> DataProto:
        """Generate sequences from agent loop with SPO-specific stop tokens.

        Args:
            batch (DataProto): Input batch.

        Returns:
            DataProto: Output batch.
            - prompts: [bsz, prompt_length], prompt token ids from dataset.
            - responses: [bsz, response_length], output token ids include response tokens
              from LLM generation and observation tokens from tool_calls.
            - response_mask: [bsz, response_length], 1 for LLM generated tokens, 0 for observation/padding tokens.
            - input_ids: [bsz, prompt_length + response_length], whole sequence token ids, including prompt tokens
              and response tokens.
            - attention_mask: [bsz, prompt_length + response_length], 0 for padding tokens, 1 for other tokens.
            - position_ids: [bsz, prompt_length + response_length], incremental position ids.
        """
        rollout_cfg = self.config.actor_rollout_ref.rollout

        # SPO-specific: terminate generation at </code> and keep the stop
        # string in the output.
        sampling_params = dict(
            temperature=rollout_cfg.temperature,
            top_p=rollout_cfg.top_p,
            repetition_penalty=1.0,
            logprobs=rollout_cfg.calculate_log_probs,
            stop="</code>",  # SPO-SPECIFIC
            include_stop_str_in_output=True,  # SPO-SPECIFIC
        )

        # Validation overrides temperature/top_p.
        is_validation = batch.meta_info.get("validate", False)
        if is_validation:
            sampling_params["top_p"] = rollout_cfg.val_kwargs.top_p
            sampling_params["temperature"] = rollout_cfg.val_kwargs.temperature

        # Default to a single-turn agent when none is specified.
        if "agent_name" not in batch.non_tensor_batch:
            default_agent_loop = rollout_cfg.agent.default_agent_loop
            batch.non_tensor_batch["agent_name"] = np.array([default_agent_loop] * len(batch), dtype=object)

        if "index" in batch.non_tensor_batch:
            index = batch.non_tensor_batch["index"]
        else:
            index = np.arange(len(batch))

        trajectory_info = await get_trajectory_info(
            batch.meta_info.get("global_steps", -1), index.tolist(), is_validation
        )

        # One agent-loop task per sample; run them concurrently.
        pending = [
            asyncio.create_task(
                self._run_agent_loop(
                    sampling_params,
                    trajectory_info[i],
                    **{key: values[i] for key, values in batch.non_tensor_batch.items()},
                )
            )
            for i in range(len(batch))
        ]
        results = await asyncio.gather(*pending)

        return self._postprocess(results)
116
+
117
+
118
@ray.remote
class SPOAgentLoopWorker(AgentLoopWorkerBase):
    """SPO Agent Loop Worker as a Ray remote actor.

    A thin Ray actor wrapper around AgentLoopWorkerBase that enables
    distributed execution with SPO-specific stop tokens.
    """

    def __init__(self, config, server_handles, reward_router_address=None):
        """Initialize SPO Agent Loop Worker.

        Args:
            config: trainer config.
            server_handles: OpenAI compatible LLM server actor handles.
            reward_router_address: reward router address.
        """
        # No extra state; defer entirely to the base worker.
        super().__init__(config, server_handles, reward_router_address)
135
+
136
+
137
class SPOAgentLoopManager(AgentLoopManager):
    """SPO-specific Agent Loop Manager that uses SPO's AgentLoopWorker.

    Inherits all functionality from the base AgentLoopManager and only
    overrides agent_loop_workers_class so workers include the SPO code
    generation stop tokens.
    """

    def __init__(self, config, worker_group=None, rm_wg=None):
        """Initialize SPO Agent Loop Manager.

        Args:
            config: trainer config.
            worker_group: ActorRolloutRef worker group for hybrid mode; None for standalone mode.
            rm_wg: Reward model worker group.
        """
        # Must be set before the parent constructor spawns workers.
        self.agent_loop_workers_class = SPOAgentLoopWorker
        super().__init__(config, worker_group, rm_wg)
ICL/DAPO/verl-recipe/spo/agent_loop/spo_tool_agent_loop.py ADDED
@@ -0,0 +1,414 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates
2
+ # Modifications Copyright 2025 SPO authors
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import asyncio
16
+ import copy
17
+ import logging
18
+ import os
19
+ from typing import Any, Optional
20
+ from uuid import uuid4
21
+
22
+ from verl.experimental.agent_loop.agent_loop import (
23
+ AgentLoopBase,
24
+ AgentLoopOutput,
25
+ register,
26
+ )
27
+ from verl.experimental.agent_loop.tool_agent_loop import AgentState
28
+ from verl.interactions.base import BaseInteraction
29
+ from verl.interactions.utils.interaction_registry import (
30
+ initialize_interactions_from_config,
31
+ )
32
+ from verl.tools.schemas import ToolResponse
33
+ from verl.tools.utils.tool_registry import initialize_tools_from_config
34
+ from verl.utils.profiler import simple_timer
35
+ from verl.utils.rollout_trace import rollout_trace_op
36
+
37
+ logger = logging.getLogger(__file__)
38
+ logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
39
+
40
+
41
class AgentData:
    """Encapsulates all state variables for the agent loop."""

    def __init__(
        self,
        messages: list[dict[str, Any]],
        image_data: Any,
        metrics: dict[str, Any],
        request_id: str,
        tools_kwargs: dict[str, Any],
        interaction: Optional[BaseInteraction] = None,
        interaction_kwargs: Optional[dict[str, Any]] = None,
    ):
        # Inputs carried through the loop.
        self.messages = messages
        self.image_data = image_data
        self.metrics = metrics
        self.request_id = request_id
        self.tools_kwargs = tools_kwargs
        self.interaction = interaction
        self.interaction_kwargs = interaction_kwargs or {}

        # Rolling token-level state accumulated across turns.
        self.prompt_ids: list[int] = []
        self.response_ids: list[int] = []
        self.response_mask: list[int] = []
        self.response_logprobs: list[float] = []
        self.turn_scores: list[float] = []
        self.tool_rewards: list[float] = []
        self.user_turns = 0
        self.assistant_turns = 0

        # Temporary state for tool calls: raw Python code strings
        # extracted from <code> tags.
        self.tool_calls: list[str] = []
75
+
76
+ @register("spo_tool_agent")
77
+ class SPOToolAgentLoop(AgentLoopBase):
78
+ @classmethod
79
+ def init_class(cls, config, tokenizer, processor, **kwargs):
80
+ if cls._class_initialized:
81
+ return
82
+ cls._class_initialized = True
83
+ print("Performing class-level ToolAgentLoop initialization")
84
+
85
+ # Initialize tools from config file
86
+ cls.tokenizer = tokenizer
87
+ cls.processor = processor
88
+ cls.max_user_turns = config.actor_rollout_ref.rollout.multi_turn.max_user_turns
89
+ cls.max_assistant_turns = config.actor_rollout_ref.rollout.multi_turn.max_assistant_turns
90
+ cls.max_parallel_calls = config.actor_rollout_ref.rollout.multi_turn.max_parallel_calls
91
+ cls.max_tool_response_length = config.actor_rollout_ref.rollout.multi_turn.max_tool_response_length
92
+ cls.tool_response_truncate_side = config.actor_rollout_ref.rollout.multi_turn.tool_response_truncate_side
93
+ tool_config_path = config.actor_rollout_ref.rollout.multi_turn.tool_config_path
94
+ tool_list = initialize_tools_from_config(tool_config_path) if tool_config_path else []
95
+ cls.tools = {tool.name: tool for tool in tool_list}
96
+ cls.tool_schemas = [tool.tool_schema.model_dump(exclude_unset=True, exclude_none=True) for tool in tool_list]
97
+ print(f"Initialized tools: {cls.tools}")
98
+
99
+ cls.apply_chat_template_kwargs = config.data.get("apply_chat_template_kwargs", {})
100
+ cls.prompt_length = config.actor_rollout_ref.rollout.prompt_length
101
+ cls.response_length = config.actor_rollout_ref.rollout.response_length
102
+ cls.system_prompt = tokenizer.apply_chat_template(
103
+ [{}], add_generation_prompt=False, tokenize=True, **cls.apply_chat_template_kwargs
104
+ )
105
+ # Initialize interactions from config file
106
+ cls.interaction_config_file = config.actor_rollout_ref.rollout.multi_turn.interaction_config_path
107
+ if cls.interaction_config_file:
108
+ cls.interaction_map: dict[str, BaseInteraction] = cls._initialize_interactions(cls.interaction_config_file)
109
+
110
+ @rollout_trace_op
111
+ async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:
112
+ messages = list(kwargs["raw_prompt"])
113
+ image_data = copy.deepcopy(kwargs.get("multi_modal_data", {}).get("image", None))
114
+ metrics = {}
115
+ request_id = uuid4().hex
116
+ tools_kwargs = kwargs.get("tools_kwargs", {})
117
+
118
+ # Initialize interaction if needed
119
+ interaction = None
120
+ interaction_kwargs = {}
121
+ if self.interaction_config_file:
122
+ interaction_kwargs = kwargs["extra_info"]["interaction_kwargs"]
123
+ if "name" not in interaction_kwargs:
124
+ raise ValueError("'name' key is required in interaction_kwargs")
125
+ interaction_name = interaction_kwargs["name"]
126
+ if interaction_name not in self.interaction_map:
127
+ raise ValueError(
128
+ f"Interaction '{interaction_name}' not found in interaction_map. Available interactions: "
129
+ f"{list(self.interaction_map.keys())}"
130
+ )
131
+ interaction = self.interaction_map[interaction_name]
132
+ await interaction.start_interaction(request_id, **interaction_kwargs)
133
+ # Create AgentData instance to encapsulate all state
134
+ agent_data = AgentData(
135
+ messages=messages,
136
+ image_data=image_data,
137
+ metrics=metrics,
138
+ request_id=request_id,
139
+ tools_kwargs=tools_kwargs,
140
+ interaction=interaction,
141
+ interaction_kwargs=interaction_kwargs,
142
+ )
143
+
144
+ # State machine loop
145
+ state = AgentState.PENDING
146
+ while state != AgentState.TERMINATED:
147
+ if state == AgentState.PENDING:
148
+ state = await self._handle_pending_state(agent_data, sampling_params)
149
+ elif state == AgentState.GENERATING:
150
+ state = await self._handle_generating_state(agent_data, sampling_params)
151
+ elif state == AgentState.PROCESSING_TOOLS:
152
+ state = await self._handle_processing_tools_state(agent_data)
153
+ elif state == AgentState.INTERACTING:
154
+ state = await self._handle_interacting_state(agent_data)
155
+ else:
156
+ logger.error(f"Invalid state: {state}")
157
+ state = AgentState.TERMINATED
158
+
159
+ # Finalize output
160
+ response_ids = agent_data.prompt_ids[-len(agent_data.response_mask) :]
161
+ prompt_ids = agent_data.prompt_ids[: len(agent_data.prompt_ids) - len(agent_data.response_mask)]
162
+ multi_modal_data = {"image": agent_data.image_data} if agent_data.image_data is not None else {}
163
+ output = AgentLoopOutput(
164
+ prompt_ids=prompt_ids,
165
+ response_ids=response_ids[: self.response_length],
166
+ response_mask=agent_data.response_mask[: self.response_length],
167
+ multi_modal_data=multi_modal_data,
168
+ response_logprobs=agent_data.response_logprobs[: self.response_length]
169
+ if agent_data.response_logprobs
170
+ else None,
171
+ num_turns=agent_data.user_turns + agent_data.assistant_turns + 1,
172
+ metrics=agent_data.metrics,
173
+ extra_fields={},
174
+ )
175
+ output.extra_fields.update({"turn_scores": agent_data.turn_scores, "tool_rewards": agent_data.tool_rewards})
176
+ return output
177
+
178
+ def _extract_code_blocks(self, response_ids: list[int]) -> list[str]:
179
+ """Extract Python code from <code>...</code> tags in response.
180
+
181
+ Args:
182
+ response_ids: Token IDs from model response
183
+
184
+ Returns:
185
+ List of cleaned Python code strings
186
+ """
187
+ import re
188
+
189
+ # Decode token IDs to text
190
+ response_text = self.tokenizer.decode(response_ids, skip_special_tokens=False)
191
+
192
+ # Extract all code blocks between <code> and </code> tags
193
+ pattern = r"<code>(.*?)</code>"
194
+ matches = re.findall(pattern, response_text, re.DOTALL)
195
+
196
+ # Clean each code block (remove markdown fences, strip whitespace)
197
+ cleaned_codes = []
198
+ for match in matches:
199
+ # Remove markdown code fences if present
200
+ cleaned = re.sub(r"^```(?:python)?\s*\n?", "", match.strip())
201
+ cleaned = re.sub(r"\n?```\s*$", "", cleaned)
202
+ cleaned_codes.append(cleaned.strip())
203
+
204
+ return cleaned_codes
205
+
206
+ async def _handle_pending_state(self, agent_data: AgentData, sampling_params: dict[str, Any]) -> AgentState:
207
+ """Handle the pending state: prepare the prompt and start generation."""
208
+ problem = agent_data.messages[0]["content"]
209
+ user_prompt = (
210
+ "Solve the following problem step by step. "
211
+ "You now have the ability to selectively write executable Python code to enhance your reasoning process. "
212
+ "The Python code will be executed by an external sandbox, and the output "
213
+ "(wrapped in `<interpreter>output_str</interpreter>`)"
214
+ " can be returned to aid your reasoning and help you arrive at the final answer. "
215
+ "The Python code should be complete scripts, including necessary imports. "
216
+ "Important: The sandbox is stateless and non-interactive; thus, prior imports, definitions, "
217
+ "and state do not persist between executions and cannot be referenced.\n"
218
+ "Each code snippet is wrapped with `<code>\n```python\ncode snippet\n```\n</code>`.\n"
219
+ )
220
+ user_prompt += "*user question:*\n"
221
+ user_prompt += problem
222
+ messages = [{"role": "user", "content": user_prompt}]
223
+ agent_data.prompt_ids = await self.loop.run_in_executor(
224
+ None,
225
+ lambda: self.tokenizer.apply_chat_template(
226
+ messages, add_generation_prompt=True, tokenize=True, **self.apply_chat_template_kwargs
227
+ ),
228
+ )
229
+
230
+ return AgentState.GENERATING
231
+
232
+ async def _handle_generating_state(
233
+ self, agent_data: AgentData, sampling_params: dict[str, Any], ignore_termination: bool = False
234
+ ) -> AgentState:
235
+ """Handle the generating state: generate model response and check for tool calls."""
236
+ add_messages: list[dict[str, Any]] = []
237
+
238
+ with simple_timer("generate_sequences", agent_data.metrics):
239
+ output = await self.server_manager.generate(
240
+ request_id=agent_data.request_id,
241
+ prompt_ids=agent_data.prompt_ids,
242
+ sampling_params=sampling_params,
243
+ image_data=agent_data.image_data,
244
+ )
245
+
246
+ agent_data.assistant_turns += 1
247
+ agent_data.response_ids = output.token_ids
248
+ agent_data.prompt_ids += agent_data.response_ids
249
+ agent_data.response_mask += [1] * len(agent_data.response_ids)
250
+ if output.log_probs:
251
+ agent_data.response_logprobs += output.log_probs
252
+
253
+ # Check termination conditions
254
+ if not ignore_termination and len(agent_data.response_mask) >= self.response_length:
255
+ return AgentState.TERMINATED
256
+ if self.max_assistant_turns and agent_data.assistant_turns >= self.max_assistant_turns:
257
+ return AgentState.TERMINATED
258
+ if self.max_user_turns and agent_data.user_turns >= self.max_user_turns:
259
+ return AgentState.TERMINATED
260
+
261
+ # Extract code blocks from <code> tags
262
+ agent_data.tool_calls = self._extract_code_blocks(agent_data.response_ids)
263
+
264
+ # Handle interaction if needed
265
+ if self.interaction_config_file:
266
+ assistant_message = await self.loop.run_in_executor(
267
+ None, lambda: self.tokenizer.decode(agent_data.response_ids, skip_special_tokens=True)
268
+ )
269
+ add_messages.append({"role": "assistant", "content": assistant_message})
270
+ agent_data.messages.extend(add_messages)
271
+
272
+ # Determine next state
273
+ if agent_data.tool_calls:
274
+ return AgentState.PROCESSING_TOOLS
275
+ elif self.interaction_config_file:
276
+ return AgentState.INTERACTING
277
+ else:
278
+ return AgentState.TERMINATED
279
+
280
+ async def _handle_processing_tools_state(self, agent_data: AgentData) -> AgentState:
281
+ """Handle the processing tools state: execute tool calls and prepare tool responses."""
282
+ tasks = []
283
+ tool_call_names = []
284
+ for tool_call in agent_data.tool_calls[: self.max_parallel_calls]:
285
+ tasks.append(self._call_tool(tool_call, agent_data.tools_kwargs))
286
+ tool_call_names.append("code_interpreter")
287
+
288
+ with simple_timer("tool_calls", agent_data.metrics):
289
+ responses = await asyncio.gather(*tasks)
290
+
291
+ response_ids = await self.loop.run_in_executor(
292
+ None, lambda: self.tokenizer.encode(responses[0].text or "", add_special_tokens=False)
293
+ )
294
+
295
+ if len(agent_data.response_mask) + len(response_ids) >= self.response_length:
296
+ return AgentState.TERMINATED
297
+ # Update prompt_ids and response_mask
298
+ agent_data.prompt_ids += response_ids
299
+ agent_data.response_mask += [0] * len(response_ids)
300
+ if agent_data.response_logprobs:
301
+ agent_data.response_logprobs += [0.0] * len(response_ids)
302
+ agent_data.user_turns += 1
303
+ # Change agent_data.request_id to avoid caching issues
304
+ agent_data.request_id = uuid4().hex
305
+ return AgentState.GENERATING
306
+
307
+ async def _handle_interacting_state(self, agent_data: AgentData) -> AgentState:
308
+ """Handle the interacting state: get user input from interaction."""
309
+ (
310
+ should_terminate_sequence,
311
+ interaction_responses,
312
+ reward,
313
+ metrics,
314
+ ) = await agent_data.interaction.generate_response(
315
+ agent_data.request_id, agent_data.messages, **agent_data.interaction_kwargs
316
+ )
317
+ agent_data.user_turns += 1
318
+
319
+ add_messages: list[dict[str, Any]] = [{"role": "user", "content": interaction_responses}]
320
+ agent_data.messages.extend(add_messages)
321
+
322
+ if reward is not None:
323
+ agent_data.turn_scores.append(reward)
324
+
325
+ # Update prompt with user responses (similar to _handle_processing_tools_state)
326
+ if self.processor is not None:
327
+ raw_user_response = await self.loop.run_in_executor(
328
+ None,
329
+ lambda: self.processor.apply_chat_template(
330
+ add_messages,
331
+ add_generation_prompt=True,
332
+ tokenize=False,
333
+ **self.apply_chat_template_kwargs,
334
+ ),
335
+ )
336
+ model_inputs = self.processor(text=[raw_user_response], images=None, return_tensors="pt")
337
+ response_ids = model_inputs.pop("input_ids").squeeze(0).tolist()
338
+ else:
339
+ response_ids = await self.loop.run_in_executor(
340
+ None,
341
+ lambda: self.tokenizer.apply_chat_template(add_messages, add_generation_prompt=True, tokenize=True),
342
+ )
343
+ response_ids = response_ids[len(self.system_prompt) :]
344
+
345
+ # Update prompt_ids and response_mask
346
+ agent_data.prompt_ids += response_ids
347
+ agent_data.response_mask += [0] * len(response_ids)
348
+ if agent_data.response_logprobs:
349
+ agent_data.response_logprobs += [0.0] * len(response_ids)
350
+
351
+ # double check prompt
352
+ # Check termination condition
353
+ if should_terminate_sequence:
354
+ return AgentState.TERMINATED
355
+ else:
356
+ return AgentState.GENERATING
357
+
358
+ async def _call_tool(self, tool_call: str, tools_kwargs: dict[str, Any]) -> tuple[ToolResponse, float, dict]:
359
+ """Call tool and return tool response."""
360
+ tool, instance_id = None, None
361
+ try:
362
+ tool = self.tools["code_interpreter"]
363
+ instance_id, _ = await tool.create(create_kwargs={})
364
+
365
+ tool_execution_response, _, _ = await tool.execute(instance_id, tool_call)
366
+ except Exception as e:
367
+ logger.warning(f"Error when executing tool: {e}")
368
+ return (
369
+ ToolResponse(
370
+ text=f"Error when executing tool: {e}",
371
+ ),
372
+ 0.0,
373
+ {},
374
+ )
375
+ finally:
376
+ if tool and instance_id:
377
+ await tool.release(instance_id)
378
+
379
+ tool_response_text = tool_execution_response.text
380
+ if tool_response_text and len(tool_response_text) > self.max_tool_response_length:
381
+ if self.tool_response_truncate_side == "left":
382
+ tool_response_text = tool_response_text[: self.max_tool_response_length] + "...(truncated)"
383
+ elif self.tool_response_truncate_side == "right":
384
+ tool_response_text = "(truncated)..." + tool_response_text[-self.max_tool_response_length :]
385
+ else:
386
+ length = self.max_tool_response_length // 2
387
+ tool_response_text = tool_response_text[:length] + "...(truncated)..." + tool_response_text[-length:]
388
+
389
+ tool_response_text = f"<interpreter>\n{tool_response_text}\n</interpreter>\n\n"
390
+
391
+ # Create ToolResponse from tool execution result
392
+ tool_response_kwargs = {"text": tool_response_text}
393
+
394
+ # Add multimedia data if present
395
+ for attr_name in ["image", "video"]:
396
+ if hasattr(tool_execution_response, attr_name):
397
+ attr_value = getattr(tool_execution_response, attr_name)
398
+ if attr_value is not None:
399
+ tool_response_kwargs[attr_name] = attr_value
400
+
401
+ return ToolResponse(**tool_response_kwargs)
402
+
403
+ @classmethod
404
+ def _initialize_interactions(cls, interaction_config_file):
405
+ """Initialize interactions from configuration.
406
+ Returns:
407
+ dict[str, BaseInteraction]: A dictionary mapping interaction names to interaction instances.
408
+ """
409
+ if interaction_config_file is None:
410
+ return {}
411
+
412
+ interaction_map = initialize_interactions_from_config(interaction_config_file)
413
+ logger.info(f"Initialize interactions from configuration: interaction_map: {list(interaction_map.keys())}")
414
+ return interaction_map
ICL/DAPO/verl-recipe/sppo/config/sppo_trainer.yaml ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # the sppo config will override default ppo_trainer.yaml
2
+
3
+ hydra:
4
+ searchpath:
5
+ - file://verl/trainer/config
6
+
7
+ defaults:
8
+ - ppo_trainer
9
+ - _self_
10
+
11
+ actor_rollout_ref:
12
+ actor:
13
+ _target_: recipe.sppo.config.SPPOActorConfig
14
+
15
+ # sppo_eta is an additional hyperparameter for SPPO, not available in
16
+ # verl core. specifying _target_ with SPPOActorConfig is needed to
17
+ # extend verl ActorConfig with custom fields.
18
+ # additional, it is also possible to use the `extra` field natively supported
19
+ # by all verl core dataclasses, without having to define SPPOActorConfig
20
+ # extra:
21
+ # sppo_eta: 1.0
22
+ sppo_eta: 1.0
23
+
24
+ optim:
25
+ lr_warmup_steps: 15
26
+ rollout:
27
+ name: sglang
28
+ tensor_model_parallel_size: 2
29
+ gpu_memory_utilization: 0.5
30
+ val_kwargs:
31
+ n: 2 # 2 will trigger validation, 1 will bypass
32
+
33
+ algorithm:
34
+ adv_estimator: null
35
+ sppo_eta: 1.0
36
+
37
+ trainer:
38
+ log_val_generations: 0
ICL/EVAL_GUIDE.md ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ICL 模型评测步骤
2
+
3
+ ## Step 1: 合并 DeepSpeed checkpoint(safetensors 格式)
4
+
5
+ ```bash
6
+ cd /workspace/xiaobin/ICL
7
+
8
+ python3 sft_model/zero_to_fp32.py \
9
+ sft_model \
10
+ sft_model/merged_hf \
11
+ --safe_serialization
12
+ ```
13
+
14
+ ## Step 2: 复制 tokenizer 和 config(注意不要复制 model.safetensors.index.json)
15
+
16
+ ```bash
17
+ cp /workspace/models/Qwen3-VL-8B-Instruct/config.json sft_model/merged_hf/
18
+ cp /workspace/models/Qwen3-VL-8B-Instruct/generation_config.json sft_model/merged_hf/
19
+ cp /workspace/models/Qwen3-VL-8B-Instruct/preprocessor_config.json sft_model/merged_hf/
20
+ cp /workspace/models/Qwen3-VL-8B-Instruct/chat_template.json sft_model/merged_hf/ 2>/dev/null
21
+ cp /workspace/models/Qwen3-VL-8B-Instruct/tokenizer* sft_model/merged_hf/
22
+ cp /workspace/models/Qwen3-VL-8B-Instruct/merges.txt sft_model/merged_hf/
23
+ cp /workspace/models/Qwen3-VL-8B-Instruct/vocab.json sft_model/merged_hf/
24
+ ```
25
+
26
+ ## Step 3: 跑评测
27
+
28
+ 单卡:
29
+
30
+ ```bash
31
+ python3 eval_icl.py \
32
+ --model-path sft_model/merged_hf \
33
+ --all-categories \
34
+ --num-samples 100 \
35
+ --max-rounds 4 \
36
+ --device cuda:0
37
+ ```
38
+
39
+ 多卡 (8 GPU):
40
+
41
+ ```bash
42
+ torchrun --nproc_per_node=8 eval_icl.py \
43
+ --model-path sft_model/merged_hf \
44
+ --all-categories \
45
+ --num-samples 100 \
46
+ --max-rounds 4
47
+ ```
ICL/LV/dataset_inspect.tree.txt ADDED
@@ -0,0 +1,456 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ M3IT/
2
+ .git/
3
+ data/
4
+ .gitattributes (2.8KB)
5
+ .gitignore (29.0B)
6
+ M3IT.py (54.5KB)
7
+ README.md (18.3KB)
8
+ branches/
9
+ hooks/
10
+ info/
11
+ lfs/
12
+ logs/
13
+ objects/
14
+ refs/
15
+ FETCH_HEAD (110.0B)
16
+ HEAD (21.0B)
17
+ config (339.0B)
18
+ description (73.0B)
19
+ packed-refs (112.0B)
20
+ refs/
21
+ HEAD (189.0B)
22
+ heads/
23
+ remotes/
24
+ main (189.0B)
25
+ heads/
26
+ remotes/
27
+ tags/
28
+ origin/
29
+ HEAD (30.0B)
30
+ main (41.0B)
31
+ info/
32
+ pack/
33
+ pack-ee3e40a1a23ec17affa3b8afb61dc14bdffb229c.idx (38.9KB)
34
+ pack-ee3e40a1a23ec17affa3b8afb61dc14bdffb229c.pack (195.5KB)
35
+ applypatch-msg.sample (478.0B)
36
+ commit-msg.sample (896.0B)
37
+ fsmonitor-watchman.sample (4.5KB)
38
+ post-checkout (280.0B)
39
+ post-commit (276.0B)
40
+ post-merge (274.0B)
41
+ post-update.sample (189.0B)
42
+ pre-applypatch.sample (424.0B)
43
+ pre-commit.sample (1.6KB)
44
+ pre-merge-commit.sample (416.0B)
45
+ pre-push (270.0B)
46
+ pre-push.sample (1.3KB)
47
+ pre-rebase.sample (4.8KB)
48
+ pre-receive.sample (544.0B)
49
+ prepare-commit-msg.sample (1.5KB)
50
+ push-to-checkout.sample (2.7KB)
51
+ update.sample (3.6KB)
52
+ incomplete/
53
+ logs/
54
+ objects/
55
+ tmp/
56
+ 0152398d9443f2d300adc9e6099a773c66303d4e2e085812cd502cb36da7a0c73483193049 (0.0B)
57
+ 0152398d9443f2d300adc9e6099a773c66303d4e2e085812cd502cb36da7a0c7763208216 (0.0B)
58
+ 0152398d9443f2d300adc9e6099a773c66303d4e2e085812cd502cb36da7a0c789921672 (2.5MB)
59
+ 0968a4438d46277583968011563e959e130feaee66f51bb2d66dbd7e8c979f8c.part (0.0B)
60
+ 1f77f56225e10edca84be06b6e0d796c579cbf1d4884aee46da564438ad1ba9b1484563810 (437.0KB)
61
+ 1f77f56225e10edca84be06b6e0d796c579cbf1d4884aee46da564438ad1ba9b3850099655 (326.7KB)
62
+ 1f77f56225e10edca84be06b6e0d796c579cbf1d4884aee46da564438ad1ba9b3898577811 (4.1MB)
63
+ 220d32d087b6b29d1c5aaa49324d32b32ae1c19f42e9800f40f24d3a695c2a8d1743027097 (0.0B)
64
+ 220d32d087b6b29d1c5aaa49324d32b32ae1c19f42e9800f40f24d3a695c2a8d3014727128 (0.0B)
65
+ 220d32d087b6b29d1c5aaa49324d32b32ae1c19f42e9800f40f24d3a695c2a8d71894927 (62.6KB)
66
+ 24f014bb5bc7b1fa7d9183dd65fd4b43c0c49aafd6af01bb91ae3a0e7e65502b2818819757 (49.3MB)
67
+ 3da69649bfbc671710f38c2c2f7c6aaecb8f8544de3446866054bf927257c9332854861486 (158.6KB)
68
+ 3da69649bfbc671710f38c2c2f7c6aaecb8f8544de3446866054bf927257c9334214717938 (0.0B)
69
+ 3da69649bfbc671710f38c2c2f7c6aaecb8f8544de3446866054bf927257c933593947826 (0.0B)
70
+ 45e8c51ed0df8edb1ae51d2012b3f7d6cd9cc84addf41e6f9f9adb0f625d41033126870057 (259.2MB)
71
+ 4a80559730d917177e4d13246da0ce23ca318735b29d519d0448bea5579b1a771450117433 (154.4MB)
72
+ 4fda2aa4918e5dec847935db6d46e9bebc570a173bd4201c5f48e60a3f73813a1530155941 (1.1MB)
73
+ 4fda2aa4918e5dec847935db6d46e9bebc570a173bd4201c5f48e60a3f73813a2738070238 (0.0B)
74
+ 4fda2aa4918e5dec847935db6d46e9bebc570a173bd4201c5f48e60a3f73813a2828099128 (0.0B)
75
+ 52a445f8a26cd898e64129e7f1d4bfa6d7203311442068684f5344fc73407310.part (0.0B)
76
+ 6728a8fb7bad0bad3a2a27669232cb9ae66461c635172f1f7958c80a28e09fa32607733000 (150.2MB)
77
+ 6bb6c9f17e77eab7d88e4a4501c38cb31a6cf792fe77e3b75d511b964a5667df2998182268 (91.8MB)
78
+ 8cb15647ff6bbac322142fea1a38599c523f73acb3614ddb7d12e6a1975a79dc1986657385 (0.0B)
79
+ 8cb15647ff6bbac322142fea1a38599c523f73acb3614ddb7d12e6a1975a79dc2743098052 (0.0B)
80
+ 8cb15647ff6bbac322142fea1a38599c523f73acb3614ddb7d12e6a1975a79dc4193739161 (0.0B)
81
+ 9919274ad6bc88e37235a4c7245d05e357e404ef3352a90a1ba0594e694893c01114223911 (0.0B)
82
+ 9919274ad6bc88e37235a4c7245d05e357e404ef3352a90a1ba0594e694893c03545613611 (0.0B)
83
+ 9919274ad6bc88e37235a4c7245d05e357e404ef3352a90a1ba0594e694893c0559090370 (2.8MB)
84
+ 9cdf4d1a6972db893c8db1a4f2be0d1ec0362ba22a44542402b336760029c87253830692 (88.0MB)
85
+ b6aed90c79d180c5346994f8e7d0657b3d8a9aab002c057503736b4013a2096b.part (0.0B)
86
+ ba47b9680dc949322877399218d1f210a057249803bc70addfb9528152e4b1662004000729 (218.5MB)
87
+ ca49e0b3f3400f38519a1103b2a567db32c9fa990a7395b1024b94454601479b.part (0.0B)
88
+ d66a5b3267a7935b8ff272bcc166a8f43a8d66fb89c59503d536ac87661a02022501429466 (0.0B)
89
+ d66a5b3267a7935b8ff272bcc166a8f43a8d66fb89c59503d536ac87661a020230475132 (0.0B)
90
+ d66a5b3267a7935b8ff272bcc166a8f43a8d66fb89c59503d536ac87661a0202373225118 (62.5KB)
91
+ e5a3eb3e2d0c47d6f014e294ef7398bf26375920c8d2af80fd65e255396dcc78.part (0.0B)
92
+ f19cacf3a9f9a57abdcafc4a6d242aa9c6fa48188ad0a394b1a2558cb8ab4dc5372340294 (199.2MB)
93
+ 20251021T152133.441099492.log (1.4KB)
94
+ 01/
95
+ 02/
96
+ 03/
97
+ 05/
98
+ 06/
99
+ 07/
100
+ 09/
101
+ 0b/
102
+ 0f/
103
+ 10/
104
+ 12/
105
+ 15/
106
+ 16/
107
+ 19/
108
+ 1d/
109
+ 1e/
110
+ 1f/
111
+ 21/
112
+ 22/
113
+ 23/
114
+ 24/
115
+ 2a/
116
+ 2b/
117
+ 2c/
118
+ 2d/
119
+ 2f/
120
+ 30/
121
+ 32/
122
+ 34/
123
+ 37/
124
+ 3b/
125
+ 3d/
126
+ 44/
127
+ 45/
128
+ 4a/
129
+ 4f/
130
+ 50/
131
+ 52/
132
+ 54/
133
+ 56/
134
+ 58/
135
+ 5a/
136
+ 5b/
137
+ 60/
138
+ 61/
139
+ 64/
140
+ 65/
141
+ 67/
142
+ 68/
143
+ 69/
144
+ 6b/
145
+ 6d/
146
+ 6e/
147
+ 70/
148
+ 75/
149
+ 76/
150
+ 7b/
151
+ 7c/
152
+ 80/
153
+ 87/
154
+ 88/
155
+ 89/
156
+ 8b/
157
+ 8c/
158
+ 90/
159
+ 91/
160
+ 93/
161
+ 99/
162
+ 9a/
163
+ 9b/
164
+ 9c/
165
+ 9e/
166
+ 9f/
167
+ a0/
168
+ a5/
169
+ a9/
170
+ ac/
171
+ ae/
172
+ b1/
173
+ b3/
174
+ b4/
175
+ b6/
176
+ ba/
177
+ bb/
178
+ bc/
179
+ bd/
180
+ be/
181
+ c0/
182
+ c1/
183
+ c2/
184
+ c4/
185
+ c6/
186
+ c7/
187
+ c8/
188
+ ca/
189
+ cb/
190
+ d6/
191
+ d9/
192
+ dd/
193
+ e2/
194
+ e5/
195
+ e7/
196
+ e8/
197
+ e9/
198
+ ee/
199
+ ef/
200
+ f1/
201
+ f3/
202
+ f4/
203
+ f5/
204
+ f6/
205
+ f7/
206
+ f8/
207
+ f9/
208
+ fc/
209
+ exclude (240.0B)
210
+ captioning/
211
+ classification/
212
+ generation/
213
+ reasoning/
214
+ vqa/
215
+ chinesefoodnet-10/
216
+ coco-goi/
217
+ coco-text/
218
+ imagenet/
219
+ iqa/
220
+ itm/
221
+ mocheg/
222
+ refcoco/
223
+ snli-ve/
224
+ ss/
225
+ vsr/
226
+ winoground/
227
+ .gitattributes (141.0B)
228
+ README.md (211.0B)
229
+ instructions.json (1.4KB)
230
+ labels.json (9.0KB)
231
+ test.jsonl (223.5MB)
232
+ train.jsonl (238.9MB)
233
+ val.jsonl (227.6MB)
234
+ README.md (31.0B)
235
+ esnlive_test.jsonl (743.0MB)
236
+ esnlive_train.jsonl (1000.8MB)
237
+ esnlive_val.jsonl (717.9MB)
238
+ instructions.json (1.9KB)
239
+ test_2023-10-09.jsonl (2.9GB)
240
+ train_2023-10-09.jsonl (3.9GB)
241
+ instructions.json (825.0B)
242
+ mapping.txt (30.9KB)
243
+ test_2023-10-08.jsonl (10.6GB)
244
+ train.jsonl (1.5GB)
245
+ train_2023-10-08.jsonl (5.9GB)
246
+ val.jsonl (2.6GB)
247
+ instructions.json (907.0B)
248
+ test.jsonl (330.4MB)
249
+ test_2023-10-09.jsonl (1.3GB)
250
+ train.jsonl (1.9GB)
251
+ train_2023-10-08.jsonl (7.8GB)
252
+ val.jsonl (330.8MB)
253
+ instructions.json (773.0B)
254
+ test.jsonl (730.0MB)
255
+ test_2023-10-09.jsonl (2.9GB)
256
+ train.jsonl (4.3GB)
257
+ train_2023-10-08.jsonl (17.1GB)
258
+ val.jsonl (730.2MB)
259
+ instructions.json (1.4KB)
260
+ test_2023-10-09.jsonl (553.7MB)
261
+ train_2023-10-09.jsonl (1.9GB)
262
+ vsr_test.jsonl (137.7MB)
263
+ vsr_train.jsonl (483.3MB)
264
+ vsr_val.jsonl (68.8MB)
265
+ instructions.json (774.0B)
266
+ test_2023-10-10.jsonl (7.6GB)
267
+ train.jsonl (8.2GB)
268
+ train_2023-10-08.jsonl (32.8GB)
269
+ val.jsonl (1.9GB)
270
+ instructions.json (733.0B)
271
+ test_2023-10-07.jsonl (279.1MB)
272
+ train.jsonl (2.0GB)
273
+ train_2023-10-06.jsonl (4.1GB)
274
+ val.jsonl (138.9MB)
275
+ instructions.json (2.0KB)
276
+ winoground_test.jsonl (245.5MB)
277
+ instructions.json (1.3KB)
278
+ test.jsonl (122.9MB)
279
+ instructions.json (1.0KB)
280
+ mocheg_test.jsonl (60.3MB)
281
+ mocheg_train.jsonl (631.7MB)
282
+ mocheg_val.jsonl (28.2MB)
283
+ test_2023-10-08.jsonl (242.5MB)
284
+ train_2023-10-08.jsonl (2.5GB)
285
+ instructions.json (1.5KB)
286
+ test.jsonl (701.9MB)
287
+ test_2023-10-08.jsonl (2.7GB)
288
+ train.jsonl (3.9GB)
289
+ train_2023-10-08.jsonl (15.6GB)
290
+ val.jsonl (667.7MB)
291
+ clevr/
292
+ nlvr/
293
+ science_qa/
294
+ vcr/
295
+ visual_mrc/
296
+ instructions.json (2.5KB)
297
+ science_qa_test.jsonl (174.0MB)
298
+ science_qa_train.jsonl (531.3MB)
299
+ science_qa_validation.jsonl (176.4MB)
300
+ instructions.json (976.0B)
301
+ train.jsonl (5.6GB)
302
+ train_2023-10-07.jsonl (11.1GB)
303
+ val.jsonl (379.6MB)
304
+ val_2023-10-07.jsonl (760.4MB)
305
+ instructions.json (911.0B)
306
+ test.jsonl (1.2GB)
307
+ train.jsonl (3.9GB)
308
+ val.jsonl (266.9MB)
309
+ instructions.json (1.3KB)
310
+ test.jsonl (909.3MB)
311
+ train.jsonl (4.3GB)
312
+ val.jsonl (992.9MB)
313
+ instructions.json (1.2KB)
314
+ test.jsonl (489.0MB)
315
+ train.jsonl (7.9GB)
316
+ val.jsonl (533.3MB)
317
+ mmchat/
318
+ multi30k/
319
+ vist/
320
+ visual_dialog/
321
+ instructions.json (818.0B)
322
+ test.jsonl (65.2MB)
323
+ test_2023-10-10.jsonl (262.2MB)
324
+ train.jsonl (3.2GB)
325
+ train_2023-10-09.jsonl (13.0GB)
326
+ val.jsonl (66.0MB)
327
+ instructions.json (1.2KB)
328
+ test.jsonl (610.6MB)
329
+ train.jsonl (4.4GB)
330
+ val.jsonl (301.1MB)
331
+ instructions.json (809.0B)
332
+ test.jsonl (2.3GB)
333
+ train.jsonl (6.2GB)
334
+ train_new.jsonl (6.2GB)
335
+ validation.jsonl (2.0GB)
336
+ instructions.json (1.0KB)
337
+ test.jsonl (14.0GB)
338
+ train.jsonl (15.4GB)
339
+ val.jsonl (13.0GB)
340
+ a-okvqa/
341
+ activitynet-qa/
342
+ docvqa/
343
+ fm-iqa/
344
+ gqa/
345
+ ivqa/
346
+ msrvtt-qa/
347
+ msvd-qa/
348
+ ocr-vqa/
349
+ okvqa/
350
+ shapes/
351
+ st-vqa/
352
+ text-vqa/
353
+ viquae/
354
+ vqav2/
355
+ instruction.json (905.0B)
356
+ train.jsonl (533.5MB)
357
+ train_new.jsonl (533.5MB)
358
+ validation.jsonl (228.3MB)
359
+ instructions.json (1.9KB)
360
+ train.jsonl (1.2GB)
361
+ train_v2.jsonl (1.2GB)
362
+ val.jsonl (77.7MB)
363
+ val_v2.jsonl (78.2MB)
364
+ instruction.json (905.0B)
365
+ test.jsonl (713.3MB)
366
+ train.jsonl (3.3GB)
367
+ validation_new.jsonl (529.5MB)
368
+ instruction.json (772.0B)
369
+ train.jsonl (1.5GB)
370
+ validation.jsonl (260.3MB)
371
+ instruction.json (853.0B)
372
+ test.jsonl (229.4MB)
373
+ train.jsonl (1.4GB)
374
+ README.md (288.0B)
375
+ instructions.json (1.2KB)
376
+ test.jsonl (132.4MB)
377
+ train.jsonl (343.1MB)
378
+ val.jsonl (60.9MB)
379
+ instructions.json (853.0B)
380
+ train.jsonl (1.9GB)
381
+ val.jsonl (1.9GB)
382
+ instructions.json (1.7KB)
383
+ train.jsonl (7.2GB)
384
+ val.jsonl (976.6MB)
385
+ instructions.json (1.5KB)
386
+ test.jsonl (1.4MB)
387
+ test_2023-10-08.jsonl (7.0MB)
388
+ train.large.jsonl (18.3MB)
389
+ train_2023-10-08.jsonl (92.6MB)
390
+ val.jsonl (1.4MB)
391
+ README.md (334.0B)
392
+ instructions.json (1.0KB)
393
+ test.jsonl (500.8MB)
394
+ train.jsonl (1.5GB)
395
+ val.jsonl (485.4MB)
396
+ README.md (434.0B)
397
+ instructions.json (1.0KB)
398
+ test.jsonl (348.1MB)
399
+ train.jsonl (757.5MB)
400
+ val.jsonl (58.0MB)
401
+ .gitattributes (141.0B)
402
+ README.md (332.0B)
403
+ instructions.json (1.4KB)
404
+ test.jsonl (474.7MB)
405
+ train.jsonl (2.1GB)
406
+ val.jsonl (1.1GB)
407
+ instructions.json (1.2KB)
408
+ train.jsonl (594.8MB)
409
+ train_v2.jsonl (596.3MB)
410
+ val.jsonl (334.3MB)
411
+ val_v2.jsonl (335.2MB)
412
+ instructions.json (802.0B)
413
+ para_train.jsonl (10.5GB)
414
+ para_val.jsonl (4.8GB)
415
+ train.jsonl (10.5GB)
416
+ val.jsonl (4.8GB)
417
+ instructions.json (1.2KB)
418
+ test.jsonl (122.5MB)
419
+ test_v2.jsonl (120.9MB)
420
+ train.jsonl (110.1MB)
421
+ train_v2.jsonl (110.2MB)
422
+ validation.jsonl (125.5MB)
423
+ validation_v2.jsonl (125.6MB)
424
+ coco/
425
+ coco-cn/
426
+ flickr8k-cn/
427
+ image_paragraph_captioning/
428
+ msrvtt/
429
+ textcap/
430
+ .gitattributes (141.0B)
431
+ README.md (490.0B)
432
+ instructions.json (1010.0B)
433
+ test.jsonl (117.1MB)
434
+ train.jsonl (231.1MB)
435
+ val.jsonl (116.9MB)
436
+ instructions.json (541.0B)
437
+ test.jsonl (49.4MB)
438
+ train.jsonl (300.0MB)
439
+ val.jsonl (49.9MB)
440
+ instructions.json (790.0B)
441
+ test.jsonl (66.4MB)
442
+ train.jsonl (1.2GB)
443
+ val.jsonl (65.0MB)
444
+ image_paragraph_captioning_test.jsonl (120.7MB)
445
+ image_paragraph_captioning_train.jsonl (701.2MB)
446
+ image_paragraph_captioning_val.jsonl (118.0MB)
447
+ instruction.json (1.4KB)
448
+ README.md (73.0B)
449
+ create_dataset.py (5.5KB)
450
+ instructions.json (882.0B)
451
+ test.jsonl (333.1MB)
452
+ train.jsonl (7.4GB)
453
+ val.jsonl (333.4MB)
454
+ instructions.json (1.1KB)
455
+ train.jsonl (5.7GB)
456
+ val.jsonl (851.3MB)
ICL/RL_DAPO/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
ICL/SFT_new/README.md ADDED
@@ -0,0 +1,389 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Qwen3-VL-8B Single-Step Decision SFT
2
+
3
+ ## 项目结构
4
+
5
+ ```
6
+ SFT_new/
7
+ ├── build_sft.py # 数据构造 (SigLIP2 相似度选 shots, 单步决策格式)
8
+ ├── generate_captions.py # VLM 批量 caption 生成 (替代短答案作为检索描述)
9
+ ├── train.py # 训练主脚本 (DeepSpeed + Flash Attention 2)
10
+ ├── ds_zero2.json # DeepSpeed ZeRO-2 配置 (推荐, 速度快)
11
+ ├── ds_zero3.json # DeepSpeed ZeRO-3 配置 (备用, 更省显存)
12
+ ├── run_single_node.sh # 单机启动脚本 (debug)
13
+ ├── run_multi_node.sh # 多机训练入口 (每个 node 执行)
14
+ ├── submit_northjob.sh # northjob 集群提交 (64卡)
15
+ ├── launch_wrapper.py # northjob → bash 桥接
16
+ └── README.md # 本文件
17
+ ```
18
+
19
+ ---
20
+
21
+ ## 整体 Pipeline
22
+
23
+ ```
24
+ 原始数据集 (jsonl + 图片)
25
+
26
+
27
+ ┌─────────────────┐
28
+ │ build_sft.py │ --build-cache ← 只跑一次, GPU
29
+ │ SigLIP2 编码 │ 生成 emb_cache/
30
+ └────────┬────────┘
31
+
32
+
33
+ ┌──────────────────────┐
34
+ │ generate_captions.py │ VLM API 批量生成 ← 只跑一次, 无需 GPU
35
+ │ 生成 caption_cache/ │ (vLLM 部署的 Qwen3-VL)
36
+ └────────┬─────────────┘
37
+
38
+
39
+ ┌─────────────────┐
40
+ │ build_sft.py │ 构造 SFT 数据 ← CPU, 可多进程并行
41
+ │ 读取 emb_cache │ 读取 caption_cache
42
+ │ + caption_cache │ 输出 sft.jsonl
43
+ └────────┬────────┘
44
+
45
+
46
+ ┌─────────────────┐
47
+ │ train.py │ DeepSpeed 训练
48
+ └─────────────────┘
49
+ ```
50
+
51
+ ---
52
+
53
+ ## 1. 配环境
54
+
55
+ ```bash
56
+ # 创建 conda 环境
57
+ conda create -n sft python=3.11 -y
58
+ conda activate sft
59
+
60
+ # PyTorch 2.4 + CUDA 12 (匹配 flash-attn whl)
61
+ pip install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124
62
+
63
+ # Flash Attention 2 (本地 whl, 先试 TRUE 版, 不行换 FALSE 版)
64
+ pip install /workspace/flash_attn-2.8.3+cu12torch2.4cxx11abiTRUE-cp311-cp311-linux_x86_64.whl
65
+ # 如果报 CXX11 ABI 不匹配:
66
+ # pip install /workspace/flash_attn-2.8.3+cu12torch2.4cxx11abiFALSE-cp311-cp311-linux_x86_64.whl
67
+
68
+ # 核心依赖
69
+ pip install transformers>=4.57.0
70
+ pip install accelerate>=1.13.0
71
+ pip install peft>=0.18.0
72
+ pip install deepspeed>=0.16.0
73
+ pip install qwen-vl-utils
74
+ pip install tqdm pillow
75
+ pip install openai # generate_captions.py 需要
76
+
77
+ # 验证安装
78
+ python -c "
79
+ import torch, transformers, deepspeed, flash_attn, peft
80
+ print(f'torch: {torch.__version__}')
81
+ print(f'transformers: {transformers.__version__}')
82
+ print(f'deepspeed: {deepspeed.__version__}')
83
+ print(f'flash_attn: {flash_attn.__version__}')
84
+ print(f'peft: {peft.__version__}')
85
+ print(f'CUDA: {torch.cuda.is_available()}, {torch.cuda.get_device_name(0)}')
86
+ from transformers import Qwen3VLForConditionalGeneration
87
+ print('Qwen3VL: OK')
88
+ "
89
+ ```
90
+
91
+ **注意**: flash-attn whl 是针对 torch 2.4 编译的, 所以 PyTorch 必须装 2.4.x 版本.
92
+
93
+ ---
94
+
95
+ ## 2. 构造数据
96
+
97
+ ### 2.1 构建 SigLIP embedding 缓存 (只跑一次, GPU)
98
+
99
+ ```bash
100
+ conda activate sft
101
+
102
+ python /workspace/xiaobin/ICL/SFT_new/build_sft.py \
103
+ --build-cache \
104
+ --data-root /path/to/your/dataset \
105
+ --output-dir /workspace/xiaobin/ICL/SFT_new/output \
106
+ --siglip-model /workspace/models/siglip2-so400m-patch16-naflex \
107
+ --device cuda:0 \
108
+ --categories vqa,captioning,classification,reasoning
109
+ ```
110
+
111
+ 缓存保存在 `output/emb_cache/` 下, JSON 格式 (float16 base64), 可跨环境复用.
112
+
113
+ ### 2.2 生成 VLM Caption (只跑一次, 调 API 无需本地 GPU)
114
+
115
+ **为什么需要这一步**: 很多 VQA 数据集的 answer 是短答案 ("yes", "3", "cab"), 不适合做语义检索的 query 描述. 用 VLM 给每张 pool 图片生成描述性 caption, 作为 `<RET>` 输出的 Description 和 context shot 的 Caption, 质量远好于原始 answer.
116
+
117
+ #### 启动 vLLM 服务 (NorthServe)
118
+
119
+ ```bash
120
+ # 启动 Qwen3-VL-8B 推理服务(32 副本,每副本 1 卡)
121
+ HOME=/root /workspace/nex-agi/NorthServe/northserve launch \
122
+ --model-name qwen3vl8b-caption \
123
+ --served-model-name Qwen3-VL-8B-Instruct \
124
+ --namespace bg-agentic-coding \
125
+ --model-path /i_workspace/models/Qwen3-VL-8B-Instruct \
126
+ --volumes "i-xinsiyang-y4zy0sik0a:/i_workspace" \
127
+ --replicas 32 \
128
+ --gpus-per-pod 1 \
129
+ --pods-per-job 1 \
130
+ --profile generation \
131
+ --backend vllm \
132
+ --priority-class-name higher-priority-job \
133
+ --extra-cmds "--trust-remote-code --max-model-len 4096 --max-num-seqs 128" \
134
+ -y
135
+
136
+ # 验证(所有模型共用 http://10.51.6.110/v1,模型名在请求体里指定)
137
+ curl http://10.51.6.110/v1/models
138
+ ```
139
+
140
+ #### 生成 caption (emb_cache 对齐版)
141
+
142
+ ```bash
143
+ python /workspace/xiaobin/ICL/SFT_new/generate_captions.py \
144
+ --api-base http://10.51.6.110/v1 \
145
+ --model Qwen3-VL-8B-Instruct \
146
+ --emb-cache-dir /workspace/xiaobin/ICL/SFT_new/output/emb_cache \
147
+ --output-dir /workspace/xiaobin/ICL/SFT_new/output/caption_cache \
148
+ --num-workers 128 \
149
+ --prompt "Describe this image in one or two sentences. Focus on the main objects, their attributes, and spatial relationships."
150
+ ```
151
+
152
+ #### 生成 caption (全量图片版, 按 split 分开保存)
153
+
154
+ ```bash
155
+ # 全量跑 (~200 万张图)
156
+ python /workspace/xiaobin/ICL/SFT_new/generate_captions_all.py \
157
+ --api-base http://10.51.6.110/v1 \
158
+ --model Qwen3-VL-8B-Instruct \
159
+ --num-workers 128
160
+
161
+ # 只跑某个 category
162
+ python /workspace/xiaobin/ICL/SFT_new/generate_captions_all.py \
163
+ --api-base http://10.51.6.110/v1 \
164
+ --model Qwen3-VL-8B-Instruct \
165
+ --categories vqa \
166
+ --num-workers 128
167
+ ```
168
+
169
+ 输出到 `/workspace/xiaobin/dataset/detail/{category}/{dataset}/{split}/captions.json`
170
+
171
+ #### 停止服务
172
+
173
+ ```bash
174
+ HOME=/root /workspace/nex-agi/NorthServe/northserve stop \
175
+ --model-name qwen3vl8b-caption
176
+ ```
177
+
178
+ **关键特性**:
179
+ - **断点续传**: 已完成的文件自动跳过, 部分完成的只处理缺失图片
180
+ - **定期存盘**: 每 500 张自动保存 (防崩溃丢数据), `--save-every` 可调
181
+ - **并发请求**: `--num-workers 128`, 32 副本理论上限 4096, 不报错就往大了开
182
+
183
+ ### 2.3 构建 SFT 数据集 (CPU, 不需要 GPU, 可多进程并行)
184
+
185
+ ```bash
186
+ # 单进程
187
+ python /workspace/xiaobin/ICL/SFT_new/build_sft.py \
188
+ --data-root /path/to/your/dataset \
189
+ --output-dir /workspace/xiaobin/ICL/SFT_new/output \
190
+ --caption-cache-dir /workspace/xiaobin/ICL/SFT_new/output/caption_cache \
191
+ --samples-per-cat 20000 \
192
+ --max-shots 3 \
193
+ --answer-at-weights 3,3,2,1
194
+
195
+ # 多进程并行 (4 shards)
196
+ for i in 0 1 2 3; do
197
+ python /workspace/xiaobin/ICL/SFT_new/build_sft.py \
198
+ --data-root /path/to/your/dataset \
199
+ --output-dir /workspace/xiaobin/ICL/SFT_new/output \
200
+ --caption-cache-dir /workspace/xiaobin/ICL/SFT_new/output/caption_cache \
201
+ --shard-id $i --num-shards 4 &
202
+ done
203
+ wait
204
+
205
+ # 合并
206
+ python /workspace/xiaobin/ICL/SFT_new/build_sft.py \
207
+ --data-root /path/to/your/dataset \
208
+ --output-dir /workspace/xiaobin/ICL/SFT_new/output \
209
+ --merge --shuffle
210
+ ```
211
+
212
+ **注意**: `--caption-cache-dir` 不传或目录不存在时行为和之前完全一致(用原始 answer)。正式训练前务必先跑 `generate_captions.py` 生成完整的 caption cache。
213
+
214
+ 最终数据: `output/all/sft.jsonl`
215
+
216
+ **生成数据中的描述字段变化**:
217
+ ```
218
+ # 之前 (用原始 answer, 短答案质量差)
219
+ {"from": "gpt", "value": "<RET>\nDescription: yes"}
220
+ {"from": "human", "value": "...<image>\nCaption: yes..."}
221
+
222
+ # 现在 (用 VLM 生成的描述, 适合语义检索)
223
+ {"from": "gpt", "value": "<RET>\nDescription: A woman cutting a large white cake in a kitchen."}
224
+ {"from": "human", "value": "...<image>\nCaption: A woman cutting a large white cake in a kitchen...."}
225
+ ```
226
+
227
+ ---
228
+
229
+ ## 3. 训练
230
+
231
+ ### 3.1 单机 debug (1 node x 8 H100)
232
+
233
+ ```bash
234
+ conda activate sft
235
+
236
+ bash /workspace/xiaobin/ICL/SFT_new/run_single_node.sh \
237
+ /workspace/xiaobin/ICL/SFT_new/output/all/sft.jsonl \
238
+ 8
239
+ ```
240
+
241
+ 可改 GPU 数快速 debug:
242
+ ```bash
243
+ # 用 2 卡 debug
244
+ bash /workspace/xiaobin/ICL/SFT_new/run_single_node.sh /path/to/sft.jsonl 2
245
+ ```
246
+
247
+ ### 3.2 多机训练 (8 nodes x 8 GPUs = 64 H100)
248
+
249
+ **方式 A: northjob 提交 (推荐)**
250
+
251
+ 先修改 `submit_northjob.sh` 里的 k8s 参数 (queue/namespace/pvc-name 改成你自己的), 然后:
252
+
253
+ ```bash
254
+ bash /workspace/xiaobin/ICL/SFT_new/submit_northjob.sh 64 # 64卡
255
+ bash /workspace/xiaobin/ICL/SFT_new/submit_northjob.sh 32 # 32卡
256
+ ```
257
+
258
+ **方式 B: 手动 torchrun (每个 node 上跑)**
259
+
260
+ ```bash
261
+ # 在每个 node 上执行, 修改 --node_rank=0/1/2/.../7
262
+ torchrun \
263
+ --nproc_per_node=8 \
264
+ --nnodes=8 \
265
+ --node_rank=${NODE_RANK} \
266
+ --master_addr=${MASTER_ADDR} \
267
+ --master_port=29500 \
268
+ /workspace/xiaobin/ICL/SFT_new/train.py \
269
+ --model-path /workspace/models/Qwen3-VL-8B-Instruct \
270
+ --data-path /workspace/xiaobin/ICL/SFT_new/output/all/sft.jsonl \
271
+ --output-dir /workspace/xiaobin/ICL/SFT_new/output/qwen3vl_sft_64gpu \
272
+ --deepspeed /workspace/xiaobin/ICL/SFT_new/ds_zero2.json \
273
+ --num-epochs 3 \
274
+ --batch-size 1 \
275
+ --gradient-accumulation-steps 2 \
276
+ --learning-rate 2e-5
277
+ ```
278
+
279
+ ---
280
+
281
+ ## 4. 训练策略说明
282
+
283
+ | 配置 | 单机 8 GPU (debug) | 64 GPU (正式) |
284
+ |------|-------------------|---------------|
285
+ | 并行 | DeepSpeed ZeRO-2 | DeepSpeed ZeRO-2 |
286
+ | micro_batch/GPU | 1 | 1 |
287
+ | grad_accum | 8 | 2 |
288
+ | **global_batch** | **64** | **128** |
289
+ | LR | 1e-5 | 2e-5 |
290
+ | Epochs | 3 | 3 |
291
+ | max_length | 4096 | 4096 |
292
+ | 精度 | BF16 | BF16 |
293
+ | Attention | Flash Attention 2 | Flash Attention 2 |
294
+ | Gradient ckpt | yes | yes |
295
+ | 训练方式 | Full fine-tuning | Full fine-tuning |
296
+
297
+ **为什么 ZeRO-2**: 8B 模型 BF16 约 16GB, H100 80GB 绰绰有余, ZeRO-2 比 ZeRO-3 快 30-40%.
298
+
299
+ **为什么 Full FT**: 任务需要学 `<RET>/<ANS>` 新 token + 新决策能力, LoRA 对 embedding 层学习有限. 加 `--use-lora` 可切换.
300
+
301
+ **Loss**: 只在 assistant turn 内容上计算, user turn 全部 mask (-100).
302
+
303
+ ---
304
+
305
+ ## 5. 关键参数调整
306
+
307
+ ```bash
308
+ # 如果显存不够 → 降 max_pixels 或切 ZeRO-3
309
+ --max-pixels $((512*28*28)) # 减少图片分辨率
310
+ --deepspeed ds_zero3.json # 切 ZeRO-3
311
+
312
+ # 如果想用 LoRA (省显存, 快, 但效果可能差一点)
313
+ --use-lora --lora-rank 64 --lora-alpha 128
314
+
315
+ # 调整 n-shot 分布 (answer_at_weights)
316
+ --answer-at-weights 3,3,2,1 # 偏向少 shot (默认)
317
+ --answer-at-weights 1,1,1,1 # 均匀分布
318
+ --answer-at-weights 1,2,3,3 # 偏向多 shot
319
+ ```
320
+
321
+ ---
322
+
323
+ ## 6. 输出目录结构
324
+
325
+ ```
326
+ output/
327
+ ├── emb_cache/ # SigLIP2 embedding 缓存
328
+ │ ├── vqa_vqav2.json
329
+ │ ├── vqa_okvqa.json
330
+ │ └── ...
331
+ ├── caption_cache/ # VLM 生成的 caption 缓存
332
+ │ ├── vqa_vqav2.json
333
+ │ ├── vqa_okvqa.json
334
+ │ └── ...
335
+ ├── vqa/
336
+ │ ├── sft.part00.jsonl # 分片
337
+ │ └── sft.jsonl # 合并后
338
+ ├── captioning/
339
+ │ └── ...
340
+ ├── classification/
341
+ │ └── ...
342
+ ├── reasoning/
343
+ │ └── ...
344
+ └── all/
345
+ └── sft.jsonl # 全部合并 + shuffle, 训练用这个
346
+ ```
347
+
348
+ ---
349
+
350
+ ## 7. 快速验证 (小规模测试)
351
+
352
+ ```bash
353
+ # Step 1: 建 embedding cache
354
+ python build_sft.py --build-cache --data-root /path/to/data \
355
+ --categories vqa --device cuda:0
356
+
357
+ # Step 2: 生成 VLM caption (先小规模测试)
358
+ python generate_captions.py \
359
+ --api-base http://10.51.6.110/v1 \
360
+ --model Qwen3-VL-8B-Instruct \
361
+ --emb-cache-dir ./output/emb_cache \
362
+ --output-dir ./output/caption_cache \
363
+ --num-workers 128 --save-every 50
364
+
365
+ # Step 3: 检查 caption 质量
366
+ python -c "
367
+ import json
368
+ d = json.load(open('./output/caption_cache/vqa_vqav2.json'))
369
+ for k, v in list(d['items'].items())[:10]:
370
+ print(f'{k}\n → {v}\n')
371
+ "
372
+
373
+ # Step 4: 构造 SFT 数据 (100 条快速测试)
374
+ python build_sft.py --data-root /path/to/data \
375
+ --caption-cache-dir ./output/caption_cache \
376
+ --categories vqa --samples-per-cat 100
377
+
378
+ # Step 5: 检查生成结果
379
+ python -c "
380
+ import json
381
+ with open('./output/vqa/sft.part00.jsonl') as f:
382
+ for i, line in enumerate(f):
383
+ if i >= 5: break
384
+ r = json.loads(line)
385
+ for c in r['conversations']:
386
+ print(f'[{c[\"from\"]}] {c[\"value\"][:120]}')
387
+ print('---')
388
+ "
389
+ ```
ICL/SFT_new/convert_and_eval.sh ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # =============================================================================
3
+ # DeepSpeed ZeRO checkpoint -> HuggingFace 格式转换 + 跑评测
4
+ #
5
+ # 用法:
6
+ # bash convert_and_eval.sh # 转换 epoch3_step1406,8卡评测
7
+ # bash convert_and_eval.sh final # 转换 final checkpoint
8
+ # bash convert_and_eval.sh epoch2_step937 # 转换指定 checkpoint
9
+ # NUM_GPUS=4 bash convert_and_eval.sh # 4卡评测
10
+ # SKIP_EVAL=1 bash convert_and_eval.sh # 只转换不评测
11
+ # =============================================================================
12
+ set -euo pipefail
13
+
14
+ SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
15
+
16
+ # ---- 参数 ----
17
+ CKPT_TAG="${1:-epoch3_step1406}"
18
+ CKPT_DIR="/workspace/xiaobin/ICL/sft_model"
19
+ BASE_MODEL="/workspace/models/Qwen3-VL-8B-Instruct"
20
+ OUTPUT_DIR="${CKPT_DIR}/${CKPT_TAG}_fp32"
21
+ NUM_GPUS="${NUM_GPUS:-8}"
22
+ BATCH_SIZE="${BATCH_SIZE:-32}"
23
+ SKIP_EVAL="${SKIP_EVAL:-0}"
24
+
25
+ echo "============================================"
26
+ echo " Checkpoint: ${CKPT_TAG}"
27
+ echo " Source: ${CKPT_DIR}/${CKPT_TAG}"
28
+ echo " Output: ${OUTPUT_DIR}"
29
+ echo " Base model: ${BASE_MODEL}"
30
+ echo "============================================"
31
+
32
+ # ---- Step 1: 检查源 checkpoint 存在 ----
33
+ if [ ! -d "${CKPT_DIR}/${CKPT_TAG}" ]; then
34
+ echo "[ERROR] Checkpoint not found: ${CKPT_DIR}/${CKPT_TAG}"
35
+ echo "Available checkpoints:"
36
+ ls -d "${CKPT_DIR}"/epoch* "${CKPT_DIR}"/final 2>/dev/null || echo " (none)"
37
+ exit 1
38
+ fi
39
+
40
+ # ---- Step 2: 转换 DeepSpeed ZeRO -> fp32 ----
41
+ if [ -d "${OUTPUT_DIR}" ] && [ "$(ls -A "${OUTPUT_DIR}" 2>/dev/null)" ]; then
42
+ echo "[SKIP] ${OUTPUT_DIR} already exists, skipping conversion."
43
+ echo " Delete it if you want to re-convert."
44
+ else
45
+ echo "[1/3] Converting DeepSpeed ZeRO checkpoint to fp32..."
46
+ mkdir -p "${OUTPUT_DIR}"
47
+ python3 "${CKPT_DIR}/zero_to_fp32.py" \
48
+ "${CKPT_DIR}" \
49
+ "${OUTPUT_DIR}" \
50
+ --tag "${CKPT_TAG}" \
51
+ --safe_serialization
52
+ echo "Done."
53
+ fi
54
+
55
+ # ---- Step 3: 拷贝 config / tokenizer ----
56
+ echo "[2/3] Copying config & tokenizer from base model..."
57
+ FILES_TO_COPY=(
58
+ config.json
59
+ tokenizer.json
60
+ tokenizer_config.json
61
+ generation_config.json
62
+ preprocessor_config.json
63
+ video_preprocessor_config.json
64
+ special_tokens_map.json
65
+ chat_template.json
66
+ merges.txt
67
+ vocab.json
68
+ )
69
+ copied=0
70
+ for f in "${FILES_TO_COPY[@]}"; do
71
+ if [ -f "${BASE_MODEL}/${f}" ] && [ ! -f "${OUTPUT_DIR}/${f}" ]; then
72
+ cp "${BASE_MODEL}/${f}" "${OUTPUT_DIR}/"
73
+ copied=$((copied + 1))
74
+ fi
75
+ done
76
+ echo "Copied ${copied} files. Model ready at: ${OUTPUT_DIR}"
77
+
78
+ # ---- Step 4: 跑评测 ----
79
+ if [ "${SKIP_EVAL}" = "1" ]; then
80
+ echo "[3/3] SKIP_EVAL=1, skipping evaluation."
81
+ echo "To run eval manually:"
82
+ echo " MODEL_PATH=${OUTPUT_DIR} BATCH_SIZE=${BATCH_SIZE} bash ${SCRIPT_DIR}/run_eval.sh ${NUM_GPUS}"
83
+ exit 0
84
+ fi
85
+
86
+ echo "[3/3] Running evaluation (${NUM_GPUS} GPUs, batch_size=${BATCH_SIZE})..."
87
+ MODEL_PATH="${OUTPUT_DIR}" BATCH_SIZE="${BATCH_SIZE}" bash "${SCRIPT_DIR}/run_eval.sh" "${NUM_GPUS}"
ICL/SFT_new/ds_zero2.json ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bf16": {
3
+ "enabled": true
4
+ },
5
+ "zero_optimization": {
6
+ "stage": 2,
7
+ "overlap_comm": true,
8
+ "contiguous_gradients": true,
9
+ "reduce_scatter": true,
10
+ "reduce_bucket_size": 5e8,
11
+ "allgather_bucket_size": 5e8
12
+ },
13
+ "optimizer": {
14
+ "type": "AdamW",
15
+ "params": {
16
+ "lr": 1e-6,
17
+ "betas": [0.9, 0.999],
18
+ "eps": 1e-8,
19
+ "weight_decay": 0.1
20
+ }
21
+ },
22
+ "scheduler": {
23
+ "type": "WarmupDecayLR",
24
+ "params": {
25
+ "warmup_min_lr": 0,
26
+ "warmup_max_lr": 1e-6,
27
+ "warmup_num_steps": 50,
28
+ "total_num_steps": 950
29
+ }
30
+ },
31
+ "gradient_accumulation_steps": 4,
32
+ "gradient_clipping": 1.0,
33
+ "train_batch_size": 64,
34
+ "train_micro_batch_size_per_gpu": 2,
35
+ "wall_clock_breakdown": false,
36
+ "steps_per_print": 50
37
+ }
ICL/SFT_new/ds_zero3.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bf16": {
3
+ "enabled": true
4
+ },
5
+ "zero_optimization": {
6
+ "stage": 3,
7
+ "overlap_comm": true,
8
+ "contiguous_gradients": true,
9
+ "reduce_bucket_size": 5e8,
10
+ "stage3_prefetch_bucket_size": 5e8,
11
+ "stage3_param_persistence_threshold": 1e6,
12
+ "stage3_gather_16bit_weights_on_model_save": true
13
+ },
14
+ "optimizer": {
15
+ "type": "AdamW",
16
+ "params": {
17
+ "lr": 1e-5,
18
+ "betas": [0.9, 0.999],
19
+ "eps": 1e-8,
20
+ "weight_decay": 0.1
21
+ }
22
+ },
23
+ "gradient_accumulation_steps": 4,
24
+ "gradient_clipping": 1.0,
25
+ "train_micro_batch_size_per_gpu": 2,
26
+ "wall_clock_breakdown": false,
27
+ "steps_per_print": 50
28
+ }
ICL/SFT_new/eval.py ADDED
@@ -0,0 +1,961 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ ICL 多轮推理评测脚本:模拟 RET/ANS 决策循环,验证 SFT 模型效果。
5
+
6
+ 流程:
7
+ 1. 从 source index 的 val split 加载原始记录(与训练集无重叠)
8
+ 2. 给模型 query_image + question(0-shot)
9
+ 3. 模型输出 <RET> → 从预计算 top5 取下一张 shot + caption,追加 context,再问
10
+ 4. 模型输出 <ANS> → 提取答案,结束
11
+ 5. 最多 max_rounds 轮(防止死循环 RET)
12
+
13
+ 多卡策略:
14
+ - 每张 GPU 加载一份模型,按 dataset 粒度分配任务
15
+ - 只有 rank 0 打印进度日志(其他 rank 静默)
16
+ - 最后 rank 0 汇总并写出有序 JSON log
17
+
18
+ 用法:
19
+ # 单卡 (debug)
20
+ python3 eval.py \\
21
+ --model-path /workspace/xiaobin/ICL/sft_model/merged_hf \\
22
+ --category vqa --dataset vqav2 --split val \\
23
+ --num-samples 20 --device cuda:0
24
+
25
+ # 多卡
26
+ torchrun --nproc_per_node=8 eval.py \\
27
+ --model-path /workspace/xiaobin/ICL/sft_model/merged_hf \\
28
+ --all-categories --split val --num-samples 200
29
+ """
30
+
31
+ import argparse
32
+ import json
33
+ import math
34
+ import os
35
+ import random
36
+ import re
37
+ import sys
38
+ import time
39
+ from collections import defaultdict
40
+ from pathlib import Path
41
+ from typing import Dict, List, Optional, Tuple
42
+
43
+ import torch
44
+ import torch.distributed as dist
45
+
46
+ # 绕过 transformers 对 torch<2.6 的 torch.load 安全检查 (CVE-2025-32434)
47
+ # 在 import transformers 之前 patch modeling_utils.load_state_dict
48
+ import transformers.utils.import_utils as _tu
49
+ if hasattr(_tu, "check_torch_load_is_safe"):
50
+ _tu.check_torch_load_is_safe = lambda: None
51
+ import transformers.modeling_utils as _mu
52
+ if hasattr(_mu, "check_torch_load_is_safe"):
53
+ _mu.check_torch_load_is_safe = lambda: None
54
+ # 直接 patch load_state_dict 里调用的那个
55
+ _orig_load_state_dict = getattr(_mu, "load_state_dict", None)
56
+ if _orig_load_state_dict is not None:
57
+ import functools
58
+ @functools.wraps(_orig_load_state_dict)
59
+ def _patched_load_state_dict(checkpoint_file, **kwargs):
60
+ # 直接用 torch.load 跳过安全检查
61
+ return torch.load(checkpoint_file, map_location="cpu", weights_only=False)
62
+ _mu.load_state_dict = _patched_load_state_dict
63
+
64
+ from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
65
+ from qwen_vl_utils import process_vision_info
66
+ from tqdm import tqdm
67
+
68
+ # ---------------------------------------------------------------------------
69
+ # 默认路径
70
+ # ---------------------------------------------------------------------------
71
+ INDEX_ROOT = "/workspace/xiaobin/dataset/index"
72
+ EMBEDDINGS_DIR = "/workspace/xiaobin/dataset/embeddings"
73
+ CAPTION_CACHE_DIR = "/workspace/xiaobin/dataset/caption_cache"
74
+
75
+ # ---------------------------------------------------------------------------
76
+ # 分布式工具
77
+ # ---------------------------------------------------------------------------
78
+
79
+ def setup_distributed():
80
+ """初始化分布式环境,返回 (rank, world_size, device)。"""
81
+ if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
82
+ rank = int(os.environ["RANK"])
83
+ world_size = int(os.environ["WORLD_SIZE"])
84
+ local_rank = int(os.environ.get("LOCAL_RANK", rank))
85
+ dist.init_process_group("nccl")
86
+ torch.cuda.set_device(local_rank)
87
+ device = f"cuda:{local_rank}"
88
+ else:
89
+ rank, world_size = 0, 1
90
+ device = None
91
+ return rank, world_size, device
92
+
93
+
94
+ def gather_results(local_results: List[Dict], rank: int, world_size: int) -> List[Dict]:
95
+ """各 rank 结果汇总到 rank 0。"""
96
+ if world_size == 1:
97
+ return local_results
98
+
99
+ data = json.dumps(local_results, ensure_ascii=False).encode("utf-8")
100
+ size = torch.tensor([len(data)], dtype=torch.long, device=f"cuda:{rank}")
101
+
102
+ size_list = [torch.zeros(1, dtype=torch.long, device=f"cuda:{rank}") for _ in range(world_size)]
103
+ dist.all_gather(size_list, size)
104
+ max_size = max(s.item() for s in size_list)
105
+
106
+ padded = data + b"\x00" * (max_size - len(data))
107
+ tensor = torch.ByteTensor(list(padded)).cuda(rank)
108
+ tensor_list = [torch.zeros(max_size, dtype=torch.uint8, device=f"cuda:{rank}") for _ in range(world_size)]
109
+ dist.all_gather(tensor_list, tensor)
110
+
111
+ if rank == 0:
112
+ all_results = []
113
+ for t, s in zip(tensor_list, size_list):
114
+ raw = bytes(t[: s.item()].cpu().tolist())
115
+ all_results.extend(json.loads(raw.decode("utf-8")))
116
+ return all_results
117
+ return []
118
+
119
+
120
+ def log(msg: str, rank: int = 0, force: bool = False):
121
+ """只在 rank 0 或 force=True 时打印。"""
122
+ if rank == 0 or force:
123
+ print(msg, flush=True)
124
+
125
+
126
+ # ---------------------------------------------------------------------------
127
+ # 数据加载
128
+ # ---------------------------------------------------------------------------
129
+
130
+ def load_records(cat: str, ds: str, split: str, limit: int = 0) -> List[Dict]:
131
+ """从 index root 加载指定 split 的记录。"""
132
+ path = os.path.join(INDEX_ROOT, cat, ds, f"{split}.jsonl")
133
+ if not os.path.exists(path):
134
+ return []
135
+ records = []
136
+ with open(path, "r", encoding="utf-8") as f:
137
+ for line in f:
138
+ line = line.strip()
139
+ if not line:
140
+ continue
141
+ r = json.loads(line)
142
+ if r.get("image") and r.get("answer"):
143
+ records.append(r)
144
+ if limit and len(records) >= limit:
145
+ break
146
+ return records
147
+
148
+
149
+ def load_top5(cat: str, ds: str) -> Dict[str, List[str]]:
150
+ path = os.path.join(EMBEDDINGS_DIR, f"{cat}_{ds}_top5.json")
151
+ if not os.path.exists(path):
152
+ return {}
153
+ with open(path, "r", encoding="utf-8") as f:
154
+ return json.load(f)
155
+
156
+
157
+ def load_caption_cache(cat: str, ds: str) -> Dict[str, str]:
158
+ path = os.path.join(CAPTION_CACHE_DIR, f"{cat}_{ds}.json")
159
+ if not os.path.exists(path):
160
+ return {}
161
+ with open(path, "r", encoding="utf-8") as f:
162
+ data = json.load(f)
163
+ if isinstance(data, dict) and "items" in data:
164
+ return data["items"]
165
+ return data if isinstance(data, dict) else {}
166
+
167
+
168
+ def load_instructions(cat: str, ds: str) -> List[str]:
169
+ path = os.path.join(INDEX_ROOT, cat, ds, "instructions.json")
170
+ if not os.path.exists(path):
171
+ return ["Look at the image and answer the question."]
172
+ with open(path, "r", encoding="utf-8") as f:
173
+ data = json.load(f)
174
+ if isinstance(data, list):
175
+ return [str(x).strip() for x in data if str(x).strip()]
176
+ if isinstance(data, dict):
177
+ for key in ("instructions", "instruction", "prompts"):
178
+ v = data.get(key)
179
+ if isinstance(v, list):
180
+ return [str(x).strip() for x in v if str(x).strip()]
181
+ return ["Look at the image and answer the question."]
182
+
183
+
184
+ def discover_datasets(categories: List[str]) -> List[Tuple[str, str]]:
185
+ results = []
186
+ for cat in sorted(os.listdir(INDEX_ROOT)):
187
+ if categories and cat not in categories:
188
+ continue
189
+ cat_dir = os.path.join(INDEX_ROOT, cat)
190
+ if not os.path.isdir(cat_dir):
191
+ continue
192
+ for ds in sorted(os.listdir(cat_dir)):
193
+ if os.path.isdir(os.path.join(cat_dir, ds)):
194
+ results.append((cat, ds))
195
+ return results
196
+
197
+
198
+ # ---------------------------------------------------------------------------
199
+ # 模型加载
200
+ # ---------------------------------------------------------------------------
201
+
202
+ def load_model(model_path: str, device: str):
203
+ from transformers import AutoConfig
204
+
205
+ processor = AutoProcessor.from_pretrained(
206
+ model_path,
207
+ trust_remote_code=True,
208
+ min_pixels=256 * 28 * 28,
209
+ max_pixels=1280 * 28 * 28,
210
+ )
211
+
212
+ # 先添加 special tokens 到 tokenizer,这样 vocab_size 对齐 checkpoint
213
+ special_tokens = ["<RET>", "<ANS>", "</ANS>", "<RETQ>", "</RETQ>"]
214
+ processor.tokenizer.add_tokens(special_tokens, special_tokens=True)
215
+ # batch 推理 decoder-only 模型必须左 padding
216
+ processor.tokenizer.padding_side = "left"
217
+ target_vocab_size = len(processor.tokenizer)
218
+
219
+ # 关键:把 config 的 vocab_size 改成 checkpoint 实际大小,
220
+ # 否则 ignore_mismatched_sizes 会导致 embed_tokens/lm_head 被随机初始化!
221
+ config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
222
+ print(f"[load_model] text_config.vocab_size={config.text_config.vocab_size}, target={target_vocab_size}")
223
+ config.text_config.vocab_size = target_vocab_size
224
+
225
+ model = Qwen3VLForConditionalGeneration.from_pretrained(
226
+ model_path,
227
+ config=config,
228
+ trust_remote_code=True,
229
+ torch_dtype=torch.bfloat16,
230
+ attn_implementation="sdpa",
231
+ device_map=device,
232
+ )
233
+
234
+ model.eval()
235
+
236
+ ret_id = processor.tokenizer.convert_tokens_to_ids("<RET>")
237
+ ans_id = processor.tokenizer.convert_tokens_to_ids("<ANS>")
238
+ return model, processor, ret_id, ans_id
239
+
240
+
241
+ # ---------------------------------------------------------------------------
242
+ # 推理核心
243
+ # ---------------------------------------------------------------------------
244
+
245
+ def build_messages(
246
+ instruction: str,
247
+ query_image: str,
248
+ question: Optional[str],
249
+ shots: List[Dict],
250
+ min_pixels: int = 256 * 28 * 28,
251
+ max_pixels: int = 1280 * 28 * 28,
252
+ ) -> List[Dict]:
253
+ """构建 Qwen3-VL chat messages。"""
254
+ user_content = []
255
+
256
+ if instruction:
257
+ user_content.append({"type": "text", "text": instruction})
258
+
259
+ user_content.append({
260
+ "type": "image",
261
+ "image": f"file://{query_image}",
262
+ "min_pixels": min_pixels,
263
+ "max_pixels": max_pixels,
264
+ })
265
+
266
+ if question:
267
+ user_content.append({"type": "text", "text": f"Question: {question}"})
268
+
269
+ for shot in shots:
270
+ user_content.append({
271
+ "type": "image",
272
+ "image": f"file://{shot['image']}",
273
+ "min_pixels": min_pixels,
274
+ "max_pixels": max_pixels,
275
+ })
276
+ if shot.get("caption"):
277
+ user_content.append({"type": "text", "text": f"Caption: {shot['caption']}"})
278
+
279
+ user_content.append({"type": "text", "text": "Action:"})
280
+ return [{"role": "user", "content": user_content}]
281
+
282
+
283
+ @torch.no_grad()
284
+ def generate_action(model, processor, messages: List[Dict], max_new_tokens: int = 256) -> str:
285
+ """单条推理(fallback 用)。"""
286
+ text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
287
+
288
+ image_inputs = None
289
+ try:
290
+ image_inputs, _ = process_vision_info(messages)
291
+ except Exception:
292
+ pass
293
+
294
+ inputs = processor(
295
+ text=[text],
296
+ images=image_inputs if image_inputs else None,
297
+ return_tensors="pt",
298
+ padding=False,
299
+ truncation=False,
300
+ )
301
+
302
+ device = next(model.parameters()).device
303
+ inputs = {k: v.to(device) if hasattr(v, "to") else v for k, v in inputs.items()}
304
+
305
+ outputs = model.generate(
306
+ **inputs,
307
+ max_new_tokens=max_new_tokens,
308
+ do_sample=False,
309
+ temperature=None,
310
+ top_p=None,
311
+ )
312
+
313
+ input_len = inputs["input_ids"].shape[1]
314
+ generated = outputs[0][input_len:]
315
+ return processor.tokenizer.decode(generated, skip_special_tokens=False)
316
+
317
+
318
+ @torch.no_grad()
319
+ def generate_action_batch(
320
+ model, processor, messages_list: List[List[Dict]],
321
+ max_new_tokens: int = 256, batch_size: int = 4,
322
+ pbar=None,
323
+ ) -> List[str]:
324
+ """批量推理,按 batch_size 分批处理。每个 batch 完成后更新 pbar。"""
325
+ all_results = []
326
+ device = next(model.parameters()).device
327
+
328
+ for start in range(0, len(messages_list), batch_size):
329
+ batch_msgs = messages_list[start : start + batch_size]
330
+
331
+ texts = []
332
+ all_images_nested = [] # 嵌套 list: [[sample0 imgs], [sample1 imgs], ...]
333
+ has_any_image = False
334
+ for msgs in batch_msgs:
335
+ texts.append(processor.apply_chat_template(
336
+ msgs, tokenize=False, add_generation_prompt=True
337
+ ))
338
+ try:
339
+ imgs, _ = process_vision_info(msgs)
340
+ if imgs:
341
+ all_images_nested.append(imgs)
342
+ has_any_image = True
343
+ else:
344
+ all_images_nested.append([])
345
+ except Exception:
346
+ all_images_nested.append([])
347
+
348
+ inputs = processor(
349
+ text=texts,
350
+ images=all_images_nested if has_any_image else None,
351
+ return_tensors="pt",
352
+ padding=True,
353
+ truncation=False,
354
+ )
355
+
356
+ inputs = {k: v.to(device) if hasattr(v, "to") else v for k, v in inputs.items()}
357
+
358
+ outputs = model.generate(
359
+ **inputs,
360
+ max_new_tokens=max_new_tokens,
361
+ do_sample=False,
362
+ temperature=None,
363
+ top_p=None,
364
+ )
365
+
366
+ # 解码每条(左 padding 时,所有样本的 padded 输入长度相同)
367
+ input_len = inputs["input_ids"].shape[1]
368
+ for i in range(len(batch_msgs)):
369
+ generated = outputs[i][input_len:]
370
+ text = processor.tokenizer.decode(generated, skip_special_tokens=False)
371
+ all_results.append(text)
372
+
373
+ # 每个 batch 完成后更新进度条
374
+ if pbar is not None:
375
+ pbar.set_postfix_str(f"batch {start // batch_size + 1}/{math.ceil(len(messages_list) / batch_size)}")
376
+
377
+ return all_results
378
+
379
+
380
+ def parse_action(text: str) -> Tuple[str, str]:
381
+ """解析模型输出,返回 (action, content)。"""
382
+ text = text.strip()
383
+
384
+ if text.startswith("<RET>"):
385
+ desc = text[len("<RET>"):].strip()
386
+ if desc.startswith("Description:"):
387
+ desc = desc[len("Description:"):].strip()
388
+ for tok in ["<|im_end|>", "</s>", "<|endoftext|>"]:
389
+ desc = desc.replace(tok, "").strip()
390
+ return "ret", desc
391
+
392
+ if text.startswith("<ANS>"):
393
+ ans = text[len("<ANS>"):]
394
+ end_idx = ans.find("</ANS>")
395
+ if end_idx != -1:
396
+ ans = ans[:end_idx]
397
+ else:
398
+ for tok in ["<|im_end|>", "</s>", "<|endoftext|>"]:
399
+ ans = ans.replace(tok, "").strip()
400
+ return "ans", ans.strip()
401
+
402
+ return "unknown", text
403
+
404
+
405
+ def run_icl_loop(
406
+ model,
407
+ processor,
408
+ record: Dict,
409
+ instruction: str,
410
+ top5: Dict[str, List[str]],
411
+ caption_cache: Dict[str, str],
412
+ max_rounds: int = 4,
413
+ ) -> Dict:
414
+ """对单条记录跑多轮 RET/ANS 循环(fallback 用)。"""
415
+ query_image = record["image"]
416
+ question = record.get("question", "")
417
+ gt_answer = record.get("answer", "")
418
+
419
+ shots = []
420
+ used_images = {query_image}
421
+ rounds = []
422
+ candidates = top5.get(query_image, [])
423
+
424
+ for round_idx in range(max_rounds):
425
+ messages = build_messages(instruction, query_image, question, shots)
426
+ raw_output = generate_action(model, processor, messages)
427
+ action, content = parse_action(raw_output)
428
+
429
+ rounds.append({
430
+ "round": round_idx,
431
+ "action": action,
432
+ "content": content,
433
+ "raw": raw_output[:300],
434
+ })
435
+
436
+ if action == "ans":
437
+ return {
438
+ "image": query_image,
439
+ "question": question,
440
+ "gt_answer": gt_answer,
441
+ "final_answer": content,
442
+ "num_rounds": round_idx + 1,
443
+ "terminated_by": "ans",
444
+ "rounds": rounds,
445
+ }
446
+
447
+ if action == "ret":
448
+ next_image = None
449
+ for c in candidates:
450
+ if c not in used_images:
451
+ next_image = c
452
+ break
453
+
454
+ if next_image is None:
455
+ return {
456
+ "image": query_image,
457
+ "question": question,
458
+ "gt_answer": gt_answer,
459
+ "final_answer": None,
460
+ "num_rounds": round_idx + 1,
461
+ "terminated_by": "no_more_shots",
462
+ "rounds": rounds,
463
+ }
464
+
465
+ cap = caption_cache.get(next_image, content)
466
+ shots.append({"image": next_image, "caption": cap})
467
+ used_images.add(next_image)
468
+ else:
469
+ return {
470
+ "image": query_image,
471
+ "question": question,
472
+ "gt_answer": gt_answer,
473
+ "final_answer": content,
474
+ "num_rounds": round_idx + 1,
475
+ "terminated_by": "unknown_action",
476
+ "rounds": rounds,
477
+ }
478
+
479
+ return {
480
+ "image": query_image,
481
+ "question": question,
482
+ "gt_answer": gt_answer,
483
+ "final_answer": None,
484
+ "num_rounds": max_rounds,
485
+ "terminated_by": "max_rounds",
486
+ "rounds": rounds,
487
+ }
488
+
489
+
490
+ def run_icl_batch(
491
+ model, processor,
492
+ records: List[Dict],
493
+ instructions: List[str],
494
+ top5: Dict[str, List[str]],
495
+ caption_cache: Dict[str, str],
496
+ max_rounds: int = 4,
497
+ batch_size: int = 4,
498
+ rank: int = 0,
499
+ ds_label: str = "",
500
+ ) -> List[Dict]:
501
+ """对一批记录做 round-parallel 的批量 ICL 推理。
502
+
503
+ Round 0: 所有样本 batch 推理
504
+ Round 1: RET 的样本加 shot 后 batch 推理
505
+ ...直到全部完成或 max_rounds
506
+ """
507
+ rng = random.Random(42)
508
+
509
+ # 初始化每条样本的状态
510
+ states = []
511
+ for rec in records:
512
+ states.append({
513
+ "record": rec,
514
+ "instruction": rng.choice(instructions),
515
+ "query_image": rec["image"],
516
+ "question": rec.get("question", ""),
517
+ "gt_answer": rec.get("answer", ""),
518
+ "shots": [],
519
+ "used_images": {rec["image"]},
520
+ "candidates": top5.get(rec["image"], []),
521
+ "rounds": [],
522
+ "done": False,
523
+ "result": None,
524
+ })
525
+
526
+ total = len(states)
527
+ pbar = tqdm(total=total, desc=f" {ds_label}", unit="done",
528
+ disable=(rank != 0))
529
+
530
+ for round_idx in range(max_rounds):
531
+ # 收集未完成的样本
532
+ active = [(i, s) for i, s in enumerate(states) if not s["done"]]
533
+ if not active:
534
+ break
535
+
536
+ n_active = len(active)
537
+ pbar.set_postfix(round=round_idx, active=n_active)
538
+
539
+ # 构建 messages
540
+ messages_list = []
541
+ active_indices = []
542
+ for i, s in active:
543
+ msgs = build_messages(
544
+ s["instruction"], s["query_image"], s["question"], s["shots"]
545
+ )
546
+ messages_list.append(msgs)
547
+ active_indices.append(i)
548
+
549
+ # 批量推理
550
+ try:
551
+ raw_outputs = generate_action_batch(
552
+ model, processor, messages_list,
553
+ batch_size=batch_size,
554
+ pbar=pbar,
555
+ )
556
+ except Exception as e:
557
+ # batch 推理失败时 fallback 到逐条
558
+ log(f" [WARN] Batch failed at round {round_idx}, falling back to single: {e}", rank)
559
+ raw_outputs = []
560
+ for msgs in messages_list:
561
+ try:
562
+ raw_outputs.append(generate_action(model, processor, msgs))
563
+ except Exception:
564
+ raw_outputs.append("")
565
+
566
+ # 解析结果、更新状态
567
+ newly_done = 0
568
+ for idx_in_batch, global_idx in enumerate(active_indices):
569
+ s = states[global_idx]
570
+ raw = raw_outputs[idx_in_batch]
571
+ action, content = parse_action(raw)
572
+
573
+ s["rounds"].append({
574
+ "round": round_idx,
575
+ "action": action,
576
+ "content": content,
577
+ "raw": raw[:300],
578
+ })
579
+
580
+ if action == "ans":
581
+ s["done"] = True
582
+ s["result"] = {
583
+ "image": s["query_image"],
584
+ "question": s["question"],
585
+ "gt_answer": s["gt_answer"],
586
+ "final_answer": content,
587
+ "num_rounds": round_idx + 1,
588
+ "terminated_by": "ans",
589
+ "rounds": s["rounds"],
590
+ }
591
+ newly_done += 1
592
+ elif action == "ret":
593
+ next_image = None
594
+ for c in s["candidates"]:
595
+ if c not in s["used_images"]:
596
+ next_image = c
597
+ break
598
+
599
+ if next_image is None:
600
+ s["done"] = True
601
+ s["result"] = {
602
+ "image": s["query_image"],
603
+ "question": s["question"],
604
+ "gt_answer": s["gt_answer"],
605
+ "final_answer": None,
606
+ "num_rounds": round_idx + 1,
607
+ "terminated_by": "no_more_shots",
608
+ "rounds": s["rounds"],
609
+ }
610
+ newly_done += 1
611
+ else:
612
+ cap = caption_cache.get(next_image, content)
613
+ s["shots"].append({"image": next_image, "caption": cap})
614
+ s["used_images"].add(next_image)
615
+ else:
616
+ s["done"] = True
617
+ s["result"] = {
618
+ "image": s["query_image"],
619
+ "question": s["question"],
620
+ "gt_answer": s["gt_answer"],
621
+ "final_answer": content,
622
+ "num_rounds": round_idx + 1,
623
+ "terminated_by": "unknown_action",
624
+ "rounds": s["rounds"],
625
+ }
626
+ newly_done += 1
627
+
628
+ pbar.update(newly_done)
629
+
630
+ n_active = sum(1 for s in states if not s["done"])
631
+ if rank == 0:
632
+ pbar.set_postfix(round=round_idx, active=n_active)
633
+
634
+ # 处理还没完成的(达到 max_rounds)
635
+ for s in states:
636
+ if not s["done"]:
637
+ s["result"] = {
638
+ "image": s["query_image"],
639
+ "question": s["question"],
640
+ "gt_answer": s["gt_answer"],
641
+ "final_answer": None,
642
+ "num_rounds": max_rounds,
643
+ "terminated_by": "max_rounds",
644
+ "rounds": s["rounds"],
645
+ }
646
+ pbar.update(1)
647
+
648
+ pbar.close()
649
+ return [s["result"] for s in states]
650
+
651
+
652
+ # ---------------------------------------------------------------------------
653
+ # 答案质量指标
654
+ # ---------------------------------------------------------------------------
655
+
656
+ def normalize_answer(s: str) -> str:
657
+ """归一化答案用于比较。"""
658
+ s = s.lower().strip()
659
+ # 去标点
660
+ s = re.sub(r"[^\w\s]", "", s)
661
+ # 去多余空格
662
+ s = " ".join(s.split())
663
+ return s
664
+
665
+
666
+ def compute_metrics(results: List[Dict]) -> Dict:
667
+ """计算答案质量指标。"""
668
+ answered = [r for r in results if r.get("final_answer") is not None]
669
+ if not answered:
670
+ return {"exact_match": 0.0, "contains_gt": 0.0, "answer_rate": 0.0}
671
+
672
+ em_count = 0
673
+ contains_count = 0
674
+
675
+ for r in answered:
676
+ pred = normalize_answer(r["final_answer"])
677
+ gt = normalize_answer(r["gt_answer"])
678
+
679
+ if pred == gt:
680
+ em_count += 1
681
+ if gt in pred or pred in gt:
682
+ contains_count += 1
683
+
684
+ n_total = len(results)
685
+ n_answered = len(answered)
686
+
687
+ return {
688
+ "exact_match": em_count / n_answered * 100 if n_answered else 0.0,
689
+ "contains_gt": contains_count / n_answered * 100 if n_answered else 0.0,
690
+ "answer_rate": n_answered / n_total * 100 if n_total else 0.0,
691
+ "shot_distribution": compute_shot_distribution(results),
692
+ "avg_shots": compute_avg_shots(results),
693
+ }
694
+
695
+
696
+ def compute_shot_distribution(results: List[Dict]) -> Dict[str, int]:
697
+ """统计 shot 数量分布。"""
698
+ shot_counts = defaultdict(int)
699
+ for r in results:
700
+ if r.get("terminated_by") == "ans":
701
+ n_shots = r["num_rounds"] - 1
702
+ else:
703
+ n_shots = r["num_rounds"]
704
+ shot_counts[f"{n_shots}-shot"] += 1
705
+ return dict(sorted(shot_counts.items()))
706
+
707
+
708
+ def compute_avg_shots(results: List[Dict]) -> float:
709
+ if not results:
710
+ return 0.0
711
+ total = 0
712
+ for r in results:
713
+ if r.get("terminated_by") == "ans":
714
+ total += r["num_rounds"] - 1
715
+ else:
716
+ total += r["num_rounds"]
717
+ return total / len(results)
718
+
719
+
720
+ # ---------------------------------------------------------------------------
721
+ # 统计输出
722
+ # ---------------------------------------------------------------------------
723
+
724
+ def print_stats(results: List[Dict], cat: str = "", ds: str = ""):
725
+ prefix = f"[{cat}/{ds}]" if ds else f"[{cat}]" if cat else "[ALL]"
726
+ n = len(results)
727
+ if n == 0:
728
+ print(f"{prefix} 无结果")
729
+ return
730
+
731
+ # 终止原因
732
+ term_counts = defaultdict(int)
733
+ for r in results:
734
+ term_counts[r["terminated_by"]] += 1
735
+
736
+ # 每轮 action 分布
737
+ round_actions = defaultdict(lambda: defaultdict(int))
738
+ for r in results:
739
+ for rd in r["rounds"]:
740
+ round_actions[rd["round"]][rd["action"]] += 1
741
+
742
+ avg_rounds = sum(r["num_rounds"] for r in results) / n
743
+
744
+ # 答案质量
745
+ metrics = compute_metrics(results)
746
+
747
+ print(f"\n{'=' * 64}")
748
+ print(f"{prefix} 共 {n} 条样本")
749
+ print(f" 平均轮次: {avg_rounds:.2f}")
750
+ print(f" 终止原因:")
751
+ for k, v in sorted(term_counts.items()):
752
+ print(f" {k}: {v} ({v / n * 100:.1f}%)")
753
+
754
+ print(f" 每轮 RET/ANS 分布:")
755
+ for rd_idx in sorted(round_actions.keys()):
756
+ actions = round_actions[rd_idx]
757
+ total = sum(actions.values())
758
+ parts = [f"{a}={c}({c / total * 100:.0f}%)" for a, c in sorted(actions.items())]
759
+ print(f" Round {rd_idx}: {' | '.join(parts)} (共 {total} 条)")
760
+
761
+ # Shot 数量统计(num_rounds - 1 = 回答前检索了几个 shot)
762
+ shot_counts = defaultdict(int)
763
+ for r in results:
764
+ if r["terminated_by"] == "ans":
765
+ n_shots = r["num_rounds"] - 1 # RET 次数 = 回答时已有的 shot 数
766
+ else:
767
+ n_shots = r["num_rounds"] # 没回答的,全是 RET
768
+ shot_counts[n_shots] += 1
769
+
770
+ print(f" Shot 数量分布 (回答时已有的 shot 数):")
771
+ for k in sorted(shot_counts.keys()):
772
+ v = shot_counts[k]
773
+ bar = "█" * int(v / n * 40)
774
+ print(f" {k}-shot: {v:4d} ({v / n * 100:5.1f}%) {bar}")
775
+ avg_shots = sum(k * v for k, v in shot_counts.items()) / n
776
+ print(f" 平均 shot 数: {avg_shots:.2f}")
777
+
778
+ answered = [r for r in results if r["final_answer"] is not None]
779
+ print(f" 产出答案: {len(answered)}/{n} ({metrics['answer_rate']:.1f}%)")
780
+ if answered:
781
+ print(f" 答案质量 (仅 ans 样本):")
782
+ print(f" Exact Match: {metrics['exact_match']:.1f}%")
783
+ print(f" Contains GT: {metrics['contains_gt']:.1f}%")
784
+ print(f"{'=' * 64}")
785
+
786
+
787
+ # ---------------------------------------------------------------------------
788
+ # Main
789
+ # ---------------------------------------------------------------------------
790
+
791
+ def main():
792
+ parser = argparse.ArgumentParser(description="ICL 多轮推理评测(支持多卡,log 对齐)")
793
+ parser.add_argument("--model-path", required=True, help="合并后的 HF 模型路径")
794
+ parser.add_argument("--category", type=str, default="")
795
+ parser.add_argument("--dataset", type=str, default="")
796
+ parser.add_argument("--split", type=str, default="val",
797
+ help="使用的数据 split(默认 val,与训练集 train 隔离)")
798
+ parser.add_argument("--all-categories", action="store_true")
799
+ parser.add_argument("--num-samples", type=int, default=100,
800
+ help="每个 dataset 采样数")
801
+ parser.add_argument("--max-rounds", type=int, default=4)
802
+ parser.add_argument("--batch-size", type=int, default=4,
803
+ help="每轮 batch 推理的样本数")
804
+ parser.add_argument("--device", type=str, default="cuda:0",
805
+ help="单卡时用的设备")
806
+ parser.add_argument("--output-dir", type=str,
807
+ default="/workspace/xiaobin/ICL/SFT_new/eval_results",
808
+ help="评测结果保存目录")
809
+ parser.add_argument("--seed", type=int, default=42)
810
+ args = parser.parse_args()
811
+
812
+ random.seed(args.seed)
813
+
814
+ # ---- 分布式初始化 ----
815
+ rank, world_size, dist_device = setup_distributed()
816
+ device = dist_device or args.device
817
+ is_main = rank == 0
818
+
819
+ log(f"World size: {world_size}", rank)
820
+ log(f"Model: {args.model_path}", rank)
821
+ log(f"Split: {args.split} (与训练集 train 隔离)", rank)
822
+
823
+ # ---- 加载模型 ----
824
+ model, processor, ret_id, ans_id = load_model(args.model_path, device)
825
+ log(f"Model loaded. <RET>={ret_id}, <ANS>={ans_id}", rank)
826
+
827
+ # ---- 确定 dataset 列表 ----
828
+ if args.all_categories:
829
+ categories = ["vqa", "captioning", "classification", "reasoning"]
830
+ elif args.category:
831
+ categories = [args.category]
832
+ else:
833
+ categories = ["vqa"]
834
+
835
+ if args.dataset:
836
+ ds_list = [(args.category or "vqa", args.dataset)]
837
+ else:
838
+ ds_list = discover_datasets(categories)
839
+
840
+ # ---- 按 rank 分配 dataset ----
841
+ my_ds_list = ds_list[rank::world_size]
842
+ log(f"共 {len(ds_list)} 个 dataset,rank {rank} 分到 {len(my_ds_list)} 个", rank)
843
+
844
+ local_results = []
845
+ t_start = time.time()
846
+
847
+ for ds_idx, (cat, ds) in enumerate(my_ds_list):
848
+ log(f"[{ds_idx + 1}/{len(my_ds_list)}] Evaluating {cat}/{ds} ({args.split})", rank)
849
+
850
+ records = load_records(cat, ds, args.split, limit=args.num_samples * 5)
851
+ if not records:
852
+ log(f" 跳过 {cat}/{ds}:无记录", rank)
853
+ continue
854
+
855
+ top5 = load_top5(cat, ds)
856
+ if not top5:
857
+ log(f" 跳过 {cat}/{ds}:无 top5 embedding", rank)
858
+ continue
859
+
860
+ caption_cache = load_caption_cache(cat, ds)
861
+ instructions = load_instructions(cat, ds)
862
+
863
+ # 过滤:需要 top5 覆盖
864
+ records = [r for r in records if r["image"] in top5]
865
+ if not records:
866
+ log(f" 跳过 {cat}/{ds}:val 图片无 top5 覆盖", rank)
867
+ continue
868
+
869
+ if len(records) > args.num_samples:
870
+ records = random.sample(records, args.num_samples)
871
+ log(f" {cat}/{ds}: {len(records)} 条, batch_size={args.batch_size}", rank)
872
+
873
+ ds_results = run_icl_batch(
874
+ model, processor, records, instructions, top5, caption_cache,
875
+ max_rounds=args.max_rounds,
876
+ batch_size=args.batch_size,
877
+ rank=rank,
878
+ ds_label=f"{cat}/{ds}",
879
+ )
880
+ for r in ds_results:
881
+ r["category"] = cat
882
+ r["dataset"] = ds
883
+ local_results.extend(ds_results)
884
+
885
+ elapsed = time.time() - t_start
886
+ log(f"\nrank {rank} 完成,{len(local_results)} 条,耗时 {elapsed:.1f}s", rank)
887
+
888
+ # ---- 汇总结果 ----
889
+ all_results = gather_results(local_results, rank, world_size)
890
+
891
+ if is_main:
892
+ # 排序:category → dataset → image
893
+ all_results.sort(key=lambda r: (r.get("category", ""), r.get("dataset", ""), r.get("image", "")))
894
+
895
+ # ---- 按 category / dataset 打印统计 ----
896
+ cat_results = defaultdict(list)
897
+ for r in all_results:
898
+ cat_results[r["category"]].append(r)
899
+
900
+ for cat in categories:
901
+ if not cat_results.get(cat):
902
+ continue
903
+ ds_groups = defaultdict(list)
904
+ for r in cat_results[cat]:
905
+ ds_groups[r["dataset"]].append(r)
906
+ for d in sorted(ds_groups):
907
+ print_stats(ds_groups[d], cat, d)
908
+ # category 汇总
909
+ if len(ds_groups) > 1:
910
+ print_stats(cat_results[cat], cat)
911
+
912
+ # 总汇总
913
+ if len(categories) > 1 or not args.dataset:
914
+ print_stats(all_results)
915
+
916
+ # ---- 保存 JSON log ----
917
+ os.makedirs(args.output_dir, exist_ok=True)
918
+ timestamp = time.strftime("%Y%m%d_%H%M%S")
919
+ output_path = os.path.join(args.output_dir, f"eval_{args.split}_{timestamp}.json")
920
+
921
+ # 构建 summary
922
+ summary = {
923
+ "model_path": args.model_path,
924
+ "split": args.split,
925
+ "num_samples_per_ds": args.num_samples,
926
+ "max_rounds": args.max_rounds,
927
+ "total_samples": len(all_results),
928
+ "world_size": world_size,
929
+ "elapsed_seconds": elapsed,
930
+ "metrics": {},
931
+ }
932
+
933
+ # 整体 metrics
934
+ summary["metrics"]["overall"] = compute_metrics(all_results)
935
+
936
+ # 按 category metrics
937
+ for cat in categories:
938
+ if cat_results.get(cat):
939
+ summary["metrics"][cat] = compute_metrics(cat_results[cat])
940
+
941
+ output_data = {
942
+ "summary": summary,
943
+ "results": all_results,
944
+ }
945
+
946
+ with open(output_path, "w", encoding="utf-8") as f:
947
+ json.dump(output_data, f, ensure_ascii=False, indent=2)
948
+ print(f"\n详细结果已保存到: {output_path}")
949
+
950
+ # 也保存一份不带时间戳的 latest
951
+ latest_path = os.path.join(args.output_dir, f"eval_{args.split}_latest.json")
952
+ with open(latest_path, "w", encoding="utf-8") as f:
953
+ json.dump(output_data, f, ensure_ascii=False, indent=2)
954
+ print(f"Latest 链接: {latest_path}")
955
+
956
+ if world_size > 1:
957
+ dist.destroy_process_group()
958
+
959
+
960
+ if __name__ == "__main__":
961
+ main()
ICL/SFT_new/launch_wrapper.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Wrapper for northjob: receives torchrun args, launches run_multi_node.sh."""
3
+ import subprocess
4
+ import sys
5
+ import os
6
+
7
+ if __name__ == "__main__":
8
+ script_dir = os.path.dirname(os.path.abspath(__file__))
9
+ bash_script = os.path.join(script_dir, "run_multi_node.sh")
10
+ args = sys.argv[1:]
11
+ cmd = ["bash", bash_script] + args
12
+ result = subprocess.run(cmd, env=os.environ.copy())
13
+ sys.exit(result.returncode)
ICL/SFT_new/rebuild_and_train.sh ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # =============================================================================
3
+ # 一键:重建 SFT 数据 → 提交 16 卡训练任务
4
+ #
5
+ # 1. 用新配比 (answer_at_weights=1,3,3,2 + 去掉中间ANS) 重建数据
6
+ # 2. 通过 northjob 提交 16 GPU 训练
7
+ #
8
+ # Usage:
9
+ # bash rebuild_and_train.sh
10
+ # =============================================================================
11
+ set -euo pipefail
12
+
13
+ SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
14
+ ICL_DIR="$(dirname "${SCRIPT_DIR}")"
15
+ PYTHON_BIN="/workspace/miniconda3/envs/sft/bin/python3"
16
+
17
+ BUILD_SCRIPT="${ICL_DIR}/build_sft.py"
18
+ SFT_OUTPUT="/workspace/xiaobin/dataset/sft"
19
+ SFT_DATA="${SFT_OUTPUT}/all/sft.jsonl"
20
+
21
+ echo "============================================"
22
+ echo "Step 1: 重建 SFT 数据集"
23
+ echo " 权重: 5,3,2,1 (多给 0-shot ANS,轨迹式无矛盾)"
24
+ echo " 轨迹式生成:同一输入只出现一种 action"
25
+ echo "============================================"
26
+
27
+ # 备份旧数据
28
+ if [ -f "${SFT_DATA}" ]; then
29
+ BACKUP="${SFT_DATA}.bak.$(date +%Y%m%d_%H%M%S)"
30
+ cp "${SFT_DATA}" "${BACKUP}"
31
+ echo "旧数据已备份: ${BACKUP}"
32
+ fi
33
+
34
+ # 重建数据(4 类,总量 ~6 万条 SFT 样本)
35
+ ${PYTHON_BIN} "${BUILD_SCRIPT}" \
36
+ --answer-at-weights "5,3,2,1" \
37
+ --samples-per-cat 7800 \
38
+ --shuffle
39
+
40
+ echo ""
41
+
42
+ # 验证新数据
43
+ echo "============================================"
44
+ echo "Step 2: 验证新数据配比"
45
+ echo "============================================"
46
+ ${PYTHON_BIN} -c "
47
+ import json
48
+ ret, ans = 0, 0
49
+ shot_ret, shot_ans = {}, {}
50
+ with open('${SFT_DATA}') as f:
51
+ for line in f:
52
+ r = json.loads(line)
53
+ n = len(r.get('shots', []))
54
+ if r['type'] == 'ret':
55
+ ret += 1
56
+ shot_ret[n] = shot_ret.get(n, 0) + 1
57
+ else:
58
+ ans += 1
59
+ shot_ans[n] = shot_ans.get(n, 0) + 1
60
+ total = ret + ans
61
+ print(f'总样本: {total}')
62
+ print(f'RET: {ret} ({ret/total*100:.1f}%)')
63
+ print(f'ANS: {ans} ({ans/total*100:.1f}%)')
64
+ print(f'RET/ANS 比: {ret/max(ans,1):.2f}:1')
65
+ print()
66
+ print('RET shot 分布:')
67
+ for k in sorted(shot_ret): print(f' {k}-shot: {shot_ret[k]}')
68
+ print('ANS shot 分布:')
69
+ for k in sorted(shot_ans): print(f' {k}-shot: {shot_ans[k]}')
70
+ r0 = shot_ret.get(0, 0); a0 = shot_ans.get(0, 0)
71
+ print(f'\n0-shot: RET={r0}({r0/(r0+a0)*100:.1f}%) ANS={a0}({a0/(r0+a0)*100:.1f}%)')
72
+ "
73
+
74
+ echo ""
75
+ echo "============================================"
76
+ echo "Step 3: 提交 16 卡训练任务"
77
+ echo "============================================"
78
+
79
+ bash "${SCRIPT_DIR}/submit_northjob.sh" 16
80
+
81
+ echo ""
82
+ echo "============================================"
83
+ echo "全部完成!"
84
+ echo " 数据: ${SFT_DATA}"
85
+ echo " 任务: 16 GPU via northjob"
86
+ echo "============================================"
ICL/SFT_new/run_eval.sh ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # =============================================================================
3
+ # ICL 评测启动脚本
4
+ #
5
+ # 默认:四类任务 (vqa, captioning, classification, reasoning) 各 500 条
6
+ #
7
+ # 用法:
8
+ # bash run_eval.sh # 单卡,四类各 500 条
9
+ # bash run_eval.sh 8 # 8 卡,四类各 500 条
10
+ # bash run_eval.sh 1 vqa vqav2 20 # 单卡,指定 dataset,20 条
11
+ #
12
+ # 环境变量:
13
+ # MODEL_PATH=... bash run_eval.sh # 指定模型路径
14
+ # BATCH_SIZE=8 bash run_eval.sh # 调大 batch
15
+ # =============================================================================
16
+ set -euo pipefail
17
+
18
+ SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
19
+
20
+ # ---- 默认参数 ----
21
+ NUM_GPUS="${1:-1}"
22
+ CATEGORY="${2:-}"
23
+ DATASET="${3:-}"
24
+ NUM_SAMPLES="${4:-500}"
25
+ BATCH_SIZE="${BATCH_SIZE:-4}"
26
+ SPLIT="val"
27
+ MODEL_PATH="${MODEL_PATH:-/workspace/xiaobin/ICL/sft_model/epoch3_step1406_fp32}"
28
+ OUTPUT_DIR="${SCRIPT_DIR}/eval_results"
29
+
30
+ echo "============================================"
31
+ echo "ICL Evaluation"
32
+ echo " GPUs: ${NUM_GPUS}"
33
+ echo " Model: ${MODEL_PATH}"
34
+ echo " Split: ${SPLIT}"
35
+ echo " Batch size: ${BATCH_SIZE}"
36
+ echo " Samples/ds: ${NUM_SAMPLES}"
37
+ echo " Category: ${CATEGORY:-all (vqa,captioning,classification,reasoning)}"
38
+ echo " Dataset: ${DATASET:-all}"
39
+ echo " Output: ${OUTPUT_DIR}"
40
+ echo "============================================"
41
+
42
+ # ---- 构建参数 ----
43
+ EXTRA_ARGS=""
44
+ if [ -n "${CATEGORY}" ] && [ -n "${DATASET}" ]; then
45
+ EXTRA_ARGS="--category ${CATEGORY} --dataset ${DATASET}"
46
+ elif [ -n "${CATEGORY}" ]; then
47
+ EXTRA_ARGS="--category ${CATEGORY}"
48
+ else
49
+ EXTRA_ARGS="--all-categories"
50
+ fi
51
+
52
+ if [ "${NUM_GPUS}" -eq 1 ]; then
53
+ python3 "${SCRIPT_DIR}/eval.py" \
54
+ --model-path "${MODEL_PATH}" \
55
+ --split "${SPLIT}" \
56
+ --num-samples "${NUM_SAMPLES}" \
57
+ --batch-size "${BATCH_SIZE}" \
58
+ --max-rounds 4 \
59
+ --output-dir "${OUTPUT_DIR}" \
60
+ --device cuda:0 \
61
+ ${EXTRA_ARGS}
62
+ else
63
+ torchrun \
64
+ --nproc_per_node="${NUM_GPUS}" \
65
+ --master_port=29501 \
66
+ "${SCRIPT_DIR}/eval.py" \
67
+ --model-path "${MODEL_PATH}" \
68
+ --split "${SPLIT}" \
69
+ --num-samples "${NUM_SAMPLES}" \
70
+ --batch-size "${BATCH_SIZE}" \
71
+ --max-rounds 4 \
72
+ --output-dir "${OUTPUT_DIR}" \
73
+ ${EXTRA_ARGS}
74
+ fi
ICL/SFT_new/run_single_node.sh ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # =============================================================================
3
+ # Single-node training (1 machine, 8x H100)
4
+ # For debugging and quick iteration
5
+ #
6
+ # Usage:
7
+ # bash run_single_node.sh <data.jsonl> [num_gpus]
8
+ # bash run_single_node.sh /path/to/sft.jsonl 8
9
+ # bash run_single_node.sh /path/to/sft.jsonl 2 # quick debug
10
+ # =============================================================================
11
+ set -euo pipefail
12
+
13
+ SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
14
+
15
+ # ---- Config ----
16
+ MODEL_PATH="/workspace/models/Qwen3-VL-8B-Instruct"
17
+ DATA_PATH="${1:?Usage: $0 <data.jsonl> [num_gpus]}"
18
+ NUM_GPUS="${2:-8}"
19
+ OUTPUT_DIR="/workspace/xiaobin/ICL/sft_model"
20
+
21
+ # ---- Env ----
22
+ export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
23
+ export NCCL_P2P_DISABLE=0
24
+ export NCCL_IB_DISABLE=0
25
+
26
+ # ---- Launch ----
27
+ echo "============================================"
28
+ echo "Single-node SFT: ${NUM_GPUS} GPUs"
29
+ echo "Model: ${MODEL_PATH}"
30
+ echo "Data: ${DATA_PATH}"
31
+ echo "Output: ${OUTPUT_DIR}"
32
+ echo "============================================"
33
+
34
+ torchrun \
35
+ --nproc_per_node=${NUM_GPUS} \
36
+ --master_port=29500 \
37
+ ${SCRIPT_DIR}/train.py \
38
+ --model-path ${MODEL_PATH} \
39
+ --data-path ${DATA_PATH} \
40
+ --output-dir ${OUTPUT_DIR} \
41
+ --deepspeed ${SCRIPT_DIR}/ds_zero2.json \
42
+ --num-epochs 3 \
43
+ --batch-size 2 \
44
+ --gradient-accumulation-steps 4 \
45
+ --learning-rate 1e-6 \
46
+ --max-length 32768 \
47
+ --gradient-checkpointing \
48
+ --log-interval 10 \
49
+ --save-interval 500
ICL/SFT_new/submit_northjob.sh ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # =============================================================================
3
+ # Submit multi-node job via northjob (16 GPUs = 2 nodes × 8 H100)
4
+ #
5
+ # Usage:
6
+ # bash submit_northjob.sh [num_gpus]
7
+ # bash submit_northjob.sh 16 # 2 nodes
8
+ # bash submit_northjob.sh 8 # 1 node (debug)
9
+ # =============================================================================
10
+ set -euo pipefail
11
+
12
+ SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
13
+ GPU_NUMS="${1:-16}"
14
+ GPU_PER_NODE=8
15
+ NNODES=$((GPU_NUMS / GPU_PER_NODE))
16
+
17
+ JOB_NAME="qwen3vl-sft-${GPU_NUMS}gpu"
18
+ WORK_DIR="${SCRIPT_DIR}"
19
+ TRAIN_SCRIPT="${SCRIPT_DIR}/launch_wrapper.py"
20
+
21
+ echo "Submitting: ${JOB_NAME} (${NNODES} nodes × ${GPU_PER_NODE} GPUs)"
22
+
23
+ /workspace/miniconda3/envs/sft/bin/northjob \
24
+ create \
25
+ --job-type train \
26
+ --nproc-per-node ${GPU_PER_NODE} \
27
+ --gpu-per-node ${GPU_PER_NODE} \
28
+ --nnodes ${NNODES} \
29
+ --k8s-priority 3 \
30
+ --k8s-queue bg-agentic-coding \
31
+ --k8s-namespace bg-agentic-coding \
32
+ --k8s-pvc-name i-xinsiyang-y4zy0sik0a \
33
+ --k8s-pvc-mount-path /workspace \
34
+ --k8s-no-reclaim \
35
+ --k8s-images harbor.local.clusters/bp/megatron-bplm:25.03_fp8.ibgda.qwen3.next.fix_triton.fix_te.hf457.qwen3_vl \
36
+ --job-name ${JOB_NAME} \
37
+ --workspace ${WORK_DIR} \
38
+ ${TRAIN_SCRIPT} ${GPU_PER_NODE}
ICL/SFT_new/train.py ADDED
@@ -0,0 +1,659 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Qwen3-VL-8B SFT Training Script (single-step RET/ANS decision).
5
+
6
+ Supports:
7
+ - Full fine-tuning or LoRA
8
+ - DeepSpeed ZeRO-2/3
9
+ - Multi-image conversations
10
+ - Loss masking on user turns only
11
+ - Flash Attention 2 on H100
12
+ """
13
+
14
+ import argparse
15
+ import json
16
+ import logging
17
+ import math
18
+ import os
19
+ import sys
20
+ from pathlib import Path
21
+ from typing import Dict, List, Optional, Sequence
22
+
23
+ import torch
24
+ import torch.distributed as dist
25
+ from torch.utils.data import Dataset, DataLoader
26
+
27
+ from transformers import (
28
+ AutoProcessor,
29
+ Qwen3VLForConditionalGeneration,
30
+ get_cosine_schedule_with_warmup,
31
+ )
32
+ from peft import LoraConfig, get_peft_model, TaskType
33
+ from qwen_vl_utils import process_vision_info
34
+
35
+ try:
36
+ import deepspeed
37
+ HAS_DEEPSPEED = True
38
+ except ImportError:
39
+ HAS_DEEPSPEED = False
40
+
41
+ logging.basicConfig(
42
+ format="%(asctime)s [%(levelname)s] %(message)s",
43
+ level=logging.INFO,
44
+ )
45
+ logger = logging.getLogger(__name__)
46
+
47
+ # Special token IDs (Qwen3-VL)
48
+ IM_START_ID = 151644
49
+ IM_END_ID = 151645
50
+ IGNORE_INDEX = -100
51
+
52
+
53
+ # ============================================================================
54
+ # Dataset
55
+ # ============================================================================
56
+
57
+ class SFTDataset(Dataset):
58
+ """Load single-step SFT JSONL (轻量引用格式).
59
+
60
+ 支持两种格式:
61
+
62
+ 格式 A (新,轻量引用):
63
+ {
64
+ "type": "ret" | "ans",
65
+ "query_image": "/path/to/query.jpg",
66
+ "question": "What color?",
67
+ "answer": "black",
68
+ "instruction": "Answer the question...",
69
+ "shots": [{"image": "/path/shot.jpg", "caption": "A cat..."}],
70
+ "next_description": "A dog..." // 仅 ret 类型
71
+ }
72
+
73
+ 格式 B (旧,conversations):
74
+ {
75
+ "images": ["path1.jpg", ...],
76
+ "conversations": [
77
+ {"from": "human", "value": "...<image>..."},
78
+ {"from": "gpt", "value": "<ANS>answer</ANS>"}
79
+ ]
80
+ }
81
+ """
82
+
83
+ def __init__(self, data_path: str, processor, max_length: int = 4096,
84
+ min_pixels: int = 256 * 28 * 28,
85
+ max_pixels: int = 1280 * 28 * 28):
86
+ self.processor = processor
87
+ self.max_length = max_length
88
+ self.min_pixels = min_pixels
89
+ self.max_pixels = max_pixels
90
+ self.records = []
91
+
92
+ logger.info(f"Loading data from {data_path}")
93
+ with open(data_path, "r", encoding="utf-8") as f:
94
+ for line in f:
95
+ line = line.strip()
96
+ if not line:
97
+ continue
98
+ try:
99
+ self.records.append(json.loads(line))
100
+ except Exception:
101
+ continue
102
+ logger.info(f"Loaded {len(self.records)} samples")
103
+
104
+ def __len__(self):
105
+ return len(self.records)
106
+
107
+ # ---- 新格式: 从引用字段动态构建 messages ----
108
+
109
+ def _build_messages_v2(self, record: Dict) -> List[Dict]:
110
+ """从轻量引用格式构建 Qwen3-VL chat messages."""
111
+ user_content = []
112
+
113
+ # 1. instruction
114
+ inst = record.get("instruction", "")
115
+ if inst:
116
+ user_content.append({"type": "text", "text": inst})
117
+
118
+ # 2. query image
119
+ user_content.append({
120
+ "type": "image",
121
+ "image": f"file://{record['query_image']}",
122
+ "min_pixels": self.min_pixels,
123
+ "max_pixels": self.max_pixels,
124
+ })
125
+
126
+ # 3. question (可能为空,如 captioning 类)
127
+ question = record.get("question", "")
128
+ if question:
129
+ user_content.append({"type": "text", "text": f"Question: {question}"})
130
+
131
+ # 4. context shots (image + caption)
132
+ for shot in record.get("shots", []):
133
+ user_content.append({
134
+ "type": "image",
135
+ "image": f"file://{shot['image']}",
136
+ "min_pixels": self.min_pixels,
137
+ "max_pixels": self.max_pixels,
138
+ })
139
+ cap = shot.get("caption", "")
140
+ if cap:
141
+ user_content.append({"type": "text", "text": f"Caption: {cap}"})
142
+
143
+ # 5. Action prompt
144
+ user_content.append({"type": "text", "text": "Action:"})
145
+
146
+ # 6. assistant response
147
+ if record["type"] == "ret":
148
+ desc = record.get("next_description", "")
149
+ assistant_text = f"<RET>\nDescription: {desc}"
150
+ else:
151
+ assistant_text = f"<ANS>{record['answer']}</ANS>"
152
+
153
+ messages = [
154
+ {"role": "user", "content": user_content},
155
+ {"role": "assistant", "content": [{"type": "text", "text": assistant_text}]},
156
+ ]
157
+ return messages
158
+
159
    # ---- Legacy format: conversations list with <image> placeholders ----

    def _build_messages_v1(self, record: Dict) -> List[Dict]:
        """Convert conversations format → Qwen3-VL chat messages."""
        convs = record["conversations"]
        image_paths = record.get("images", [])
        messages = []

        for turn in convs:
            role = "user" if turn["from"] == "human" else "assistant"
            text = turn["value"]

            if role == "user":
                content = []
                # Split on <image> placeholders; every boundary after the first
                # segment consumes the next image path in order.
                # NOTE(review): img_idx restarts at 0 for every user turn, so a
                # multi-turn record would reuse image_paths[0..] in each turn —
                # confirm records are single-turn or images are per-turn.
                parts = text.split("<image>")
                img_idx = 0
                for i, part in enumerate(parts):
                    if i > 0 and img_idx < len(image_paths):
                        content.append({
                            "type": "image",
                            "image": f"file://{image_paths[img_idx]}",
                            "min_pixels": self.min_pixels,
                            "max_pixels": self.max_pixels,
                        })
                        img_idx += 1
                    if part.strip():
                        content.append({"type": "text", "text": part.strip()})
                messages.append({"role": role, "content": content})
            else:
                # Assistant turns are plain text.
                messages.append({
                    "role": role,
                    "content": [{"type": "text", "text": text}],
                })

        return messages
194
+
195
    def __getitem__(self, idx):
        """Render one record into model-ready tensors (ids, mask, labels, pixels)."""
        record = self.records[idx]

        # Auto-detect record format: v2 records carry "type" + "query_image".
        if "type" in record and "query_image" in record:
            messages = self._build_messages_v2(record)
        else:
            messages = self._build_messages_v1(record)

        # Apply chat template (no generation prompt for training)
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False
        )

        # Process images; best-effort — on failure we fall through with no images.
        image_inputs = None
        try:
            image_inputs, _ = process_vision_info(messages)
        except Exception:
            pass

        # Tokenize — no truncation here, to avoid image-token count mismatches
        inputs = self.processor(
            text=[text],
            images=image_inputs if image_inputs else None,
            return_tensors="pt",
            padding=False,
            truncation=False,
        )

        # Squeeze batch dim
        input_ids = inputs["input_ids"].squeeze(0)
        attention_mask = inputs["attention_mask"].squeeze(0)

        # If over-long, keep only the first max_length tokens.
        # NOTE(review): truncating AFTER tokenization can still cut through a
        # trailing image's placeholder tokens, contradicting the intent above —
        # confirm over-long multi-image samples are rare/filtered upstream.
        if input_ids.shape[0] > self.max_length:
            input_ids = input_ids[:self.max_length]
            attention_mask = attention_mask[:self.max_length]

        # Build labels: mask user turns, keep assistant turns
        labels = self._build_labels(input_ids)

        result = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }
        # Pass through pixel values if present.
        # NOTE(review): pixel_values is conditionally squeezed but
        # image_grid_thw is forwarded as-is — verify the collator's cat(dim=0)
        # expectations hold for both.
        if "pixel_values" in inputs:
            result["pixel_values"] = inputs["pixel_values"].squeeze(0) \
                if inputs["pixel_values"].dim() > 3 else inputs["pixel_values"]
        if "image_grid_thw" in inputs:
            result["image_grid_thw"] = inputs["image_grid_thw"]

        return result
250
+
251
    def _build_labels(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Mask everything except assistant responses.

        Strategy: find <|im_start|>assistant ... <|im_end|> spans,
        only compute loss on tokens after 'assistant\n' until <|im_end|>.
        """
        # Start fully masked; only assistant spans are un-masked below.
        labels = torch.full_like(input_ids, IGNORE_INDEX)
        ids = input_ids.tolist()

        # Token sequence that marks the start of an assistant turn's header.
        assist_tokens = self.processor.tokenizer.encode(
            "assistant\n", add_special_tokens=False
        )

        i = 0
        while i < len(ids):
            if ids[i] == IM_START_ID:
                start = i + 1
                end = start + len(assist_tokens)
                if end <= len(ids) and ids[start:end] == assist_tokens:
                    # Scan to the matching <|im_end|> (or end of sequence if
                    # the turn was truncated).
                    content_start = end
                    j = content_start
                    while j < len(ids) and ids[j] != IM_END_ID:
                        j += 1
                    # Slice end j+1 deliberately includes the closing
                    # <|im_end|> token in the loss span.
                    labels[content_start:j + 1] = input_ids[content_start:j + 1]
                    i = j + 1
                    continue
            i += 1

        return labels
280
+
281
+
282
+ # ============================================================================
283
+ # Collator
284
+ # ============================================================================
285
+
286
class SFTCollator:
    """Pad variable-length samples into a single batched tensor dict.

    Sequences are right-padded with ``pad_token_id`` (labels with
    IGNORE_INDEX) up to the longest sample in the batch, capped at
    ``max_length``. Vision tensors, when present, are concatenated along
    their first dimension.
    """

    def __init__(self, pad_token_id: int, max_length: int = 4096):
        self.pad_token_id = pad_token_id
        self.max_length = max_length

    def __call__(self, features: List[Dict]) -> Dict[str, torch.Tensor]:
        longest = max(f["input_ids"].size(0) for f in features)
        target_len = min(longest, self.max_length)

        padded_ids, padded_masks, padded_labels = [], [], []
        pixel_chunks, grid_chunks = [], []

        for sample in features:
            ids = sample["input_ids"][:target_len]
            mask = sample["attention_mask"][:target_len]
            lab = sample["labels"][:target_len]
            deficit = target_len - ids.size(0)

            if deficit > 0:
                # Right-pad each stream; labels use IGNORE_INDEX so padding
                # never contributes to the loss.
                ids = torch.cat([ids, torch.full((deficit,), self.pad_token_id, dtype=ids.dtype)])
                mask = torch.cat([mask, torch.zeros(deficit, dtype=mask.dtype)])
                lab = torch.cat([lab, torch.full((deficit,), IGNORE_INDEX, dtype=lab.dtype)])

            padded_ids.append(ids)
            padded_masks.append(mask)
            padded_labels.append(lab)

            if "pixel_values" in sample:
                pixel_chunks.append(sample["pixel_values"])
            if "image_grid_thw" in sample:
                grid_chunks.append(sample["image_grid_thw"])

        batch = {
            "input_ids": torch.stack(padded_ids),
            "attention_mask": torch.stack(padded_masks),
            "labels": torch.stack(padded_labels),
        }

        if pixel_chunks:
            batch["pixel_values"] = torch.cat(pixel_chunks, dim=0)
        if grid_chunks:
            batch["image_grid_thw"] = torch.cat(grid_chunks, dim=0)

        return batch
337
+
338
+
339
+ # ============================================================================
340
+ # Training
341
+ # ============================================================================
342
+
343
def train(args):
    """Run SFT for Qwen3-VL, in one of two execution modes.

    * DeepSpeed (when --deepspeed points at a config JSON and the deepspeed
      package imported successfully, i.e. HAS_DEEPSPEED): deepspeed.initialize
      owns the optimizer, data loader and gradient accumulation; its scheduler
      is replaced with a cosine warmup schedule on the unwrapped optimizer.
    * Vanilla DDP / single GPU otherwise: manual gradient accumulation,
      gradient clipping and a cosine schedule.
    """
    # ---- Distributed setup ----
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    rank = int(os.environ.get("RANK", 0))

    if world_size > 1 and not dist.is_initialized():
        dist.init_process_group("nccl")

    torch.cuda.set_device(local_rank)
    device = torch.device(f"cuda:{local_rank}")
    is_main = rank == 0

    if is_main:
        logger.info(f"World size: {world_size}, Local rank: {local_rank}")
        logger.info(f"Args: {vars(args)}")

    # ---- Load processor & model ----
    processor = AutoProcessor.from_pretrained(
        args.model_path, trust_remote_code=True,
        min_pixels=args.min_pixels, max_pixels=args.max_pixels,
    )

    model_kwargs = {
        "trust_remote_code": True,
        "torch_dtype": torch.bfloat16,
        "attn_implementation": "flash_attention_2",
    }
    # Under DeepSpeed, placement is handled by the engine — no device_map.
    if not (HAS_DEEPSPEED and args.deepspeed):
        model_kwargs["device_map"] = {"": device}

    model = Qwen3VLForConditionalGeneration.from_pretrained(
        args.model_path, **model_kwargs,
    )

    # Add special tokens used by the dataset's assistant targets
    # (<RET>/<ANS>...); the embedding table must be resized to match.
    special_tokens = ["<RET>", "<ANS>", "</ANS>", "<RETQ>", "</RETQ>"]
    num_added = processor.tokenizer.add_tokens(special_tokens, special_tokens=True)
    if num_added > 0:
        model.resize_token_embeddings(len(processor.tokenizer))
        if is_main:
            logger.info(f"Added {num_added} special tokens, vocab → {len(processor.tokenizer)}")

    # ---- LoRA (optional) ----
    if args.use_lora:
        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            r=args.lora_rank,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                            "gate_proj", "up_proj", "down_proj"],
        )
        model = get_peft_model(model, lora_config)
        if is_main:
            model.print_trainable_parameters()
    else:
        if args.gradient_checkpointing:
            model.gradient_checkpointing_enable(
                gradient_checkpointing_kwargs={"use_reentrant": False}
            )

    # ---- Dataset ----
    train_dataset = SFTDataset(
        args.data_path, processor, args.max_length,
        args.min_pixels, args.max_pixels,
    )
    collator = SFTCollator(processor.tokenizer.pad_token_id, args.max_length)

    # ---- DeepSpeed or vanilla DDP ----
    if HAS_DEEPSPEED and args.deepspeed:
        # Load DS config and dynamically set scheduler params
        # NOTE(review): 'copy' is imported but never used in this function.
        import copy
        with open(args.deepspeed, "r") as _f:
            ds_config = json.load(_f)

        # Explicitly set all batch-size params (avoid "auto" which some DS versions don't support)
        micro_bs = ds_config.get("train_micro_batch_size_per_gpu", args.batch_size)
        grad_accum_cfg = ds_config.get("gradient_accumulation_steps", args.gradient_accumulation_steps)
        ds_config["train_micro_batch_size_per_gpu"] = micro_bs
        ds_config["gradient_accumulation_steps"] = grad_accum_cfg
        ds_config["train_batch_size"] = micro_bs * grad_accum_cfg * world_size

        # Override LR from CLI args
        if "optimizer" in ds_config and "params" in ds_config["optimizer"]:
            ds_config["optimizer"]["params"]["lr"] = args.learning_rate

        if is_main:
            logger.info(f"DeepSpeed config: micro_bs={micro_bs}, grad_accum={grad_accum_cfg}, "
                        f"world_size={world_size}, train_batch_size={ds_config['train_batch_size']}")

        model_engine, optimizer, train_loader, _ = deepspeed.initialize(
            model=model,
            model_parameters=[p for p in model.parameters() if p.requires_grad],
            training_data=train_dataset,
            collate_fn=collator,
            config=ds_config,
        )
        # total_steps = optimizer steps (micro-batch steps per epoch / grad_accum * num_epochs)
        grad_accum = model_engine.gradient_accumulation_steps()
        steps_per_epoch = len(train_loader) // grad_accum
        total_steps = steps_per_epoch * args.num_epochs
        warmup_steps = int(total_steps * args.warmup_ratio)

        # Replace DS scheduler with cosine schedule
        # Note: model_engine.optimizer is DeepSpeedZeroOptimizer (not a torch.optim.Optimizer),
        # so we must use the underlying torch optimizer for LambdaLR.
        base_optimizer = model_engine.optimizer.optimizer  # unwrap to torch AdamW
        ds_scheduler = get_cosine_schedule_with_warmup(
            base_optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_steps,
        )
        model_engine.lr_scheduler = ds_scheduler
        scheduler = None
    else:
        # Vanilla DDP
        if world_size > 1:
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[local_rank],
                find_unused_parameters=False,
            )
        sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset, num_replicas=world_size, rank=rank, shuffle=True,
        ) if world_size > 1 else None

        train_loader = DataLoader(
            train_dataset, batch_size=args.batch_size,
            sampler=sampler, shuffle=(sampler is None),
            collate_fn=collator, num_workers=args.num_workers,
            pin_memory=True, drop_last=True,
        )
        optimizer = torch.optim.AdamW(
            [p for p in model.parameters() if p.requires_grad],
            lr=args.learning_rate, weight_decay=args.weight_decay,
            betas=(0.9, 0.999),
        )
        total_steps = (len(train_loader) * args.num_epochs) // args.gradient_accumulation_steps
        warmup_steps = int(total_steps * args.warmup_ratio)
        scheduler = get_cosine_schedule_with_warmup(
            optimizer, warmup_steps, total_steps,
        )
        model_engine = None

    if is_main:
        logger.info(f"Dataset: {len(train_dataset)} samples")
        logger.info(f"Total steps: {total_steps}, Warmup: {warmup_steps}")

    # ---- Training loop ----
    optimizer_step = 0
    running_loss = 0.0
    running_count = 0
    accum_loss = 0.0  # accumulate loss across micro-batches within one grad accum cycle

    for epoch in range(args.num_epochs):
        # Reshuffle the distributed sampler each epoch when one is in use.
        if hasattr(train_loader, "sampler") and hasattr(train_loader.sampler, "set_epoch"):
            train_loader.sampler.set_epoch(epoch)

        model.train() if model_engine is None else model_engine.train()

        for step, batch in enumerate(train_loader):
            # Move batch to GPU
            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v
                     for k, v in batch.items()}

            # Forward
            if model_engine:
                # DeepSpeed path: the engine handles accumulation internally;
                # step() is a no-op except on accumulation boundaries.
                outputs = model_engine(**batch)
                loss = outputs.loss
                model_engine.backward(loss)
                model_engine.step()

                # Accumulate loss across micro-batches
                accum_loss += loss.item()

                # Log/save only on optimizer step boundaries
                if model_engine.is_gradient_accumulation_boundary():
                    grad_accum = model_engine.gradient_accumulation_steps()
                    optimizer_step += 1
                    cur_loss = accum_loss / grad_accum  # average over micro-batches
                    accum_loss = 0.0

                    running_loss += cur_loss
                    running_count += 1
                    avg_loss = running_loss / running_count

                    if is_main and optimizer_step % args.log_interval == 0:
                        lr_now = ds_scheduler.get_last_lr()[0]
                        logger.info(
                            f"Epoch {epoch+1}/{args.num_epochs} "
                            f"Step {optimizer_step}/{total_steps} "
                            f"Loss {cur_loss:.4f} "
                            f"AvgLoss {avg_loss:.4f} "
                            f"LR {lr_now:.2e}"
                        )

                    # Save checkpoint
                    if optimizer_step > 0 and optimizer_step % args.save_interval == 0:
                        _save_checkpoint(args, model, model_engine, processor, epoch, optimizer_step, is_main)

            else:
                # Manual accumulation path: scale the loss down per micro-batch,
                # un-scale when tracking it for logging.
                outputs = model(**batch)
                loss = outputs.loss / args.gradient_accumulation_steps
                loss.backward()
                accum_loss += loss.item() * args.gradient_accumulation_steps

                if (step + 1) % args.gradient_accumulation_steps == 0:
                    torch.nn.utils.clip_grad_norm_(
                        model.parameters(), args.max_grad_norm
                    )
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
                    optimizer_step += 1

                    cur_loss = accum_loss / args.gradient_accumulation_steps
                    accum_loss = 0.0

                    running_loss += cur_loss
                    running_count += 1
                    avg_loss = running_loss / running_count

                    if is_main and optimizer_step % args.log_interval == 0:
                        lr_now = scheduler.get_last_lr()[0]
                        logger.info(
                            f"Epoch {epoch+1}/{args.num_epochs} "
                            f"Step {optimizer_step}/{total_steps} "
                            f"Loss {cur_loss:.4f} "
                            f"AvgLoss {avg_loss:.4f} "
                            f"LR {lr_now:.2e}"
                        )
                    # NOTE(review): unlike the DeepSpeed branch, this path has
                    # no --save-interval mid-epoch checkpointing — confirm
                    # whether that asymmetry is intentional.

        # End of epoch save
        # NOTE(review): both branches below are identical — the if/else is
        # redundant (same for the final save).
        if model_engine:
            _save_checkpoint(args, model, model_engine, processor, epoch, optimizer_step, is_main)
        else:
            _save_checkpoint(args, model, model_engine, processor, epoch, optimizer_step, is_main)

    # Final save
    if model_engine:
        _save_checkpoint(args, model, model_engine, processor, args.num_epochs, optimizer_step, is_main, final=True)
    else:
        _save_checkpoint(args, model, model_engine, processor, args.num_epochs, optimizer_step, is_main, final=True)

    if is_main:
        logger.info("Training complete!")
589
+
590
+
591
def _save_checkpoint(args, model, model_engine, processor, epoch, step, is_main, final=False):
    """Persist a training checkpoint under args.output_dir/<tag>.

    Args:
        args: parsed CLI namespace (output_dir is read here).
        model: the (possibly DDP/PEFT-wrapped) model for non-DeepSpeed saves.
        model_engine: DeepSpeed engine, or None when not using DeepSpeed.
        processor: tokenizer/processor saved alongside the weights.
        epoch, step: used to build the checkpoint tag.
        is_main: True on rank 0; non-DeepSpeed saves happen only there.
        final: when True, the tag is "final" instead of epoch/step based.

    DeepSpeed checkpoints are sharded, so save_checkpoint must run on ALL
    ranks; plain/LoRA models are written by rank 0 only via save_pretrained.
    """
    tag = "final" if final else f"epoch{epoch+1}_step{step}"
    save_dir = Path(args.output_dir) / tag

    if model_engine and HAS_DEEPSPEED:
        # DeepSpeed save_checkpoint must be called by ALL ranks
        model_engine.save_checkpoint(str(args.output_dir), tag=tag)
    elif is_main:
        unwrapped = model.module if hasattr(model, "module") else model
        # save_pretrained handles both full models and PEFT/LoRA wrappers, so
        # no use_lora branch is needed (the previous if/else had identical arms).
        unwrapped.save_pretrained(str(save_dir))
        processor.save_pretrained(str(save_dir))

    if is_main:
        logger.info(f"Saved checkpoint → {save_dir}")
608
+
609
+
610
+ # ============================================================================
611
+ # Main
612
+ # ============================================================================
613
+
614
def parse_args():
    """Parse command-line arguments for Qwen3-VL SFT.

    Returns the argparse.Namespace of all training options. Fix in this
    revision: ``--gradient-checkpointing`` was ``store_true`` with
    ``default=True``, so it could never be turned off; it now uses
    BooleanOptionalAction, keeping the old flag working while adding
    ``--no-gradient-checkpointing`` to disable it.
    """
    p = argparse.ArgumentParser(description="Qwen3-VL-8B SFT")

    # Model
    p.add_argument("--model-path", default="/workspace/models/Qwen3-VL-8B-Instruct")
    p.add_argument("--output-dir", default="/workspace/xiaobin/ICL/SFT_new/output/qwen3vl_sft")

    # Data
    p.add_argument("--data-path", required=True, help="Path to sft.jsonl")
    p.add_argument("--max-length", type=int, default=4096)
    p.add_argument("--min-pixels", type=int, default=256 * 28 * 28)
    p.add_argument("--max-pixels", type=int, default=1280 * 28 * 28)

    # Training
    p.add_argument("--num-epochs", type=int, default=3)
    p.add_argument("--batch-size", type=int, default=1,
                   help="Per-GPU micro batch size")
    p.add_argument("--gradient-accumulation-steps", type=int, default=4)
    p.add_argument("--learning-rate", type=float, default=1e-5)
    p.add_argument("--weight-decay", type=float, default=0.1)
    p.add_argument("--warmup-ratio", type=float, default=0.05)
    p.add_argument("--max-grad-norm", type=float, default=1.0)
    # On by default; pass --no-gradient-checkpointing to disable.
    p.add_argument("--gradient-checkpointing",
                   action=argparse.BooleanOptionalAction, default=True)
    p.add_argument("--num-workers", type=int, default=4)

    # LoRA
    p.add_argument("--use-lora", action="store_true", default=False)
    p.add_argument("--lora-rank", type=int, default=64)
    p.add_argument("--lora-alpha", type=int, default=128)
    p.add_argument("--lora-dropout", type=float, default=0.05)

    # Logging
    p.add_argument("--log-interval", type=int, default=10)
    p.add_argument("--save-interval", type=int, default=500)

    # DeepSpeed
    p.add_argument("--deepspeed", type=str, default=None,
                   help="Path to DeepSpeed config JSON")
    p.add_argument("--local_rank", type=int, default=-1)  # torchrun sets this

    return p.parse_args()
655
+
656
+
657
# Script entry point: parse CLI arguments and launch training.
if __name__ == "__main__":
    args = parse_args()
    train(args)
ICL/build_embeddings.py ADDED
@@ -0,0 +1,370 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ 预算 SigLIP2 embeddings + Top5 相似图片映射(8卡 DataParallel)。
4
+
5
+ 用法:
6
+ python3 build_embeddings.py # 8卡,全部
7
+ python3 build_embeddings.py --datasets vqa/shapes # 测试
8
+ python3 build_embeddings.py --force # 强制重建
9
+ """
10
+
11
+ import argparse
12
+ import json
13
+ import os
14
+ import numpy as np
15
+ from concurrent.futures import ThreadPoolExecutor
16
+ from typing import Dict, List, Tuple
17
+
18
+ try:
19
+ from tqdm import tqdm
20
+ except ImportError:
21
+ def tqdm(x, **kw):
22
+ return x
23
+
24
+ import torch
25
+ import torch.nn as nn
26
+ import cv2
27
+ import numpy as np
28
+ from PIL import Image
29
+ from transformers import AutoModel, AutoProcessor
30
+
31
+ # ---------------------------------------------------------------------------
32
+ IMAGES_ROOT = "/workspace/xiaobin/dataset/images"
33
+ CAPTION_CACHE_DIR = "/workspace/xiaobin/dataset/caption_cache"
34
+ EMBEDDINGS_DIR = "/workspace/xiaobin/dataset/embeddings"
35
+ DEFAULT_MODEL = "/workspace/models/siglip2-so400m-patch14-384"
36
+
37
+
38
+ # ---------------------------------------------------------------------------
39
+ # DataParallel wrappers
40
+ # ---------------------------------------------------------------------------
41
class SigLIPImageModule(nn.Module):
    """DataParallel-friendly wrapper producing L2-normalized image features."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, **kwargs):
        features = self.model.get_image_features(**kwargs)
        # Some model versions return a pooled-output container instead of a
        # raw tensor — unwrap it when present.
        if hasattr(features, "pooler_output"):
            features = features.pooler_output
        return features / features.norm(dim=-1, keepdim=True)
+
51
+
52
class SigLIPTextModule(nn.Module):
    """DataParallel-friendly wrapper producing L2-normalized text features."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, **kwargs):
        features = self.model.get_text_features(**kwargs)
        # Unwrap the pooled-output container when the model returns one.
        if hasattr(features, "pooler_output"):
            features = features.pooler_output
        return features / features.norm(dim=-1, keepdim=True)
+
62
+
63
+ # ---------------------------------------------------------------------------
64
+ # Encoder: 单进程, 多线程读图, 小batch快速跑
65
+ # ---------------------------------------------------------------------------
66
class SigLIPEncoder:
    """Batched SigLIP image/text encoder spanning multiple GPUs.

    The forward pass runs through nn.DataParallel; image loading and
    preprocessing run in a thread pool that prefetches the next batch while
    the current one is on the GPU.
    """

    def __init__(self, model_path: str, gpu_ids: List[int],
                 batch_size_per_gpu: int = 64, num_threads: int = 16):
        self.gpu_ids = gpu_ids
        self.n_gpus = len(gpu_ids)
        # Effective batch handed to DataParallel (split across the GPUs).
        self.batch_size = batch_size_per_gpu * self.n_gpus
        self.num_threads = num_threads
        self.primary = torch.device(f"cuda:{gpu_ids[0]}")

        print(f"  GPU: {gpu_ids} ({self.n_gpus} 张)")
        print(f"  batch: {batch_size_per_gpu}/卡 × {self.n_gpus}卡 = {self.batch_size}")

        self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
        base_model = AutoModel.from_pretrained(
            model_path, dtype=torch.bfloat16, trust_remote_code=True
        ).to(self.primary).eval()

        self.img_module = nn.DataParallel(
            SigLIPImageModule(base_model), device_ids=gpu_ids)
        self.txt_module = nn.DataParallel(
            SigLIPTextModule(base_model), device_ids=gpu_ids)

    @staticmethod
    def _load_and_preprocess(path):
        """Read + OpenCV resize + normalize → numpy (3, 384, 384) float32.

        Returns (path, array), or (path, None) on any read/decode failure.
        NOTE(review): the 384×384 size and (x - 0.5) / 0.5 normalization are
        hard-coded to match the SigLIP processor — re-check if the model changes.
        """
        try:
            img = cv2.imread(path)
            if img is None:
                return (path, None)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = cv2.resize(img, (384, 384))
            img = img.astype(np.float32) / 255.0
            img = (img - 0.5) / 0.5
            img = np.transpose(img, (2, 0, 1))  # (3, 384, 384)
            return (path, img)
        except Exception:
            return (path, None)

    def encode_images(self, paths: List[str]) -> Tuple[List[str], np.ndarray]:
        """Encode images in prefetch-pipelined batches.

        Returns (valid_paths, embeddings): unreadable images are silently
        dropped; embeddings are float16, row-aligned with valid_paths.
        """
        all_embs = []
        valid_paths = []
        n = len(paths)
        pbar = tqdm(total=n, desc="  encode-img", unit="张", dynamic_ncols=True)

        thread_pool = ThreadPoolExecutor(max_workers=self.num_threads)

        batches = [paths[s:s + self.batch_size]
                   for s in range(0, n, self.batch_size)]

        # Prefetch the first batch (list() materializes the lazy map, so the
        # IO actually happens here).
        if batches:
            next_future = list(thread_pool.map(self._load_and_preprocess, batches[0]))
        else:
            next_future = []

        for i, batch_paths in enumerate(batches):
            loaded = next_future

            # Kick off the next batch's IO + preprocessing before computing
            # on the current one.
            if i + 1 < len(batches):
                next_futures_list = [thread_pool.submit(self._load_and_preprocess, p)
                                     for p in batches[i + 1]]
            else:
                next_futures_list = None

            batch_valid = []
            batch_arrays = []
            for p, arr in loaded:
                if arr is not None:
                    batch_valid.append(p)
                    batch_arrays.append(arr)

            # Entire batch failed to load: still advance progress/prefetch.
            if not batch_arrays:
                pbar.update(len(batch_paths))
                if next_futures_list:
                    next_future = [f.result() for f in next_futures_list]
                continue

            # numpy stack → torch → GPU
            pixel_values = torch.from_numpy(np.stack(batch_arrays)).to(
                dtype=torch.bfloat16, device=self.primary)

            with torch.inference_mode():
                feat = self.img_module(pixel_values=pixel_values)
            all_embs.append(feat.cpu().float().numpy())
            valid_paths.extend(batch_valid)

            pbar.update(len(batch_paths))

            if next_futures_list:
                next_future = [f.result() for f in next_futures_list]

        # NOTE(review): wait=False leaves worker threads running if an
        # exception escapes the loop above — consider try/finally.
        thread_pool.shutdown(wait=False)
        pbar.close()
        if not all_embs:
            return [], np.empty((0, 0), dtype=np.float16)
        return valid_paths, np.concatenate(all_embs, axis=0).astype(np.float16)

    def encode_texts(self, texts: List[str]) -> np.ndarray:
        """Encode caption strings; returns float16 embeddings, one row per text."""
        all_embs = []
        n = len(texts)
        pbar = tqdm(total=n, desc="  encode-txt", unit="条", dynamic_ncols=True)

        for start in range(0, n, self.batch_size):
            batch = texts[start:start + self.batch_size]
            inp = self.processor(text=batch, return_tensors="pt",
                                 padding="max_length", truncation=True,
                                 max_length=64)
            # Forward only the tensor kwargs the text tower expects.
            keys = {k: v.to(self.primary) for k, v in inp.items()
                    if k in ("input_ids", "attention_mask", "position_ids")}
            with torch.inference_mode():
                feat = self.txt_module(**keys)
            all_embs.append(feat.cpu().float().numpy())
            pbar.update(len(batch))

        pbar.close()
        if not all_embs:
            return np.empty((0, 0), dtype=np.float16)
        return np.concatenate(all_embs, axis=0).astype(np.float16)
185
+
186
+
187
+ # ---------------------------------------------------------------------------
188
+ # Top-K(GPU)
189
+ # ---------------------------------------------------------------------------
190
def compute_top_k(caption_embs, image_embs, image_paths, k=5,
                  chunk_size=5000, device="cuda:0"):
    """For each image, find the k images whose embeddings best match its caption.

    caption_embs[i] and image_embs[i] must correspond to image_paths[i].
    Similarity is a dot product (embeddings are normalized by the encoder).
    Each row's own image is masked out before selection.

    Args:
        caption_embs, image_embs: (n, d) arrays, row-aligned with image_paths.
        image_paths: list of n path strings (used as keys and values).
        k: neighbors per image; clamped to n-1 so topk cannot raise on
           datasets smaller than k+1 (previously this crashed).
        chunk_size: caption rows processed per GPU pass.
        device: torch device string for the similarity computation.

    Returns:
        {image_path: [k most similar other image_paths]}.
    """
    n = len(image_paths)
    # Clamp: at most n-1 neighbors exist once self-matches are excluded.
    k = min(k, max(n - 1, 0))
    img_gpu = torch.from_numpy(image_embs.astype(np.float32)).to(device)
    top_k_map = {}

    for start in tqdm(range(0, n, chunk_size), desc="  compute-top5", unit="chunk"):
        end = min(start + chunk_size, n)
        cap = torch.from_numpy(caption_embs[start:end].astype(np.float32)).to(device)
        sim = cap @ img_gpu.T
        # Mask self-similarity: chunk row i corresponds to global column start+i.
        idx_range = torch.arange(end - start, device=sim.device)
        sim[idx_range, torch.arange(start, end, device=sim.device)] = -1.0
        _, top_idx = sim.topk(k, dim=1)
        top_idx_cpu = top_idx.cpu().numpy()
        for i in range(end - start):
            top_k_map[image_paths[start + i]] = [
                image_paths[j] for j in top_idx_cpu[i]]
    return top_k_map
208
+
209
+
210
+ # ---------------------------------------------------------------------------
211
+ # 数据集工具
212
+ # ---------------------------------------------------------------------------
213
def discover_datasets(categories=None, specific=None):
    """Enumerate (category, dataset) pairs.

    When *specific* is given as "cat/ds" strings, they are parsed directly
    (entries without a "/" are dropped) and the filesystem is not touched.
    Otherwise IMAGES_ROOT is scanned, optionally filtered by *categories*.
    """
    if specific:
        return [tuple(entry.split("/")[:2]) for entry in specific if "/" in entry]

    found = []
    for cat in sorted(os.listdir(IMAGES_ROOT)):
        cat_dir = os.path.join(IMAGES_ROOT, cat)
        if not os.path.isdir(cat_dir):
            continue
        if categories and cat not in categories:
            continue
        found.extend(
            (cat, ds)
            for ds in sorted(os.listdir(cat_dir))
            if os.path.isdir(os.path.join(cat_dir, ds))
        )
    return found
227
+
228
+
229
def load_captions(cat, ds):
    """Return the cached {image_path: caption} map for a dataset.

    A missing, unreadable, or malformed cache file yields an empty dict
    (best-effort by design).
    """
    cache_file = os.path.join(CAPTION_CACHE_DIR, f"{cat}_{ds}.json")
    if not os.path.exists(cache_file):
        return {}
    try:
        with open(cache_file) as fh:
            return json.load(fh).get("items", {})
    except Exception:
        # Corrupt cache is treated the same as a missing one.
        return {}
238
+
239
+
240
def collect_images(cat, ds):
    """List every image file path for a dataset, across its split directories.

    Scans the fixed split names train/val/test/other under
    IMAGES_ROOT/cat/ds, in that order, with filenames sorted within each split.
    """
    dataset_root = os.path.join(IMAGES_ROOT, cat, ds)
    collected = []
    for split in ("train", "val", "test", "other"):
        split_dir = os.path.join(dataset_root, split)
        if not os.path.isdir(split_dir):
            continue
        for filename in sorted(os.listdir(split_dir)):
            candidate = os.path.join(split_dir, filename)
            if os.path.isfile(candidate):
                collected.append(candidate)
    return collected
252
+
253
+
254
+ # ---------------------------------------------------------------------------
255
+ # 处理单个数据集(含断点续传)
256
+ # ---------------------------------------------------------------------------
257
def process_dataset(cat, ds, encoder, top_k, force):
    """Encode one dataset and write embeddings plus a top-k similar-image map.

    Supports resuming: finished datasets are skipped, and datasets with
    embeddings but no top-k map only recompute the map. Returns True on
    success / already-done, False when skipped or failed.
    """
    tag = f"{cat}_{ds}"
    npz_path = os.path.join(EMBEDDINGS_DIR, f"{tag}.npz")
    top5_path = os.path.join(EMBEDDINGS_DIR, f"{tag}_top{top_k}.json")

    # Resume point 1: everything already finished (embedding count matches
    # the top-k map and is non-empty).
    if not force and os.path.exists(npz_path) and os.path.exists(top5_path):
        try:
            data = np.load(npz_path, allow_pickle=True)
            n_emb = len(data["image_paths"])
            with open(top5_path) as f:
                n_top = len(json.load(f))
            if n_emb == n_top and n_emb > 0:
                print(f"  [SKIP] {tag} ({n_emb} 张)")
                return True
        except Exception:
            pass

    # Resume point 2: embeddings exist but the top-k map is missing —
    # recompute only the map.
    if not force and os.path.exists(npz_path) and not os.path.exists(top5_path):
        try:
            data = np.load(npz_path, allow_pickle=True)
            sp = list(data["image_paths"])
            si, sc = data["image_embs"], data["caption_embs"]
            if len(sp) > 0 and si.shape[0] == len(sp):
                print(f"  [RESUME] {tag} 只算 top{top_k} ({len(sp)} 张)")
                m = compute_top_k(sc, si, sp, k=top_k, device=str(encoder.primary))
                with open(top5_path, 'w') as f:
                    json.dump(m, f, ensure_ascii=False)
                print(f"  top{top_k}: {os.path.getsize(top5_path)/1048576:.1f}MB")
                return True
        except Exception:
            pass

    # Fresh run: collect images, intersect with captions, then encode.
    all_paths = collect_images(cat, ds)
    if not all_paths:
        print(f"  [SKIP] {tag} 无图片")
        return False

    captions = load_captions(cat, ds)
    if not captions:
        print(f"  [WARN] {tag} 无 caption,跳过")
        return False

    # Only images that have a cached caption can be embedded on both sides.
    paths_with_cap = [p for p in all_paths if p in captions]
    if not paths_with_cap:
        print(f"  [WARN] {tag} 无交集,跳过")
        return False

    print(f"\n  {tag}: {len(paths_with_cap)} 张图")

    valid_paths, image_embs = encoder.encode_images(paths_with_cap)
    if not valid_paths:
        print(f"  [ERROR] {tag} 编码失败")
        return False

    # Caption embeddings are aligned with the images that survived encoding.
    caption_embs = encoder.encode_texts([captions[p] for p in valid_paths])

    os.makedirs(EMBEDDINGS_DIR, exist_ok=True)
    np.savez_compressed(npz_path, image_paths=np.array(valid_paths),
                        image_embs=image_embs, caption_embs=caption_embs)
    print(f"  embeddings: {os.path.getsize(npz_path)/1048576:.1f}MB")

    m = compute_top_k(caption_embs, image_embs, valid_paths,
                      k=top_k, device=str(encoder.primary))
    with open(top5_path, 'w') as f:
        json.dump(m, f, ensure_ascii=False)
    print(f"  top{top_k}: {os.path.getsize(top5_path)/1048576:.1f}MB")
    return True
327
+
328
+
329
+ # ---------------------------------------------------------------------------
330
def main():
    """CLI driver: discover datasets, then encode + build top-k maps for each."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", default=DEFAULT_MODEL)
    parser.add_argument("--gpus", default="")
    parser.add_argument("--batch-size-per-gpu", type=int, default=256,
                        help="每卡batch(预处理不再是瓶颈,可以开大)")
    parser.add_argument("--num-threads", type=int, default=16,
                        help="图片IO线程数")
    parser.add_argument("--top-k", type=int, default=5)
    parser.add_argument("--categories", default="")
    parser.add_argument("--datasets", default="")
    parser.add_argument("--force", action="store_true")
    args = parser.parse_args()

    # Default to all visible GPUs when --gpus is not given.
    gpu_ids = ([int(x) for x in args.gpus.split(",") if x.strip()]
               or list(range(torch.cuda.device_count())))
    total_batch = args.batch_size_per_gpu * len(gpu_ids)
    print(f"GPU: {gpu_ids} ({len(gpu_ids)} 张), batch: {total_batch}")

    # Empty CLI strings become None so discover_datasets applies no filter.
    cats = [c.strip() for c in args.categories.split(",") if c.strip()] or None
    specific = [d.strip() for d in args.datasets.split(",") if d.strip()] or None
    datasets = discover_datasets(categories=cats, specific=specific)
    print(f"共 {len(datasets)} 个数据集\n")

    encoder = SigLIPEncoder(args.model_path, gpu_ids,
                            args.batch_size_per_gpu, args.num_threads)

    ok, fail = 0, 0
    pbar = tqdm(datasets, desc="总进度", unit="ds", dynamic_ncols=True)
    for i, (cat, ds) in enumerate(pbar, 1):
        pbar.set_postfix(current=f"{cat}/{ds}", ok=ok, fail=fail)
        if process_dataset(cat, ds, encoder, args.top_k, args.force):
            ok += 1
        else:
            fail += 1
    pbar.close()
    print(f"\n完成: {ok} 成功, {fail} 失败/跳过 → {EMBEDDINGS_DIR}")
367
+
368
+
369
# Script entry point.
if __name__ == "__main__":
    main()
ICL/build_index.py ADDED
@@ -0,0 +1,506 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ 生成索引 JSONL:将原始 base64 JSONL 的文本字段 + 提取后的图片路径 + VLM描述 对应起来。
4
+
5
+ 输入:
6
+ /workspace/xiaobin/dataset/data/{cat}/{ds}/{split}.jsonl (原始,含base64)
7
+ /workspace/xiaobin/dataset/images/{cat}/{ds}/{split}/ (已提取的图片)
8
+ /workspace/xiaobin/dataset/detail/{cat}/{ds}/{split}/captions.json (VLM描述)
9
+
10
+ 输出:
11
+ /workspace/xiaobin/dataset/index/{cat}/{ds}/{split}.jsonl (轻量索引)
12
+
13
+ 每条记录格式:
14
+ {
15
+ "image": "/workspace/xiaobin/dataset/images/vqa/shapes/test/00000000.jpg",
16
+ "images": ["/path/..."], # 多图时(video_str/images字段)
17
+ "question": "...",
18
+ "answer": "...",
19
+ "description": "A cat sitting...", # 来自 detail/captions.json
20
+ "meta": {...}, # 原始meta(如有)
21
+ "id": "...", # 原始id/img_id
22
+ "category": "vqa",
23
+ "dataset": "shapes",
24
+ "split": "test"
25
+ }
26
+
27
+ 用法:
28
+ python3 build_index.py # 全部(已完成的自动跳过)
29
+ python3 build_index.py vqa/shapes # 某个数据集
30
+ python3 build_index.py --force # 全部强制重建
31
+ python3 build_index.py --force vqa/shapes # 某个数据集强制重建
32
+ """
33
+
34
+ import os
35
+ import sys
36
+ import json
37
+ import glob
38
+ import re
39
+ from tqdm import tqdm
40
+
41
+ DATA_ROOT = "/workspace/xiaobin/dataset/data"
42
+ IMAGES_ROOT = "/workspace/xiaobin/dataset/images"
43
+ DETAIL_ROOT = "/workspace/xiaobin/dataset/detail"
44
+ INDEX_ROOT = "/workspace/xiaobin/dataset/index"
45
+
46
+ # 图片base64字段(用于判断"这行有图",和extract_images.py一致)
47
+ ALL_IMAGE_FIELDS = [
48
+ "image", "image_str", "image_base64_str", "img_str",
49
+ "base64", "image_base64", "image_base_url",
50
+ "video_str", "images",
51
+ ]
52
+
53
+ # 文本字段提取
54
+ QUESTION_FIELDS = ["question", "text", "query", "prompt", "input", "inputs", "user_prompt"]
55
+ ANSWER_FIELDS = ["answer", "output", "outputs", "label", "target", "caption", "paraphrased_answer", "original_answer"]
56
+
57
+
58
def classify_split(filename):
    """Map a raw JSONL filename to a canonical split name.

    Substrings are checked in priority order train > test > val
    (case-insensitive); anything else lands in the "other" bucket.
    """
    lowered = filename.lower()
    for tag in ("train", "test", "val"):
        if tag in lowered:
            return tag
    return "other"
68
+
69
+
70
def has_image(record):
    """Return True if the record carries at least one plausible base64
    image payload (mirrors the detection logic in extract_images.py)."""
    for key in ALL_IMAGE_FIELDS:
        payload = record.get(key)
        if not payload:
            continue
        if isinstance(payload, str):
            # heuristic: real base64 blobs are long strings
            if len(payload) > 100:
                return True
        elif isinstance(payload, list):
            if any(isinstance(entry, str) and len(entry) > 100 for entry in payload):
                return True
    return False
82
+
83
+
84
def is_multi_image(record):
    """Return True when the record holds more than one plausible base64
    image blob in any of its list-valued image fields."""

    def _n_blobs(value):
        # Count long strings inside a list-valued field; anything else -> 0.
        if isinstance(value, list):
            return sum(1 for entry in value
                       if isinstance(entry, str) and len(entry) > 100)
        return 0

    # video_str / images are list fields by design; image_str / image_base64
    # may also appear as lists in some datasets.
    for key in ("video_str", "images", "image_str", "image_base64"):
        if _n_blobs(record.get(key)) > 1:
            return True
    return False
99
+
100
+
101
def count_images_in_record(record):
    """Return how many images the record contributes.

    Only the first non-empty image field counts, mirroring the order in
    which extract_images.py walked the fields.
    """
    for key in ALL_IMAGE_FIELDS:
        payload = record.get(key)
        if not payload:
            continue
        if isinstance(payload, str):
            if len(payload) > 100:
                return 1
        elif isinstance(payload, list):
            return sum(1 for entry in payload
                       if isinstance(entry, str) and len(entry) > 100)
    return 0
112
+
113
+
114
def extract_text(record, fields):
    """Return the first non-empty string among *fields* (stripped),
    falling back to the first usable entry of an "answers" list.
    Returns None when nothing matches."""
    for field in fields:
        value = record.get(field)
        if isinstance(value, str) and value.strip():
            return value.strip()
    # fallback: some datasets store a list under "answers"
    fallback = record.get("answers")
    if isinstance(fallback, list):
        for candidate in fallback:
            if isinstance(candidate, str) and candidate.strip():
                return candidate.strip()
    return None
128
+
129
+
130
def extract_id(record):
    """Best-effort record identifier as a string.

    Top-level id fields are preferred; the same keys inside the nested
    "meta" dict are tried next. Returns None when nothing is found.
    """
    meta = record.get("meta")
    sources = [(record, ("id", "image_id", "img_id"))]
    if isinstance(meta, dict):
        sources.append((meta, ("img_id", "id", "image_id")))
    for container, keys in sources:
        for key in keys:
            value = container.get(key)
            if value is not None:
                return str(value)
    return None
143
+
144
+
145
def extract_meta(record):
    """Copy of record["meta"] with bulky/binary-ish entries stripped.

    Drops image/img/base64/video-named keys, strings over 500 chars, and
    lists whose first element is a string over 200 chars. Returns None
    when "meta" is not a dict or when nothing survives filtering.
    """
    meta = record.get("meta")
    if not isinstance(meta, dict):
        return None

    def _keep(key, value):
        lowered = key.lower()
        # image/base64-related payloads live in their own files
        if any(tag in lowered for tag in ("image", "img", "base64", "video")):
            return False
        if isinstance(value, str) and len(value) > 500:
            return False  # oversized string
        if isinstance(value, list) and value \
                and isinstance(value[0], str) and len(value[0]) > 200:
            return False  # list of large strings
        return True

    cleaned = {k: v for k, v in meta.items() if _keep(k, v)}
    return cleaned or None
163
+
164
+
165
def load_detail(category, dataset, split):
    """Load the cached VLM captions for one split.

    Returns the {image_path: description} mapping stored under "items"
    in captions.json, or {} when the cache is missing or malformed.
    """
    cache_file = os.path.join(DETAIL_ROOT, category, dataset, split, "captions.json")
    if not os.path.exists(cache_file):
        return {}
    try:
        with open(cache_file, 'r', encoding='utf-8') as fh:
            payload = json.load(fh)
        items = payload.get("items", {})
        if isinstance(items, dict):
            return items
    except Exception:
        # best-effort cache: any parse problem degrades to "no captions"
        pass
    return {}
179
+
180
+
181
def count_lines(filepath):
    """Count the number of lines in *filepath*.

    Reads in 8 MiB chunks so huge JSONL files never have to fit in
    memory. Fix over the naive newline count: a final line that lacks a
    trailing newline is still counted — matching what line-iteration
    yields, which is what the tqdm totals built from this number
    describe.
    """
    buf_size = 8 * 1024 * 1024
    count = 0
    tail = b'\n'  # empty file -> 0 lines
    with open(filepath, 'rb') as f:
        while True:
            buf = f.read(buf_size)
            if not buf:
                break
            count += buf.count(b'\n')
            tail = buf[-1:]
    if tail != b'\n':
        count += 1  # unterminated last line
    return count
190
+
191
+
192
def process_one(jsonl_path, file_idx, total_files):
    """Process a single raw JSONL and emit its index JSONL.

    NOTE(review): main() routes everything through group_by_split() /
    process_group(); this single-file variant appears superseded (it
    cannot share the image cursor across multi-file splits) — confirm
    before relying on it.

    Pairing logic: extract_images.py wrote image files in record order,
    so a running cursor (img_idx) maps record i to the next
    count_images_in_record(i) files of the sorted image directory.

    Returns the number of index records written (0 when skipped).
    """
    rel_path = os.path.relpath(jsonl_path, DATA_ROOT)
    parts = rel_path.split(os.sep)
    if len(parts) < 3:
        # unexpected layout; expected {category}/{dataset}/{file}.jsonl
        return 0

    category, dataset, filename = parts[0], parts[1], parts[2]
    split = classify_split(filename)

    # image directory produced by extract_images.py
    img_dir = os.path.join(IMAGES_ROOT, category, dataset, split)
    if not os.path.isdir(img_dir):
        print(f"  [SKIP] 无图片目录: {img_dir}")
        return 0

    # image file list (sorted by zero-padded numbering)
    img_files = sorted([f for f in os.listdir(img_dir) if os.path.isfile(os.path.join(img_dir, f))])
    if not img_files:
        print(f"  [SKIP] 图片目录为空: {img_dir}")
        return 0

    # VLM descriptions keyed by absolute image path
    detail = load_detail(category, dataset, split)

    # output index file
    out_path = os.path.join(INDEX_ROOT, category, dataset, f"{split}.jsonl")
    os.makedirs(os.path.dirname(out_path), exist_ok=True)

    total_lines = count_lines(jsonl_path)
    file_size_mb = os.path.getsize(jsonl_path) / (1024 * 1024)
    desc = f"[{file_idx}/{total_files}] {category}/{dataset}/{split} ({file_size_mb:.0f}MB)"

    img_idx = 0  # image-file cursor
    written = 0
    skipped = 0

    with open(jsonl_path, 'r', encoding='utf-8') as fin, \
         open(out_path, 'w', encoding='utf-8') as fout:

        pbar = tqdm(fin, total=total_lines, desc=desc, unit="行",
                    dynamic_ncols=True, miniters=100)

        for line in pbar:
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                # corrupt line: drop silently, keeps cursor consistent
                continue

            if not has_image(record):
                skipped += 1
                continue

            n_imgs = count_images_in_record(record)
            if img_idx + n_imgs > len(img_files):
                # ran out of extracted images; extraction likely had errors
                skipped += 1
                continue

            # collect the image paths belonging to this record
            if n_imgs == 1:
                img_path = os.path.join(img_dir, img_files[img_idx])
                img_paths = [img_path]
            else:
                img_paths = [os.path.join(img_dir, img_files[img_idx + i])
                             for i in range(n_imgs)]
                img_path = img_paths[0]

            # VLM description of the primary image
            desc_text = detail.get(img_path, "")
            # multi-image records: one description per image
            if n_imgs > 1:
                descs = [detail.get(p, "") for p in img_paths]
            else:
                descs = None

            # lightweight index record (no base64 payloads)
            idx_record = {
                "image": img_path,
                "question": extract_text(record, QUESTION_FIELDS),
                "answer": extract_text(record, ANSWER_FIELDS),
                "description": desc_text,
                "category": category,
                "dataset": dataset,
                "split": split,
            }

            # multi-image extras
            if n_imgs > 1:
                idx_record["images"] = img_paths
                idx_record["descriptions"] = descs

            # original record ID, when recoverable
            rid = extract_id(record)
            if rid:
                idx_record["id"] = rid

            # filtered meta (large/binary fields removed)
            meta = extract_meta(record)
            if meta:
                idx_record["meta"] = meta

            # instruction templates carried by some datasets
            insts = record.get("instructions")
            if isinstance(insts, list) and insts:
                idx_record["instructions"] = insts

            fout.write(json.dumps(idx_record, ensure_ascii=False) + "\n")
            written += 1
            img_idx += n_imgs

            pbar.set_postfix(written=written, imgs=img_idx, skip=skipped, refresh=False)

        pbar.close()

    print(f"  -> {written} 条, 用了 {img_idx} 张图, 跳过 {skipped} 行")
    if img_idx != len(img_files):
        # cursor mismatch means record/image pairing drifted somewhere
        print(f"  [WARN] 图片游标 {img_idx} != 图片总数 {len(img_files)}")
    return written
314
+
315
+
316
def find_all_jsonl_files():
    """Enumerate original data JSONLs under DATA_ROOT/{cat}/{ds}/.

    Derived/auxiliary files are skipped: dated snapshots
    (*_YYYY-MM-DD.jsonl), *_v2 / *_new variants, and para_* files.
    """
    dated = re.compile(r'_\d{4}-\d{2}-\d{2}\.jsonl$')
    keep = []
    for path in sorted(glob.glob(os.path.join(DATA_ROOT, "*/*/*.jsonl"))):
        name = os.path.basename(path)
        if dated.search(name):
            continue
        if '_v2.jsonl' in name or '_new.jsonl' in name:
            continue
        if name.startswith('para_'):
            continue
        keep.append(path)
    return keep
328
+
329
+
330
def group_by_split(files):
    """Group JSONL paths by (category, dataset, split).

    Order within each group follows the input order, because
    extract_images.py extracted images by walking files in that same
    order and the image-cursor logic in process_group depends on it.
    """
    from collections import OrderedDict
    buckets = OrderedDict()
    for path in files:
        pieces = os.path.relpath(path, DATA_ROOT).split(os.sep)
        if len(pieces) < 3:
            continue
        category, dataset, filename = pieces[0], pieces[1], pieces[2]
        key = (category, dataset, classify_split(filename))
        buckets.setdefault(key, []).append(path)
    return buckets
345
+
346
+
347
def process_group(jsonl_files, category, dataset, split, group_idx, total_groups,
                  force=False):
    """Build the index for one (category, dataset, split) from one or
    more raw JSONL files.

    The image-file cursor (img_idx) runs across ALL files of the group
    in order, because extract_images.py numbered images by walking the
    same file order. Resume logic: an existing index is kept only when
    its line count equals the number of extracted images; otherwise the
    split is rebuilt from scratch.

    Returns the number of index records written (or the existing count
    when the split is skipped as already complete).
    """
    out_path = os.path.join(INDEX_ROOT, category, dataset, f"{split}.jsonl")

    # resume check: only skip when index line count matches image count
    if not force and os.path.exists(out_path) and os.path.getsize(out_path) > 0:
        existing = sum(1 for _ in open(out_path, 'r', encoding='utf-8'))
        img_dir = os.path.join(IMAGES_ROOT, category, dataset, split)
        if os.path.isdir(img_dir):
            img_count = len([f for f in os.listdir(img_dir) if os.path.isfile(os.path.join(img_dir, f))])
            if existing == img_count:
                print(f"  [SKIP] {category}/{dataset}/{split} 索引完整 ({existing}/{img_count})")
                return existing
            else:
                print(f"  [REDO] {category}/{dataset}/{split} 索引不完整 ({existing}/{img_count}), 重建")

    img_dir = os.path.join(IMAGES_ROOT, category, dataset, split)
    if not os.path.isdir(img_dir):
        print(f"  [SKIP] 无图片目录: {img_dir}")
        return 0

    img_files = sorted([f for f in os.listdir(img_dir) if os.path.isfile(os.path.join(img_dir, f))])
    if not img_files:
        print(f"  [SKIP] 图片目录为空: {img_dir}")
        return 0

    # VLM descriptions keyed by absolute image path
    detail = load_detail(category, dataset, split)

    os.makedirs(os.path.dirname(out_path), exist_ok=True)

    img_idx = 0  # image cursor, accumulates across files of the group
    written = 0

    with open(out_path, 'w', encoding='utf-8') as fout:
        for fi, jsonl_path in enumerate(jsonl_files):
            total_lines = count_lines(jsonl_path)
            file_size_mb = os.path.getsize(jsonl_path) / (1024 * 1024)
            fn = os.path.basename(jsonl_path)
            # show the file name only when the group has several files
            if len(jsonl_files) > 1:
                desc = f"[{group_idx}/{total_groups}] {category}/{dataset}/{split} ({fn}, {file_size_mb:.0f}MB)"
            else:
                desc = f"[{group_idx}/{total_groups}] {category}/{dataset}/{split} ({file_size_mb:.0f}MB)"

            skipped = 0
            with open(jsonl_path, 'r', encoding='utf-8') as fin:
                pbar = tqdm(fin, total=total_lines, desc=desc, unit="行",
                            dynamic_ncols=True, miniters=100)
                for line in pbar:
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        record = json.loads(line)
                    except json.JSONDecodeError:
                        # corrupt line: drop silently
                        continue

                    if not has_image(record):
                        skipped += 1
                        continue

                    n_imgs = count_images_in_record(record)
                    if img_idx + n_imgs > len(img_files):
                        # not enough extracted images left for this record
                        skipped += 1
                        continue

                    # map record to its image file(s) via the cursor
                    if n_imgs == 1:
                        img_path = os.path.join(img_dir, img_files[img_idx])
                        img_paths = [img_path]
                    else:
                        img_paths = [os.path.join(img_dir, img_files[img_idx + i])
                                     for i in range(n_imgs)]
                        img_path = img_paths[0]

                    desc_text = detail.get(img_path, "")

                    # lightweight index record (no base64 payloads)
                    idx_record = {
                        "image": img_path,
                        "question": extract_text(record, QUESTION_FIELDS),
                        "answer": extract_text(record, ANSWER_FIELDS),
                        "description": desc_text,
                        "category": category,
                        "dataset": dataset,
                        "split": split,
                    }

                    if n_imgs > 1:
                        idx_record["images"] = img_paths
                        idx_record["descriptions"] = [detail.get(p, "") for p in img_paths]

                    rid = extract_id(record)
                    if rid:
                        idx_record["id"] = rid

                    meta = extract_meta(record)
                    if meta:
                        idx_record["meta"] = meta

                    insts = record.get("instructions")
                    if isinstance(insts, list) and insts:
                        idx_record["instructions"] = insts

                    fout.write(json.dumps(idx_record, ensure_ascii=False) + "\n")
                    written += 1
                    img_idx += n_imgs

                    pbar.set_postfix(written=written, imgs=img_idx, skip=skipped, refresh=False)
                pbar.close()

    print(f"  -> {written} 条, 用了 {img_idx}/{len(img_files)} 张图")
    if img_idx != len(img_files):
        # cursor mismatch means record/image pairing drifted somewhere
        print(f"  [WARN] 图片游标 {img_idx} != 图片总数 {len(img_files)}")
    return written
460
+
461
+
462
def main():
    """CLI driver: discover raw JSONLs, group them by split, and build
    the per-split index files under INDEX_ROOT.

    Arguments are parsed by hand from sys.argv:
      --force           rebuild even when an index looks complete
      <cat/ds | file>   optional single target (dataset dir or one file)
    """
    print("=" * 60)
    print("生成索引 JSONL (图片路径 + 文本 + VLM描述)")
    print(f"原始数据: {DATA_ROOT}")
    print(f"图片目录: {IMAGES_ROOT}")
    print(f"描述缓存: {DETAIL_ROOT}")
    print(f"输出索引: {INDEX_ROOT}")
    print("=" * 60)

    force = "--force" in sys.argv
    args = [a for a in sys.argv[1:] if a != "--force"]

    if args:
        target = args[0]
        if os.path.isfile(target):
            files = [target]
        else:
            # treat target as "{category}/{dataset}" under DATA_ROOT and
            # apply the same exclusions as find_all_jsonl_files()
            files = sorted(glob.glob(os.path.join(DATA_ROOT, target, "*.jsonl")))
            files = [f for f in files
                     if not re.search(r'_\d{4}-\d{2}-\d{2}\.jsonl$', os.path.basename(f))
                     and '_v2.jsonl' not in os.path.basename(f)
                     and '_new.jsonl' not in os.path.basename(f)
                     and not os.path.basename(f).startswith('para_')]
    else:
        files = find_all_jsonl_files()

    groups = group_by_split(files)
    print(f"\n共 {len(groups)} 个 split 组 ({len(files)} 个文件):")
    for (cat, ds, split), flist in groups.items():
        for f in flist:
            size_mb = os.path.getsize(f) / (1024 * 1024)
            print(f"  {cat}/{ds}/{split}: {os.path.basename(f):40s} {size_mb:>10.1f} MB")

    total = 0
    for i, ((cat, ds, split), flist) in enumerate(groups.items(), 1):
        n = process_group(flist, cat, ds, split, i, len(groups), force=force)
        total += n

    print(f"\n{'=' * 60}")
    print(f"全部完成!共生成 {total} 条索引记录")
    print(f"保存在: {INDEX_ROOT}")


if __name__ == "__main__":
    main()
ICL/build_sft.py ADDED
@@ -0,0 +1,466 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ 构建单步决策 SFT 数据集(轻量版,只存引用路径)。
4
+
5
+ 输入:
6
+ /workspace/xiaobin/dataset/index/{cat}/{ds}/{split}.jsonl (索引)
7
+ /workspace/xiaobin/dataset/embeddings/{cat}_{ds}_top5.json (预计算相似图)
8
+ /workspace/xiaobin/dataset/caption_cache/{cat}_{ds}.json (VLM描述)
9
+ /workspace/xiaobin/dataset/index/{cat}/{ds}/instructions.json
10
+
11
+ 输出:
12
+ /workspace/xiaobin/dataset/sft/{cat}/sft.part{shard}.jsonl
13
+ /workspace/xiaobin/dataset/sft/all/sft.jsonl (合并后)
14
+
15
+ 每条记录格式(不含conversation,由train.py动态构建):
16
+ {
17
+ "type": "ret" | "ans",
18
+ "query_image": "/path/to/query.jpg",
19
+ "question": "...",
20
+ "answer": "...",
21
+ "instruction": "...",
22
+ "shots": [{"image": "...", "caption": "..."}],
23
+ "next_description": "...", # 仅 ret 类型
24
+ "category": "vqa",
25
+ "dataset": "vqav2"
26
+ }
27
+
28
+ 用法:
29
+ python3 build_sft.py # 全部
30
+ python3 build_sft.py --categories vqa # 单类
31
+ python3 build_sft.py --shard-id 0 --num-shards 4 # 分片
32
+ python3 build_sft.py --merge --shuffle # 合并
33
+ """
34
+
35
+ import argparse
36
+ import json
37
+ import os
38
+ import random
39
+ from pathlib import Path
40
+ from typing import Dict, List, Optional, Tuple
41
+
42
# Optional dependency: tqdm only provides a progress bar. Fall back to a
# pass-through wrapper so the script still runs where tqdm is missing.
try:
    from tqdm import tqdm
except ImportError:
    def tqdm(x, **kw):
        # identity stand-in; ignores tqdm keyword arguments
        return x
47
+
48
+ # ---------------------------------------------------------------------------
49
+ # 默认路径
50
+ # ---------------------------------------------------------------------------
51
+ INDEX_ROOT = "/workspace/xiaobin/dataset/index"
52
+ EMBEDDINGS_DIR = "/workspace/xiaobin/dataset/embeddings"
53
+ CAPTION_CACHE_DIR = "/workspace/xiaobin/dataset/caption_cache"
54
+ OUTPUT_DIR = "/workspace/xiaobin/dataset/sft"
55
+
56
+
57
+ # ---------------------------------------------------------------------------
58
+ # 数据加载
59
+ # ---------------------------------------------------------------------------
60
def discover_datasets(index_root: str, categories: List[str]) -> List[Tuple[str, str]]:
    """List every (category, dataset) directory pair under *index_root*.

    When *categories* is non-empty, only those categories are kept.
    Results are sorted by category name, then dataset name.
    """
    pairs: List[Tuple[str, str]] = []
    for category in sorted(os.listdir(index_root)):
        cat_path = os.path.join(index_root, category)
        if not os.path.isdir(cat_path):
            continue
        if categories and category not in categories:
            continue
        pairs.extend(
            (category, name)
            for name in sorted(os.listdir(cat_path))
            if os.path.isdir(os.path.join(cat_path, name))
        )
    return pairs
73
+
74
+
75
def load_index(index_root: str, cat: str, ds: str, split: str) -> List[Dict]:
    """Read one split's index JSONL, keeping only records that name an
    image. Malformed lines are skipped silently; missing files yield []."""
    path = os.path.join(index_root, cat, ds, f"{split}.jsonl")
    if not os.path.exists(path):
        return []
    kept: List[Dict] = []
    with open(path, "r", encoding="utf-8") as fh:
        for raw in fh:
            raw = raw.strip()
            if not raw:
                continue
            try:
                rec = json.loads(raw)
                # a usable record must at least reference an image
                usable = rec.get("image")
            except Exception:
                continue
            if usable:
                kept.append(rec)
    return kept
94
+
95
+
96
def load_top5(embeddings_dir: str, cat: str, ds: str, k: int = 5) -> Dict[str, List[str]]:
    """Load the precomputed {query_image: [k nearest images]} mapping;
    {} when no file exists for this (cat, ds, k)."""
    target = Path(embeddings_dir) / f"{cat}_{ds}_top{k}.json"
    if not target.exists():
        return {}
    with target.open("r", encoding="utf-8") as fh:
        return json.load(fh)
103
+
104
+
105
def load_captions(caption_cache_dir: str, cat: str, ds: str) -> Dict[str, str]:
    """Load the caption cache as {image_path: description}.

    Returns {} for a missing or corrupt file, or when "items" is not a
    dict — callers treat captions as strictly optional.
    """
    path = os.path.join(caption_cache_dir, f"{cat}_{ds}.json")
    try:
        with open(path, "r", encoding="utf-8") as fh:
            blob = json.load(fh)
        items = blob.get("items", {})
    except Exception:
        return {}
    return items if isinstance(items, dict) else {}
117
+
118
+
119
def load_instructions(index_root: str, cat: str, ds: str) -> List[str]:
    """Load instruction templates for a dataset.

    Accepts either a bare JSON list or a dict holding the list under
    "instructions" / "instruction" / "prompts". Entries are stringified
    and stripped; blanks are dropped. Returns [] on any problem.
    """
    path = os.path.join(index_root, cat, ds, "instructions.json")
    try:
        with open(path, "r", encoding="utf-8") as fh:
            data = json.load(fh)
    except Exception:
        return []

    def _clean(seq) -> List[str]:
        return [str(item).strip() for item in seq if str(item).strip()]

    if isinstance(data, list):
        return _clean(data)
    if isinstance(data, dict):
        for key in ("instructions", "instruction", "prompts"):
            candidate = data.get(key)
            if isinstance(candidate, list):
                return _clean(candidate)
    return []
137
+
138
+
139
+ # ---------------------------------------------------------------------------
140
+ # 样本生成
141
+ # ---------------------------------------------------------------------------
142
def generate_samples(
    records: List[Dict],
    top5_map: Dict[str, List[str]],
    caption_map: Dict[str, str],
    instructions: List[str],
    cat: str, ds: str,
    rng: random.Random,
    max_shots: int = 3,
    answer_at_weights: List[float] = None,
    target_count: int = 0,
) -> List[Dict]:
    """Generate SFT samples for one dataset.

    target_count=0 walks every valid record once; >0 draws random records
    (oversampled x5) and stops once the target number of samples is
    reached, then truncates.

    Each source record yields one coherent RET→…→ANS trajectory:
    answer_at RET samples (deciding to retrieve one more shot) followed
    by a single ANS sample at answer_at shots. RET and ANS are never
    emitted for the same (image, question, n-shot) state, which avoids
    contradictory supervision signals.
    """
    if answer_at_weights is None:
        answer_at_weights = [1, 3, 3, 2]

    # filter: need answer + top5 coverage; question may be empty (captioning)
    valid = [r for r in records
             if r.get("answer") and r.get("image") in top5_map]
    if not valid:
        return []

    # weights index i corresponds to answering at i shots
    answer_at_values = list(range(len(answer_at_weights)))
    default_inst = "Please answer the question based on the image."
    samples = []

    # decide iteration source: full sweep vs random sampling
    if target_count > 0:
        # random-sampling mode (with replacement, oversampled x5)
        source = [rng.choice(valid) for _ in range(target_count * 5)]
    else:
        # full mode: iterate every valid record
        source = valid

    for q in source:
        q_img = q["image"]
        q_question = q.get("question") or ""
        q_answer = q["answer"]

        inst = rng.choice(instructions) if instructions else default_inst

        # weighted draw of how many shots to collect before answering
        answer_at = rng.choices(answer_at_values, weights=answer_at_weights, k=1)[0]
        answer_at = min(answer_at, max_shots)

        top5 = top5_map.get(q_img, [])
        if answer_at > 0 and not top5:
            continue

        # degrade gracefully when fewer neighbours exist than requested
        if answer_at > len(top5):
            answer_at = len(top5)

        # pick answer_at distinct shots from the neighbour list
        chosen = rng.sample(top5, answer_at) if answer_at > 0 else []

        shots = []
        for img_path in chosen:
            cap = caption_map.get(img_path, "")
            shots.append({"image": img_path, "caption": cap})

        remaining = [p for p in top5 if p not in chosen]

        # ---- trajectory generation: one consistent RET→...→ANS per record ----
        # answer_at=0: answer directly (0-shot)
        # answer_at=2: RET(0-shot) → RET(1-shot) → ANS(2-shot)
        # never emit RET and ANS for the same (image, question, n-shot) state
        for n in range(answer_at):
            if n < len(chosen):
                next_desc = caption_map.get(chosen[n], "")
            elif remaining:
                next_desc = caption_map.get(rng.choice(remaining), "")
            else:
                break

            # RET sample: at n shots, decide to keep retrieving
            samples.append({
                "type": "ret",
                "query_image": q_img,
                "question": q_question,
                "answer": q_answer,
                "instruction": inst,
                "shots": shots[:n],
                "next_description": next_desc,
                "category": cat,
                "dataset": ds,
            })

        # ANS sample: answer at answer_at shots
        samples.append({
            "type": "ans",
            "query_image": q_img,
            "question": q_question,
            "answer": q_answer,
            "instruction": inst,
            "shots": shots[:answer_at],
            "category": cat,
            "dataset": ds,
        })

        if target_count > 0 and len(samples) >= target_count:
            break

    if target_count > 0:
        samples = samples[:target_count]

    return samples
250
+
251
+
252
+ # ---------------------------------------------------------------------------
253
+ # 文件工具
254
+ # ---------------------------------------------------------------------------
255
def write_jsonl(path: str, records: List[Dict]):
    """Write *records* as one JSON object per line (UTF-8, no ASCII
    escaping), creating parent directories as needed."""
    os.makedirs(os.path.dirname(path), exist_ok=True)
    serialized = (json.dumps(rec, ensure_ascii=False) + "\n" for rec in records)
    with open(path, "w", encoding="utf-8") as fh:
        fh.writelines(serialized)
260
+
261
+
262
def concat_and_shuffle(output_dir: str, categories: List[str], shuffle: bool, seed: int):
    """Merge per-category shard files into sft.jsonl, then merge every
    category into all/sft.jsonl. Blank lines are dropped; line order is
    optionally shuffled with a seeded RNG."""
    rng = random.Random(seed)

    def _read_lines(path) -> List[str]:
        # keep raw lines (trailing newline included); drop blanks
        with open(path, "r", encoding="utf-8") as fh:
            return [ln for ln in fh if ln.strip()]

    # 1) per-category: concat sft.part*.jsonl -> sft.jsonl
    for cat in categories:
        cat_dir = os.path.join(output_dir, cat)
        if not os.path.isdir(cat_dir):
            continue
        shards = sorted(Path(cat_dir).glob("sft.part*.jsonl"))
        if not shards:
            continue
        merged: List[str] = []
        for shard in shards:
            merged.extend(_read_lines(shard))
        if shuffle:
            rng.shuffle(merged)
        with open(os.path.join(cat_dir, "sft.jsonl"), "w", encoding="utf-8") as fh:
            fh.writelines(merged)
        print(f"  [OK] {cat}: {len(merged)} 条")

    # 2) global: concat each category's sft.jsonl -> all/sft.jsonl
    combined: List[str] = []
    for cat in categories:
        cat_file = os.path.join(output_dir, cat, "sft.jsonl")
        if os.path.exists(cat_file):
            combined.extend(_read_lines(cat_file))
    if combined:
        if shuffle:
            rng.shuffle(combined)
        all_dir = os.path.join(output_dir, "all")
        os.makedirs(all_dir, exist_ok=True)
        all_path = os.path.join(all_dir, "sft.jsonl")
        with open(all_path, "w", encoding="utf-8") as fh:
            fh.writelines(combined)
        print(f"  [OK] all: {len(combined)} 条 → {all_path}")
300
+
301
+
302
+ # ---------------------------------------------------------------------------
303
+ # Main
304
+ # ---------------------------------------------------------------------------
305
def main():
    """CLI driver: either merge existing shards (--merge) or build SFT
    sample shards per category from index + embeddings + caption caches."""
    parser = argparse.ArgumentParser(description="构建单步决策 SFT 数据集")

    # paths
    parser.add_argument("--index-root", default=INDEX_ROOT)
    parser.add_argument("--embeddings-dir", default=EMBEDDINGS_DIR)
    parser.add_argument("--caption-cache-dir", default=CAPTION_CACHE_DIR)
    parser.add_argument("--output-dir", default=OUTPUT_DIR)

    # dataset selection
    parser.add_argument("--categories", default="vqa,captioning,classification,reasoning")
    parser.add_argument("--split", default="train", help="query 来自哪个 split")
    parser.add_argument("--top-k", type=int, default=5)

    # sample parameters
    parser.add_argument("--samples-per-cat", type=int, default=0,
                        help="每类目标数,0=全量遍历所有记录")
    parser.add_argument("--samples-per-ds", type=int, default=0,
                        help="每个数据集最多取多少条原始记录(0=不限)")
    parser.add_argument("--max-shots", type=int, default=3)
    parser.add_argument("--answer-at-weights", default="1,3,3,2",
                        help="0/1/2/3-shot 的权重(默认 1,3,3,2,鼓励多轮 RET)")
    parser.add_argument("--seed", type=int, default=42)

    # sharding
    parser.add_argument("--shard-id", type=int, default=0)
    parser.add_argument("--num-shards", type=int, default=1)

    # modes
    parser.add_argument("--merge", action="store_true", help="合并分片")
    parser.add_argument("--shuffle", action="store_true", help="合并时 shuffle")

    args = parser.parse_args()
    categories = [c.strip() for c in args.categories.split(",") if c.strip()]

    # ---- merge mode ----
    if args.merge:
        print("合并分片...")
        concat_and_shuffle(args.output_dir, categories, args.shuffle, args.seed)
        return

    # ---- build mode ----
    aw = [float(x) for x in args.answer_at_weights.split(",") if x.strip()]
    # distinct deterministic RNG stream per shard
    rng = random.Random(args.seed + args.shard_id * 1000003)

    datasets = discover_datasets(args.index_root, categories)
    print(f"共 {len(datasets)} 个数据集")

    # group datasets by category
    cat_datasets: Dict[str, List[Tuple[str, str]]] = {}
    for cat, ds in datasets:
        cat_datasets.setdefault(cat, []).append((cat, ds))

    for cat in categories:
        ds_list = cat_datasets.get(cat, [])
        if not ds_list:
            print(f"[SKIP] {cat}: 无数据集")
            continue

        # load index / top5 / captions / instructions for each dataset
        ds_data = []
        for c, d in ds_list:
            records = load_index(args.index_root, c, d, args.split)
            top5 = load_top5(args.embeddings_dir, c, d, args.top_k)
            captions = load_captions(args.caption_cache_dir, c, d)
            insts = load_instructions(args.index_root, c, d)
            if not records or not top5:
                print(f"  [SKIP] {c}/{d}: records={len(records)} top5={len(top5)}")
                continue
            # precheck: records that have both an answer and top5 coverage
            n_valid = sum(1 for r in records
                          if r.get("answer") and r.get("image") in top5)
            if n_valid == 0:
                print(f"  [SKIP] {c}/{d}: {len(records)} 条但无 answer+top5 覆盖")
                continue

            ds_data.append({
                "cat": c, "ds": d,
                "records": records, "top5": top5,
                "captions": captions, "instructions": insts,
            })
            # report coverage statistics for this dataset
            n_cap = sum(1 for r in records if r.get("image") in captions)
            n_top5 = sum(1 for r in records if r.get("image") in top5)
            print(f"  [OK] {c}/{d}: {len(records)} 条, "
                  f"valid={n_valid}, top5覆盖={n_top5}, caption覆盖={n_cap}, "
                  f"instructions={len(insts)}")

        if not ds_data:
            print(f"[WARN] {cat}: 无可用数据集")
            continue

        all_samples = []

        # decide how many raw records to draw per dataset
        n_ds = len(ds_data)
        if args.samples_per_cat > 0:
            # target: samples_per_cat SFT samples for this category;
            # conservatively assume ~1 sample per record, oversample,
            # then truncate to samples_per_cat below
            records_per_ds = max(200, int(args.samples_per_cat / 1.0 / n_ds))
        elif args.samples_per_ds > 0:
            records_per_ds = args.samples_per_ds
        else:
            records_per_ds = 0  # full sweep

        print(f"  {cat}: {n_ds} 个数据集, 每个抽 {records_per_ds} 条记录" if records_per_ds > 0
              else f"  {cat}: {n_ds} 个数据集, 全量")

        for d in tqdm(ds_data, desc=f"{cat} shard{args.shard_id}"):
            recs = d["records"]

            # subsample raw records when a budget is set
            if records_per_ds > 0 and len(recs) > records_per_ds:
                recs = rng.sample(recs, records_per_ds)

            samples = generate_samples(
                records=recs,
                top5_map=d["top5"],
                caption_map=d["captions"],
                instructions=d["instructions"],
                cat=d["cat"], ds=d["ds"],
                rng=rng,
                max_shots=args.max_shots,
                answer_at_weights=aw,
                target_count=0,  # sweep every drawn record
            )
            all_samples.extend(samples)

        # truncate to target (only when samples-per-cat > 0)
        if args.samples_per_cat > 0 and len(all_samples) > args.samples_per_cat:
            rng.shuffle(all_samples)
            all_samples = all_samples[:args.samples_per_cat]

        # shuffle so datasets are interleaved within the shard
        rng.shuffle(all_samples)

        # write the shard file
        out_path = os.path.join(args.output_dir, cat, f"sft.part{args.shard_id:02d}.jsonl")
        write_jsonl(out_path, all_samples)

        # shard statistics: ret/ans split and shot-count distribution
        n_ret = sum(1 for r in all_samples if r["type"] == "ret")
        n_ans = sum(1 for r in all_samples if r["type"] == "ans")
        n_dist = {}
        for r in all_samples:
            nc = len(r.get("shots", []))
            n_dist[nc] = n_dist.get(nc, 0) + 1
        print(f"[OK] {cat} shard{args.shard_id}: {len(all_samples)} 条 "
              f"(ret={n_ret} ans={n_ans}) shot分布={dict(sorted(n_dist.items()))}")
        print(f"  → {out_path}")

    # single shard: merge + shuffle automatically
    if args.num_shards == 1:
        print("\n自动合并所有 category...")
        concat_and_shuffle(args.output_dir, categories, shuffle=True, seed=args.seed)

    print(f"\n完成!输出: {args.output_dir}")


if __name__ == "__main__":
    main()
+ main()
ICL/dataset_inspect.tree.txt ADDED
@@ -0,0 +1,456 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ M3IT/
2
+ .git/
3
+ data/
4
+ .gitattributes (2.8KB)
5
+ .gitignore (29.0B)
6
+ M3IT.py (54.5KB)
7
+ README.md (18.3KB)
8
+ branches/
9
+ hooks/
10
+ info/
11
+ lfs/
12
+ logs/
13
+ objects/
14
+ refs/
15
+ FETCH_HEAD (110.0B)
16
+ HEAD (21.0B)
17
+ config (339.0B)
18
+ description (73.0B)
19
+ packed-refs (112.0B)
20
+ refs/
21
+ HEAD (189.0B)
22
+ heads/
23
+ remotes/
24
+ main (189.0B)
25
+ heads/
26
+ remotes/
27
+ tags/
28
+ origin/
29
+ HEAD (30.0B)
30
+ main (41.0B)
31
+ info/
32
+ pack/
33
+ pack-ee3e40a1a23ec17affa3b8afb61dc14bdffb229c.idx (38.9KB)
34
+ pack-ee3e40a1a23ec17affa3b8afb61dc14bdffb229c.pack (195.5KB)
35
+ applypatch-msg.sample (478.0B)
36
+ commit-msg.sample (896.0B)
37
+ fsmonitor-watchman.sample (4.5KB)
38
+ post-checkout (280.0B)
39
+ post-commit (276.0B)
40
+ post-merge (274.0B)
41
+ post-update.sample (189.0B)
42
+ pre-applypatch.sample (424.0B)
43
+ pre-commit.sample (1.6KB)
44
+ pre-merge-commit.sample (416.0B)
45
+ pre-push (270.0B)
46
+ pre-push.sample (1.3KB)
47
+ pre-rebase.sample (4.8KB)
48
+ pre-receive.sample (544.0B)
49
+ prepare-commit-msg.sample (1.5KB)
50
+ push-to-checkout.sample (2.7KB)
51
+ update.sample (3.6KB)
52
+ incomplete/
53
+ logs/
54
+ objects/
55
+ tmp/
56
+ 0152398d9443f2d300adc9e6099a773c66303d4e2e085812cd502cb36da7a0c73483193049 (0.0B)
57
+ 0152398d9443f2d300adc9e6099a773c66303d4e2e085812cd502cb36da7a0c7763208216 (0.0B)
58
+ 0152398d9443f2d300adc9e6099a773c66303d4e2e085812cd502cb36da7a0c789921672 (2.5MB)
59
+ 0968a4438d46277583968011563e959e130feaee66f51bb2d66dbd7e8c979f8c.part (0.0B)
60
+ 1f77f56225e10edca84be06b6e0d796c579cbf1d4884aee46da564438ad1ba9b1484563810 (437.0KB)
61
+ 1f77f56225e10edca84be06b6e0d796c579cbf1d4884aee46da564438ad1ba9b3850099655 (326.7KB)
62
+ 1f77f56225e10edca84be06b6e0d796c579cbf1d4884aee46da564438ad1ba9b3898577811 (4.1MB)
63
+ 220d32d087b6b29d1c5aaa49324d32b32ae1c19f42e9800f40f24d3a695c2a8d1743027097 (0.0B)
64
+ 220d32d087b6b29d1c5aaa49324d32b32ae1c19f42e9800f40f24d3a695c2a8d3014727128 (0.0B)
65
+ 220d32d087b6b29d1c5aaa49324d32b32ae1c19f42e9800f40f24d3a695c2a8d71894927 (62.6KB)
66
+ 24f014bb5bc7b1fa7d9183dd65fd4b43c0c49aafd6af01bb91ae3a0e7e65502b2818819757 (49.3MB)
67
+ 3da69649bfbc671710f38c2c2f7c6aaecb8f8544de3446866054bf927257c9332854861486 (158.6KB)
68
+ 3da69649bfbc671710f38c2c2f7c6aaecb8f8544de3446866054bf927257c9334214717938 (0.0B)
69
+ 3da69649bfbc671710f38c2c2f7c6aaecb8f8544de3446866054bf927257c933593947826 (0.0B)
70
+ 45e8c51ed0df8edb1ae51d2012b3f7d6cd9cc84addf41e6f9f9adb0f625d41033126870057 (259.2MB)
71
+ 4a80559730d917177e4d13246da0ce23ca318735b29d519d0448bea5579b1a771450117433 (154.4MB)
72
+ 4fda2aa4918e5dec847935db6d46e9bebc570a173bd4201c5f48e60a3f73813a1530155941 (1.1MB)
73
+ 4fda2aa4918e5dec847935db6d46e9bebc570a173bd4201c5f48e60a3f73813a2738070238 (0.0B)
74
+ 4fda2aa4918e5dec847935db6d46e9bebc570a173bd4201c5f48e60a3f73813a2828099128 (0.0B)
75
+ 52a445f8a26cd898e64129e7f1d4bfa6d7203311442068684f5344fc73407310.part (0.0B)
76
+ 6728a8fb7bad0bad3a2a27669232cb9ae66461c635172f1f7958c80a28e09fa32607733000 (150.2MB)
77
+ 6bb6c9f17e77eab7d88e4a4501c38cb31a6cf792fe77e3b75d511b964a5667df2998182268 (91.8MB)
78
+ 8cb15647ff6bbac322142fea1a38599c523f73acb3614ddb7d12e6a1975a79dc1986657385 (0.0B)
79
+ 8cb15647ff6bbac322142fea1a38599c523f73acb3614ddb7d12e6a1975a79dc2743098052 (0.0B)
80
+ 8cb15647ff6bbac322142fea1a38599c523f73acb3614ddb7d12e6a1975a79dc4193739161 (0.0B)
81
+ 9919274ad6bc88e37235a4c7245d05e357e404ef3352a90a1ba0594e694893c01114223911 (0.0B)
82
+ 9919274ad6bc88e37235a4c7245d05e357e404ef3352a90a1ba0594e694893c03545613611 (0.0B)
83
+ 9919274ad6bc88e37235a4c7245d05e357e404ef3352a90a1ba0594e694893c0559090370 (2.8MB)
84
+ 9cdf4d1a6972db893c8db1a4f2be0d1ec0362ba22a44542402b336760029c87253830692 (88.0MB)
85
+ b6aed90c79d180c5346994f8e7d0657b3d8a9aab002c057503736b4013a2096b.part (0.0B)
86
+ ba47b9680dc949322877399218d1f210a057249803bc70addfb9528152e4b1662004000729 (218.5MB)
87
+ ca49e0b3f3400f38519a1103b2a567db32c9fa990a7395b1024b94454601479b.part (0.0B)
88
+ d66a5b3267a7935b8ff272bcc166a8f43a8d66fb89c59503d536ac87661a02022501429466 (0.0B)
89
+ d66a5b3267a7935b8ff272bcc166a8f43a8d66fb89c59503d536ac87661a020230475132 (0.0B)
90
+ d66a5b3267a7935b8ff272bcc166a8f43a8d66fb89c59503d536ac87661a0202373225118 (62.5KB)
91
+ e5a3eb3e2d0c47d6f014e294ef7398bf26375920c8d2af80fd65e255396dcc78.part (0.0B)
92
+ f19cacf3a9f9a57abdcafc4a6d242aa9c6fa48188ad0a394b1a2558cb8ab4dc5372340294 (199.2MB)
93
+ 20251021T152133.441099492.log (1.4KB)
94
+ 01/
95
+ 02/
96
+ 03/
97
+ 05/
98
+ 06/
99
+ 07/
100
+ 09/
101
+ 0b/
102
+ 0f/
103
+ 10/
104
+ 12/
105
+ 15/
106
+ 16/
107
+ 19/
108
+ 1d/
109
+ 1e/
110
+ 1f/
111
+ 21/
112
+ 22/
113
+ 23/
114
+ 24/
115
+ 2a/
116
+ 2b/
117
+ 2c/
118
+ 2d/
119
+ 2f/
120
+ 30/
121
+ 32/
122
+ 34/
123
+ 37/
124
+ 3b/
125
+ 3d/
126
+ 44/
127
+ 45/
128
+ 4a/
129
+ 4f/
130
+ 50/
131
+ 52/
132
+ 54/
133
+ 56/
134
+ 58/
135
+ 5a/
136
+ 5b/
137
+ 60/
138
+ 61/
139
+ 64/
140
+ 65/
141
+ 67/
142
+ 68/
143
+ 69/
144
+ 6b/
145
+ 6d/
146
+ 6e/
147
+ 70/
148
+ 75/
149
+ 76/
150
+ 7b/
151
+ 7c/
152
+ 80/
153
+ 87/
154
+ 88/
155
+ 89/
156
+ 8b/
157
+ 8c/
158
+ 90/
159
+ 91/
160
+ 93/
161
+ 99/
162
+ 9a/
163
+ 9b/
164
+ 9c/
165
+ 9e/
166
+ 9f/
167
+ a0/
168
+ a5/
169
+ a9/
170
+ ac/
171
+ ae/
172
+ b1/
173
+ b3/
174
+ b4/
175
+ b6/
176
+ ba/
177
+ bb/
178
+ bc/
179
+ bd/
180
+ be/
181
+ c0/
182
+ c1/
183
+ c2/
184
+ c4/
185
+ c6/
186
+ c7/
187
+ c8/
188
+ ca/
189
+ cb/
190
+ d6/
191
+ d9/
192
+ dd/
193
+ e2/
194
+ e5/
195
+ e7/
196
+ e8/
197
+ e9/
198
+ ee/
199
+ ef/
200
+ f1/
201
+ f3/
202
+ f4/
203
+ f5/
204
+ f6/
205
+ f7/
206
+ f8/
207
+ f9/
208
+ fc/
209
+ exclude (240.0B)
210
+ captioning/
211
+ classification/
212
+ generation/
213
+ reasoning/
214
+ vqa/
215
+ chinesefoodnet-10/
216
+ coco-goi/
217
+ coco-text/
218
+ imagenet/
219
+ iqa/
220
+ itm/
221
+ mocheg/
222
+ refcoco/
223
+ snli-ve/
224
+ ss/
225
+ vsr/
226
+ winoground/
227
+ .gitattributes (141.0B)
228
+ README.md (211.0B)
229
+ instructions.json (1.4KB)
230
+ labels.json (9.0KB)
231
+ test.jsonl (223.5MB)
232
+ train.jsonl (238.9MB)
233
+ val.jsonl (227.6MB)
234
+ README.md (31.0B)
235
+ esnlive_test.jsonl (743.0MB)
236
+ esnlive_train.jsonl (1000.8MB)
237
+ esnlive_val.jsonl (717.9MB)
238
+ instructions.json (1.9KB)
239
+ test_2023-10-09.jsonl (2.9GB)
240
+ train_2023-10-09.jsonl (3.9GB)
241
+ instructions.json (825.0B)
242
+ mapping.txt (30.9KB)
243
+ test_2023-10-08.jsonl (10.6GB)
244
+ train.jsonl (1.5GB)
245
+ train_2023-10-08.jsonl (5.9GB)
246
+ val.jsonl (2.6GB)
247
+ instructions.json (907.0B)
248
+ test.jsonl (330.4MB)
249
+ test_2023-10-09.jsonl (1.3GB)
250
+ train.jsonl (1.9GB)
251
+ train_2023-10-08.jsonl (7.8GB)
252
+ val.jsonl (330.8MB)
253
+ instructions.json (773.0B)
254
+ test.jsonl (730.0MB)
255
+ test_2023-10-09.jsonl (2.9GB)
256
+ train.jsonl (4.3GB)
257
+ train_2023-10-08.jsonl (17.1GB)
258
+ val.jsonl (730.2MB)
259
+ instructions.json (1.4KB)
260
+ test_2023-10-09.jsonl (553.7MB)
261
+ train_2023-10-09.jsonl (1.9GB)
262
+ vsr_test.jsonl (137.7MB)
263
+ vsr_train.jsonl (483.3MB)
264
+ vsr_val.jsonl (68.8MB)
265
+ instructions.json (774.0B)
266
+ test_2023-10-10.jsonl (7.6GB)
267
+ train.jsonl (8.2GB)
268
+ train_2023-10-08.jsonl (32.8GB)
269
+ val.jsonl (1.9GB)
270
+ instructions.json (733.0B)
271
+ test_2023-10-07.jsonl (279.1MB)
272
+ train.jsonl (2.0GB)
273
+ train_2023-10-06.jsonl (4.1GB)
274
+ val.jsonl (138.9MB)
275
+ instructions.json (2.0KB)
276
+ winoground_test.jsonl (245.5MB)
277
+ instructions.json (1.3KB)
278
+ test.jsonl (122.9MB)
279
+ instructions.json (1.0KB)
280
+ mocheg_test.jsonl (60.3MB)
281
+ mocheg_train.jsonl (631.7MB)
282
+ mocheg_val.jsonl (28.2MB)
283
+ test_2023-10-08.jsonl (242.5MB)
284
+ train_2023-10-08.jsonl (2.5GB)
285
+ instructions.json (1.5KB)
286
+ test.jsonl (701.9MB)
287
+ test_2023-10-08.jsonl (2.7GB)
288
+ train.jsonl (3.9GB)
289
+ train_2023-10-08.jsonl (15.6GB)
290
+ val.jsonl (667.7MB)
291
+ clevr/
292
+ nlvr/
293
+ science_qa/
294
+ vcr/
295
+ visual_mrc/
296
+ instructions.json (2.5KB)
297
+ science_qa_test.jsonl (174.0MB)
298
+ science_qa_train.jsonl (531.3MB)
299
+ science_qa_validation.jsonl (176.4MB)
300
+ instructions.json (976.0B)
301
+ train.jsonl (5.6GB)
302
+ train_2023-10-07.jsonl (11.1GB)
303
+ val.jsonl (379.6MB)
304
+ val_2023-10-07.jsonl (760.4MB)
305
+ instructions.json (911.0B)
306
+ test.jsonl (1.2GB)
307
+ train.jsonl (3.9GB)
308
+ val.jsonl (266.9MB)
309
+ instructions.json (1.3KB)
310
+ test.jsonl (909.3MB)
311
+ train.jsonl (4.3GB)
312
+ val.jsonl (992.9MB)
313
+ instructions.json (1.2KB)
314
+ test.jsonl (489.0MB)
315
+ train.jsonl (7.9GB)
316
+ val.jsonl (533.3MB)
317
+ mmchat/
318
+ multi30k/
319
+ vist/
320
+ visual_dialog/
321
+ instructions.json (818.0B)
322
+ test.jsonl (65.2MB)
323
+ test_2023-10-10.jsonl (262.2MB)
324
+ train.jsonl (3.2GB)
325
+ train_2023-10-09.jsonl (13.0GB)
326
+ val.jsonl (66.0MB)
327
+ instructions.json (1.2KB)
328
+ test.jsonl (610.6MB)
329
+ train.jsonl (4.4GB)
330
+ val.jsonl (301.1MB)
331
+ instructions.json (809.0B)
332
+ test.jsonl (2.3GB)
333
+ train.jsonl (6.2GB)
334
+ train_new.jsonl (6.2GB)
335
+ validation.jsonl (2.0GB)
336
+ instructions.json (1.0KB)
337
+ test.jsonl (14.0GB)
338
+ train.jsonl (15.4GB)
339
+ val.jsonl (13.0GB)
340
+ a-okvqa/
341
+ activitynet-qa/
342
+ docvqa/
343
+ fm-iqa/
344
+ gqa/
345
+ ivqa/
346
+ msrvtt-qa/
347
+ msvd-qa/
348
+ ocr-vqa/
349
+ okvqa/
350
+ shapes/
351
+ st-vqa/
352
+ text-vqa/
353
+ viquae/
354
+ vqav2/
355
+ instruction.json (905.0B)
356
+ train.jsonl (533.5MB)
357
+ train_new.jsonl (533.5MB)
358
+ validation.jsonl (228.3MB)
359
+ instructions.json (1.9KB)
360
+ train.jsonl (1.2GB)
361
+ train_v2.jsonl (1.2GB)
362
+ val.jsonl (77.7MB)
363
+ val_v2.jsonl (78.2MB)
364
+ instruction.json (905.0B)
365
+ test.jsonl (713.3MB)
366
+ train.jsonl (3.3GB)
367
+ validation_new.jsonl (529.5MB)
368
+ instruction.json (772.0B)
369
+ train.jsonl (1.5GB)
370
+ validation.jsonl (260.3MB)
371
+ instruction.json (853.0B)
372
+ test.jsonl (229.4MB)
373
+ train.jsonl (1.4GB)
374
+ README.md (288.0B)
375
+ instructions.json (1.2KB)
376
+ test.jsonl (132.4MB)
377
+ train.jsonl (343.1MB)
378
+ val.jsonl (60.9MB)
379
+ instructions.json (853.0B)
380
+ train.jsonl (1.9GB)
381
+ val.jsonl (1.9GB)
382
+ instructions.json (1.7KB)
383
+ train.jsonl (7.2GB)
384
+ val.jsonl (976.6MB)
385
+ instructions.json (1.5KB)
386
+ test.jsonl (1.4MB)
387
+ test_2023-10-08.jsonl (7.0MB)
388
+ train.large.jsonl (18.3MB)
389
+ train_2023-10-08.jsonl (92.6MB)
390
+ val.jsonl (1.4MB)
391
+ README.md (334.0B)
392
+ instructions.json (1.0KB)
393
+ test.jsonl (500.8MB)
394
+ train.jsonl (1.5GB)
395
+ val.jsonl (485.4MB)
396
+ README.md (434.0B)
397
+ instructions.json (1.0KB)
398
+ test.jsonl (348.1MB)
399
+ train.jsonl (757.5MB)
400
+ val.jsonl (58.0MB)
401
+ .gitattributes (141.0B)
402
+ README.md (332.0B)
403
+ instructions.json (1.4KB)
404
+ test.jsonl (474.7MB)
405
+ train.jsonl (2.1GB)
406
+ val.jsonl (1.1GB)
407
+ instructions.json (1.2KB)
408
+ train.jsonl (594.8MB)
409
+ train_v2.jsonl (596.3MB)
410
+ val.jsonl (334.3MB)
411
+ val_v2.jsonl (335.2MB)
412
+ instructions.json (802.0B)
413
+ para_train.jsonl (10.5GB)
414
+ para_val.jsonl (4.8GB)
415
+ train.jsonl (10.5GB)
416
+ val.jsonl (4.8GB)
417
+ instructions.json (1.2KB)
418
+ test.jsonl (122.5MB)
419
+ test_v2.jsonl (120.9MB)
420
+ train.jsonl (110.1MB)
421
+ train_v2.jsonl (110.2MB)
422
+ validation.jsonl (125.5MB)
423
+ validation_v2.jsonl (125.6MB)
424
+ coco/
425
+ coco-cn/
426
+ flickr8k-cn/
427
+ image_paragraph_captioning/
428
+ msrvtt/
429
+ textcap/
430
+ .gitattributes (141.0B)
431
+ README.md (490.0B)
432
+ instructions.json (1010.0B)
433
+ test.jsonl (117.1MB)
434
+ train.jsonl (231.1MB)
435
+ val.jsonl (116.9MB)
436
+ instructions.json (541.0B)
437
+ test.jsonl (49.4MB)
438
+ train.jsonl (300.0MB)
439
+ val.jsonl (49.9MB)
440
+ instructions.json (790.0B)
441
+ test.jsonl (66.4MB)
442
+ train.jsonl (1.2GB)
443
+ val.jsonl (65.0MB)
444
+ image_paragraph_captioning_test.jsonl (120.7MB)
445
+ image_paragraph_captioning_train.jsonl (701.2MB)
446
+ image_paragraph_captioning_val.jsonl (118.0MB)
447
+ instruction.json (1.4KB)
448
+ README.md (73.0B)
449
+ create_dataset.py (5.5KB)
450
+ instructions.json (882.0B)
451
+ test.jsonl (333.1MB)
452
+ train.jsonl (7.4GB)
453
+ val.jsonl (333.4MB)
454
+ instructions.json (1.1KB)
455
+ train.jsonl (5.7GB)
456
+ val.jsonl (851.3MB)
ICL/eval_icl.py ADDED
@@ -0,0 +1,524 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ ICL 推理评测脚本:模拟多轮 RET/ANS 决策循环。支持多卡并行。
4
+
5
+ 流程:
6
+ 1. 给模型 query_image + question(0-shot)
7
+ 2. 模型输出 <RET> → 用预计算 top5 检索下一张图+caption,追加到 context,再问
8
+ 3. 模型输出 <ANS> → 结束,提取答案
9
+ 4. 最多 max_rounds 轮(防止一直 RET)
10
+
11
+ 多卡策略:
12
+ 每张 GPU 加载一份模型,按 dataset 粒度分配任务,最后 rank 0 汇总。
13
+
14
+ 用法:
15
+ # 单卡
16
+ python3 eval_icl.py \
17
+ --model-path /workspace/xiaobin/ICL/sft_model/merged_hf \
18
+ --category vqa --dataset vqav2 --split val \
19
+ --num-samples 200 --max-rounds 4 --device cuda:0
20
+
21
+ # 多卡(8 GPU)
22
+ torchrun --nproc_per_node=8 eval_icl.py \
23
+ --model-path /workspace/xiaobin/ICL/sft_model/merged_hf \
24
+ --all-categories --num-samples 100 --max-rounds 4
25
+ """
26
+
27
+ import argparse
28
+ import json
29
+ import os
30
+ import random
31
+ import sys
32
+ import time
33
+ from collections import defaultdict
34
+ from typing import Dict, List, Optional, Tuple
35
+
36
+ import torch
37
+ import torch.distributed as dist
38
+ from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
39
+ from qwen_vl_utils import process_vision_info
40
+
41
+ # ---------------------------------------------------------------------------
42
+ # 默认路径
43
+ # ---------------------------------------------------------------------------
44
+ INDEX_ROOT = "/workspace/xiaobin/dataset/index"
45
+ EMBEDDINGS_DIR = "/workspace/xiaobin/dataset/embeddings"
46
+ CAPTION_CACHE_DIR = "/workspace/xiaobin/dataset/caption_cache"
47
+
48
+ # ---------------------------------------------------------------------------
49
+ # 分布式工具
50
+ # ---------------------------------------------------------------------------
51
+
52
def setup_distributed():
    """Initialize the (optional) distributed environment.

    Returns:
        (rank, world_size, device). In the single-process case this is
        ``(0, 1, None)`` and the caller falls back to its own ``--device``
        argument; under torchrun the NCCL process group is initialized and
        ``device`` is the local CUDA device string.
    """
    env = os.environ
    # torchrun exports RANK/WORLD_SIZE; their absence means single-process mode.
    if "RANK" not in env or "WORLD_SIZE" not in env:
        return 0, 1, None

    rank = int(env["RANK"])
    world_size = int(env["WORLD_SIZE"])
    local_rank = int(env.get("LOCAL_RANK", rank))
    dist.init_process_group("nccl")
    torch.cuda.set_device(local_rank)
    return rank, world_size, f"cuda:{local_rank}"
66
+
67
+
68
def gather_results(local_results: List[Dict], rank: int, world_size: int) -> List[Dict]:
    """Gather each rank's result list onto rank 0.

    Results are JSON-serialized to UTF-8 bytes, padded to a common length
    (all_gather requires equal tensor shapes), exchanged via all_gather,
    then decoded and merged on rank 0. Non-zero ranks return [].

    NOTE(review): tensors are placed on f"cuda:{rank}" — this assumes global
    rank equals the local CUDA index, i.e. a single node; confirm before
    running multi-node.
    """
    if world_size == 1:
        return local_results

    # Serialize -> bytes -> tensor
    data = json.dumps(local_results, ensure_ascii=False).encode("utf-8")
    size = torch.tensor([len(data)], dtype=torch.long, device=f"cuda:{rank}")

    # Collect every rank's payload size to find the padding target.
    size_list = [torch.zeros(1, dtype=torch.long, device=f"cuda:{rank}") for _ in range(world_size)]
    dist.all_gather(size_list, size)
    max_size = max(s.item() for s in size_list)

    # Pad with NUL bytes so all payload tensors have identical length.
    padded = data + b"\x00" * (max_size - len(data))
    tensor = torch.ByteTensor(list(padded)).cuda(rank)

    tensor_list = [torch.zeros(max_size, dtype=torch.uint8, device=f"cuda:{rank}") for _ in range(world_size)]
    dist.all_gather(tensor_list, tensor)

    if rank == 0:
        all_results = []
        for i, (t, s) in enumerate(zip(tensor_list, size_list)):
            # Trim padding using the sender's true size before decoding.
            raw = bytes(t[:s.item()].cpu().tolist())
            all_results.extend(json.loads(raw.decode("utf-8")))
        return all_results
    return []
96
+
97
+
98
+ # ---------------------------------------------------------------------------
99
+ # 数据加载
100
+ # ---------------------------------------------------------------------------
101
+
102
def load_records(cat: str, ds: str, split: str, limit: int = 0) -> List[Dict]:
    """Load one dataset split, keeping only records that carry both an
    image path and a ground-truth answer.

    A positive ``limit`` caps how many qualifying records are read;
    a missing split file yields an empty list.
    """
    path = os.path.join(INDEX_ROOT, cat, ds, f"{split}.jsonl")
    if not os.path.exists(path):
        return []

    kept: List[Dict] = []
    with open(path, "r", encoding="utf-8") as fh:
        for raw in fh:
            raw = raw.strip()
            if not raw:
                continue
            rec = json.loads(raw)
            if rec.get("image") and rec.get("answer"):
                kept.append(rec)
            if limit and len(kept) >= limit:
                break
    return kept
118
+
119
+
120
def load_top5(cat: str, ds: str) -> Dict[str, List[str]]:
    """Return the precomputed top-5 retrieval map for a dataset,
    or an empty dict when the embeddings file is absent."""
    path = os.path.join(EMBEDDINGS_DIR, f"{cat}_{ds}_top5.json")
    if os.path.exists(path):
        with open(path, "r", encoding="utf-8") as fh:
            return json.load(fh)
    return {}
126
+
127
+
128
def load_caption_cache(cat: str, ds: str) -> Dict[str, str]:
    """Return the image→caption cache for a dataset,
    or an empty dict when no cache file exists."""
    path = os.path.join(CAPTION_CACHE_DIR, f"{cat}_{ds}.json")
    if os.path.exists(path):
        with open(path, "r", encoding="utf-8") as fh:
            return json.load(fh)
    return {}
134
+
135
+
136
def load_instructions(cat: str, ds: str) -> List[str]:
    """Return the instruction templates for a dataset; falls back to a
    single generic prompt when instructions.json is missing."""
    path = os.path.join(INDEX_ROOT, cat, ds, "instructions.json")
    if os.path.exists(path):
        with open(path, "r", encoding="utf-8") as fh:
            return json.load(fh)
    return ["Look at the image and answer the question."]
142
+
143
+
144
def discover_datasets(categories: List[str]) -> List[Tuple[str, str]]:
    """Enumerate (category, dataset) directory pairs under INDEX_ROOT.

    When ``categories`` is non-empty, only those category names are kept.
    """
    pairs: List[Tuple[str, str]] = []
    for cat in sorted(os.listdir(INDEX_ROOT)):
        cat_dir = os.path.join(INDEX_ROOT, cat)
        if (categories and cat not in categories) or not os.path.isdir(cat_dir):
            continue
        pairs.extend(
            (cat, ds)
            for ds in sorted(os.listdir(cat_dir))
            if os.path.isdir(os.path.join(cat_dir, ds))
        )
    return pairs
157
+
158
+
159
+ # ---------------------------------------------------------------------------
160
+ # 模型加载
161
+ # ---------------------------------------------------------------------------
162
+
163
def load_model(model_path: str, device: str):
    """Load the merged HF checkpoint plus its processor onto one device.

    Registers the RET/ANS control tokens (resizing the embedding table if
    any are new), switches the model to eval mode, and returns
    ``(model, processor)``.
    """
    print(f"[{device}] Loading model from {model_path} ...")
    processor = AutoProcessor.from_pretrained(
        model_path, trust_remote_code=True,
        min_pixels=256 * 28 * 28,
        max_pixels=1280 * 28 * 28,
    )

    model = Qwen3VLForConditionalGeneration.from_pretrained(
        model_path,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        device_map=device,
    )

    # Control tokens driving the RET/ANS decision loop; embeddings only
    # need resizing if any token was not already in the vocabulary.
    tokenizer = processor.tokenizer
    added = tokenizer.add_tokens(["<RET>", "<ANS>", "</ANS>", "<RETQ>", "</RETQ>"], special_tokens=True)
    if added > 0:
        model.resize_token_embeddings(len(tokenizer))

    model.eval()

    ret_id = tokenizer.convert_tokens_to_ids("<RET>")
    ans_id = tokenizer.convert_tokens_to_ids("<ANS>")
    print(f"[{device}] Ready. <RET>={ret_id}, <ANS>={ans_id}")

    return model, processor
191
+
192
+
193
+ # ---------------------------------------------------------------------------
194
+ # 推理核心
195
+ # ---------------------------------------------------------------------------
196
+
197
def build_messages(
    instruction: str,
    query_image: str,
    question: Optional[str],
    shots: List[Dict],
    min_pixels: int = 256 * 28 * 28,
    max_pixels: int = 1280 * 28 * 28,
) -> List[Dict]:
    """Assemble the single-turn chat message for one decision round.

    Content layout:
        [instruction] query-image [question] (shot-image [caption])* "Action:"
    Optional pieces are skipped when falsy.
    """
    def _image_item(path: str) -> Dict:
        # Local file reference with the pixel budget forwarded to the processor.
        return {
            "type": "image",
            "image": f"file://{path}",
            "min_pixels": min_pixels, "max_pixels": max_pixels,
        }

    content: List[Dict] = []

    if instruction:
        content.append({"type": "text", "text": instruction})

    content.append(_image_item(query_image))

    if question:
        content.append({"type": "text", "text": f"Question: {question}"})

    for shot in shots:
        content.append(_image_item(shot["image"]))
        caption = shot.get("caption")
        if caption:
            content.append({"type": "text", "text": f"Caption: {caption}"})

    content.append({"type": "text", "text": "Action:"})
    return [{"role": "user", "content": content}]
230
+
231
+
232
@torch.no_grad()
def generate_action(model, processor, messages: List[Dict], max_new_tokens: int = 256) -> str:
    """Run one greedy generation step for the given chat messages.

    Returns the decoded continuation only (prompt stripped), with special
    tokens kept so the caller can parse <RET>/<ANS>.
    """
    prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    try:
        image_inputs, _ = process_vision_info(messages)
    except Exception:
        # Best-effort: fall back to text-only inputs if vision preprocessing fails.
        image_inputs = None

    model_inputs = processor(
        text=[prompt],
        images=image_inputs if image_inputs else None,
        return_tensors="pt",
        padding=False,
        truncation=False,
    )

    # Move tensors to wherever the model's parameters live.
    target = next(model.parameters()).device
    model_inputs = {k: v.to(target) if hasattr(v, 'to') else v for k, v in model_inputs.items()}

    outputs = model.generate(
        **model_inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        temperature=None,
        top_p=None,
    )

    # Keep only the newly generated suffix.
    prompt_len = model_inputs["input_ids"].shape[1]
    return processor.tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=False)
264
+
265
+
266
def parse_action(text: str) -> Tuple[str, str]:
    """Classify a raw generation into an action.

    Returns one of:
        ("ret", description)  — model asked for another retrieval shot
        ("ans", answer)       — model produced a final answer
        ("unknown", raw_text) — output matched neither marker
    """
    stop_tokens = ("<|im_end|>", "</s>", "<|endoftext|>")
    text = text.strip()

    if text.startswith("<RET>"):
        desc = text[len("<RET>"):].strip()
        if desc.startswith("Description:"):
            desc = desc[len("Description:"):].strip()
        # Drop any trailing EOS-style tokens the decoder left in.
        for stop in stop_tokens:
            desc = desc.replace(stop, "").strip()
        return "ret", desc

    if text.startswith("<ANS>"):
        body = text[len("<ANS>"):]
        close = body.find("</ANS>")
        if close != -1:
            # Closing tag present: everything after it is discarded.
            body = body[:close]
        else:
            for stop in stop_tokens:
                body = body.replace(stop, "").strip()
        return "ans", body.strip()

    return "unknown", text
288
+
289
+
290
def run_icl_loop(
    model, processor,
    record: Dict,
    instruction: str,
    top5: Dict[str, List[str]],
    caption_cache: Dict[str, str],
    max_rounds: int = 4,
) -> Dict:
    """Run the multi-round RET/ANS decision loop for a single record.

    Each round the model sees the query image, the question, and all shots
    retrieved so far. On <RET> the next unused candidate from the
    precomputed top-5 list is appended as a shot; on <ANS> the loop ends
    with a final answer. Returns a result dict with the per-round trace,
    the final answer (or None), and a ``terminated_by`` reason — one of
    "ans", "no_more_shots", "unknown_action", or "max_rounds".
    """
    query_image = record["image"]
    question = record.get("question", "")
    gt_answer = record.get("answer", "")

    shots = []  # retrieved in-context examples accumulated across rounds
    used_images = {query_image}  # avoid re-retrieving the query or a prior shot
    rounds = []  # per-round trace of (action, content, truncated raw output)
    candidates = top5.get(query_image, [])

    for round_idx in range(max_rounds):
        messages = build_messages(instruction, query_image, question, shots)
        raw_output = generate_action(model, processor, messages)
        action, content = parse_action(raw_output)

        rounds.append({
            "round": round_idx,
            "action": action,
            "content": content,
            "raw": raw_output[:200],  # truncated for log/result size
        })

        if action == "ans":
            return {
                "image": query_image, "question": question,
                "gt_answer": gt_answer, "rounds": rounds,
                "final_answer": content, "num_rounds": round_idx + 1,
                "terminated_by": "ans",
            }

        if action == "ret":
            # Pick the first top-5 candidate not yet shown to the model.
            next_image = None
            for c in candidates:
                if c not in used_images:
                    next_image = c
                    break

            if next_image is None:
                # Candidate pool exhausted: stop without a final answer.
                return {
                    "image": query_image, "question": question,
                    "gt_answer": gt_answer, "rounds": rounds,
                    "final_answer": None, "num_rounds": round_idx + 1,
                    "terminated_by": "no_more_shots",
                }

            # Prefer the cached caption; fall back to the model's own
            # retrieval description when the cache has no entry.
            cap = caption_cache.get(next_image, content)
            shots.append({"image": next_image, "caption": cap})
            used_images.add(next_image)
        else:
            # Output matched neither <RET> nor <ANS>; keep it as the answer
            # but flag the abnormal termination.
            return {
                "image": query_image, "question": question,
                "gt_answer": gt_answer, "rounds": rounds,
                "final_answer": content, "num_rounds": round_idx + 1,
                "terminated_by": "unknown_action",
            }

    # Model kept retrieving until the round budget ran out.
    return {
        "image": query_image, "question": question,
        "gt_answer": gt_answer, "rounds": rounds,
        "final_answer": None, "num_rounds": max_rounds,
        "terminated_by": "max_rounds",
    }
359
+
360
+
361
+ # ---------------------------------------------------------------------------
362
+ # 统计
363
+ # ---------------------------------------------------------------------------
364
+
365
def print_stats(results: List[Dict], cat: str = "", ds: str = ""):
    """Print summary statistics for a list of run_icl_loop results.

    Reports sample count, average rounds, termination-reason breakdown,
    per-round RET/ANS action distribution, and the fraction of samples
    that produced a final answer. ``cat``/``ds`` only affect the printed
    prefix label.
    """
    prefix = f"[{cat}/{ds}]" if ds else f"[{cat}]" if cat else "[ALL]"
    n = len(results)
    if n == 0:
        print(f"{prefix} 无结果")
        return

    # Count termination reasons across all samples.
    term_counts = defaultdict(int)
    for r in results:
        term_counts[r["terminated_by"]] += 1

    # action counts keyed by round index, then action name.
    round_actions = defaultdict(lambda: defaultdict(int))
    for r in results:
        for rd in r["rounds"]:
            round_actions[rd["round"]][rd["action"]] += 1

    avg_rounds = sum(r["num_rounds"] for r in results) / n

    print(f"\n{'='*60}")
    print(f"{prefix} 共 {n} 条样本")
    print(f" 平均轮次: {avg_rounds:.2f}")
    print(f" 终止原因:")
    for k, v in sorted(term_counts.items()):
        print(f" {k}: {v} ({v/n*100:.1f}%)")

    print(f" 每轮 RET/ANS 分布:")
    for rd_idx in sorted(round_actions.keys()):
        actions = round_actions[rd_idx]
        total = sum(actions.values())
        parts = [f"{a}={c}({c/total*100:.0f}%)" for a, c in sorted(actions.items())]
        print(f" Round {rd_idx}: {' | '.join(parts)} (共{total}条)")

    # Samples that reached a final answer (terminated via "ans" or "unknown_action").
    answered = [r for r in results if r["final_answer"] is not None]
    print(f" 产出答案: {len(answered)}/{n} ({len(answered)/n*100:.1f}%)")
    print(f"{'='*60}")
400
+
401
+
402
+ # ---------------------------------------------------------------------------
403
+ # Main
404
+ # ---------------------------------------------------------------------------
405
+
406
def main():
    """Entry point: evaluate the RET/ANS ICL loop over selected datasets.

    Workflow: parse args → (optionally) initialize torch.distributed →
    load one model copy per process → shard the dataset list across ranks
    → run run_icl_loop per sampled record → gather results on rank 0 →
    print stats and dump a JSON report.
    """
    parser = argparse.ArgumentParser(description="ICL 多轮推理评测(支持多卡)")
    parser.add_argument("--model-path", required=True, help="合并后的 HF 模型路径")
    parser.add_argument("--category", type=str, default="")
    parser.add_argument("--dataset", type=str, default="")
    parser.add_argument("--split", type=str, default="val")
    parser.add_argument("--all-categories", action="store_true")
    parser.add_argument("--num-samples", type=int, default=100, help="每个 dataset 采样数")
    parser.add_argument("--max-rounds", type=int, default=4)
    parser.add_argument("--device", type=str, default="cuda:0", help="单卡时用的设备")
    parser.add_argument("--output", type=str, default="")
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    random.seed(args.seed)

    # Distributed initialization (no-op in single-process mode).
    rank, world_size, dist_device = setup_distributed()
    device = dist_device or args.device
    is_main = (rank == 0)

    if is_main:
        print(f"World size: {world_size}")

    # Load the model (one full copy per GPU/process).
    model, processor = load_model(args.model_path, device)

    # Resolve which categories to evaluate.
    if args.all_categories:
        categories = ["vqa", "captioning", "classification", "reasoning"]
    elif args.category:
        categories = [args.category]
    else:
        categories = ["vqa"]

    if args.dataset:
        ds_list = [(args.category or "vqa", args.dataset)]
    else:
        ds_list = discover_datasets(categories)

    # ---- Shard datasets across ranks (round-robin by index) ----
    my_ds_list = ds_list[rank::world_size]
    if is_main:
        print(f"共 {len(ds_list)} 个 dataset,每卡约 {len(my_ds_list)} 个")

    local_results = []

    for cat, ds in my_ds_list:
        print(f"[rank {rank}] Evaluating {cat}/{ds} ({args.split})")

        # Over-read (5x) so filtering by top5 coverage still leaves enough samples.
        records = load_records(cat, ds, args.split, limit=args.num_samples * 5)
        if not records:
            print(f" [rank {rank}] 跳过 {cat}/{ds}:无记录")
            continue

        top5 = load_top5(cat, ds)
        if not top5:
            print(f" [rank {rank}] 跳过 {cat}/{ds}:无 top5")
            continue

        caption_cache = load_caption_cache(cat, ds)
        instructions = load_instructions(cat, ds)

        # Keep only records whose query image has precomputed retrieval candidates.
        records = [r for r in records if r["image"] in top5]
        if not records:
            print(f" [rank {rank}] 跳过 {cat}/{ds}:无 top5 覆盖")
            continue

        if len(records) > args.num_samples:
            records = random.sample(records, args.num_samples)
        print(f" [rank {rank}] {cat}/{ds}: {len(records)} 条")

        for i, rec in enumerate(records):
            inst = random.choice(instructions)
            result = run_icl_loop(
                model, processor, rec, inst, top5, caption_cache,
                max_rounds=args.max_rounds,
            )
            result["category"] = cat
            result["dataset"] = ds
            local_results.append(result)

            # Progress log every 10 samples (and on the last one).
            if (i + 1) % 10 == 0 or (i + 1) == len(records):
                action_seq = " → ".join(rd["action"].upper() for rd in result["rounds"])
                print(f" [rank {rank}] [{i+1}/{len(records)}] {action_seq} | "
                      f"{result['terminated_by']}")

    # ---- Gather all ranks' results onto rank 0 ----
    all_results = gather_results(local_results, rank, world_size)

    if is_main:
        # Group by category for reporting.
        cat_results = defaultdict(list)
        for r in all_results:
            cat_results[r["category"]].append(r)

        for cat in categories:
            if cat_results[cat]:
                # Per-dataset sub-stats within the category.
                ds_groups = defaultdict(list)
                for r in cat_results[cat]:
                    ds_groups[r["dataset"]].append(r)
                for d in sorted(ds_groups):
                    print_stats(ds_groups[d], cat, d)
                print_stats(cat_results[cat], cat)

        print_stats(all_results)

        output_path = args.output or f"/workspace/xiaobin/ICL/eval_results_{args.split}.json"
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(all_results, f, ensure_ascii=False, indent=2)
        print(f"\n详细结果已保存到: {output_path}")

    if world_size > 1:
        dist.destroy_process_group()
522
+
523
# Script entry point.
if __name__ == "__main__":
    main()
ICL/extract_images.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ 从 /workspace/xiaobin/dataset/data 下所有 JSONL 文件中提取 base64 编码的图片,
4
+ 保存到 /workspace/xiaobin/dataset/images/{category}/{dataset}/{split}/ 目录。
5
+
6
+ split 由文件名推断:含 train -> train, 含 test -> test, 含 val/validation -> val
7
+
8
+ 图片字段名自动检测,支持:
9
+ image_str, image_base64_str, img_str, base64, image_base64, image_base_url,
10
+ video_str (list), images (list)
11
+
12
+ 依赖:无需额外安装(tqdm 已有)
13
+
14
+ 用法:
15
+ python3 extract_images.py # 处理全部
16
+ python3 extract_images.py vqa/shapes # 只处理某个数据集
17
+ python3 extract_images.py /path/to/some.jsonl # 只处理某个文件
18
+ """
19
+
20
+ import os
21
+ import sys
22
+ import json
23
+ import base64
24
+ import glob
25
+ import re
26
+ from tqdm import tqdm
27
+
28
DATA_ROOT = "/workspace/xiaobin/dataset/data"
OUTPUT_ROOT = "/workspace/xiaobin/dataset/images"

# All possible image field names, in priority order.
# NOTE: the same field may hold a str in one dataset and a list in another;
# extract_images_from_record() handles both shapes uniformly.
ALL_IMAGE_FIELDS = [
    "image",              # captioning/coco
    "image_str",          # several datasets (str or list)
    "image_base64_str",   # snli-ve, multi30k, vcr, visual_mrc
    "img_str",            # gqa, ocr-vqa, st-vqa, text-vqa, viquae, vqav2
    "base64",             # fm-iqa
    "image_base64",       # coco-cn, mmchat (str or list, e.g. chinesefoodnet-10)
    "image_base_url",     # textcap
    "video_str",          # msrvtt, ss, activitynet-qa, ivqa, msrvtt-qa, msvd-qa (list)
    "images",             # vist (list)
]
44
+
45
+
46
def detect_extension(data_bytes):
    """Infer an image file extension from the magic bytes of *data_bytes*.

    Recognizes JPEG, PNG, GIF, WebP and BMP signatures; anything
    unrecognized falls back to ".jpg" so the caller can always write a file.
    Slicing never raises on short input, so truncated payloads are safe.

    Args:
        data_bytes: raw (already base64-decoded) image bytes.

    Returns:
        A lowercase file extension including the leading dot.
    """
    if data_bytes[:2] == b'\xff\xd8':
        return ".jpg"
    if data_bytes[:8] == b'\x89PNG\r\n\x1a\n':
        return ".png"
    if data_bytes[:4] == b'GIF8':  # matches both GIF87a and GIF89a
        return ".gif"
    if data_bytes[:4] == b'RIFF' and data_bytes[8:12] == b'WEBP':
        return ".webp"
    if data_bytes[:2] == b'BM':  # fix: BMP used to be mislabeled as .jpg
        return ".bmp"
    return ".jpg"
58
+
59
+
60
def classify_split(filename):
    """Infer the dataset split ("train"/"test"/"val"/"other") from a file name.

    Matching is case-insensitive and checked in priority order, so a name
    containing several tokens (e.g. "train_test") resolves to the first hit.
    """
    lowered = filename.lower()
    for token, split_name in (("train", "train"), ("test", "test"), ("val", "val")):
        if token in lowered:
            return split_name
    return "other"
71
+
72
+
73
def extract_images_from_record(record):
    """Return the base64 image strings contained in one JSONL record.

    Fields are probed in ALL_IMAGE_FIELDS priority order and the first usable
    field wins.  Strings of 100 characters or fewer are treated as noise: a
    short str falls through to the next field, while a list is filtered and
    returned as-is (possibly empty).
    """
    for field in ALL_IMAGE_FIELDS:
        value = record.get(field)
        if not value:
            continue
        if isinstance(value, str):
            if len(value) > 100:
                return [value]
        elif isinstance(value, list):
            return [entry for entry in value
                    if isinstance(entry, str) and len(entry) > 100]
    return []
84
+
85
+
86
def count_lines(filepath):
    """Count newline bytes in *filepath* (cheap line total for a tqdm bar).

    Reads the file in 8 MB binary chunks so memory stays bounded.  Note this
    counts ``\\n`` bytes, so a final line lacking a trailing newline is not
    included — acceptable here because the result only sizes a progress bar.

    Fix over the original: read through the buffered file object instead of
    reaching into ``f.raw``, which bypassed the buffer layer and breaks when
    the file object is wrapped or has no ``raw`` attribute.
    """
    count = 0
    buf_size = 1024 * 1024 * 8  # 8MB chunks
    with open(filepath, 'rb') as f:
        while True:
            buf = f.read(buf_size)
            if not buf:
                break
            count += buf.count(b'\n')
    return count
97
+
98
+
99
def process_jsonl_file(jsonl_path, file_idx, total_files):
    """Extract every base64-encoded image from one JSONL file and save it.

    Images are written to ``OUTPUT_ROOT/{category}/{dataset}/{split}/`` with
    zero-padded sequential names; records holding several images also get an
    ``_fNNN`` frame suffix.  *file_idx* / *total_files* only label the
    progress bar.  Returns the number of images written.
    """
    # Layout under DATA_ROOT is expected to be category/dataset/filename.jsonl.
    rel_path = os.path.relpath(jsonl_path, DATA_ROOT)
    parts = rel_path.split(os.sep)
    if len(parts) < 3:
        print(f" [SKIP] 路径层级不够: {rel_path}")
        return 0

    category = parts[0]
    dataset = parts[1]
    filename = parts[2]
    split = classify_split(filename)

    out_dir = os.path.join(OUTPUT_ROOT, category, dataset, split)
    os.makedirs(out_dir, exist_ok=True)

    # Resume support: numbering continues after images already on disk.
    existing_count = len([f for f in os.listdir(out_dir) if os.path.isfile(os.path.join(out_dir, f))])

    # Cheap pre-pass so tqdm gets an accurate total.
    file_size_mb = os.path.getsize(jsonl_path) / (1024 * 1024)
    total_lines = count_lines(jsonl_path)

    count = 0    # images written by this call
    skipped = 0  # lines with no usable image payload
    errors = 0   # JSON-decode or base64/IO failures

    desc = f"[{file_idx}/{total_files}] {category}/{dataset}/{split} ({file_size_mb:.0f}MB)"

    try:
        with open(jsonl_path, 'r', encoding='utf-8') as f:
            pbar = tqdm(f, total=total_lines, desc=desc, unit="行",
                        dynamic_ncols=True, miniters=50)
            for line in pbar:
                line = line.strip()
                if not line:
                    continue

                try:
                    record = json.loads(line)
                except json.JSONDecodeError:
                    errors += 1
                    continue

                b64_list = extract_images_from_record(record)
                if not b64_list:
                    skipped += 1
                    continue

                for img_idx, b64_str in enumerate(b64_list):
                    # Global index advances per image, so multi-image records
                    # consume several slots; resume counting stays consistent.
                    global_idx = existing_count + count
                    try:
                        img_bytes = base64.b64decode(b64_str)
                        ext = detect_extension(img_bytes)
                        if len(b64_list) > 1:
                            img_name = f"{global_idx:08d}_f{img_idx:03d}{ext}"
                        else:
                            img_name = f"{global_idx:08d}{ext}"
                        with open(os.path.join(out_dir, img_name), 'wb') as img_f:
                            img_f.write(img_bytes)
                        count += 1
                    except Exception as e:
                        errors += 1
                        # Only surface the first few errors to keep logs readable.
                        if errors <= 3:
                            tqdm.write(f" [ERROR] {e}")

                # Refresh the bar's postfix counters once per line.
                pbar.set_postfix(imgs=count, skip=skipped, err=errors, refresh=False)
            pbar.close()

    except Exception as e:
        # Best-effort per file: report and fall through to the summary line.
        print(f" [FATAL] {e}")

    print(f" -> 完成: {count} 张图片, 跳过 {skipped} 行(无图), 错误 {errors}")
    return count
174
+
175
+
176
def find_all_jsonl_files():
    """Collect every JSONL file under DATA_ROOT that should be processed.

    Dated snapshots (``*_YYYY-MM-DD.jsonl``), ``*_v2``/``*_new`` variants and
    ``para_*`` files are excluded; results come back in sorted order.
    """
    def _wanted(name):
        # Filter out snapshot, duplicate-version and paraphrase files.
        if re.search(r'_\d{4}-\d{2}-\d{2}\.jsonl$', name):
            return False
        if '_v2.jsonl' in name or '_new.jsonl' in name:
            return False
        return not name.startswith('para_')

    candidates = sorted(glob.glob(os.path.join(DATA_ROOT, "*/*/*.jsonl")))
    return [path for path in candidates if _wanted(os.path.basename(path))]
189
+
190
+
191
def main():
    """CLI entry point: choose the JSONL files, then extract images from each.

    With no argument every discovered file is processed.  One optional
    argument is either a path to a single JSONL file or a
    ``category/dataset`` subdirectory pattern under DATA_ROOT.
    """
    print("=" * 60)
    print("JSONL 图片提取工具")
    print(f"数据源: {DATA_ROOT}")
    print(f"输出到: {OUTPUT_ROOT}")
    print("=" * 60)

    if len(sys.argv) > 1:
        target = sys.argv[1]
        if os.path.isfile(target):
            files = [target]
        else:
            # Apply the same exclusion rules as find_all_jsonl_files(),
            # restricted to the user-supplied subdirectory pattern.
            files = sorted(glob.glob(os.path.join(DATA_ROOT, target, "*.jsonl")))
            files = [f for f in files
                     if not re.search(r'_\d{4}-\d{2}-\d{2}\.jsonl$', os.path.basename(f))
                     and '_v2.jsonl' not in os.path.basename(f)
                     and '_new.jsonl' not in os.path.basename(f)
                     and not os.path.basename(f).startswith('para_')]
    else:
        files = find_all_jsonl_files()

    # Print a size overview before starting the long extraction run.
    print(f"\n共 {len(files)} 个 JSONL 文件:")
    total_size = 0
    for f in files:
        size_mb = os.path.getsize(f) / (1024 * 1024)
        total_size += size_mb
        print(f" {os.path.relpath(f, DATA_ROOT):50s} {size_mb:>10.1f} MB")
    print(f" {'合计':50s} {total_size/1024:>10.1f} GB")

    total_images = 0
    for i, jsonl_path in enumerate(files, 1):
        n = process_jsonl_file(jsonl_path, i, len(files))
        total_images += n

    print(f"\n{'=' * 60}")
    print(f"全部完成!共提取 {total_images} 张图片")
    print(f"保存在: {OUTPUT_ROOT}")


if __name__ == "__main__":
    main()
ICL/merge_captions.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ 把 detail/{cat}/{ds}/{split}/captions.json 合并成 build_sft.py 需要的格式:
4
+ caption_cache/{cat}_{ds}.json = {"items": {img_path: caption, ...}}
5
+
6
+ 这样 build_sft.py --caption-cache-dir caption_cache 就能直接复用。
7
+
8
+ 用法:
9
+ python3 merge_captions.py
10
+ python3 merge_captions.py --force # 强制重建
11
+ """
12
+
13
+ import os
14
+ import sys
15
+ import json
16
+ import glob
17
+
18
# Source tree: detail/{cat}/{ds}/{split}/captions.json files to be merged.
DETAIL_ROOT = "/workspace/xiaobin/dataset/detail"
# Destination: one merged {cat}_{ds}.json cache per dataset (for build_sft.py).
CAPTION_CACHE_DIR = "/workspace/xiaobin/dataset/caption_cache"
20
+
21
+
22
def main():
    """Merge per-split caption files into one cache file per dataset.

    Scans DETAIL_ROOT for ``{cat}/{ds}/{split}/captions.json`` and writes
    ``{cat}_{ds}.json`` (shape ``{"items": {...}}``) into CAPTION_CACHE_DIR.
    Existing non-empty outputs are skipped unless ``--force`` is passed.
    """
    force = "--force" in sys.argv
    os.makedirs(CAPTION_CACHE_DIR, exist_ok=True)

    # Find all dataset directories (cat/ds) with at least one captions.json.
    datasets = set()
    for captions_file in glob.glob(os.path.join(DETAIL_ROOT, "*/*/*/captions.json")):
        rel = os.path.relpath(captions_file, DETAIL_ROOT)
        parts = rel.split(os.sep)  # cat/ds/split/captions.json
        datasets.add((parts[0], parts[1]))

    print(f"共 {len(datasets)} 个数据集")

    for cat, ds in sorted(datasets):
        out_name = f"{cat}_{ds}.json"
        out_path = os.path.join(CAPTION_CACHE_DIR, out_name)

        # Resume support: skip outputs that already exist and are non-empty.
        if not force and os.path.exists(out_path) and os.path.getsize(out_path) > 0:
            print(f" [SKIP] {out_name}")
            continue

        # Later splits overwrite duplicate keys from earlier ones (dict.update).
        merged = {}
        for split in ("train", "val", "test"):
            src = os.path.join(DETAIL_ROOT, cat, ds, split, "captions.json")
            if not os.path.exists(src):
                continue
            try:
                with open(src, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                items = data.get("items", {})
                if isinstance(items, dict):
                    merged.update(items)
            except Exception as e:
                # Best-effort merge: a corrupt split file is reported, not fatal.
                print(f" [WARN] {src}: {e}")

        if not merged:
            print(f" [EMPTY] {cat}/{ds}")
            continue

        with open(out_path, 'w', encoding='utf-8') as f:
            json.dump({"items": merged}, f, ensure_ascii=False)

        print(f" [OK] {out_name}: {len(merged)} 条")

    print(f"\n完成! 输出: {CAPTION_CACHE_DIR}")


if __name__ == "__main__":
    main()
ICL/sft_model/epoch3_step1406_fp32/chat_template.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {%- if messages[0].content is string %}\n {{- messages[0].content }}\n {%- else %}\n {%- for content in messages[0].content %}\n {%- if 'text' in content %}\n {{- content.text }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].content is string %}\n {{- messages[0].content }}\n {%- else %}\n {%- for content in messages[0].content %}\n {%- if 'text' in content %}\n {{- content.text }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set image_count = namespace(value=0) %}\n{%- set video_count = namespace(value=0) %}\n{%- for message in messages %}\n {%- if message.role == \"user\" %}\n {{- '<|im_start|>' + message.role + '\\n' }}\n {%- if message.content is string %}\n {{- message.content }}\n {%- else %}\n {%- for content in message.content %}\n {%- if content.type == 'image' or 'image' in content or 'image_url' in content %}\n {%- set image_count.value = image_count.value + 1 %}\n {%- if add_vision_id %}Picture {{ image_count.value }}: {% endif -%}\n <|vision_start|><|image_pad|><|vision_end|>\n {%- elif content.type == 'video' or 'video' in content %}\n {%- set video_count.value = video_count.value + 1 %}\n {%- if add_vision_id %}Video {{ 
video_count.value }}: {% endif -%}\n <|vision_start|><|video_pad|><|vision_end|>\n {%- elif 'text' in content %}\n {{- content.text }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role + '\\n' }}\n {%- if message.content is string %}\n {{- message.content }}\n {%- else %}\n {%- for content_item in message.content %}\n {%- if 'text' in content_item %}\n {{- content_item.text }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and message.content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {%- if message.content is string %}\n {{- message.content }}\n {%- else %}\n {%- for content in message.content %}\n {%- if content.type == 'image' or 'image' in content or 'image_url' in content %}\n {%- set image_count.value = image_count.value + 1 %}\n {%- if add_vision_id %}Picture {{ image_count.value }}: {% endif -%}\n <|vision_start|><|image_pad|><|vision_end|>\n {%- elif content.type == 'video' or 'video' in content %}\n {%- set video_count.value = video_count.value + 1 %}\n {%- if add_vision_id %}Video {{ video_count.value }}: {% endif -%}\n <|vision_start|><|video_pad|><|vision_end|>\n {%- elif 'text' in content %}\n {{- content.text }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- 
'\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n"
3
+ }
ICL/sft_model/epoch3_step1406_fp32/config.json ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Qwen3VLForConditionalGeneration"
4
+ ],
5
+ "image_token_id": 151655,
6
+ "model_type": "qwen3_vl",
7
+ "text_config": {
8
+ "attention_bias": false,
9
+ "attention_dropout": 0.0,
10
+ "bos_token_id": 151643,
11
+ "dtype": "bfloat16",
12
+ "eos_token_id": 151645,
13
+ "head_dim": 128,
14
+ "hidden_act": "silu",
15
+ "hidden_size": 4096,
16
+ "initializer_range": 0.02,
17
+ "intermediate_size": 12288,
18
+ "max_position_embeddings": 262144,
19
+ "model_type": "qwen3_vl_text",
20
+ "num_attention_heads": 32,
21
+ "num_hidden_layers": 36,
22
+ "num_key_value_heads": 8,
23
+ "rms_norm_eps": 1e-06,
24
+ "rope_scaling": {
25
+ "mrope_interleaved": true,
26
+ "mrope_section": [
27
+ 24,
28
+ 20,
29
+ 20
30
+ ],
31
+ "rope_type": "default"
32
+ },
33
+ "rope_theta": 5000000,
34
+ "use_cache": true,
35
+ "vocab_size": 151936
36
+ },
37
+ "tie_word_embeddings": false,
38
+ "transformers_version": "4.57.0.dev0",
39
+ "video_token_id": 151656,
40
+ "vision_config": {
41
+ "deepstack_visual_indexes": [
42
+ 8,
43
+ 16,
44
+ 24
45
+ ],
46
+ "depth": 27,
47
+ "hidden_act": "gelu_pytorch_tanh",
48
+ "hidden_size": 1152,
49
+ "in_channels": 3,
50
+ "initializer_range": 0.02,
51
+ "intermediate_size": 4304,
52
+ "model_type": "qwen3_vl",
53
+ "num_heads": 16,
54
+ "num_position_embeddings": 2304,
55
+ "out_hidden_size": 4096,
56
+ "patch_size": 16,
57
+ "spatial_merge_size": 2,
58
+ "temporal_patch_size": 2
59
+ },
60
+ "vision_end_token_id": 151653,
61
+ "vision_start_token_id": 151652
62
+ }
ICL/sft_model/epoch3_step1406_fp32/generation_config.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token_id": 151643,
3
+ "pad_token_id": 151643,
4
+ "do_sample": true,
5
+ "eos_token_id": [
6
+ 151645,
7
+ 151643
8
+ ],
9
+ "top_k": 20,
10
+ "top_p": 0.8,
11
+ "repetition_penalty": 1.0,
12
+ "temperature": 0.7,
13
+ "transformers_version": "4.56.0"
14
+ }
ICL/sft_model/epoch3_step1406_fp32/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
ICL/sft_model/epoch3_step1406_fp32/model.safetensors.index.json ADDED
@@ -0,0 +1,757 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "metadata": {
3
+ "total_size": 35059909568
4
+ },
5
+ "weight_map": {
6
+ "lm_head.weight": "model-00008-of-00008.safetensors",
7
+ "model.language_model.embed_tokens.weight": "model-00001-of-00008.safetensors",
8
+ "model.language_model.layers.0.input_layernorm.weight": "model-00002-of-00008.safetensors",
9
+ "model.language_model.layers.0.mlp.down_proj.weight": "model-00002-of-00008.safetensors",
10
+ "model.language_model.layers.0.mlp.gate_proj.weight": "model-00002-of-00008.safetensors",
11
+ "model.language_model.layers.0.mlp.up_proj.weight": "model-00002-of-00008.safetensors",
12
+ "model.language_model.layers.0.post_attention_layernorm.weight": "model-00002-of-00008.safetensors",
13
+ "model.language_model.layers.0.self_attn.k_norm.weight": "model-00001-of-00008.safetensors",
14
+ "model.language_model.layers.0.self_attn.k_proj.weight": "model-00001-of-00008.safetensors",
15
+ "model.language_model.layers.0.self_attn.o_proj.weight": "model-00001-of-00008.safetensors",
16
+ "model.language_model.layers.0.self_attn.q_norm.weight": "model-00001-of-00008.safetensors",
17
+ "model.language_model.layers.0.self_attn.q_proj.weight": "model-00001-of-00008.safetensors",
18
+ "model.language_model.layers.0.self_attn.v_proj.weight": "model-00001-of-00008.safetensors",
19
+ "model.language_model.layers.1.input_layernorm.weight": "model-00002-of-00008.safetensors",
20
+ "model.language_model.layers.1.mlp.down_proj.weight": "model-00002-of-00008.safetensors",
21
+ "model.language_model.layers.1.mlp.gate_proj.weight": "model-00002-of-00008.safetensors",
22
+ "model.language_model.layers.1.mlp.up_proj.weight": "model-00002-of-00008.safetensors",
23
+ "model.language_model.layers.1.post_attention_layernorm.weight": "model-00002-of-00008.safetensors",
24
+ "model.language_model.layers.1.self_attn.k_norm.weight": "model-00002-of-00008.safetensors",
25
+ "model.language_model.layers.1.self_attn.k_proj.weight": "model-00002-of-00008.safetensors",
26
+ "model.language_model.layers.1.self_attn.o_proj.weight": "model-00002-of-00008.safetensors",
27
+ "model.language_model.layers.1.self_attn.q_norm.weight": "model-00002-of-00008.safetensors",
28
+ "model.language_model.layers.1.self_attn.q_proj.weight": "model-00002-of-00008.safetensors",
29
+ "model.language_model.layers.1.self_attn.v_proj.weight": "model-00002-of-00008.safetensors",
30
+ "model.language_model.layers.10.input_layernorm.weight": "model-00003-of-00008.safetensors",
31
+ "model.language_model.layers.10.mlp.down_proj.weight": "model-00003-of-00008.safetensors",
32
+ "model.language_model.layers.10.mlp.gate_proj.weight": "model-00003-of-00008.safetensors",
33
+ "model.language_model.layers.10.mlp.up_proj.weight": "model-00003-of-00008.safetensors",
34
+ "model.language_model.layers.10.post_attention_layernorm.weight": "model-00003-of-00008.safetensors",
35
+ "model.language_model.layers.10.self_attn.k_norm.weight": "model-00003-of-00008.safetensors",
36
+ "model.language_model.layers.10.self_attn.k_proj.weight": "model-00003-of-00008.safetensors",
37
+ "model.language_model.layers.10.self_attn.o_proj.weight": "model-00003-of-00008.safetensors",
38
+ "model.language_model.layers.10.self_attn.q_norm.weight": "model-00003-of-00008.safetensors",
39
+ "model.language_model.layers.10.self_attn.q_proj.weight": "model-00003-of-00008.safetensors",
40
+ "model.language_model.layers.10.self_attn.v_proj.weight": "model-00003-of-00008.safetensors",
41
+ "model.language_model.layers.11.input_layernorm.weight": "model-00003-of-00008.safetensors",
42
+ "model.language_model.layers.11.mlp.down_proj.weight": "model-00003-of-00008.safetensors",
43
+ "model.language_model.layers.11.mlp.gate_proj.weight": "model-00003-of-00008.safetensors",
44
+ "model.language_model.layers.11.mlp.up_proj.weight": "model-00003-of-00008.safetensors",
45
+ "model.language_model.layers.11.post_attention_layernorm.weight": "model-00003-of-00008.safetensors",
46
+ "model.language_model.layers.11.self_attn.k_norm.weight": "model-00003-of-00008.safetensors",
47
+ "model.language_model.layers.11.self_attn.k_proj.weight": "model-00003-of-00008.safetensors",
48
+ "model.language_model.layers.11.self_attn.o_proj.weight": "model-00003-of-00008.safetensors",
49
+ "model.language_model.layers.11.self_attn.q_norm.weight": "model-00003-of-00008.safetensors",
50
+ "model.language_model.layers.11.self_attn.q_proj.weight": "model-00003-of-00008.safetensors",
51
+ "model.language_model.layers.11.self_attn.v_proj.weight": "model-00003-of-00008.safetensors",
52
+ "model.language_model.layers.12.input_layernorm.weight": "model-00004-of-00008.safetensors",
53
+ "model.language_model.layers.12.mlp.down_proj.weight": "model-00004-of-00008.safetensors",
54
+ "model.language_model.layers.12.mlp.gate_proj.weight": "model-00003-of-00008.safetensors",
55
+ "model.language_model.layers.12.mlp.up_proj.weight": "model-00003-of-00008.safetensors",
56
+ "model.language_model.layers.12.post_attention_layernorm.weight": "model-00004-of-00008.safetensors",
57
+ "model.language_model.layers.12.self_attn.k_norm.weight": "model-00003-of-00008.safetensors",
58
+ "model.language_model.layers.12.self_attn.k_proj.weight": "model-00003-of-00008.safetensors",
59
+ "model.language_model.layers.12.self_attn.o_proj.weight": "model-00003-of-00008.safetensors",
60
+ "model.language_model.layers.12.self_attn.q_norm.weight": "model-00003-of-00008.safetensors",
61
+ "model.language_model.layers.12.self_attn.q_proj.weight": "model-00003-of-00008.safetensors",
62
+ "model.language_model.layers.12.self_attn.v_proj.weight": "model-00003-of-00008.safetensors",
63
+ "model.language_model.layers.13.input_layernorm.weight": "model-00004-of-00008.safetensors",
64
+ "model.language_model.layers.13.mlp.down_proj.weight": "model-00004-of-00008.safetensors",
65
+ "model.language_model.layers.13.mlp.gate_proj.weight": "model-00004-of-00008.safetensors",
66
+ "model.language_model.layers.13.mlp.up_proj.weight": "model-00004-of-00008.safetensors",
67
+ "model.language_model.layers.13.post_attention_layernorm.weight": "model-00004-of-00008.safetensors",
68
+ "model.language_model.layers.13.self_attn.k_norm.weight": "model-00004-of-00008.safetensors",
69
+ "model.language_model.layers.13.self_attn.k_proj.weight": "model-00004-of-00008.safetensors",
70
+ "model.language_model.layers.13.self_attn.o_proj.weight": "model-00004-of-00008.safetensors",
71
+ "model.language_model.layers.13.self_attn.q_norm.weight": "model-00004-of-00008.safetensors",
72
+ "model.language_model.layers.13.self_attn.q_proj.weight": "model-00004-of-00008.safetensors",
73
+ "model.language_model.layers.13.self_attn.v_proj.weight": "model-00004-of-00008.safetensors",
74
+ "model.language_model.layers.14.input_layernorm.weight": "model-00004-of-00008.safetensors",
75
+ "model.language_model.layers.14.mlp.down_proj.weight": "model-00004-of-00008.safetensors",
76
+ "model.language_model.layers.14.mlp.gate_proj.weight": "model-00004-of-00008.safetensors",
77
+ "model.language_model.layers.14.mlp.up_proj.weight": "model-00004-of-00008.safetensors",
78
+ "model.language_model.layers.14.post_attention_layernorm.weight": "model-00004-of-00008.safetensors",
79
+ "model.language_model.layers.14.self_attn.k_norm.weight": "model-00004-of-00008.safetensors",
80
+ "model.language_model.layers.14.self_attn.k_proj.weight": "model-00004-of-00008.safetensors",
81
+ "model.language_model.layers.14.self_attn.o_proj.weight": "model-00004-of-00008.safetensors",
82
+ "model.language_model.layers.14.self_attn.q_norm.weight": "model-00004-of-00008.safetensors",
83
+ "model.language_model.layers.14.self_attn.q_proj.weight": "model-00004-of-00008.safetensors",
84
+ "model.language_model.layers.14.self_attn.v_proj.weight": "model-00004-of-00008.safetensors",
85
+ "model.language_model.layers.15.input_layernorm.weight": "model-00004-of-00008.safetensors",
86
+ "model.language_model.layers.15.mlp.down_proj.weight": "model-00004-of-00008.safetensors",
87
+ "model.language_model.layers.15.mlp.gate_proj.weight": "model-00004-of-00008.safetensors",
88
+ "model.language_model.layers.15.mlp.up_proj.weight": "model-00004-of-00008.safetensors",
89
+ "model.language_model.layers.15.post_attention_layernorm.weight": "model-00004-of-00008.safetensors",
90
+ "model.language_model.layers.15.self_attn.k_norm.weight": "model-00004-of-00008.safetensors",
91
+ "model.language_model.layers.15.self_attn.k_proj.weight": "model-00004-of-00008.safetensors",
92
+ "model.language_model.layers.15.self_attn.o_proj.weight": "model-00004-of-00008.safetensors",
93
+ "model.language_model.layers.15.self_attn.q_norm.weight": "model-00004-of-00008.safetensors",
94
+ "model.language_model.layers.15.self_attn.q_proj.weight": "model-00004-of-00008.safetensors",
95
+ "model.language_model.layers.15.self_attn.v_proj.weight": "model-00004-of-00008.safetensors",
96
+ "model.language_model.layers.16.input_layernorm.weight": "model-00004-of-00008.safetensors",
97
+ "model.language_model.layers.16.mlp.down_proj.weight": "model-00004-of-00008.safetensors",
98
+ "model.language_model.layers.16.mlp.gate_proj.weight": "model-00004-of-00008.safetensors",
99
+ "model.language_model.layers.16.mlp.up_proj.weight": "model-00004-of-00008.safetensors",
100
+ "model.language_model.layers.16.post_attention_layernorm.weight": "model-00004-of-00008.safetensors",
101
+ "model.language_model.layers.16.self_attn.k_norm.weight": "model-00004-of-00008.safetensors",
102
+ "model.language_model.layers.16.self_attn.k_proj.weight": "model-00004-of-00008.safetensors",
103
+ "model.language_model.layers.16.self_attn.o_proj.weight": "model-00004-of-00008.safetensors",
104
+ "model.language_model.layers.16.self_attn.q_norm.weight": "model-00004-of-00008.safetensors",
105
+ "model.language_model.layers.16.self_attn.q_proj.weight": "model-00004-of-00008.safetensors",
106
+ "model.language_model.layers.16.self_attn.v_proj.weight": "model-00004-of-00008.safetensors",
107
+ "model.language_model.layers.17.input_layernorm.weight": "model-00004-of-00008.safetensors",
108
+ "model.language_model.layers.17.mlp.down_proj.weight": "model-00004-of-00008.safetensors",
109
+ "model.language_model.layers.17.mlp.gate_proj.weight": "model-00004-of-00008.safetensors",
110
+ "model.language_model.layers.17.mlp.up_proj.weight": "model-00004-of-00008.safetensors",
111
+ "model.language_model.layers.17.post_attention_layernorm.weight": "model-00004-of-00008.safetensors",
112
+ "model.language_model.layers.17.self_attn.k_norm.weight": "model-00004-of-00008.safetensors",
113
+ "model.language_model.layers.17.self_attn.k_proj.weight": "model-00004-of-00008.safetensors",
114
+ "model.language_model.layers.17.self_attn.o_proj.weight": "model-00004-of-00008.safetensors",
115
+ "model.language_model.layers.17.self_attn.q_norm.weight": "model-00004-of-00008.safetensors",
116
+ "model.language_model.layers.17.self_attn.q_proj.weight": "model-00004-of-00008.safetensors",
117
+ "model.language_model.layers.17.self_attn.v_proj.weight": "model-00004-of-00008.safetensors",
118
+ "model.language_model.layers.18.input_layernorm.weight": "model-00004-of-00008.safetensors",
119
+ "model.language_model.layers.18.mlp.down_proj.weight": "model-00004-of-00008.safetensors",
120
+ "model.language_model.layers.18.mlp.gate_proj.weight": "model-00004-of-00008.safetensors",
121
+ "model.language_model.layers.18.mlp.up_proj.weight": "model-00004-of-00008.safetensors",
122
+ "model.language_model.layers.18.post_attention_layernorm.weight": "model-00004-of-00008.safetensors",
123
+ "model.language_model.layers.18.self_attn.k_norm.weight": "model-00004-of-00008.safetensors",
124
+ "model.language_model.layers.18.self_attn.k_proj.weight": "model-00004-of-00008.safetensors",
125
+ "model.language_model.layers.18.self_attn.o_proj.weight": "model-00004-of-00008.safetensors",
126
+ "model.language_model.layers.18.self_attn.q_norm.weight": "model-00004-of-00008.safetensors",
127
+ "model.language_model.layers.18.self_attn.q_proj.weight": "model-00004-of-00008.safetensors",
128
+ "model.language_model.layers.18.self_attn.v_proj.weight": "model-00004-of-00008.safetensors",
129
+ "model.language_model.layers.19.input_layernorm.weight": "model-00005-of-00008.safetensors",
130
+ "model.language_model.layers.19.mlp.down_proj.weight": "model-00005-of-00008.safetensors",
131
+ "model.language_model.layers.19.mlp.gate_proj.weight": "model-00005-of-00008.safetensors",
132
+ "model.language_model.layers.19.mlp.up_proj.weight": "model-00005-of-00008.safetensors",
133
+ "model.language_model.layers.19.post_attention_layernorm.weight": "model-00005-of-00008.safetensors",
134
+ "model.language_model.layers.19.self_attn.k_norm.weight": "model-00004-of-00008.safetensors",
135
+ "model.language_model.layers.19.self_attn.k_proj.weight": "model-00004-of-00008.safetensors",
136
+ "model.language_model.layers.19.self_attn.o_proj.weight": "model-00004-of-00008.safetensors",
137
+ "model.language_model.layers.19.self_attn.q_norm.weight": "model-00004-of-00008.safetensors",
138
+ "model.language_model.layers.19.self_attn.q_proj.weight": "model-00004-of-00008.safetensors",
139
+ "model.language_model.layers.19.self_attn.v_proj.weight": "model-00004-of-00008.safetensors",
140
+ "model.language_model.layers.2.input_layernorm.weight": "model-00002-of-00008.safetensors",
141
+ "model.language_model.layers.2.mlp.down_proj.weight": "model-00002-of-00008.safetensors",
142
+ "model.language_model.layers.2.mlp.gate_proj.weight": "model-00002-of-00008.safetensors",
143
+ "model.language_model.layers.2.mlp.up_proj.weight": "model-00002-of-00008.safetensors",
144
+ "model.language_model.layers.2.post_attention_layernorm.weight": "model-00002-of-00008.safetensors",
145
+ "model.language_model.layers.2.self_attn.k_norm.weight": "model-00002-of-00008.safetensors",
146
+ "model.language_model.layers.2.self_attn.k_proj.weight": "model-00002-of-00008.safetensors",
147
+ "model.language_model.layers.2.self_attn.o_proj.weight": "model-00002-of-00008.safetensors",
148
+ "model.language_model.layers.2.self_attn.q_norm.weight": "model-00002-of-00008.safetensors",
149
+ "model.language_model.layers.2.self_attn.q_proj.weight": "model-00002-of-00008.safetensors",
150
+ "model.language_model.layers.2.self_attn.v_proj.weight": "model-00002-of-00008.safetensors",
151
+ "model.language_model.layers.20.input_layernorm.weight": "model-00005-of-00008.safetensors",
152
+ "model.language_model.layers.20.mlp.down_proj.weight": "model-00005-of-00008.safetensors",
153
+ "model.language_model.layers.20.mlp.gate_proj.weight": "model-00005-of-00008.safetensors",
154
+ "model.language_model.layers.20.mlp.up_proj.weight": "model-00005-of-00008.safetensors",
155
+ "model.language_model.layers.20.post_attention_layernorm.weight": "model-00005-of-00008.safetensors",
156
+ "model.language_model.layers.20.self_attn.k_norm.weight": "model-00005-of-00008.safetensors",
157
+ "model.language_model.layers.20.self_attn.k_proj.weight": "model-00005-of-00008.safetensors",
158
+ "model.language_model.layers.20.self_attn.o_proj.weight": "model-00005-of-00008.safetensors",
159
+ "model.language_model.layers.20.self_attn.q_norm.weight": "model-00005-of-00008.safetensors",
160
+ "model.language_model.layers.20.self_attn.q_proj.weight": "model-00005-of-00008.safetensors",
161
+ "model.language_model.layers.20.self_attn.v_proj.weight": "model-00005-of-00008.safetensors",
162
+ "model.language_model.layers.21.input_layernorm.weight": "model-00005-of-00008.safetensors",
163
+ "model.language_model.layers.21.mlp.down_proj.weight": "model-00005-of-00008.safetensors",
164
+ "model.language_model.layers.21.mlp.gate_proj.weight": "model-00005-of-00008.safetensors",
165
+ "model.language_model.layers.21.mlp.up_proj.weight": "model-00005-of-00008.safetensors",
166
+ "model.language_model.layers.21.post_attention_layernorm.weight": "model-00005-of-00008.safetensors",
167
+ "model.language_model.layers.21.self_attn.k_norm.weight": "model-00005-of-00008.safetensors",
168
+ "model.language_model.layers.21.self_attn.k_proj.weight": "model-00005-of-00008.safetensors",
169
+ "model.language_model.layers.21.self_attn.o_proj.weight": "model-00005-of-00008.safetensors",
170
+ "model.language_model.layers.21.self_attn.q_norm.weight": "model-00005-of-00008.safetensors",
171
+ "model.language_model.layers.21.self_attn.q_proj.weight": "model-00005-of-00008.safetensors",
172
+ "model.language_model.layers.21.self_attn.v_proj.weight": "model-00005-of-00008.safetensors",
173
+ "model.language_model.layers.22.input_layernorm.weight": "model-00005-of-00008.safetensors",
174
+ "model.language_model.layers.22.mlp.down_proj.weight": "model-00005-of-00008.safetensors",
175
+ "model.language_model.layers.22.mlp.gate_proj.weight": "model-00005-of-00008.safetensors",
176
+ "model.language_model.layers.22.mlp.up_proj.weight": "model-00005-of-00008.safetensors",
177
+ "model.language_model.layers.22.post_attention_layernorm.weight": "model-00005-of-00008.safetensors",
178
+ "model.language_model.layers.22.self_attn.k_norm.weight": "model-00005-of-00008.safetensors",
179
+ "model.language_model.layers.22.self_attn.k_proj.weight": "model-00005-of-00008.safetensors",
180
+ "model.language_model.layers.22.self_attn.o_proj.weight": "model-00005-of-00008.safetensors",
181
+ "model.language_model.layers.22.self_attn.q_norm.weight": "model-00005-of-00008.safetensors",
182
+ "model.language_model.layers.22.self_attn.q_proj.weight": "model-00005-of-00008.safetensors",
183
+ "model.language_model.layers.22.self_attn.v_proj.weight": "model-00005-of-00008.safetensors",
184
+ "model.language_model.layers.23.input_layernorm.weight": "model-00005-of-00008.safetensors",
185
+ "model.language_model.layers.23.mlp.down_proj.weight": "model-00005-of-00008.safetensors",
186
+ "model.language_model.layers.23.mlp.gate_proj.weight": "model-00005-of-00008.safetensors",
187
+ "model.language_model.layers.23.mlp.up_proj.weight": "model-00005-of-00008.safetensors",
188
+ "model.language_model.layers.23.post_attention_layernorm.weight": "model-00005-of-00008.safetensors",
189
+ "model.language_model.layers.23.self_attn.k_norm.weight": "model-00005-of-00008.safetensors",
190
+ "model.language_model.layers.23.self_attn.k_proj.weight": "model-00005-of-00008.safetensors",
191
+ "model.language_model.layers.23.self_attn.o_proj.weight": "model-00005-of-00008.safetensors",
192
+ "model.language_model.layers.23.self_attn.q_norm.weight": "model-00005-of-00008.safetensors",
193
+ "model.language_model.layers.23.self_attn.q_proj.weight": "model-00005-of-00008.safetensors",
194
+ "model.language_model.layers.23.self_attn.v_proj.weight": "model-00005-of-00008.safetensors",
195
+ "model.language_model.layers.24.input_layernorm.weight": "model-00005-of-00008.safetensors",
196
+ "model.language_model.layers.24.mlp.down_proj.weight": "model-00005-of-00008.safetensors",
197
+ "model.language_model.layers.24.mlp.gate_proj.weight": "model-00005-of-00008.safetensors",
198
+ "model.language_model.layers.24.mlp.up_proj.weight": "model-00005-of-00008.safetensors",
199
+ "model.language_model.layers.24.post_attention_layernorm.weight": "model-00005-of-00008.safetensors",
200
+ "model.language_model.layers.24.self_attn.k_norm.weight": "model-00005-of-00008.safetensors",
201
+ "model.language_model.layers.24.self_attn.k_proj.weight": "model-00005-of-00008.safetensors",
202
+ "model.language_model.layers.24.self_attn.o_proj.weight": "model-00005-of-00008.safetensors",
203
+ "model.language_model.layers.24.self_attn.q_norm.weight": "model-00005-of-00008.safetensors",
204
+ "model.language_model.layers.24.self_attn.q_proj.weight": "model-00005-of-00008.safetensors",
205
+ "model.language_model.layers.24.self_attn.v_proj.weight": "model-00005-of-00008.safetensors",
206
+ "model.language_model.layers.25.input_layernorm.weight": "model-00006-of-00008.safetensors",
207
+ "model.language_model.layers.25.mlp.down_proj.weight": "model-00006-of-00008.safetensors",
208
+ "model.language_model.layers.25.mlp.gate_proj.weight": "model-00005-of-00008.safetensors",
209
+ "model.language_model.layers.25.mlp.up_proj.weight": "model-00006-of-00008.safetensors",
210
+ "model.language_model.layers.25.post_attention_layernorm.weight": "model-00006-of-00008.safetensors",
211
+ "model.language_model.layers.25.self_attn.k_norm.weight": "model-00005-of-00008.safetensors",
212
+ "model.language_model.layers.25.self_attn.k_proj.weight": "model-00005-of-00008.safetensors",
213
+ "model.language_model.layers.25.self_attn.o_proj.weight": "model-00005-of-00008.safetensors",
214
+ "model.language_model.layers.25.self_attn.q_norm.weight": "model-00005-of-00008.safetensors",
215
+ "model.language_model.layers.25.self_attn.q_proj.weight": "model-00005-of-00008.safetensors",
216
+ "model.language_model.layers.25.self_attn.v_proj.weight": "model-00005-of-00008.safetensors",
217
+ "model.language_model.layers.26.input_layernorm.weight": "model-00006-of-00008.safetensors",
218
+ "model.language_model.layers.26.mlp.down_proj.weight": "model-00006-of-00008.safetensors",
219
+ "model.language_model.layers.26.mlp.gate_proj.weight": "model-00006-of-00008.safetensors",
220
+ "model.language_model.layers.26.mlp.up_proj.weight": "model-00006-of-00008.safetensors",
221
+ "model.language_model.layers.26.post_attention_layernorm.weight": "model-00006-of-00008.safetensors",
222
+ "model.language_model.layers.26.self_attn.k_norm.weight": "model-00006-of-00008.safetensors",
223
+ "model.language_model.layers.26.self_attn.k_proj.weight": "model-00006-of-00008.safetensors",
224
+ "model.language_model.layers.26.self_attn.o_proj.weight": "model-00006-of-00008.safetensors",
225
+ "model.language_model.layers.26.self_attn.q_norm.weight": "model-00006-of-00008.safetensors",
226
+ "model.language_model.layers.26.self_attn.q_proj.weight": "model-00006-of-00008.safetensors",
227
+ "model.language_model.layers.26.self_attn.v_proj.weight": "model-00006-of-00008.safetensors",
228
+ "model.language_model.layers.27.input_layernorm.weight": "model-00006-of-00008.safetensors",
229
+ "model.language_model.layers.27.mlp.down_proj.weight": "model-00006-of-00008.safetensors",
230
+ "model.language_model.layers.27.mlp.gate_proj.weight": "model-00006-of-00008.safetensors",
231
+ "model.language_model.layers.27.mlp.up_proj.weight": "model-00006-of-00008.safetensors",
232
+ "model.language_model.layers.27.post_attention_layernorm.weight": "model-00006-of-00008.safetensors",
233
+ "model.language_model.layers.27.self_attn.k_norm.weight": "model-00006-of-00008.safetensors",
234
+ "model.language_model.layers.27.self_attn.k_proj.weight": "model-00006-of-00008.safetensors",
235
+ "model.language_model.layers.27.self_attn.o_proj.weight": "model-00006-of-00008.safetensors",
236
+ "model.language_model.layers.27.self_attn.q_norm.weight": "model-00006-of-00008.safetensors",
237
+ "model.language_model.layers.27.self_attn.q_proj.weight": "model-00006-of-00008.safetensors",
238
+ "model.language_model.layers.27.self_attn.v_proj.weight": "model-00006-of-00008.safetensors",
239
+ "model.language_model.layers.28.input_layernorm.weight": "model-00006-of-00008.safetensors",
240
+ "model.language_model.layers.28.mlp.down_proj.weight": "model-00006-of-00008.safetensors",
241
+ "model.language_model.layers.28.mlp.gate_proj.weight": "model-00006-of-00008.safetensors",
242
+ "model.language_model.layers.28.mlp.up_proj.weight": "model-00006-of-00008.safetensors",
243
+ "model.language_model.layers.28.post_attention_layernorm.weight": "model-00006-of-00008.safetensors",
244
+ "model.language_model.layers.28.self_attn.k_norm.weight": "model-00006-of-00008.safetensors",
245
+ "model.language_model.layers.28.self_attn.k_proj.weight": "model-00006-of-00008.safetensors",
246
+ "model.language_model.layers.28.self_attn.o_proj.weight": "model-00006-of-00008.safetensors",
247
+ "model.language_model.layers.28.self_attn.q_norm.weight": "model-00006-of-00008.safetensors",
248
+ "model.language_model.layers.28.self_attn.q_proj.weight": "model-00006-of-00008.safetensors",
249
+ "model.language_model.layers.28.self_attn.v_proj.weight": "model-00006-of-00008.safetensors",
250
+ "model.language_model.layers.29.input_layernorm.weight": "model-00006-of-00008.safetensors",
251
+ "model.language_model.layers.29.mlp.down_proj.weight": "model-00006-of-00008.safetensors",
252
+ "model.language_model.layers.29.mlp.gate_proj.weight": "model-00006-of-00008.safetensors",
253
+ "model.language_model.layers.29.mlp.up_proj.weight": "model-00006-of-00008.safetensors",
254
+ "model.language_model.layers.29.post_attention_layernorm.weight": "model-00006-of-00008.safetensors",
255
+ "model.language_model.layers.29.self_attn.k_norm.weight": "model-00006-of-00008.safetensors",
256
+ "model.language_model.layers.29.self_attn.k_proj.weight": "model-00006-of-00008.safetensors",
257
+ "model.language_model.layers.29.self_attn.o_proj.weight": "model-00006-of-00008.safetensors",
258
+ "model.language_model.layers.29.self_attn.q_norm.weight": "model-00006-of-00008.safetensors",
259
+ "model.language_model.layers.29.self_attn.q_proj.weight": "model-00006-of-00008.safetensors",
260
+ "model.language_model.layers.29.self_attn.v_proj.weight": "model-00006-of-00008.safetensors",
261
+ "model.language_model.layers.3.input_layernorm.weight": "model-00002-of-00008.safetensors",
262
+ "model.language_model.layers.3.mlp.down_proj.weight": "model-00002-of-00008.safetensors",
263
+ "model.language_model.layers.3.mlp.gate_proj.weight": "model-00002-of-00008.safetensors",
264
+ "model.language_model.layers.3.mlp.up_proj.weight": "model-00002-of-00008.safetensors",
265
+ "model.language_model.layers.3.post_attention_layernorm.weight": "model-00002-of-00008.safetensors",
266
+ "model.language_model.layers.3.self_attn.k_norm.weight": "model-00002-of-00008.safetensors",
267
+ "model.language_model.layers.3.self_attn.k_proj.weight": "model-00002-of-00008.safetensors",
268
+ "model.language_model.layers.3.self_attn.o_proj.weight": "model-00002-of-00008.safetensors",
269
+ "model.language_model.layers.3.self_attn.q_norm.weight": "model-00002-of-00008.safetensors",
270
+ "model.language_model.layers.3.self_attn.q_proj.weight": "model-00002-of-00008.safetensors",
271
+ "model.language_model.layers.3.self_attn.v_proj.weight": "model-00002-of-00008.safetensors",
272
+ "model.language_model.layers.30.input_layernorm.weight": "model-00006-of-00008.safetensors",
273
+ "model.language_model.layers.30.mlp.down_proj.weight": "model-00006-of-00008.safetensors",
274
+ "model.language_model.layers.30.mlp.gate_proj.weight": "model-00006-of-00008.safetensors",
275
+ "model.language_model.layers.30.mlp.up_proj.weight": "model-00006-of-00008.safetensors",
276
+ "model.language_model.layers.30.post_attention_layernorm.weight": "model-00006-of-00008.safetensors",
277
+ "model.language_model.layers.30.self_attn.k_norm.weight": "model-00006-of-00008.safetensors",
278
+ "model.language_model.layers.30.self_attn.k_proj.weight": "model-00006-of-00008.safetensors",
279
+ "model.language_model.layers.30.self_attn.o_proj.weight": "model-00006-of-00008.safetensors",
280
+ "model.language_model.layers.30.self_attn.q_norm.weight": "model-00006-of-00008.safetensors",
281
+ "model.language_model.layers.30.self_attn.q_proj.weight": "model-00006-of-00008.safetensors",
282
+ "model.language_model.layers.30.self_attn.v_proj.weight": "model-00006-of-00008.safetensors",
283
+ "model.language_model.layers.31.input_layernorm.weight": "model-00007-of-00008.safetensors",
284
+ "model.language_model.layers.31.mlp.down_proj.weight": "model-00007-of-00008.safetensors",
285
+ "model.language_model.layers.31.mlp.gate_proj.weight": "model-00006-of-00008.safetensors",
286
+ "model.language_model.layers.31.mlp.up_proj.weight": "model-00006-of-00008.safetensors",
287
+ "model.language_model.layers.31.post_attention_layernorm.weight": "model-00007-of-00008.safetensors",
288
+ "model.language_model.layers.31.self_attn.k_norm.weight": "model-00006-of-00008.safetensors",
289
+ "model.language_model.layers.31.self_attn.k_proj.weight": "model-00006-of-00008.safetensors",
290
+ "model.language_model.layers.31.self_attn.o_proj.weight": "model-00006-of-00008.safetensors",
291
+ "model.language_model.layers.31.self_attn.q_norm.weight": "model-00006-of-00008.safetensors",
292
+ "model.language_model.layers.31.self_attn.q_proj.weight": "model-00006-of-00008.safetensors",
293
+ "model.language_model.layers.31.self_attn.v_proj.weight": "model-00006-of-00008.safetensors",
294
+ "model.language_model.layers.32.input_layernorm.weight": "model-00007-of-00008.safetensors",
295
+ "model.language_model.layers.32.mlp.down_proj.weight": "model-00007-of-00008.safetensors",
296
+ "model.language_model.layers.32.mlp.gate_proj.weight": "model-00007-of-00008.safetensors",
297
+ "model.language_model.layers.32.mlp.up_proj.weight": "model-00007-of-00008.safetensors",
298
+ "model.language_model.layers.32.post_attention_layernorm.weight": "model-00007-of-00008.safetensors",
299
+ "model.language_model.layers.32.self_attn.k_norm.weight": "model-00007-of-00008.safetensors",
300
+ "model.language_model.layers.32.self_attn.k_proj.weight": "model-00007-of-00008.safetensors",
301
+ "model.language_model.layers.32.self_attn.o_proj.weight": "model-00007-of-00008.safetensors",
302
+ "model.language_model.layers.32.self_attn.q_norm.weight": "model-00007-of-00008.safetensors",
303
+ "model.language_model.layers.32.self_attn.q_proj.weight": "model-00007-of-00008.safetensors",
304
+ "model.language_model.layers.32.self_attn.v_proj.weight": "model-00007-of-00008.safetensors",
305
+ "model.language_model.layers.33.input_layernorm.weight": "model-00007-of-00008.safetensors",
306
+ "model.language_model.layers.33.mlp.down_proj.weight": "model-00007-of-00008.safetensors",
307
+ "model.language_model.layers.33.mlp.gate_proj.weight": "model-00007-of-00008.safetensors",
308
+ "model.language_model.layers.33.mlp.up_proj.weight": "model-00007-of-00008.safetensors",
309
+ "model.language_model.layers.33.post_attention_layernorm.weight": "model-00007-of-00008.safetensors",
310
+ "model.language_model.layers.33.self_attn.k_norm.weight": "model-00007-of-00008.safetensors",
311
+ "model.language_model.layers.33.self_attn.k_proj.weight": "model-00007-of-00008.safetensors",
312
+ "model.language_model.layers.33.self_attn.o_proj.weight": "model-00007-of-00008.safetensors",
313
+ "model.language_model.layers.33.self_attn.q_norm.weight": "model-00007-of-00008.safetensors",
314
+ "model.language_model.layers.33.self_attn.q_proj.weight": "model-00007-of-00008.safetensors",
315
+ "model.language_model.layers.33.self_attn.v_proj.weight": "model-00007-of-00008.safetensors",
316
+ "model.language_model.layers.34.input_layernorm.weight": "model-00007-of-00008.safetensors",
317
+ "model.language_model.layers.34.mlp.down_proj.weight": "model-00007-of-00008.safetensors",
318
+ "model.language_model.layers.34.mlp.gate_proj.weight": "model-00007-of-00008.safetensors",
319
+ "model.language_model.layers.34.mlp.up_proj.weight": "model-00007-of-00008.safetensors",
320
+ "model.language_model.layers.34.post_attention_layernorm.weight": "model-00007-of-00008.safetensors",
321
+ "model.language_model.layers.34.self_attn.k_norm.weight": "model-00007-of-00008.safetensors",
322
+ "model.language_model.layers.34.self_attn.k_proj.weight": "model-00007-of-00008.safetensors",
323
+ "model.language_model.layers.34.self_attn.o_proj.weight": "model-00007-of-00008.safetensors",
324
+ "model.language_model.layers.34.self_attn.q_norm.weight": "model-00007-of-00008.safetensors",
325
+ "model.language_model.layers.34.self_attn.q_proj.weight": "model-00007-of-00008.safetensors",
326
+ "model.language_model.layers.34.self_attn.v_proj.weight": "model-00007-of-00008.safetensors",
327
+ "model.language_model.layers.35.input_layernorm.weight": "model-00007-of-00008.safetensors",
328
+ "model.language_model.layers.35.mlp.down_proj.weight": "model-00007-of-00008.safetensors",
329
+ "model.language_model.layers.35.mlp.gate_proj.weight": "model-00007-of-00008.safetensors",
330
+ "model.language_model.layers.35.mlp.up_proj.weight": "model-00007-of-00008.safetensors",
331
+ "model.language_model.layers.35.post_attention_layernorm.weight": "model-00007-of-00008.safetensors",
332
+ "model.language_model.layers.35.self_attn.k_norm.weight": "model-00007-of-00008.safetensors",
333
+ "model.language_model.layers.35.self_attn.k_proj.weight": "model-00007-of-00008.safetensors",
334
+ "model.language_model.layers.35.self_attn.o_proj.weight": "model-00007-of-00008.safetensors",
335
+ "model.language_model.layers.35.self_attn.q_norm.weight": "model-00007-of-00008.safetensors",
336
+ "model.language_model.layers.35.self_attn.q_proj.weight": "model-00007-of-00008.safetensors",
337
+ "model.language_model.layers.35.self_attn.v_proj.weight": "model-00007-of-00008.safetensors",
338
+ "model.language_model.layers.4.input_layernorm.weight": "model-00002-of-00008.safetensors",
339
+ "model.language_model.layers.4.mlp.down_proj.weight": "model-00002-of-00008.safetensors",
340
+ "model.language_model.layers.4.mlp.gate_proj.weight": "model-00002-of-00008.safetensors",
341
+ "model.language_model.layers.4.mlp.up_proj.weight": "model-00002-of-00008.safetensors",
342
+ "model.language_model.layers.4.post_attention_layernorm.weight": "model-00002-of-00008.safetensors",
343
+ "model.language_model.layers.4.self_attn.k_norm.weight": "model-00002-of-00008.safetensors",
344
+ "model.language_model.layers.4.self_attn.k_proj.weight": "model-00002-of-00008.safetensors",
345
+ "model.language_model.layers.4.self_attn.o_proj.weight": "model-00002-of-00008.safetensors",
346
+ "model.language_model.layers.4.self_attn.q_norm.weight": "model-00002-of-00008.safetensors",
347
+ "model.language_model.layers.4.self_attn.q_proj.weight": "model-00002-of-00008.safetensors",
348
+ "model.language_model.layers.4.self_attn.v_proj.weight": "model-00002-of-00008.safetensors",
349
+ "model.language_model.layers.5.input_layernorm.weight": "model-00002-of-00008.safetensors",
350
+ "model.language_model.layers.5.mlp.down_proj.weight": "model-00002-of-00008.safetensors",
351
+ "model.language_model.layers.5.mlp.gate_proj.weight": "model-00002-of-00008.safetensors",
352
+ "model.language_model.layers.5.mlp.up_proj.weight": "model-00002-of-00008.safetensors",
353
+ "model.language_model.layers.5.post_attention_layernorm.weight": "model-00002-of-00008.safetensors",
354
+ "model.language_model.layers.5.self_attn.k_norm.weight": "model-00002-of-00008.safetensors",
355
+ "model.language_model.layers.5.self_attn.k_proj.weight": "model-00002-of-00008.safetensors",
356
+ "model.language_model.layers.5.self_attn.o_proj.weight": "model-00002-of-00008.safetensors",
357
+ "model.language_model.layers.5.self_attn.q_norm.weight": "model-00002-of-00008.safetensors",
358
+ "model.language_model.layers.5.self_attn.q_proj.weight": "model-00002-of-00008.safetensors",
359
+ "model.language_model.layers.5.self_attn.v_proj.weight": "model-00002-of-00008.safetensors",
360
+ "model.language_model.layers.6.input_layernorm.weight": "model-00003-of-00008.safetensors",
361
+ "model.language_model.layers.6.mlp.down_proj.weight": "model-00003-of-00008.safetensors",
362
+ "model.language_model.layers.6.mlp.gate_proj.weight": "model-00002-of-00008.safetensors",
363
+ "model.language_model.layers.6.mlp.up_proj.weight": "model-00003-of-00008.safetensors",
364
+ "model.language_model.layers.6.post_attention_layernorm.weight": "model-00003-of-00008.safetensors",
365
+ "model.language_model.layers.6.self_attn.k_norm.weight": "model-00002-of-00008.safetensors",
366
+ "model.language_model.layers.6.self_attn.k_proj.weight": "model-00002-of-00008.safetensors",
367
+ "model.language_model.layers.6.self_attn.o_proj.weight": "model-00002-of-00008.safetensors",
368
+ "model.language_model.layers.6.self_attn.q_norm.weight": "model-00002-of-00008.safetensors",
369
+ "model.language_model.layers.6.self_attn.q_proj.weight": "model-00002-of-00008.safetensors",
370
+ "model.language_model.layers.6.self_attn.v_proj.weight": "model-00002-of-00008.safetensors",
371
+ "model.language_model.layers.7.input_layernorm.weight": "model-00003-of-00008.safetensors",
372
+ "model.language_model.layers.7.mlp.down_proj.weight": "model-00003-of-00008.safetensors",
373
+ "model.language_model.layers.7.mlp.gate_proj.weight": "model-00003-of-00008.safetensors",
374
+ "model.language_model.layers.7.mlp.up_proj.weight": "model-00003-of-00008.safetensors",
375
+ "model.language_model.layers.7.post_attention_layernorm.weight": "model-00003-of-00008.safetensors",
376
+ "model.language_model.layers.7.self_attn.k_norm.weight": "model-00003-of-00008.safetensors",
377
+ "model.language_model.layers.7.self_attn.k_proj.weight": "model-00003-of-00008.safetensors",
378
+ "model.language_model.layers.7.self_attn.o_proj.weight": "model-00003-of-00008.safetensors",
379
+ "model.language_model.layers.7.self_attn.q_norm.weight": "model-00003-of-00008.safetensors",
380
+ "model.language_model.layers.7.self_attn.q_proj.weight": "model-00003-of-00008.safetensors",
381
+ "model.language_model.layers.7.self_attn.v_proj.weight": "model-00003-of-00008.safetensors",
382
+ "model.language_model.layers.8.input_layernorm.weight": "model-00003-of-00008.safetensors",
383
+ "model.language_model.layers.8.mlp.down_proj.weight": "model-00003-of-00008.safetensors",
384
+ "model.language_model.layers.8.mlp.gate_proj.weight": "model-00003-of-00008.safetensors",
385
+ "model.language_model.layers.8.mlp.up_proj.weight": "model-00003-of-00008.safetensors",
386
+ "model.language_model.layers.8.post_attention_layernorm.weight": "model-00003-of-00008.safetensors",
387
+ "model.language_model.layers.8.self_attn.k_norm.weight": "model-00003-of-00008.safetensors",
388
+ "model.language_model.layers.8.self_attn.k_proj.weight": "model-00003-of-00008.safetensors",
389
+ "model.language_model.layers.8.self_attn.o_proj.weight": "model-00003-of-00008.safetensors",
390
+ "model.language_model.layers.8.self_attn.q_norm.weight": "model-00003-of-00008.safetensors",
391
+ "model.language_model.layers.8.self_attn.q_proj.weight": "model-00003-of-00008.safetensors",
392
+ "model.language_model.layers.8.self_attn.v_proj.weight": "model-00003-of-00008.safetensors",
393
+ "model.language_model.layers.9.input_layernorm.weight": "model-00003-of-00008.safetensors",
394
+ "model.language_model.layers.9.mlp.down_proj.weight": "model-00003-of-00008.safetensors",
395
+ "model.language_model.layers.9.mlp.gate_proj.weight": "model-00003-of-00008.safetensors",
396
+ "model.language_model.layers.9.mlp.up_proj.weight": "model-00003-of-00008.safetensors",
397
+ "model.language_model.layers.9.post_attention_layernorm.weight": "model-00003-of-00008.safetensors",
398
+ "model.language_model.layers.9.self_attn.k_norm.weight": "model-00003-of-00008.safetensors",
399
+ "model.language_model.layers.9.self_attn.k_proj.weight": "model-00003-of-00008.safetensors",
400
+ "model.language_model.layers.9.self_attn.o_proj.weight": "model-00003-of-00008.safetensors",
401
+ "model.language_model.layers.9.self_attn.q_norm.weight": "model-00003-of-00008.safetensors",
402
+ "model.language_model.layers.9.self_attn.q_proj.weight": "model-00003-of-00008.safetensors",
403
+ "model.language_model.layers.9.self_attn.v_proj.weight": "model-00003-of-00008.safetensors",
404
+ "model.language_model.norm.weight": "model-00007-of-00008.safetensors",
405
+ "model.visual.blocks.0.attn.proj.bias": "model-00001-of-00008.safetensors",
406
+ "model.visual.blocks.0.attn.proj.weight": "model-00001-of-00008.safetensors",
407
+ "model.visual.blocks.0.attn.qkv.bias": "model-00001-of-00008.safetensors",
408
+ "model.visual.blocks.0.attn.qkv.weight": "model-00001-of-00008.safetensors",
409
+ "model.visual.blocks.0.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
410
+ "model.visual.blocks.0.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
411
+ "model.visual.blocks.0.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
412
+ "model.visual.blocks.0.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
413
+ "model.visual.blocks.0.norm1.bias": "model-00001-of-00008.safetensors",
414
+ "model.visual.blocks.0.norm1.weight": "model-00001-of-00008.safetensors",
415
+ "model.visual.blocks.0.norm2.bias": "model-00001-of-00008.safetensors",
416
+ "model.visual.blocks.0.norm2.weight": "model-00001-of-00008.safetensors",
417
+ "model.visual.blocks.1.attn.proj.bias": "model-00001-of-00008.safetensors",
418
+ "model.visual.blocks.1.attn.proj.weight": "model-00001-of-00008.safetensors",
419
+ "model.visual.blocks.1.attn.qkv.bias": "model-00001-of-00008.safetensors",
420
+ "model.visual.blocks.1.attn.qkv.weight": "model-00001-of-00008.safetensors",
421
+ "model.visual.blocks.1.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
422
+ "model.visual.blocks.1.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
423
+ "model.visual.blocks.1.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
424
+ "model.visual.blocks.1.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
425
+ "model.visual.blocks.1.norm1.bias": "model-00001-of-00008.safetensors",
426
+ "model.visual.blocks.1.norm1.weight": "model-00001-of-00008.safetensors",
427
+ "model.visual.blocks.1.norm2.bias": "model-00001-of-00008.safetensors",
428
+ "model.visual.blocks.1.norm2.weight": "model-00001-of-00008.safetensors",
429
+ "model.visual.blocks.10.attn.proj.bias": "model-00001-of-00008.safetensors",
430
+ "model.visual.blocks.10.attn.proj.weight": "model-00001-of-00008.safetensors",
431
+ "model.visual.blocks.10.attn.qkv.bias": "model-00001-of-00008.safetensors",
432
+ "model.visual.blocks.10.attn.qkv.weight": "model-00001-of-00008.safetensors",
433
+ "model.visual.blocks.10.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
434
+ "model.visual.blocks.10.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
435
+ "model.visual.blocks.10.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
436
+ "model.visual.blocks.10.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
437
+ "model.visual.blocks.10.norm1.bias": "model-00001-of-00008.safetensors",
438
+ "model.visual.blocks.10.norm1.weight": "model-00001-of-00008.safetensors",
439
+ "model.visual.blocks.10.norm2.bias": "model-00001-of-00008.safetensors",
440
+ "model.visual.blocks.10.norm2.weight": "model-00001-of-00008.safetensors",
441
+ "model.visual.blocks.11.attn.proj.bias": "model-00001-of-00008.safetensors",
442
+ "model.visual.blocks.11.attn.proj.weight": "model-00001-of-00008.safetensors",
443
+ "model.visual.blocks.11.attn.qkv.bias": "model-00001-of-00008.safetensors",
444
+ "model.visual.blocks.11.attn.qkv.weight": "model-00001-of-00008.safetensors",
445
+ "model.visual.blocks.11.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
446
+ "model.visual.blocks.11.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
447
+ "model.visual.blocks.11.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
448
+ "model.visual.blocks.11.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
449
+ "model.visual.blocks.11.norm1.bias": "model-00001-of-00008.safetensors",
450
+ "model.visual.blocks.11.norm1.weight": "model-00001-of-00008.safetensors",
451
+ "model.visual.blocks.11.norm2.bias": "model-00001-of-00008.safetensors",
452
+ "model.visual.blocks.11.norm2.weight": "model-00001-of-00008.safetensors",
453
+ "model.visual.blocks.12.attn.proj.bias": "model-00001-of-00008.safetensors",
454
+ "model.visual.blocks.12.attn.proj.weight": "model-00001-of-00008.safetensors",
455
+ "model.visual.blocks.12.attn.qkv.bias": "model-00001-of-00008.safetensors",
456
+ "model.visual.blocks.12.attn.qkv.weight": "model-00001-of-00008.safetensors",
457
+ "model.visual.blocks.12.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
458
+ "model.visual.blocks.12.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
459
+ "model.visual.blocks.12.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
460
+ "model.visual.blocks.12.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
461
+ "model.visual.blocks.12.norm1.bias": "model-00001-of-00008.safetensors",
462
+ "model.visual.blocks.12.norm1.weight": "model-00001-of-00008.safetensors",
463
+ "model.visual.blocks.12.norm2.bias": "model-00001-of-00008.safetensors",
464
+ "model.visual.blocks.12.norm2.weight": "model-00001-of-00008.safetensors",
465
+ "model.visual.blocks.13.attn.proj.bias": "model-00001-of-00008.safetensors",
466
+ "model.visual.blocks.13.attn.proj.weight": "model-00001-of-00008.safetensors",
467
+ "model.visual.blocks.13.attn.qkv.bias": "model-00001-of-00008.safetensors",
468
+ "model.visual.blocks.13.attn.qkv.weight": "model-00001-of-00008.safetensors",
469
+ "model.visual.blocks.13.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
470
+ "model.visual.blocks.13.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
471
+ "model.visual.blocks.13.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
472
+ "model.visual.blocks.13.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
473
+ "model.visual.blocks.13.norm1.bias": "model-00001-of-00008.safetensors",
474
+ "model.visual.blocks.13.norm1.weight": "model-00001-of-00008.safetensors",
475
+ "model.visual.blocks.13.norm2.bias": "model-00001-of-00008.safetensors",
476
+ "model.visual.blocks.13.norm2.weight": "model-00001-of-00008.safetensors",
477
+ "model.visual.blocks.14.attn.proj.bias": "model-00001-of-00008.safetensors",
478
+ "model.visual.blocks.14.attn.proj.weight": "model-00001-of-00008.safetensors",
479
+ "model.visual.blocks.14.attn.qkv.bias": "model-00001-of-00008.safetensors",
480
+ "model.visual.blocks.14.attn.qkv.weight": "model-00001-of-00008.safetensors",
481
+ "model.visual.blocks.14.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
482
+ "model.visual.blocks.14.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
483
+ "model.visual.blocks.14.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
484
+ "model.visual.blocks.14.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
485
+ "model.visual.blocks.14.norm1.bias": "model-00001-of-00008.safetensors",
486
+ "model.visual.blocks.14.norm1.weight": "model-00001-of-00008.safetensors",
487
+ "model.visual.blocks.14.norm2.bias": "model-00001-of-00008.safetensors",
488
+ "model.visual.blocks.14.norm2.weight": "model-00001-of-00008.safetensors",
489
+ "model.visual.blocks.15.attn.proj.bias": "model-00001-of-00008.safetensors",
490
+ "model.visual.blocks.15.attn.proj.weight": "model-00001-of-00008.safetensors",
491
+ "model.visual.blocks.15.attn.qkv.bias": "model-00001-of-00008.safetensors",
492
+ "model.visual.blocks.15.attn.qkv.weight": "model-00001-of-00008.safetensors",
493
+ "model.visual.blocks.15.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
494
+ "model.visual.blocks.15.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
495
+ "model.visual.blocks.15.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
496
+ "model.visual.blocks.15.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
497
+ "model.visual.blocks.15.norm1.bias": "model-00001-of-00008.safetensors",
498
+ "model.visual.blocks.15.norm1.weight": "model-00001-of-00008.safetensors",
499
+ "model.visual.blocks.15.norm2.bias": "model-00001-of-00008.safetensors",
500
+ "model.visual.blocks.15.norm2.weight": "model-00001-of-00008.safetensors",
501
+ "model.visual.blocks.16.attn.proj.bias": "model-00001-of-00008.safetensors",
502
+ "model.visual.blocks.16.attn.proj.weight": "model-00001-of-00008.safetensors",
503
+ "model.visual.blocks.16.attn.qkv.bias": "model-00001-of-00008.safetensors",
504
+ "model.visual.blocks.16.attn.qkv.weight": "model-00001-of-00008.safetensors",
505
+ "model.visual.blocks.16.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
506
+ "model.visual.blocks.16.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
507
+ "model.visual.blocks.16.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
508
+ "model.visual.blocks.16.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
509
+ "model.visual.blocks.16.norm1.bias": "model-00001-of-00008.safetensors",
510
+ "model.visual.blocks.16.norm1.weight": "model-00001-of-00008.safetensors",
511
+ "model.visual.blocks.16.norm2.bias": "model-00001-of-00008.safetensors",
512
+ "model.visual.blocks.16.norm2.weight": "model-00001-of-00008.safetensors",
513
+ "model.visual.blocks.17.attn.proj.bias": "model-00001-of-00008.safetensors",
514
+ "model.visual.blocks.17.attn.proj.weight": "model-00001-of-00008.safetensors",
515
+ "model.visual.blocks.17.attn.qkv.bias": "model-00001-of-00008.safetensors",
516
+ "model.visual.blocks.17.attn.qkv.weight": "model-00001-of-00008.safetensors",
517
+ "model.visual.blocks.17.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
518
+ "model.visual.blocks.17.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
519
+ "model.visual.blocks.17.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
520
+ "model.visual.blocks.17.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
521
+ "model.visual.blocks.17.norm1.bias": "model-00001-of-00008.safetensors",
522
+ "model.visual.blocks.17.norm1.weight": "model-00001-of-00008.safetensors",
523
+ "model.visual.blocks.17.norm2.bias": "model-00001-of-00008.safetensors",
524
+ "model.visual.blocks.17.norm2.weight": "model-00001-of-00008.safetensors",
525
+ "model.visual.blocks.18.attn.proj.bias": "model-00001-of-00008.safetensors",
526
+ "model.visual.blocks.18.attn.proj.weight": "model-00001-of-00008.safetensors",
527
+ "model.visual.blocks.18.attn.qkv.bias": "model-00001-of-00008.safetensors",
528
+ "model.visual.blocks.18.attn.qkv.weight": "model-00001-of-00008.safetensors",
529
+ "model.visual.blocks.18.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
530
+ "model.visual.blocks.18.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
531
+ "model.visual.blocks.18.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
532
+ "model.visual.blocks.18.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
533
+ "model.visual.blocks.18.norm1.bias": "model-00001-of-00008.safetensors",
534
+ "model.visual.blocks.18.norm1.weight": "model-00001-of-00008.safetensors",
535
+ "model.visual.blocks.18.norm2.bias": "model-00001-of-00008.safetensors",
536
+ "model.visual.blocks.18.norm2.weight": "model-00001-of-00008.safetensors",
537
+ "model.visual.blocks.19.attn.proj.bias": "model-00001-of-00008.safetensors",
538
+ "model.visual.blocks.19.attn.proj.weight": "model-00001-of-00008.safetensors",
539
+ "model.visual.blocks.19.attn.qkv.bias": "model-00001-of-00008.safetensors",
540
+ "model.visual.blocks.19.attn.qkv.weight": "model-00001-of-00008.safetensors",
541
+ "model.visual.blocks.19.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
542
+ "model.visual.blocks.19.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
543
+ "model.visual.blocks.19.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
544
+ "model.visual.blocks.19.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
545
+ "model.visual.blocks.19.norm1.bias": "model-00001-of-00008.safetensors",
546
+ "model.visual.blocks.19.norm1.weight": "model-00001-of-00008.safetensors",
547
+ "model.visual.blocks.19.norm2.bias": "model-00001-of-00008.safetensors",
548
+ "model.visual.blocks.19.norm2.weight": "model-00001-of-00008.safetensors",
549
+ "model.visual.blocks.2.attn.proj.bias": "model-00001-of-00008.safetensors",
550
+ "model.visual.blocks.2.attn.proj.weight": "model-00001-of-00008.safetensors",
551
+ "model.visual.blocks.2.attn.qkv.bias": "model-00001-of-00008.safetensors",
552
+ "model.visual.blocks.2.attn.qkv.weight": "model-00001-of-00008.safetensors",
553
+ "model.visual.blocks.2.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
554
+ "model.visual.blocks.2.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
555
+ "model.visual.blocks.2.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
556
+ "model.visual.blocks.2.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
557
+ "model.visual.blocks.2.norm1.bias": "model-00001-of-00008.safetensors",
558
+ "model.visual.blocks.2.norm1.weight": "model-00001-of-00008.safetensors",
559
+ "model.visual.blocks.2.norm2.bias": "model-00001-of-00008.safetensors",
560
+ "model.visual.blocks.2.norm2.weight": "model-00001-of-00008.safetensors",
561
+ "model.visual.blocks.20.attn.proj.bias": "model-00001-of-00008.safetensors",
562
+ "model.visual.blocks.20.attn.proj.weight": "model-00001-of-00008.safetensors",
563
+ "model.visual.blocks.20.attn.qkv.bias": "model-00001-of-00008.safetensors",
564
+ "model.visual.blocks.20.attn.qkv.weight": "model-00001-of-00008.safetensors",
565
+ "model.visual.blocks.20.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
566
+ "model.visual.blocks.20.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
567
+ "model.visual.blocks.20.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
568
+ "model.visual.blocks.20.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
569
+ "model.visual.blocks.20.norm1.bias": "model-00001-of-00008.safetensors",
570
+ "model.visual.blocks.20.norm1.weight": "model-00001-of-00008.safetensors",
571
+ "model.visual.blocks.20.norm2.bias": "model-00001-of-00008.safetensors",
572
+ "model.visual.blocks.20.norm2.weight": "model-00001-of-00008.safetensors",
573
+ "model.visual.blocks.21.attn.proj.bias": "model-00001-of-00008.safetensors",
574
+ "model.visual.blocks.21.attn.proj.weight": "model-00001-of-00008.safetensors",
575
+ "model.visual.blocks.21.attn.qkv.bias": "model-00001-of-00008.safetensors",
576
+ "model.visual.blocks.21.attn.qkv.weight": "model-00001-of-00008.safetensors",
577
+ "model.visual.blocks.21.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
578
+ "model.visual.blocks.21.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
579
+ "model.visual.blocks.21.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
580
+ "model.visual.blocks.21.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
581
+ "model.visual.blocks.21.norm1.bias": "model-00001-of-00008.safetensors",
582
+ "model.visual.blocks.21.norm1.weight": "model-00001-of-00008.safetensors",
583
+ "model.visual.blocks.21.norm2.bias": "model-00001-of-00008.safetensors",
584
+ "model.visual.blocks.21.norm2.weight": "model-00001-of-00008.safetensors",
585
+ "model.visual.blocks.22.attn.proj.bias": "model-00001-of-00008.safetensors",
586
+ "model.visual.blocks.22.attn.proj.weight": "model-00001-of-00008.safetensors",
587
+ "model.visual.blocks.22.attn.qkv.bias": "model-00001-of-00008.safetensors",
588
+ "model.visual.blocks.22.attn.qkv.weight": "model-00001-of-00008.safetensors",
589
+ "model.visual.blocks.22.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
590
+ "model.visual.blocks.22.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
591
+ "model.visual.blocks.22.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
592
+ "model.visual.blocks.22.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
593
+ "model.visual.blocks.22.norm1.bias": "model-00001-of-00008.safetensors",
594
+ "model.visual.blocks.22.norm1.weight": "model-00001-of-00008.safetensors",
595
+ "model.visual.blocks.22.norm2.bias": "model-00001-of-00008.safetensors",
596
+ "model.visual.blocks.22.norm2.weight": "model-00001-of-00008.safetensors",
597
+ "model.visual.blocks.23.attn.proj.bias": "model-00001-of-00008.safetensors",
598
+ "model.visual.blocks.23.attn.proj.weight": "model-00001-of-00008.safetensors",
599
+ "model.visual.blocks.23.attn.qkv.bias": "model-00001-of-00008.safetensors",
600
+ "model.visual.blocks.23.attn.qkv.weight": "model-00001-of-00008.safetensors",
601
+ "model.visual.blocks.23.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
602
+ "model.visual.blocks.23.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
603
+ "model.visual.blocks.23.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
604
+ "model.visual.blocks.23.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
605
+ "model.visual.blocks.23.norm1.bias": "model-00001-of-00008.safetensors",
606
+ "model.visual.blocks.23.norm1.weight": "model-00001-of-00008.safetensors",
607
+ "model.visual.blocks.23.norm2.bias": "model-00001-of-00008.safetensors",
608
+ "model.visual.blocks.23.norm2.weight": "model-00001-of-00008.safetensors",
609
+ "model.visual.blocks.24.attn.proj.bias": "model-00001-of-00008.safetensors",
610
+ "model.visual.blocks.24.attn.proj.weight": "model-00001-of-00008.safetensors",
611
+ "model.visual.blocks.24.attn.qkv.bias": "model-00001-of-00008.safetensors",
612
+ "model.visual.blocks.24.attn.qkv.weight": "model-00001-of-00008.safetensors",
613
+ "model.visual.blocks.24.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
614
+ "model.visual.blocks.24.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
615
+ "model.visual.blocks.24.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
616
+ "model.visual.blocks.24.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
617
+ "model.visual.blocks.24.norm1.bias": "model-00001-of-00008.safetensors",
618
+ "model.visual.blocks.24.norm1.weight": "model-00001-of-00008.safetensors",
619
+ "model.visual.blocks.24.norm2.bias": "model-00001-of-00008.safetensors",
620
+ "model.visual.blocks.24.norm2.weight": "model-00001-of-00008.safetensors",
621
+ "model.visual.blocks.25.attn.proj.bias": "model-00001-of-00008.safetensors",
622
+ "model.visual.blocks.25.attn.proj.weight": "model-00001-of-00008.safetensors",
623
+ "model.visual.blocks.25.attn.qkv.bias": "model-00001-of-00008.safetensors",
624
+ "model.visual.blocks.25.attn.qkv.weight": "model-00001-of-00008.safetensors",
625
+ "model.visual.blocks.25.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
626
+ "model.visual.blocks.25.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
627
+ "model.visual.blocks.25.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
628
+ "model.visual.blocks.25.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
629
+ "model.visual.blocks.25.norm1.bias": "model-00001-of-00008.safetensors",
630
+ "model.visual.blocks.25.norm1.weight": "model-00001-of-00008.safetensors",
631
+ "model.visual.blocks.25.norm2.bias": "model-00001-of-00008.safetensors",
632
+ "model.visual.blocks.25.norm2.weight": "model-00001-of-00008.safetensors",
633
+ "model.visual.blocks.26.attn.proj.bias": "model-00001-of-00008.safetensors",
634
+ "model.visual.blocks.26.attn.proj.weight": "model-00001-of-00008.safetensors",
635
+ "model.visual.blocks.26.attn.qkv.bias": "model-00001-of-00008.safetensors",
636
+ "model.visual.blocks.26.attn.qkv.weight": "model-00001-of-00008.safetensors",
637
+ "model.visual.blocks.26.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
638
+ "model.visual.blocks.26.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
639
+ "model.visual.blocks.26.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
640
+ "model.visual.blocks.26.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
641
+ "model.visual.blocks.26.norm1.bias": "model-00001-of-00008.safetensors",
642
+ "model.visual.blocks.26.norm1.weight": "model-00001-of-00008.safetensors",
643
+ "model.visual.blocks.26.norm2.bias": "model-00001-of-00008.safetensors",
644
+ "model.visual.blocks.26.norm2.weight": "model-00001-of-00008.safetensors",
645
+ "model.visual.blocks.3.attn.proj.bias": "model-00001-of-00008.safetensors",
646
+ "model.visual.blocks.3.attn.proj.weight": "model-00001-of-00008.safetensors",
647
+ "model.visual.blocks.3.attn.qkv.bias": "model-00001-of-00008.safetensors",
648
+ "model.visual.blocks.3.attn.qkv.weight": "model-00001-of-00008.safetensors",
649
+ "model.visual.blocks.3.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
650
+ "model.visual.blocks.3.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
651
+ "model.visual.blocks.3.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
652
+ "model.visual.blocks.3.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
653
+ "model.visual.blocks.3.norm1.bias": "model-00001-of-00008.safetensors",
654
+ "model.visual.blocks.3.norm1.weight": "model-00001-of-00008.safetensors",
655
+ "model.visual.blocks.3.norm2.bias": "model-00001-of-00008.safetensors",
656
+ "model.visual.blocks.3.norm2.weight": "model-00001-of-00008.safetensors",
657
+ "model.visual.blocks.4.attn.proj.bias": "model-00001-of-00008.safetensors",
658
+ "model.visual.blocks.4.attn.proj.weight": "model-00001-of-00008.safetensors",
659
+ "model.visual.blocks.4.attn.qkv.bias": "model-00001-of-00008.safetensors",
660
+ "model.visual.blocks.4.attn.qkv.weight": "model-00001-of-00008.safetensors",
661
+ "model.visual.blocks.4.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
662
+ "model.visual.blocks.4.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
663
+ "model.visual.blocks.4.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
664
+ "model.visual.blocks.4.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
665
+ "model.visual.blocks.4.norm1.bias": "model-00001-of-00008.safetensors",
666
+ "model.visual.blocks.4.norm1.weight": "model-00001-of-00008.safetensors",
667
+ "model.visual.blocks.4.norm2.bias": "model-00001-of-00008.safetensors",
668
+ "model.visual.blocks.4.norm2.weight": "model-00001-of-00008.safetensors",
669
+ "model.visual.blocks.5.attn.proj.bias": "model-00001-of-00008.safetensors",
670
+ "model.visual.blocks.5.attn.proj.weight": "model-00001-of-00008.safetensors",
671
+ "model.visual.blocks.5.attn.qkv.bias": "model-00001-of-00008.safetensors",
672
+ "model.visual.blocks.5.attn.qkv.weight": "model-00001-of-00008.safetensors",
673
+ "model.visual.blocks.5.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
674
+ "model.visual.blocks.5.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
675
+ "model.visual.blocks.5.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
676
+ "model.visual.blocks.5.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
677
+ "model.visual.blocks.5.norm1.bias": "model-00001-of-00008.safetensors",
678
+ "model.visual.blocks.5.norm1.weight": "model-00001-of-00008.safetensors",
679
+ "model.visual.blocks.5.norm2.bias": "model-00001-of-00008.safetensors",
680
+ "model.visual.blocks.5.norm2.weight": "model-00001-of-00008.safetensors",
681
+ "model.visual.blocks.6.attn.proj.bias": "model-00001-of-00008.safetensors",
682
+ "model.visual.blocks.6.attn.proj.weight": "model-00001-of-00008.safetensors",
683
+ "model.visual.blocks.6.attn.qkv.bias": "model-00001-of-00008.safetensors",
684
+ "model.visual.blocks.6.attn.qkv.weight": "model-00001-of-00008.safetensors",
685
+ "model.visual.blocks.6.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
686
+ "model.visual.blocks.6.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
687
+ "model.visual.blocks.6.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
688
+ "model.visual.blocks.6.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
689
+ "model.visual.blocks.6.norm1.bias": "model-00001-of-00008.safetensors",
690
+ "model.visual.blocks.6.norm1.weight": "model-00001-of-00008.safetensors",
691
+ "model.visual.blocks.6.norm2.bias": "model-00001-of-00008.safetensors",
692
+ "model.visual.blocks.6.norm2.weight": "model-00001-of-00008.safetensors",
693
+ "model.visual.blocks.7.attn.proj.bias": "model-00001-of-00008.safetensors",
694
+ "model.visual.blocks.7.attn.proj.weight": "model-00001-of-00008.safetensors",
695
+ "model.visual.blocks.7.attn.qkv.bias": "model-00001-of-00008.safetensors",
696
+ "model.visual.blocks.7.attn.qkv.weight": "model-00001-of-00008.safetensors",
697
+ "model.visual.blocks.7.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
698
+ "model.visual.blocks.7.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
699
+ "model.visual.blocks.7.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
700
+ "model.visual.blocks.7.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
701
+ "model.visual.blocks.7.norm1.bias": "model-00001-of-00008.safetensors",
702
+ "model.visual.blocks.7.norm1.weight": "model-00001-of-00008.safetensors",
703
+ "model.visual.blocks.7.norm2.bias": "model-00001-of-00008.safetensors",
704
+ "model.visual.blocks.7.norm2.weight": "model-00001-of-00008.safetensors",
705
+ "model.visual.blocks.8.attn.proj.bias": "model-00001-of-00008.safetensors",
706
+ "model.visual.blocks.8.attn.proj.weight": "model-00001-of-00008.safetensors",
707
+ "model.visual.blocks.8.attn.qkv.bias": "model-00001-of-00008.safetensors",
708
+ "model.visual.blocks.8.attn.qkv.weight": "model-00001-of-00008.safetensors",
709
+ "model.visual.blocks.8.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
710
+ "model.visual.blocks.8.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
711
+ "model.visual.blocks.8.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
712
+ "model.visual.blocks.8.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
713
+ "model.visual.blocks.8.norm1.bias": "model-00001-of-00008.safetensors",
714
+ "model.visual.blocks.8.norm1.weight": "model-00001-of-00008.safetensors",
715
+ "model.visual.blocks.8.norm2.bias": "model-00001-of-00008.safetensors",
716
+ "model.visual.blocks.8.norm2.weight": "model-00001-of-00008.safetensors",
717
+ "model.visual.blocks.9.attn.proj.bias": "model-00001-of-00008.safetensors",
718
+ "model.visual.blocks.9.attn.proj.weight": "model-00001-of-00008.safetensors",
719
+ "model.visual.blocks.9.attn.qkv.bias": "model-00001-of-00008.safetensors",
720
+ "model.visual.blocks.9.attn.qkv.weight": "model-00001-of-00008.safetensors",
721
+ "model.visual.blocks.9.mlp.linear_fc1.bias": "model-00001-of-00008.safetensors",
722
+ "model.visual.blocks.9.mlp.linear_fc1.weight": "model-00001-of-00008.safetensors",
723
+ "model.visual.blocks.9.mlp.linear_fc2.bias": "model-00001-of-00008.safetensors",
724
+ "model.visual.blocks.9.mlp.linear_fc2.weight": "model-00001-of-00008.safetensors",
725
+ "model.visual.blocks.9.norm1.bias": "model-00001-of-00008.safetensors",
726
+ "model.visual.blocks.9.norm1.weight": "model-00001-of-00008.safetensors",
727
+ "model.visual.blocks.9.norm2.bias": "model-00001-of-00008.safetensors",
728
+ "model.visual.blocks.9.norm2.weight": "model-00001-of-00008.safetensors",
729
+ "model.visual.deepstack_merger_list.0.linear_fc1.bias": "model-00001-of-00008.safetensors",
730
+ "model.visual.deepstack_merger_list.0.linear_fc1.weight": "model-00001-of-00008.safetensors",
731
+ "model.visual.deepstack_merger_list.0.linear_fc2.bias": "model-00001-of-00008.safetensors",
732
+ "model.visual.deepstack_merger_list.0.linear_fc2.weight": "model-00001-of-00008.safetensors",
733
+ "model.visual.deepstack_merger_list.0.norm.bias": "model-00001-of-00008.safetensors",
734
+ "model.visual.deepstack_merger_list.0.norm.weight": "model-00001-of-00008.safetensors",
735
+ "model.visual.deepstack_merger_list.1.linear_fc1.bias": "model-00001-of-00008.safetensors",
736
+ "model.visual.deepstack_merger_list.1.linear_fc1.weight": "model-00001-of-00008.safetensors",
737
+ "model.visual.deepstack_merger_list.1.linear_fc2.bias": "model-00001-of-00008.safetensors",
738
+ "model.visual.deepstack_merger_list.1.linear_fc2.weight": "model-00001-of-00008.safetensors",
739
+ "model.visual.deepstack_merger_list.1.norm.bias": "model-00001-of-00008.safetensors",
740
+ "model.visual.deepstack_merger_list.1.norm.weight": "model-00001-of-00008.safetensors",
741
+ "model.visual.deepstack_merger_list.2.linear_fc1.bias": "model-00001-of-00008.safetensors",
742
+ "model.visual.deepstack_merger_list.2.linear_fc1.weight": "model-00001-of-00008.safetensors",
743
+ "model.visual.deepstack_merger_list.2.linear_fc2.bias": "model-00001-of-00008.safetensors",
744
+ "model.visual.deepstack_merger_list.2.linear_fc2.weight": "model-00001-of-00008.safetensors",
745
+ "model.visual.deepstack_merger_list.2.norm.bias": "model-00001-of-00008.safetensors",
746
+ "model.visual.deepstack_merger_list.2.norm.weight": "model-00001-of-00008.safetensors",
747
+ "model.visual.merger.linear_fc1.bias": "model-00001-of-00008.safetensors",
748
+ "model.visual.merger.linear_fc1.weight": "model-00001-of-00008.safetensors",
749
+ "model.visual.merger.linear_fc2.bias": "model-00001-of-00008.safetensors",
750
+ "model.visual.merger.linear_fc2.weight": "model-00001-of-00008.safetensors",
751
+ "model.visual.merger.norm.bias": "model-00001-of-00008.safetensors",
752
+ "model.visual.merger.norm.weight": "model-00001-of-00008.safetensors",
753
+ "model.visual.patch_embed.proj.bias": "model-00001-of-00008.safetensors",
754
+ "model.visual.patch_embed.proj.weight": "model-00001-of-00008.safetensors",
755
+ "model.visual.pos_embed.weight": "model-00001-of-00008.safetensors"
756
+ }
757
+ }
ICL/sft_model/epoch3_step1406_fp32/preprocessor_config.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "size": {
3
+ "longest_edge": 16777216,
4
+ "shortest_edge": 65536
5
+ },
6
+ "patch_size": 16,
7
+ "temporal_patch_size": 2,
8
+ "merge_size": 2,
9
+ "image_mean": [
10
+ 0.5,
11
+ 0.5,
12
+ 0.5
13
+ ],
14
+ "image_std": [
15
+ 0.5,
16
+ 0.5,
17
+ 0.5
18
+ ],
19
+ "processor_class": "Qwen3VLProcessor",
20
+ "image_processor_type": "Qwen2VLImageProcessorFast"
21
+ }
ICL/sft_model/epoch3_step1406_fp32/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
ICL/sft_model/epoch3_step1406_fp32/tokenizer_config.json ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": false,
3
+ "add_prefix_space": false,
4
+ "added_tokens_decoder": {
5
+ "151643": {
6
+ "content": "<|endoftext|>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "151644": {
14
+ "content": "<|im_start|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "151645": {
22
+ "content": "<|im_end|>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "151646": {
30
+ "content": "<|object_ref_start|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "151647": {
38
+ "content": "<|object_ref_end|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": true
44
+ },
45
+ "151648": {
46
+ "content": "<|box_start|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": false,
50
+ "single_word": false,
51
+ "special": true
52
+ },
53
+ "151649": {
54
+ "content": "<|box_end|>",
55
+ "lstrip": false,
56
+ "normalized": false,
57
+ "rstrip": false,
58
+ "single_word": false,
59
+ "special": true
60
+ },
61
+ "151650": {
62
+ "content": "<|quad_start|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": false,
66
+ "single_word": false,
67
+ "special": true
68
+ },
69
+ "151651": {
70
+ "content": "<|quad_end|>",
71
+ "lstrip": false,
72
+ "normalized": false,
73
+ "rstrip": false,
74
+ "single_word": false,
75
+ "special": true
76
+ },
77
+ "151652": {
78
+ "content": "<|vision_start|>",
79
+ "lstrip": false,
80
+ "normalized": false,
81
+ "rstrip": false,
82
+ "single_word": false,
83
+ "special": true
84
+ },
85
+ "151653": {
86
+ "content": "<|vision_end|>",
87
+ "lstrip": false,
88
+ "normalized": false,
89
+ "rstrip": false,
90
+ "single_word": false,
91
+ "special": true
92
+ },
93
+ "151654": {
94
+ "content": "<|vision_pad|>",
95
+ "lstrip": false,
96
+ "normalized": false,
97
+ "rstrip": false,
98
+ "single_word": false,
99
+ "special": true
100
+ },
101
+ "151655": {
102
+ "content": "<|image_pad|>",
103
+ "lstrip": false,
104
+ "normalized": false,
105
+ "rstrip": false,
106
+ "single_word": false,
107
+ "special": true
108
+ },
109
+ "151656": {
110
+ "content": "<|video_pad|>",
111
+ "lstrip": false,
112
+ "normalized": false,
113
+ "rstrip": false,
114
+ "single_word": false,
115
+ "special": true
116
+ },
117
+ "151657": {
118
+ "content": "<tool_call>",
119
+ "lstrip": false,
120
+ "normalized": false,
121
+ "rstrip": false,
122
+ "single_word": false,
123
+ "special": false
124
+ },
125
+ "151658": {
126
+ "content": "</tool_call>",
127
+ "lstrip": false,
128
+ "normalized": false,
129
+ "rstrip": false,
130
+ "single_word": false,
131
+ "special": false
132
+ },
133
+ "151659": {
134
+ "content": "<|fim_prefix|>",
135
+ "lstrip": false,
136
+ "normalized": false,
137
+ "rstrip": false,
138
+ "single_word": false,
139
+ "special": false
140
+ },
141
+ "151660": {
142
+ "content": "<|fim_middle|>",
143
+ "lstrip": false,
144
+ "normalized": false,
145
+ "rstrip": false,
146
+ "single_word": false,
147
+ "special": false
148
+ },
149
+ "151661": {
150
+ "content": "<|fim_suffix|>",
151
+ "lstrip": false,
152
+ "normalized": false,
153
+ "rstrip": false,
154
+ "single_word": false,
155
+ "special": false
156
+ },
157
+ "151662": {
158
+ "content": "<|fim_pad|>",
159
+ "lstrip": false,
160
+ "normalized": false,
161
+ "rstrip": false,
162
+ "single_word": false,
163
+ "special": false
164
+ },
165
+ "151663": {
166
+ "content": "<|repo_name|>",
167
+ "lstrip": false,
168
+ "normalized": false,
169
+ "rstrip": false,
170
+ "single_word": false,
171
+ "special": false
172
+ },
173
+ "151664": {
174
+ "content": "<|file_sep|>",
175
+ "lstrip": false,
176
+ "normalized": false,
177
+ "rstrip": false,
178
+ "single_word": false,
179
+ "special": false
180
+ },
181
+ "151665": {
182
+ "content": "<tool_response>",
183
+ "lstrip": false,
184
+ "normalized": false,
185
+ "rstrip": false,
186
+ "single_word": false,
187
+ "special": false
188
+ },
189
+ "151666": {
190
+ "content": "</tool_response>",
191
+ "lstrip": false,
192
+ "normalized": false,
193
+ "rstrip": false,
194
+ "single_word": false,
195
+ "special": false
196
+ },
197
+ "151667": {
198
+ "content": "<think>",
199
+ "lstrip": false,
200
+ "normalized": false,
201
+ "rstrip": false,
202
+ "single_word": false,
203
+ "special": false
204
+ },
205
+ "151668": {
206
+ "content": "</think>",
207
+ "lstrip": false,
208
+ "normalized": false,
209
+ "rstrip": false,
210
+ "single_word": false,
211
+ "special": false
212
+ }
213
+ },
214
+ "additional_special_tokens": [
215
+ "<|im_start|>",
216
+ "<|im_end|>",
217
+ "<|object_ref_start|>",
218
+ "<|object_ref_end|>",
219
+ "<|box_start|>",
220
+ "<|box_end|>",
221
+ "<|quad_start|>",
222
+ "<|quad_end|>",
223
+ "<|vision_start|>",
224
+ "<|vision_end|>",
225
+ "<|vision_pad|>",
226
+ "<|image_pad|>",
227
+ "<|video_pad|>"
228
+ ],
229
+ "bos_token": null,
230
+ "chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {%- if messages[0].content is string %}\n {{- messages[0].content }}\n {%- else %}\n {%- for content in messages[0].content %}\n {%- if 'text' in content %}\n {{- content.text }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].content is string %}\n {{- messages[0].content }}\n {%- else %}\n {%- for content in messages[0].content %}\n {%- if 'text' in content %}\n {{- content.text }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set image_count = namespace(value=0) %}\n{%- set video_count = namespace(value=0) %}\n{%- for message in messages %}\n {%- if message.role == \"user\" %}\n {{- '<|im_start|>' + message.role + '\\n' }}\n {%- if message.content is string %}\n {{- message.content }}\n {%- else %}\n {%- for content in message.content %}\n {%- if content.type == 'image' or 'image' in content or 'image_url' in content %}\n {%- set image_count.value = image_count.value + 1 %}\n {%- if add_vision_id %}Picture {{ image_count.value }}: {% endif -%}\n <|vision_start|><|image_pad|><|vision_end|>\n {%- elif content.type == 'video' or 'video' in content %}\n {%- set video_count.value = video_count.value + 1 %}\n {%- if add_vision_id %}Video {{ 
video_count.value }}: {% endif -%}\n <|vision_start|><|video_pad|><|vision_end|>\n {%- elif 'text' in content %}\n {{- content.text }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role + '\\n' }}\n {%- if message.content is string %}\n {{- message.content }}\n {%- else %}\n {%- for content_item in message.content %}\n {%- if 'text' in content_item %}\n {{- content_item.text }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and message.content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {%- if message.content is string %}\n {{- message.content }}\n {%- else %}\n {%- for content in message.content %}\n {%- if content.type == 'image' or 'image' in content or 'image_url' in content %}\n {%- set image_count.value = image_count.value + 1 %}\n {%- if add_vision_id %}Picture {{ image_count.value }}: {% endif -%}\n <|vision_start|><|image_pad|><|vision_end|>\n {%- elif content.type == 'video' or 'video' in content %}\n {%- set video_count.value = video_count.value + 1 %}\n {%- if add_vision_id %}Video {{ video_count.value }}: {% endif -%}\n <|vision_start|><|video_pad|><|vision_end|>\n {%- elif 'text' in content %}\n {{- content.text }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- 
'\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",
231
+ "clean_up_tokenization_spaces": false,
232
+ "eos_token": "<|im_end|>",
233
+ "errors": "replace",
234
+ "model_max_length": 262144,
235
+ "pad_token": "<|endoftext|>",
236
+ "split_special_tokens": false,
237
+ "tokenizer_class": "Qwen2Tokenizer",
238
+ "unk_token": null
239
+ }
ICL/sft_model/epoch3_step1406_fp32/video_preprocessor_config.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "size": {
3
+ "longest_edge": 25165824,
4
+ "shortest_edge": 4096
5
+ },
6
+ "patch_size": 16,
7
+ "temporal_patch_size": 2,
8
+ "merge_size": 2,
9
+ "image_mean": [
10
+ 0.5,
11
+ 0.5,
12
+ 0.5
13
+ ],
14
+ "image_std": [
15
+ 0.5,
16
+ 0.5,
17
+ 0.5
18
+ ],
19
+ "processor_class": "Qwen3VLProcessor",
20
+ "video_processor_type": "Qwen3VLVideoProcessor"
21
+ }
ICL/sft_model/epoch3_step1406_fp32/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
ICL/sft_model/zero_to_fp32.py ADDED
@@ -0,0 +1,760 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright (c) Microsoft Corporation.
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+ # DeepSpeed Team
7
+
8
+ # This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets
9
+ # copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
10
+ # the future. Once extracted, the weights don't require DeepSpeed and can be used in any
11
+ # application.
12
+ #
13
+ # example:
14
+ # python zero_to_fp32.py . output_dir/
15
+ # or
16
+ # python zero_to_fp32.py . output_dir/ --safe_serialization
17
+
18
+ import argparse
19
+ import torch
20
+ import glob
21
+ import math
22
+ import os
23
+ import re
24
+ import gc
25
+ import json
26
+ import numpy as np
27
+ from tqdm import tqdm
28
+ from collections import OrderedDict
29
+ from dataclasses import dataclass
30
+
31
+ # while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
32
+ # DeepSpeed data structures it has to be available in the current python environment.
33
+ from deepspeed.utils import logger
34
+ from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
35
+ FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
36
+ FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
37
+
38
+
39
@dataclass
class zero_model_state:
    """Per-rank model-state pieces needed to reconstruct fp32 weights.

    Fields mirror what ``parse_model_states`` extracts from each
    ``*_model_states.pt`` checkpoint file.

    NOTE: the original annotated several fields as ``dict()`` — a dict
    *instance* used as an annotation — and ``ds_version`` as ``int`` even
    though deepspeed stores a version string. Dataclasses only record the
    annotations, so it worked, but proper types are used here.
    """
    buffers: dict                 # buffer name -> fp32 tensor
    param_shapes: dict            # per-group mapping: param name -> shape
    shared_params: list           # [alias_name, source_name] pairs
    ds_version: str               # deepspeed version string (may be None)
    frozen_param_shapes: dict     # may be None when nothing is frozen
    frozen_param_fragments: dict  # may be None when nothing is frozen
47
+
48
+
49
+ debug = 0
50
+
51
+ # load to cpu
52
+ device = torch.device('cpu')
53
+
54
+
55
def atoi(text):
    """Convert *text* to an int when it is purely digits; otherwise return it as-is."""
    if text.isdigit():
        return int(text)
    return text
57
+
58
+
59
def natural_keys(text):
    """Sort key producing human ("natural") ordering.

    alist.sort(key=natural_keys) sorts in human order
    http://nedbatchelder.com/blog/200712/human_sorting.html
    (See Toothy's implementation in the comments)
    """
    # split into alternating non-digit / digit runs; digit runs compare
    # numerically, everything else lexically (atoi inlined)
    return [int(chunk) if chunk.isdigit() else chunk
            for chunk in re.split(r'(\d+)', text)]
66
+
67
+
68
def get_model_state_file(checkpoint_dir, zero_stage):
    """Return the path of the single model-states file for the given zero stage.

    Args:
        checkpoint_dir: directory containing the deepspeed shard files.
        zero_stage: 1, 2 or 3 (stages 1/2 share a filename layout).

    Raises:
        FileNotFoundError: when the directory or the expected file is missing.
        ValueError: for an unrecognized *zero_stage*. (The original code left
            ``file`` unbound in that case and crashed with UnboundLocalError.)
    """
    if not os.path.isdir(checkpoint_dir):
        raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")

    # there should be only one file
    if zero_stage <= 2:
        file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
    elif zero_stage == 3:
        file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
    else:
        raise ValueError(f"unknown zero stage {zero_stage}")

    if not os.path.exists(file):
        raise FileNotFoundError(f"can't find model states file at '{file}'")

    return file
82
+
83
+
84
def get_checkpoint_files(checkpoint_dir, glob_pattern):
    """Glob *glob_pattern* under *checkpoint_dir*, naturally sorted; raise if none found."""
    # XXX: need to test that this simple glob rule works for multi-node setup too
    def _human_order(path):
        # natural-sort key: digit runs compare numerically (natural_keys/atoi inlined)
        return [int(part) if part.isdigit() else part for part in re.split(r'(\d+)', path)]

    ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=_human_order)
    if not ckpt_files:
        raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")
    return ckpt_files
92
+
93
+
94
def get_optim_files(checkpoint_dir):
    """Return all ``*_optim_states.pt`` files under *checkpoint_dir*, naturally sorted."""
    pattern = "*_optim_states.pt"
    return get_checkpoint_files(checkpoint_dir, pattern)
96
+
97
+
98
def get_model_state_files(checkpoint_dir):
    """Return all ``*_model_states.pt`` files under *checkpoint_dir*, naturally sorted."""
    pattern = "*_model_states.pt"
    return get_checkpoint_files(checkpoint_dir, pattern)
100
+
101
+
102
def parse_model_states(files):
    """Load each ``*_model_states.pt`` file and distill it into a ``zero_model_state``.

    Keeps only what the reconstruction needs: buffers (restored to fp32),
    parameter shapes, shared/frozen parameter metadata and the deepspeed
    version; the optimizer payload lives in the ``*_optim_states.pt`` shards.
    """
    results = []
    for path in files:
        sd = torch.load(path, map_location=device, weights_only=False)

        if BUFFER_NAMES not in sd:
            raise ValueError(f"{path} is not a model state checkpoint")
        buffer_names = sd[BUFFER_NAMES]
        if debug:
            print("Found buffers:", buffer_names)

        # recover just the buffers while restoring them to fp32 if they were saved in fp16
        module_sd = sd["module"]
        buffers = {name: tensor.float() for name, tensor in module_sd.items() if name in buffer_names}
        param_shapes = sd[PARAM_SHAPES]

        # collect parameters that are included in param_shapes
        param_names = [name for group in param_shapes for name in group.keys()]

        # update with frozen parameters
        frozen_param_shapes = sd.get(FROZEN_PARAM_SHAPES, None)
        if frozen_param_shapes is not None:
            if debug:
                print(f"Found frozen_param_shapes: {frozen_param_shapes}")
            param_names += list(frozen_param_shapes.keys())

        # handle shared params
        shared_params = [[alias, source] for alias, source in sd["shared_params"].items()]

        results.append(
            zero_model_state(buffers=buffers,
                             param_shapes=param_shapes,
                             shared_params=shared_params,
                             ds_version=sd.get(DS_VERSION, None),
                             frozen_param_shapes=frozen_param_shapes,
                             frozen_param_fragments=sd.get(FROZEN_PARAM_FRAGMENTS, None)))

    return results
146
+
147
+
148
def parse_optim_states(files, ds_checkpoint_dir):
    """Load every ``*_optim_states.pt`` shard and extract the fp32 flat groups.

    Returns a ``(zero_stage, world_size, fp32_flat_groups)`` triple where
    ``fp32_flat_groups[rank]`` holds that rank's flat fp32 master-weight
    partitions.
    """
    total_files = len(files)
    shards = []
    for path in tqdm(files, desc='Loading checkpoint shards'):
        shard = torch.load(path, map_location=device, mmap=True, weights_only=False)
        # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
        # and also handle the case where it was already removed by another helper script
        shard["optimizer_state_dict"].pop("optimizer_state_dict", None)
        shards.append(shard)

    first_optim = shards[0][OPTIMIZER_STATE_DICT]
    if ZERO_STAGE not in first_optim:
        raise ValueError(f"{files[0]} is not a zero checkpoint")
    zero_stage = first_optim[ZERO_STAGE]
    world_size = first_optim[PARTITION_COUNT]

    # For ZeRO-2 each param group can have different partition_count as data parallelism for expert
    # parameters can be different from data parallelism for non-expert parameters. So we can just
    # use the max of the partition_count to get the dp world_size.
    if type(world_size) is list:
        world_size = max(world_size)

    if world_size != total_files:
        raise ValueError(
            f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
            "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
        )

    # the groups are named differently in each stage
    if zero_stage <= 2:
        fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
    elif zero_stage == 3:
        fp32_groups_key = FP32_FLAT_GROUPS
    else:
        raise ValueError(f"unknown zero stage {zero_stage}")

    fp32_flat_groups = [shard[OPTIMIZER_STATE_DICT][fp32_groups_key] for shard in shards]
    return zero_stage, world_size, fp32_flat_groups
186
+
187
+
188
def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters):
    """
    Returns fp32 state_dict reconstructed from ds checkpoint

    Args:
        - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
        - ``exclude_frozen_parameters``: when True, frozen params are left out of the result
    """
    print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")

    optim_files = get_optim_files(ds_checkpoint_dir)
    zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
    print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")

    zero_model_states = parse_model_states(get_model_state_files(ds_checkpoint_dir))
    print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')

    # dispatch on stage; parse_optim_states has already rejected unknown stages,
    # so anything above 2 is stage 3 here
    if zero_stage <= 2:
        merger = _get_fp32_state_dict_from_zero2_checkpoint
    else:
        merger = _get_fp32_state_dict_from_zero3_checkpoint
    return merger(world_size, fp32_flat_groups, zero_model_states, exclude_frozen_parameters)
213
+
214
+
215
def _zero2_merge_frozen_params(state_dict, zero_model_states):
    """Copy frozen (non-trainable) parameters into *state_dict* for ZeRO-1/2.

    Under ZeRO-1/2 frozen params are not sharded, so rank 0's fragments are
    already the full tensors and can be copied over verbatim.
    """
    shapes = zero_model_states[0].frozen_param_shapes
    if shapes is None or len(shapes) == 0:
        return

    fragments = zero_model_states[0].frozen_param_fragments

    if debug:
        num_elem = sum(s.numel() for s in shapes.values())
        print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')

    wanted_params = len(shapes)
    wanted_numel = sum(s.numel() for s in shapes.values())
    avail_numel = sum(p.numel() for p in fragments.values())
    print(f'Frozen params: Have {avail_numel} numels to process.')
    print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')

    total_params = 0
    total_numel = 0
    for name, shape in shapes.items():
        total_params += 1
        unpartitioned_numel = shape.numel()
        total_numel += unpartitioned_numel

        state_dict[name] = fragments[name]

        if debug:
            print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")

    print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
245
+
246
+
247
+ def _has_callable(obj, fn):
248
+ attr = getattr(obj, fn, None)
249
+ return callable(attr)
250
+
251
+
252
def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
    """Reassemble ZeRO-1/2 trainable parameters.

    Each param group's per-rank fp32 partitions are concatenated into one flat
    vector, from which the named tensors are carved out in order.
    """
    param_shapes = zero_model_states[0].param_shapes

    # Reconstruction protocol:
    #
    # XXX: document this

    if debug:
        for i in range(world_size):
            for j in range(len(fp32_flat_groups[0])):
                print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")

    # XXX: memory usage doubles here (zero2)
    num_param_groups = len(fp32_flat_groups[0])
    merged_single_partition_of_fp32_groups = [
        torch.cat([rank_groups[group_idx] for rank_groups in fp32_flat_groups], 0)
        for group_idx in range(num_param_groups)
    ]
    avail_numel = sum(vec.numel() for vec in merged_single_partition_of_fp32_groups)

    if debug:
        wanted_params = sum(len(shapes) for shapes in param_shapes)
        wanted_numel = sum(sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes)
        # not asserting if there is a mismatch due to possible padding
        print(f"Have {avail_numel} numels to process.")
        print(f"Need {wanted_numel} numels in {wanted_params} params.")

    # params
    # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
    # out-of-core computing solution
    total_numel = 0
    total_params = 0
    for shapes, flat_vec in zip(param_shapes, merged_single_partition_of_fp32_groups):
        offset = 0
        avail_numel = flat_vec.numel()
        for name, shape in shapes.items():
            # shapes may be torch.Size (has .numel) or plain tuples
            unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape)
            total_numel += unpartitioned_numel
            total_params += 1

            if debug:
                print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
            state_dict[name] = flat_vec.narrow(0, offset, unpartitioned_numel).view(shape)
            offset += unpartitioned_numel

        # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
        # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
        # paddings performed in the code it's almost impossible to predict the exact numbers w/o the
        # live optimizer object, so we are checking that the numbers are within the right range
        align_to = 2 * world_size

        def zero2_align(x):
            return align_to * math.ceil(x / align_to)

        if debug:
            print(f"original offset={offset}, avail_numel={avail_numel}")

        offset = zero2_align(offset)
        avail_numel = zero2_align(avail_numel)

        if debug:
            print(f"aligned offset={offset}, avail_numel={avail_numel}")

        # Sanity check
        if offset != avail_numel:
            raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")

    print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
323
+
324
+
325
def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
                                               exclude_frozen_parameters):
    """Build the full fp32 state_dict for a ZeRO-1/2 checkpoint."""
    state_dict = OrderedDict()

    # buffers first — rank 0 already holds them unsharded
    buffers = zero_model_states[0].buffers
    state_dict.update(buffers)
    if debug:
        print(f"added {len(buffers)} buffers")

    if not exclude_frozen_parameters:
        _zero2_merge_frozen_params(state_dict, zero_model_states)

    _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)

    # recover shared parameters: alias entries point at their source tensor
    for alias, source in zero_model_states[0].shared_params:
        if source in state_dict:
            state_dict[alias] = state_dict[source]

    return state_dict
346
+
347
+
348
def zero3_partitioned_param_info(unpartitioned_numel, world_size):
    """Return ``(per_rank_partition_numel, padding_numel)`` for a ZeRO-3 shard."""
    remainder = unpartitioned_numel % world_size
    if remainder:
        padding_numel = world_size - remainder
    else:
        padding_numel = 0
    partitioned_numel = math.ceil(unpartitioned_numel / world_size)
    return partitioned_numel, padding_numel
353
+
354
+
355
def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
    """Rebuild ZeRO-3 frozen params by concatenating every rank's fragment.

    The concatenation may carry tail alignment padding, which is trimmed with
    ``narrow`` before reshaping to the recorded full shape.
    """
    frozen_param_shapes = zero_model_states[0].frozen_param_shapes
    if frozen_param_shapes is None or len(frozen_param_shapes) == 0:
        return

    if debug:
        for i in range(world_size):
            num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values())
            print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')

    wanted_params = len(frozen_param_shapes)
    wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
    avail_numel = sum(p.numel() for p in zero_model_states[0].frozen_param_fragments.values()) * world_size
    print(f'Frozen params: Have {avail_numel} numels to process.')
    print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')

    total_params = 0
    total_numel = 0
    for name, shape in frozen_param_shapes.items():
        total_params += 1
        unpartitioned_numel = shape.numel()
        total_numel += unpartitioned_numel

        # one fragment per rank, in rank order
        param_frags = tuple(ms.frozen_param_fragments[name] for ms in zero_model_states)
        state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape)

        partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)

        if debug:
            print(
                f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
            )

    print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
389
+
390
+
391
class GatheredTensor:
    """
    A pseudo tensor that collects partitioned weights.
    It is more memory efficient when there are multiple groups.
    """

    def __init__(self, flat_groups, flat_groups_offset, offset, partitioned_numel, shape):
        # flat_groups[rank][group] -> that rank's flat fp32 partition of the group
        self.flat_groups = flat_groups
        # cumulative numel boundaries of the groups (len == num_groups + 1)
        self.flat_groups_offset = flat_groups_offset
        self.offset = offset
        self.partitioned_numel = partitioned_numel
        self.shape = shape
        self.dtype = self.flat_groups[0][0].dtype

    def contiguous(self):
        """
        Merge partitioned weights from flat_groups into a single tensor.
        """
        end_idx = self.offset + self.partitioned_numel
        world_size = len(self.flat_groups)
        chunks = []

        for rank in range(world_size):
            # for each rank, we need to collect weights from related group/groups
            rank_groups = self.flat_groups[rank]
            start_group = None
            end_group = None
            for gid in range(len(self.flat_groups_offset)):
                if self.flat_groups_offset[gid] <= self.offset < self.flat_groups_offset[gid + 1]:
                    start_group = gid
                if self.flat_groups_offset[gid] < end_idx <= self.flat_groups_offset[gid + 1]:
                    end_group = gid
                    break
            # slice this rank's contribution out of each covered group
            for gid in range(start_group, end_group + 1):
                base = self.flat_groups_offset[gid]
                lo = self.offset - base
                hi = min(end_idx, self.flat_groups_offset[gid + 1]) - base
                chunks.append(rank_groups[gid][lo:hi])

        # stitch all ranks together and drop trailing alignment padding
        merged = torch.cat(chunks, dim=0)
        return merged[:self.shape.numel()].view(self.shape).contiguous()
435
+
436
+
437
def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
    """Populate *state_dict* with lazy ``GatheredTensor`` views for ZeRO-3.

    Nothing is materialized here; each entry is gathered across ranks only
    when its ``.contiguous()`` is called.
    """
    param_shapes = zero_model_states[0].param_shapes
    avail_numel = sum(flat.numel() for flat in fp32_flat_groups[0]) * world_size

    # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
    # param, re-consolidating each param, while dealing with padding if any

    # merge list of dicts, preserving order
    merged_shapes = {}
    for group in param_shapes:
        merged_shapes.update(group)
    param_shapes = merged_shapes

    if debug:
        for i in range(world_size):
            print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}")

    wanted_params = len(param_shapes)
    wanted_numel = sum(shape.numel() for shape in param_shapes.values())
    # not asserting if there is a mismatch due to possible padding
    avail_numel = fp32_flat_groups[0].numel() * world_size
    print(f"Trainable params: Have {avail_numel} numels to process.")
    print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.")

    # params
    # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
    # out-of-core computing solution
    offset = 0
    total_numel = 0
    total_params = 0
    flat_groups_offset = [0] + list(np.cumsum([flat.numel() for flat in fp32_flat_groups[0]]))
    for name, shape in tqdm(param_shapes.items(), desc='Gathering sharded weights'):
        unpartitioned_numel = shape.numel()
        total_numel += unpartitioned_numel
        total_params += 1
        partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)

        if debug:
            print(
                f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
            )

        # memory efficient tensor
        state_dict[name] = GatheredTensor(fp32_flat_groups, flat_groups_offset, offset, partitioned_numel, shape)
        offset += partitioned_numel

    offset *= world_size

    # Sanity check
    if offset != avail_numel:
        raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")

    print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
488
+
489
+
490
def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
                                               exclude_frozen_parameters):
    """Build the (lazily materialized) fp32 state_dict for a ZeRO-3 checkpoint."""
    state_dict = OrderedDict()

    # buffers are kept unsharded on every rank; take rank 0's copy
    buffers = zero_model_states[0].buffers
    state_dict.update(buffers)
    if debug:
        print(f"added {len(buffers)} buffers")

    if not exclude_frozen_parameters:
        _zero3_merge_frozen_params(state_dict, world_size, zero_model_states)

    _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)

    # recover shared parameters: alias entries point at their source tensor
    for alias, source in zero_model_states[0].shared_params:
        if source in state_dict:
            state_dict[alias] = state_dict[source]

    return state_dict
511
+
512
+
513
def to_torch_tensor(state_dict, return_empty_tensor=False):
    """
    Convert a state_dict of GatheredTensor (or tensors) to real torch tensors.

    Shared entries (same underlying object) are converted once and aliased in
    the result; with ``return_empty_tensor`` only shape/dtype placeholders are
    allocated (used for shard planning).
    """
    result = {}
    seen = {}  # id(tensor) -> first name converted under, to preserve sharing
    for name, tensor in state_dict.items():
        key = id(tensor)
        if key in seen:  # shared tensors
            result[name] = result[seen[key]]
        else:
            seen[key] = name
            if return_empty_tensor:
                result[name] = torch.empty(tensor.shape, dtype=tensor.dtype)
            else:
                result[name] = tensor.contiguous()
    return result
531
+
532
+
533
def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
                                             tag=None,
                                             exclude_frozen_parameters=False,
                                             lazy_mode=False):
    """
    Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
    ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
    via a model hub.

    Args:
        - ``checkpoint_dir``: path to the desired checkpoint folder
        - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will
          attempt to load the tag named in the 'latest' file, e.g., ``global_step14``
        - ``exclude_frozen_parameters``: exclude frozen parameters
        - ``lazy_mode``: return a dict of pseudo tensors instead of torch tensors, which is more
          memory efficient; convert each pseudo tensor with ``.contiguous()``

    Returns:
        - pytorch ``state_dict``

    A typical usage might be ::

        from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
        # do the training and checkpoint saving
        state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
        model = model.cpu() # move to cpu
        model.load_state_dict(state_dict)
        # submit to model hub or save the model to share with others

    In this example the ``model`` will no longer be usable in the deepspeed context of the same
    application. i.e. you will need to re-initialize the deepspeed engine, since
    ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.

    If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.

    Note: the above usage may not work if your application doesn't have sufficient free CPU memory.
    You may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
    the checkpoint. Or you can load the state_dict in lazy mode ::

        from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
        state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, lazy_mode=True) # not on cpu
        for name, lazy_tensor in state_dict.items():
            tensor = lazy_tensor.contiguous() # to cpu
            print(name, tensor)
            # del tensor to release memory if it is no longer in use
    """
    if tag is None:
        # the 'latest' file names the tag folder of the newest checkpoint
        latest_path = os.path.join(checkpoint_dir, 'latest')
        if not os.path.isfile(latest_path):
            raise ValueError(f"Unable to find 'latest' file at {latest_path}")
        with open(latest_path, 'r') as fd:
            tag = fd.read().strip()

    ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
    if not os.path.isdir(ds_checkpoint_dir):
        raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")

    state_dict = _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)
    return state_dict if lazy_mode else to_torch_tensor(state_dict)
596
+
597
+
598
def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
                                               output_dir,
                                               max_shard_size="5GB",
                                               safe_serialization=False,
                                               tag=None,
                                               exclude_frozen_parameters=False):
    """
    Convert a ZeRO 2 or 3 checkpoint into fp32 state_dict file(s) under *output_dir*
    that can be loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for
    training without DeepSpeed.

    Args:
        - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
        - ``output_dir``: directory for the pytorch fp32 state_dict output files
        - ``max_shard_size``: the maximum size for a checkpoint before being sharded, default 5GB; ``None`` disables sharding
        - ``safe_serialization``: whether to save with `safetensors` or the traditional PyTorch way (that uses `pickle`)
        - ``tag``: checkpoint tag; defaults to the one named in the ``latest`` file in the checkpoint folder, e.g., ``global_step14``
        - ``exclude_frozen_parameters``: exclude frozen parameters
    """
    # Dependency pre-check
    if safe_serialization:
        try:
            from safetensors.torch import save_file
        except ImportError:
            print('If you want to use `safe_serialization`, please `pip install safetensors`')
            raise
    if max_shard_size is not None:
        try:
            from huggingface_hub import split_torch_state_dict_into_shards
        except ImportError:
            print('If you want to use `max_shard_size`, please `pip install huggingface_hub`')
            raise

    # Convert zero checkpoint to a lazily-materialized state_dict
    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
                                                          tag,
                                                          exclude_frozen_parameters,
                                                          lazy_mode=True)

    # Shard the model if it is too big.
    weights_name = "model.safetensors" if safe_serialization else "pytorch_model.bin"
    if max_shard_size is not None:
        filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
        # a memory-efficient approach for sharding: plan the shards from
        # empty shape/dtype placeholders instead of the real tensors
        empty_state_dict = to_torch_tensor(state_dict, return_empty_tensor=True)
        state_dict_split = split_torch_state_dict_into_shards(empty_state_dict,
                                                              filename_pattern=filename_pattern,
                                                              max_shard_size=max_shard_size)
    else:
        from collections import namedtuple
        StateDictSplit = namedtuple("StateDictSplit", ["is_sharded", "filename_to_tensors"])
        state_dict_split = StateDictSplit(is_sharded=False,
                                          filename_to_tensors={weights_name: list(state_dict.keys())})

    # Save the model shard by shard, materializing and then releasing one shard at a time
    os.makedirs(output_dir, exist_ok=True)
    for shard_file, tensors in tqdm(state_dict_split.filename_to_tensors.items(), desc="Saving checkpoint shards"):
        shard_state_dict = {tensor_name: state_dict[tensor_name] for tensor_name in tensors}
        shard_state_dict = to_torch_tensor(shard_state_dict)
        output_path = os.path.join(output_dir, shard_file)
        if safe_serialization:
            save_file(shard_state_dict, output_path, metadata={"format": "pt"})
        else:
            torch.save(shard_state_dict, output_path)
        # release the memory of current shard
        for tensor_name in list(shard_state_dict.keys()):
            del state_dict[tensor_name]
            del shard_state_dict[tensor_name]
        del shard_state_dict
        gc.collect()

    # Save index if sharded
    if state_dict_split.is_sharded:
        index = {
            "metadata": state_dict_split.metadata,
            "weight_map": state_dict_split.tensor_to_filename,
        }
        save_index_file = "model.safetensors.index.json" if safe_serialization else "pytorch_model.bin.index.json"
        save_index_file = os.path.join(output_dir, save_index_file)
        with open(save_index_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(index, indent=2, sort_keys=True) + "\n")
681
+
682
+
683
def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
    """
    1. Put the provided model to cpu
    2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
    3. Load it into the provided model

    Args:
        - ``model``: the model object to update
        - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
        - ``tag``: checkpoint tag; defaults to the one named in the ``latest`` file in the checkpoint folder, e.g., ``global_step14``

    Returns:
        - ``model``: the modified model

    Make sure you have plenty of CPU memory available before you call this function. If you don't
    have enough, use the ``zero_to_fp32.py`` utility to do the conversion — you will find it
    conveniently placed for you in the checkpoint folder.

    A typical usage might be ::

        from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
        model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
        # submit to model hub or save the model to share with others

    Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context
    of the same application. i.e. you will need to re-initialize the deepspeed engine, since
    ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
    """
    logger.info("Extracting fp32 weights")
    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)

    logger.info("Overwriting model with fp32 weights")
    model = model.cpu()
    # strict=False: buffers/aliases may not cover every key of the live model
    model.load_state_dict(state_dict, strict=False)

    return model
720
+
721
+
722
if __name__ == "__main__":
    # CLI entry point: convert a ZeRO checkpoint directory into fp32 weight files.
    cli = argparse.ArgumentParser()
    cli.add_argument("checkpoint_dir", type=str,
                     help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
    cli.add_argument("output_dir", type=str,
                     help="directory to the pytorch fp32 state_dict output files"
                     "(e.g. path/checkpoint-12-output/)")
    cli.add_argument("--max_shard_size", type=str, default="5GB",
                     help="The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size"
                     "lower than this size. If expressed as a string, needs to be digits followed by a unit (like `5MB`"
                     "We default it to 5GB in order for models to be able to run easily on free-tier google colab instances"
                     "without CPU OOM issues.")
    cli.add_argument("--safe_serialization", default=False, action='store_true',
                     help="Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).")
    cli.add_argument("-t", "--tag", type=str, default=None,
                     help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
    cli.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters")
    cli.add_argument("-d", "--debug", action='store_true', help="enable debug")
    args = cli.parse_args()

    # rebind the module-level debug flag read by the helper functions
    debug = args.debug

    convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir,
                                               args.output_dir,
                                               max_shard_size=args.max_shard_size,
                                               safe_serialization=args.safe_serialization,
                                               tag=args.tag,
                                               exclude_frozen_parameters=args.exclude_frozen_parameters)
RL_dataset/.gitattributes ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
5
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.model filter=lfs diff=lfs merge=lfs -text
12
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
13
+ *.onnx filter=lfs diff=lfs merge=lfs -text
14
+ *.ot filter=lfs diff=lfs merge=lfs -text
15
+ *.parquet filter=lfs diff=lfs merge=lfs -text
16
+ *.pb filter=lfs diff=lfs merge=lfs -text
17
+ *.pt filter=lfs diff=lfs merge=lfs -text
18
+ *.pth filter=lfs diff=lfs merge=lfs -text
19
+ *.rar filter=lfs diff=lfs merge=lfs -text
20
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
21
+ *.tar filter=lfs diff=lfs merge=lfs -text
22
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
23
+ *.mat filter=lfs diff=lfs merge=lfs -text
24
+ *.npz filter=lfs diff=lfs merge=lfs -text
25
+ *.npy filter=lfs diff=lfs merge=lfs -text
26
+ *.h5 filter=lfs diff=lfs merge=lfs -text
27
+ *.hdf5 filter=lfs diff=lfs merge=lfs -text
28
+ *.pickle filter=lfs diff=lfs merge=lfs -text
29
+ *.pkl filter=lfs diff=lfs merge=lfs -text
30
+ *.tflite filter=lfs diff=lfs merge=lfs -text
31
+ *.tgz filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
35
+ *.tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.db* filter=lfs diff=lfs merge=lfs -text
37
+ *.ark* filter=lfs diff=lfs merge=lfs -text
38
+ **/*ckpt*data* filter=lfs diff=lfs merge=lfs -text
39
+ **/*ckpt*.meta filter=lfs diff=lfs merge=lfs -text
40
+ **/*ckpt*.index filter=lfs diff=lfs merge=lfs -text
41
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
42
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
43
+ *.jpg filter=lfs diff=lfs merge=lfs -text
44
+ *.png filter=lfs diff=lfs merge=lfs -text
45
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
46
+ *.bmp filter=lfs diff=lfs merge=lfs -text
47
+ *.gif filter=lfs diff=lfs merge=lfs -text
48
+ *.webp filter=lfs diff=lfs merge=lfs -text
49
+ *.mp3 filter=lfs diff=lfs merge=lfs -text
50
+ *.wav filter=lfs diff=lfs merge=lfs -text
51
+ *.wma filter=lfs diff=lfs merge=lfs -text
52
+ *.aac filter=lfs diff=lfs merge=lfs -text
53
+ *.ogg filter=lfs diff=lfs merge=lfs -text
54
+ *.m4a filter=lfs diff=lfs merge=lfs -text
55
+ *.m3u8 filter=lfs diff=lfs merge=lfs -text
56
+ *.amr filter=lfs diff=lfs merge=lfs -text
57
+ *.audio filter=lfs diff=lfs merge=lfs -text
58
+ *.avi filter=lfs diff=lfs merge=lfs -text
59
+ *.flv filter=lfs diff=lfs merge=lfs -text
60
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
61
+ *.mpg filter=lfs diff=lfs merge=lfs -text
62
+ *.asf filter=lfs diff=lfs merge=lfs -text
63
+ *.mov filter=lfs diff=lfs merge=lfs -text
64
+ *.mpeg filter=lfs diff=lfs merge=lfs -text
65
+ *.3gp filter=lfs diff=lfs merge=lfs -text
66
+ *.wmv filter=lfs diff=lfs merge=lfs -text
67
+ *.rmvb filter=lfs diff=lfs merge=lfs -text
68
+ *.rm filter=lfs diff=lfs merge=lfs -text
69
+ *.ts filter=lfs diff=lfs merge=lfs -text
70
+ *.mkv filter=lfs diff=lfs merge=lfs -text
71
+ *.flash filter=lfs diff=lfs merge=lfs -text
72
+ *.vob filter=lfs diff=lfs merge=lfs -text
73
+ *.pdf filter=lfs diff=lfs merge=lfs -text
74
+ *.ost filter=lfs diff=lfs merge=lfs -text
75
+ *.pst filter=lfs diff=lfs merge=lfs -text
76
+ *.doc filter=lfs diff=lfs merge=lfs -text
77
+ *.docx filter=lfs diff=lfs merge=lfs -text
78
+ *.txt filter=lfs diff=lfs merge=lfs -text
79
+ *.ppt filter=lfs diff=lfs merge=lfs -text
80
+ *.pptx filter=lfs diff=lfs merge=lfs -text
81
+ *.xls filter=lfs diff=lfs merge=lfs -text
82
+ *.xlsx filter=lfs diff=lfs merge=lfs -text
83
+ *.vsd filter=lfs diff=lfs merge=lfs -text
84
+ *.vsdx filter=lfs diff=lfs merge=lfs -text
85
+ *.jsonl filter=lfs diff=lfs merge=lfs -text
86
+ *.json filter=lfs diff=lfs merge=lfs -text
87
+ dataset_infos.json ignore
88
+ *.csv filter=lfs diff=lfs merge=lfs -text
89
+ *.tsv filter=lfs diff=lfs merge=lfs -text
RL_dataset/.msc ADDED
Binary file (546 Bytes). View file
 
RL_dataset/.mv ADDED
@@ -0,0 +1 @@
 
 
1
+ master
RL_dataset/INFOSEEK_DOWNLOAD.md ADDED
@@ -0,0 +1,337 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # InfoSeek Data Download
2
+
3
+ This document collects ready-to-run scripts for downloading the InfoSeek dataset into:
4
+
5
+ `/workspace/xiaobin/RL_dataset/data`
6
+
7
+ It covers:
8
+
9
+ - InfoSeek annotations
10
+ - InfoSeek KB mapping files
11
+ - InfoSeek human set
12
+ - Wiki6M text files
13
+ - OVEN image snapshot on Hugging Face
14
+ - OVEN original-source image download workflow
15
+
16
+ InfoSeek images are derived from OVEN, so image download is handled through the OVEN release pipeline.
17
+
18
+ ## 1. Recommended Directory Layout
19
+
20
+ ```bash
21
+ mkdir -p /workspace/xiaobin/RL_dataset/data/infoseek
22
+ mkdir -p /workspace/xiaobin/RL_dataset/data/oven_hf
23
+ mkdir -p /workspace/xiaobin/RL_dataset/data/oven_source
24
+ ```
25
+
26
+ Suggested usage:
27
+
28
+ - `/workspace/xiaobin/RL_dataset/data/infoseek`: InfoSeek jsonl files
29
+ - `/workspace/xiaobin/RL_dataset/data/oven_hf`: Hugging Face image snapshot files
30
+ - `/workspace/xiaobin/RL_dataset/data/oven_source`: upstream OVEN repo for original-source image download
31
+
32
+ ## 2. Proxy Workaround
33
+
34
+ If your shell is configured with an invalid local proxy such as `127.0.0.1:7890`, use one of these patterns.
35
+
36
+ Temporarily disable proxy for a single command:
37
+
38
+ ```bash
39
+ env -u http_proxy -u https_proxy -u HTTP_PROXY -u HTTPS_PROXY wget -c URL
40
+ ```
41
+
42
+ Or disable proxy for the current shell session:
43
+
44
+ ```bash
45
+ unset http_proxy https_proxy HTTP_PROXY HTTPS_PROXY
46
+ ```
47
+
48
+ ## 3. Download All InfoSeek Text Data With `wget`
49
+
50
+ This is the simplest full download for the released InfoSeek jsonl files.
51
+
52
+ ```bash
53
+ #!/usr/bin/env bash
54
+ set -euo pipefail
55
+
56
+ TARGET_DIR="/workspace/xiaobin/RL_dataset/data/infoseek"
57
+ mkdir -p "${TARGET_DIR}"
58
+ cd "${TARGET_DIR}"
59
+
60
+ wget -c http://storage.googleapis.com/gresearch/open-vision-language/infoseek/infoseek_train.jsonl
61
+ wget -c http://storage.googleapis.com/gresearch/open-vision-language/infoseek/infoseek_val.jsonl
62
+ wget -c http://storage.googleapis.com/gresearch/open-vision-language/infoseek/infoseek_test.jsonl
63
+ wget -c http://storage.googleapis.com/gresearch/open-vision-language/infoseek/infoseek_train_withkb.jsonl
64
+ wget -c http://storage.googleapis.com/gresearch/open-vision-language/infoseek/infoseek_val_withkb.jsonl
65
+ wget -c http://storage.googleapis.com/gresearch/open-vision-language/infoseek/infoseek_human.jsonl
66
+ wget -c http://storage.googleapis.com/gresearch/open-vision-language/Wiki6M_ver_1_0.jsonl.gz
67
+ wget -c http://storage.googleapis.com/gresearch/open-vision-language/Wiki6M_ver_1_0_title_only.jsonl
68
+
69
+ ls -lh "${TARGET_DIR}"
70
+ ```
71
+
72
+ ## 4. Download All InfoSeek Text Data With `curl`
73
+
74
+ Use this if `wget` is not available.
75
+
76
+ ```bash
77
+ #!/usr/bin/env bash
78
+ set -euo pipefail
79
+
80
+ TARGET_DIR="/workspace/xiaobin/RL_dataset/data/infoseek"
81
+ mkdir -p "${TARGET_DIR}"
82
+ cd "${TARGET_DIR}"
83
+
84
+ curl -L -O http://storage.googleapis.com/gresearch/open-vision-language/infoseek/infoseek_train.jsonl
85
+ curl -L -O http://storage.googleapis.com/gresearch/open-vision-language/infoseek/infoseek_val.jsonl
86
+ curl -L -O http://storage.googleapis.com/gresearch/open-vision-language/infoseek/infoseek_test.jsonl
87
+ curl -L -O http://storage.googleapis.com/gresearch/open-vision-language/infoseek/infoseek_train_withkb.jsonl
88
+ curl -L -O http://storage.googleapis.com/gresearch/open-vision-language/infoseek/infoseek_val_withkb.jsonl
89
+ curl -L -O http://storage.googleapis.com/gresearch/open-vision-language/infoseek/infoseek_human.jsonl
90
+ curl -L -O http://storage.googleapis.com/gresearch/open-vision-language/Wiki6M_ver_1_0.jsonl.gz
91
+ curl -L -O http://storage.googleapis.com/gresearch/open-vision-language/Wiki6M_ver_1_0_title_only.jsonl
92
+
93
+ ls -lh "${TARGET_DIR}"
94
+ ```
95
+
96
+ ## 5. Download Only Core InfoSeek Splits
97
+
98
+ If you only need the standard train/val/test annotations:
99
+
100
+ ```bash
101
+ #!/usr/bin/env bash
102
+ set -euo pipefail
103
+
104
+ TARGET_DIR="/workspace/xiaobin/RL_dataset/data/infoseek"
105
+ mkdir -p "${TARGET_DIR}"
106
+ cd "${TARGET_DIR}"
107
+
108
+ wget -c http://storage.googleapis.com/gresearch/open-vision-language/infoseek/infoseek_train.jsonl
109
+ wget -c http://storage.googleapis.com/gresearch/open-vision-language/infoseek/infoseek_val.jsonl
110
+ wget -c http://storage.googleapis.com/gresearch/open-vision-language/infoseek/infoseek_test.jsonl
111
+ ```
112
+
113
+ ## 6. Download Only KB Mapping Files
114
+
115
+ ```bash
116
+ #!/usr/bin/env bash
117
+ set -euo pipefail
118
+
119
+ TARGET_DIR="/workspace/xiaobin/RL_dataset/data/infoseek"
120
+ mkdir -p "${TARGET_DIR}"
121
+ cd "${TARGET_DIR}"
122
+
123
+ wget -c http://storage.googleapis.com/gresearch/open-vision-language/infoseek/infoseek_train_withkb.jsonl
124
+ wget -c http://storage.googleapis.com/gresearch/open-vision-language/infoseek/infoseek_val_withkb.jsonl
125
+ ```
126
+
127
+ ## 7. Download Only Human Eval Set
128
+
129
+ ```bash
130
+ #!/usr/bin/env bash
131
+ set -euo pipefail
132
+
133
+ TARGET_DIR="/workspace/xiaobin/RL_dataset/data/infoseek"
134
+ mkdir -p "${TARGET_DIR}"
135
+ cd "${TARGET_DIR}"
136
+
137
+ wget -c http://storage.googleapis.com/gresearch/open-vision-language/infoseek/infoseek_human.jsonl
138
+ ```
139
+
140
+ ## 8. Download Only Wiki6M Files
141
+
142
+ ```bash
143
+ #!/usr/bin/env bash
144
+ set -euo pipefail
145
+
146
+ TARGET_DIR="/workspace/xiaobin/RL_dataset/data/infoseek"
147
+ mkdir -p "${TARGET_DIR}"
148
+ cd "${TARGET_DIR}"
149
+
150
+ wget -c http://storage.googleapis.com/gresearch/open-vision-language/Wiki6M_ver_1_0.jsonl.gz
151
+ wget -c http://storage.googleapis.com/gresearch/open-vision-language/Wiki6M_ver_1_0_title_only.jsonl
152
+ ```
153
+
154
+ Optional decompression:
155
+
156
+ ```bash
157
+ gunzip -k /workspace/xiaobin/RL_dataset/data/infoseek/Wiki6M_ver_1_0.jsonl.gz
158
+ ```
159
+
160
+ ## 9. Download OVEN Image Snapshot From Hugging Face
161
+
162
+ Upstream OVEN now points image snapshot downloads to the gated dataset `ychenNLP/oven` on Hugging Face. Before downloading:
163
+
164
+ 1. Open `https://huggingface.co/datasets/ychenNLP/oven`
165
+ 2. Accept the dataset access conditions
166
+ 3. Log in with the Hugging Face CLI
167
+
168
+ Install the CLI if needed:
169
+
170
+ ```bash
171
+ python -m pip install -U "huggingface_hub[cli]"
172
+ ```
173
+
174
+ Login:
175
+
176
+ ```bash
177
+ hf auth login
178
+ ```
179
+
180
+ Download the image snapshot and mapping file into `/workspace/xiaobin/RL_dataset/data/oven_hf`:
181
+
182
+ ```bash
183
+ #!/usr/bin/env bash
184
+ set -euo pipefail
185
+
186
+ TARGET_DIR="/workspace/xiaobin/RL_dataset/data/oven_hf"
187
+ mkdir -p "${TARGET_DIR}"
188
+
189
+ hf download ychenNLP/oven \
190
+ --repo-type dataset \
191
+ --local-dir "${TARGET_DIR}" \
192
+ --include "shard*.tar" "all_wikipedia_images.tar" "ovenid2impath.csv"
195
+ ```
196
+
197
+ Extract the snapshot tar files:
198
+
199
+ ```bash
200
+ #!/usr/bin/env bash
201
+ set -euo pipefail
202
+
203
+ HF_DIR="/workspace/xiaobin/RL_dataset/data/oven_hf"
204
+ IMG_DIR="/workspace/xiaobin/RL_dataset/data/infoseek/images"
205
+ mkdir -p "${IMG_DIR}"
206
+
207
+ for f in "${HF_DIR}"/shard*.tar; do
208
+ tar -xf "${f}" -C "${IMG_DIR}"
209
+ done
210
+
211
+ tar -xf "${HF_DIR}/all_wikipedia_images.tar" -C "${IMG_DIR}"
212
+ ```
213
+
214
+ Notes:
215
+
216
+ - Hugging Face file listing shows `shard01.tar` to `shard08.tar` plus `all_wikipedia_images.tar`
217
+ - The compressed download is very large, roughly 293 GB based on the published file sizes
218
+ - You need additional free space for extraction
219
+
220
+ ## 10. Download OVEN Images From Original Sources
221
+
222
+ This follows the upstream `oven_eval/image_downloads` workflow.
223
+
224
+ ### 10.1 Clone the Upstream Repo
225
+
226
+ ```bash
227
+ git clone https://github.com/edchengg/oven_eval /workspace/xiaobin/RL_dataset/data/oven_source/oven_eval
228
+ ```
229
+
230
+ ### 10.2 Run All Source Download Scripts
231
+
232
+ The upstream image download directory contains these scripts:
233
+
234
+ - `download_aircraft.sh`
235
+ - `download_car196.sh`
236
+ - `download_coco.sh`
237
+ - `download_food101.sh`
238
+ - `download_gldv2.sh`
239
+ - `download_imagenet.sh`
240
+ - `download_inat.sh`
241
+ - `download_oxfordflower.sh`
242
+ - `download_sports100.sh`
243
+ - `download_sun397.sh`
244
+ - `download_textvqa.sh`
245
+ - `download_v7w.sh`
246
+ - `download_vg.sh`
247
+
248
+ Run them one by one:
249
+
250
+ ```bash
251
+ #!/usr/bin/env bash
252
+ set -euo pipefail
253
+
254
+ cd /workspace/xiaobin/RL_dataset/data/oven_source/oven_eval/image_downloads
255
+
256
+ bash download_aircraft.sh
257
+ bash download_car196.sh
258
+ bash download_coco.sh
259
+ bash download_food101.sh
260
+ bash download_gldv2.sh
261
+ bash download_imagenet.sh
262
+ bash download_inat.sh
263
+ bash download_oxfordflower.sh
264
+ bash download_sports100.sh
265
+ bash download_sun397.sh
266
+ bash download_textvqa.sh
267
+ bash download_v7w.sh
268
+ bash download_vg.sh
269
+ ```
270
+
271
+ Or run them in a loop:
272
+
273
+ ```bash
274
+ #!/usr/bin/env bash
275
+ set -euo pipefail
276
+
277
+ cd /workspace/xiaobin/RL_dataset/data/oven_source/oven_eval/image_downloads
278
+
279
+ for script in download_*.sh; do
280
+ bash "${script}"
281
+ done
282
+ ```
283
+
284
+ ### 10.3 Download `ovenid2impath.csv`
285
+
286
+ You need `ovenid2impath.csv` for the merge step. The current recommended source is the Hugging Face dataset:
287
+
288
+ ```bash
289
+ #!/usr/bin/env bash
290
+ set -euo pipefail
291
+
292
+ TARGET_DIR="/workspace/xiaobin/RL_dataset/data/oven_hf"
293
+ mkdir -p "${TARGET_DIR}"
294
+
295
+ hf download ychenNLP/oven \
296
+ --repo-type dataset \
297
+ --local-dir "${TARGET_DIR}" \
298
+ --include "ovenid2impath.csv"
299
+ ```
300
+
301
+ ### 10.4 Merge Into the Final OVEN Image Layout
302
+
303
+ Run the upstream merge script after all downloads finish:
304
+
305
+ ```bash
306
+ cd /workspace/xiaobin/RL_dataset/data/oven_source/oven_eval/image_downloads
307
+ python merge_oven_images.py
308
+ ```
309
+
310
+ The upstream documentation states that `merge_oven_images.py` should be run after all image download scripts complete and after `ovenid2impath.csv` is available.
311
+
312
+ ## 11. Verify the Downloaded Files
313
+
314
+ Check text files:
315
+
316
+ ```bash
317
+ ls -lh /workspace/xiaobin/RL_dataset/data/infoseek
318
+ ```
319
+
320
+ Check Hugging Face snapshot files:
321
+
322
+ ```bash
323
+ ls -lh /workspace/xiaobin/RL_dataset/data/oven_hf
324
+ ```
325
+
326
+ Check extracted images:
327
+
328
+ ```bash
329
+ find /workspace/xiaobin/RL_dataset/data/infoseek/images -type f | wc -l
330
+ ```
331
+
332
+ ## 12. Upstream References
333
+
334
+ - InfoSeek release page: `https://github.com/open-vision-language/infoseek`
335
+ - OVEN image download page: `https://github.com/edchengg/oven_eval/tree/main/image_downloads`
336
+ - Hugging Face OVEN dataset: `https://huggingface.co/datasets/ychenNLP/oven`
337
+ - Hugging Face CLI download docs: `https://huggingface.co/docs/huggingface_hub/guides/cli`
RL_dataset/README.md ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ task_categories:
4
+ - question-answering
5
+ tags:
6
+ - deep-research
7
+ - hierarchical-reasoning
8
+ - multi-hop-qa
9
+ - synthetic-data
10
+ - data-synthesis
11
+ language:
12
+ - en
13
+ ---
14
+
15
+ # InfoSeek: Open Data Synthesis For Deep Research
16
+
17
+ [Paper](https://huggingface.co/papers/2509.00375) | [Code](https://github.com/VectorSpaceLab/InfoSeek)
18
+
19
+ ## Dataset Information
20
+
21
+ * **`data/InfoSeek.jsonl`**
22
+ Contains the full research tree structures of *InfoSeek*. Each sample starts from a root node with a research question, its corresponding entity, and process information for sub-questions (stored in `root`). Also expands into intermediate tree structure during each step of construction (stored in `all_tree_list`). Totally 52K samples.
23
+
24
+ * **`data/InfoSeekQA.jsonl`**
25
+ A collection of QA pairs derived from *InfoSeek*. Each entry corresponds to the final question (`sample['root']['question']`) and its answer entity (`sample['root']['entity']`) in `InfoSeek.jsonl`.
26
+
27
+ * **`data/InfoSeek-Hard-18K.jsonl`**
28
+ A challenging subset of *InfoSeek* (18K samples), which is better to conduct end-to-end RL, identified using an LLM with a dedicated prompt for complex deep research.
29
+
30
+ * **`data/Trajectory-RFT-17K.jsonl`**
31
+ Contains 17K reasoning trajectories generated through the workflow described in our paper. These can be used as training data for supervised fine-tuning (SFT).
32
+
33
+ ## Abstract
34
+ Large language models (LLMs) are increasingly expected to go beyond simple factual queries toward Deep Research-tasks that require decomposing questions into sub-problems, coordinating multi-step reasoning, and synthesizing evidence from diverse sources. We formalize Deep Research tasks with verifiable answers as Hierarchical Constraint Satisfaction Problems (HCSPs), which are fundamentally different from single-constraint, multi-hop, or flat CSP formulations. However, existing benchmarks (e.g., Natural Questions, HotpotQA) fail to capture this complexity, while recent synthetic datasets often introduce shortcut reasoning, knowledge leakage, or lack sufficient structural depth. To address this gap, we introduce InfoSeek, a scalable framework for synthesizing complex Deep Research tasks. InfoSeek uses a dual-agent system to recursively build a Research Tree from large-scale webpages, blurring intermediate nodes into valid sub-problems, and converting these trees into natural language questions that require traversing the full hierarchy. It also enables rapid scaling, yielding over 50K training examples, a curated test set, and reasoning trajectories generated via reject sampling. Experiments show that models trained on InfoSeek consistently outperform strong baselines. On a challenging benchmark BrowseComp-Plus, 3B LLMs optimized with InfoSeek surpass much larger 32B models and lightweight commercial APIs (e.g., Gemini2.5-Flash), while achieving performance comparable to stronger APIs (e.g., Gemini2.5-Pro). By preserving meta-information such as intermediate steps and retrieval labels, InfoSeek further supports advanced optimization strategies, including compound reward design and trajectory-level exploration.
35
+
36
+ ## 🔆 Overview
37
+ We propose **InfoSeek**, a scalable data synthesis framework for constructing structurally complex Deep Research tasks. InfoSeek designs a dual-agent system to recursively build a *Research Tree* by mining entities and relations from large-scale text, and blurring intermediate vertices to ensure they form valid sub-problems. The agent then transforms these trees into natural language questions whose solutions require traversing the entire hierarchy. Using the InfoSeek pipeline, we construct a high-quality, complexity-controllable, and intrinsically verifiable dataset.
38
+
39
+
40
+ ### Example 1:
41
+ **Question:** What is a species of bird that was named by a person employed under his father between 1818 and 1824, whose wife was a British artist, and which has three subspecies and body length is generally no more than 6 inches?
42
+
43
+ **Answer:** Russet sparrow
44
+
45
+ <details>
46
+ <summary>Tree Structure</summary>
47
+
48
+ ```
49
+ {
50
+ "root": {
51
+ "id": "A",
52
+ "entity": "Russet sparrow",
53
+ "question": "What is a species of bird that was named by a person employed under his father between 1818 and 1824, whose wife was a British artist, and which has three subspecies and body length is generally no more than 6 inches?",
54
+ "claims": [
55
+ { "target_id": "B", "claim": "A was named by B" },
56
+ { "target_id": "C", "claim": "A has three subspecies" },
57
+ { "target_id": "D", "claim": "A's body length is generally no more than 6 inches" }
58
+ ],\
59
+ "children": [
60
+ {
61
+ "id": "B",
62
+ "entity": "John Gould",
63
+ "claims": [
64
+ { "target_id": "E", "claim": "B was employed by his father between 1818 and 1824" },
65
+ { "target_id": "F", "claim": "B's wife was F" }
66
+ ],\
67
+ "children": [
68
+ { "id": "E", "entity": "None", "claims": [], "children": [] },
69
+ { "id": "F", "entity": "Elizabeth Gould", "claims": [], "children": [] }
70
+ ]
71
+ },\
72
+ { "id": "C", "entity": "None", "claims": [], "children": [] },
73
+ { "id": "D", "entity": "None", "claims": [], "children": [] }
74
+ ]
75
+ }
76
+ }
77
+ ```
78
+
79
+ ```
80
+ (A: Russet sparrow)
81
+
82
+
83
+ │── [claim] "was named by" ──> (B: John Gould)
84
+ │ │
85
+ │ │
86
+ │ │── [claim] "was employed by his father (1818-1824)"
87
+ │ │
88
+ │ │
89
+ │ │── [claim] "wife was" ──> (F: Elizabeth Gould)
90
+
91
+
92
+ │── [claim] "has three subspecies"
93
+
94
+
95
+ │── [claim] "body length is generally no more than 6 inches"
96
+ ```
97
+ </details>
98
+
99
+ ### Example 2:
100
+
101
+ **Question:** What is a women's football team whose first goals in the 2. Bundesliga were scored by a player born in Korogocho, who was discovered and developed by the Mathare Youth Sports Association?
102
+
103
+ **Answer:** SV Werder Bremen (women)
104
+
105
+ <details>
106
+ <summary>Tree Structure</summary>
107
+
108
+ ```
109
+ {
110
+ "root": {
111
+ "id": "A",
112
+ "entity": "SV Werder Bremen (women)",
113
+ "question": "What is a women's football team whose first goals in the 2. Bundesliga were scored by a player born in Korogocho, who was discovered and developed by the Mathare Youth Sports Association?",
114
+ "claims": [
115
+ { "target_id": "B", "claim": "A's first goals in the 2. Bundesliga were scored by B" }
116
+ ],
117
+ "children": [
118
+ {
119
+ "id": "B",
120
+ "entity": "Doreen Nabwire",
121
+ "claims": [
122
+ { "target_id": "C", "claim": "B was discovered and developed by C" },
123
+ { "target_id": "D", "claim": "B was born in D" }
124
+ ],
125
+ "children": [
126
+ { "id": "C", "entity": "Mathare Youth Sports Association", "claims": [], "children": [] },
127
+ { "id": "D", "entity": "Korogocho", "claims": [], "children": [] }
128
+ ]
129
+ }
130
+ ]
131
+ }
132
+ }
133
+ ```
134
+
135
+ ```
136
+ (A: SV Werder Bremen (women))
137
+
138
+
139
+ │── [claim] "first goals scored by" ──> (B: Doreen Nabwire)
140
+
141
+
142
+ │── [claim] "discovered and developed by" ──> (C:Mathare Youth Sports Association)
143
+
144
+
145
+ │── [claim] "was born in" ──> (D: Korogocho)
146
+ ```
147
+ </details>
148
+
149
+
150
+ ## 📊 Performance
151
+ Model trained on InfoSeek and our framework shows strong performances on traditional multi-hop benchmarks:
152
+
153
+ <img src="https://github.com/VectorSpaceLab/InfoSeek/raw/main/assets/results.png" width="800">
154
+
155
+ Our 3B model shows competitive results on [BrowseComp-Plus](https://github.com/texttron/BrowseComp-Plus):
156
+
157
+ <img src="https://github.com/VectorSpaceLab/InfoSeek/raw/main/assets/browsecomp_plus.png" width="800">
158
+
159
+ ## ❤️ Citing Us
160
+ If you find this repository or our work useful, please consider giving a star ⭐ and/or citing our work, which would be greatly appreciated:
161
+ ```bibtex
162
+ @misc{xia2025opendatasynthesisdeep,
163
+ title={Open Data Synthesis For Deep Research},
164
+ author={Ziyi Xia and Kun Luo and Hongjin Qian and Zheng Liu},
165
+ year={2025},
166
+ eprint={2509.00375},
167
+ archivePrefix={arXiv},
168
+ primaryClass={cs.CL},
169
+ url={https://arxiv.org/abs/2509.00375},
170
+ }
171
+ ```
RL_dataset/dataset_infos.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"default": {"features": {"root": {"_type": "Value"}, "all_tree_list": {"_type": "Value"}, "vertices": {"_type": "Value"}}, "splits": {"train": {"name": "train", "dataset_name": "InfoSeek"}}}}
RL_dataset/download_oven_hf_mirror.sh ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ MODE="${1:-all}"
5
+
6
+ REPO_ID="ychenNLP/oven"
7
+ TARGET_DIR="/workspace/xiaobin/RL_dataset/data"
8
+ CACHE_DIR="${TARGET_DIR}/.hf_cache"
9
+ ASSETS_DIR="${TARGET_DIR}/.hf_assets"
10
+ DEFAULT_ENDPOINT="https://hf-mirror.com"
11
+ MIRROR_URL="${HF_ENDPOINT:-${HF_ENDPOINT_OVERRIDE:-${DEFAULT_ENDPOINT}}}"
12
+ HARDCODED_TOKEN=""  # SECURITY: never commit a real token; export HF_TOKEN or run 'hf auth login' instead
13
+ META_FILES=(
14
+ "download_infoseek_jsonl.sh"
15
+ "download_oven_jsonl.sh"
16
+ "ovenid2impath.csv"
17
+ )
18
+ IMAGE_FILES=(
19
+ "shard01.tar"
20
+ "shard02.tar"
21
+ "shard03.tar"
22
+ "shard04.tar"
23
+ "shard05.tar"
24
+ "shard06.tar"
25
+ "shard07.tar"
26
+ "shard08.tar"
27
+ "all_wikipedia_images.tar"
28
+ )
29
+
30
+ unset http_proxy
31
+ unset https_proxy
32
+ unset HTTP_PROXY
33
+ unset HTTPS_PROXY
34
+ unset all_proxy
35
+ unset ALL_PROXY
36
+
37
+ export HF_ENDPOINT="${MIRROR_URL}"
38
+ export HF_HUB_CACHE="${CACHE_DIR}"
39
+ export HF_ASSETS_CACHE="${ASSETS_DIR}"
40
+
41
+ mkdir -p "${TARGET_DIR}" "${CACHE_DIR}" "${ASSETS_DIR}"
42
+
43
+ if command -v hf >/dev/null 2>&1; then
44
+ HF_BIN=(hf download)
45
+ elif command -v huggingface-cli >/dev/null 2>&1; then
46
+ HF_BIN=(huggingface-cli download)
47
+ else
48
+ echo "Missing Hugging Face CLI. Install it with:" >&2
49
+ echo " python -m pip install -U \"huggingface_hub[cli]\"" >&2
50
+ exit 1
51
+ fi
52
+
53
+ TOKEN_ARGS=()
54
+ if [[ -n "${HF_TOKEN:-}" ]]; then
55
+ TOKEN_ARGS=(--token "${HF_TOKEN}")
56
+ elif [[ -n "${HARDCODED_TOKEN}" ]]; then
57
+ TOKEN_ARGS=(--token "${HARDCODED_TOKEN}")
58
+ fi
59
+
60
+ print_help() {
61
+ cat <<'EOF'
62
+ Usage:
63
+ bash download_oven_hf_mirror.sh [meta|images|all]
64
+
65
+ Modes:
66
+ meta Download metadata files only:
67
+ - download_infoseek_jsonl.sh
68
+ - download_oven_jsonl.sh
69
+ - ovenid2impath.csv
70
+ images Download image tar files only:
71
+ - shard01.tar ... shard08.tar
72
+ - all_wikipedia_images.tar
73
+ all Download both metadata and image tar files
74
+
75
+ Behavior:
76
+ - unsets proxy variables before downloading
77
+ - uses the mirror endpoint: https://hf-mirror.com
78
+ - endpoint can be overridden:
79
+ HF_ENDPOINT=https://huggingface.co bash download_oven_hf_mirror.sh meta
80
+ - stores downloaded files in: /workspace/xiaobin/RL_dataset/data
81
+ - stores Hugging Face cache in: /workspace/xiaobin/RL_dataset/data/.hf_cache
82
+
83
+ Notes:
84
+ - The dataset is gated. First accept access at:
85
+ https://huggingface.co/datasets/ychenNLP/oven
86
+ - Do not keep a real token hardcoded in this script; it would leak via version control.
87
+ - If needed, export your token before running to override it:
88
+ export HF_TOKEN=hf_xxx
89
+ EOF
90
+ }
91
+
92
+ if [[ "${MODE}" == "-h" || "${MODE}" == "--help" || "${MODE}" == "help" ]]; then
93
+ print_help
94
+ exit 0
95
+ fi
96
+
97
+ require_auth() {
98
+ if [[ -n "${HF_TOKEN:-}" ]]; then
99
+ return 0
100
+ fi
101
+
102
+ if hf auth whoami >/dev/null 2>&1 || huggingface-cli whoami >/dev/null 2>&1; then
103
+ return 0
104
+ fi
105
+
106
+ echo "No Hugging Face authentication detected." >&2
107
+ echo "Do this first:" >&2
108
+ echo " 1. Open https://huggingface.co/datasets/ychenNLP/oven and accept access." >&2
109
+ echo " 2. Run: hf auth login" >&2
110
+ echo " or: export HF_TOKEN=hf_xxx" >&2
111
+ exit 2
112
+ }
113
+
114
+ run_download() {
115
+ if ! "$@"; then
116
+ echo >&2
117
+ echo "Download failed." >&2
118
+ echo "Check these items:" >&2
119
+ echo " - access was approved for https://huggingface.co/datasets/ychenNLP/oven" >&2
120
+ echo " - HF_TOKEN is valid, or 'hf auth login' succeeded" >&2
121
+ echo " - the mirror endpoint is reachable: ${MIRROR_URL}" >&2
122
+ exit 1
123
+ fi
124
+ }
125
+
126
+ verify_files() {
127
+ local missing=0
128
+ local file
129
+
130
+ for file in "$@"; do
131
+ if [[ ! -f "${TARGET_DIR}/${file}" ]]; then
132
+ echo "Missing expected file: ${TARGET_DIR}/${file}" >&2
133
+ missing=1
134
+ fi
135
+ done
136
+
137
+ if [[ "${missing}" -ne 0 ]]; then
138
+ echo >&2
139
+ echo "Download did not complete successfully." >&2
140
+ echo "This usually means one of these:" >&2
141
+ echo " - the mirror endpoint could not be reached" >&2
142
+ echo " - access to the gated dataset was not approved" >&2
143
+ echo " - authentication was missing or invalid" >&2
144
+ exit 1
145
+ fi
146
+ }
147
+
148
+ download_meta() {
149
+ run_download "${HF_BIN[@]}" "${REPO_ID}" \
150
+ --repo-type dataset \
151
+ --local-dir "${TARGET_DIR}" \
152
+ --include "download_infoseek_jsonl.sh" \
153
+ --include "download_oven_jsonl.sh" \
154
+ --include "ovenid2impath.csv" \
155
+ "${TOKEN_ARGS[@]}"
156
+ verify_files "${META_FILES[@]}"
157
+ }
158
+
159
+ download_images() {
160
+ run_download "${HF_BIN[@]}" "${REPO_ID}" \
161
+ --repo-type dataset \
162
+ --local-dir "${TARGET_DIR}" \
163
+ --include "all_wikipedia_images.tar" "shard*.tar" \
165
+ "${TOKEN_ARGS[@]}"
166
+ verify_files "${IMAGE_FILES[@]}"
167
+ }
168
+
169
+ require_auth
170
+
171
+ case "${MODE}" in
172
+ meta)
173
+ download_meta
174
+ ;;
175
+ images)
176
+ download_images
177
+ ;;
178
+ all)
179
+ download_meta
180
+ download_images
181
+ ;;
182
+ *)
183
+ echo "Unknown mode: ${MODE}" >&2
184
+ print_help >&2
185
+ exit 1
186
+ ;;
187
+ esac
188
+
189
+ echo "Download completed. Files are under: ${TARGET_DIR}"
RL_dataset/download_scienceqa_hf.sh ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Download the ScienceQA dataset: HF parquet files and/or original image zips.
set -euo pipefail

# Download mode: parquet | images | all (defaults to "all").
MODE="${1:-all}"

REPO_ID="derek-thomas/ScienceQA"
ROOT_DIR="/workspace/xiaobin/RL_dataset/data/ScienceQA"
HF_DIR="${ROOT_DIR}/hf"            # parquet files mirrored from the HF repo
IMG_DIR="${ROOT_DIR}/images"       # image zips, extracted per split
CACHE_DIR="${ROOT_DIR}/.hf_cache"  # hub cache kept next to the data
DEFAULT_ENDPOINT="https://hf-mirror.com"
# Endpoint precedence: HF_ENDPOINT env var, then HF_ENDPOINT_OVERRIDE,
# then the mirror default.
HF_ENDPOINT_VALUE="${HF_ENDPOINT:-${HF_ENDPOINT_OVERRIDE:-${DEFAULT_ENDPOINT}}}"

# Clear any proxy settings so traffic goes directly to the chosen endpoint.
unset http_proxy
unset https_proxy
unset HTTP_PROXY
unset HTTPS_PROXY
unset all_proxy
unset ALL_PROXY

export HF_ENDPOINT="${HF_ENDPOINT_VALUE}"

mkdir -p "${HF_DIR}" "${IMG_DIR}" "${CACHE_DIR}"

# Prefer the modern `hf` CLI; fall back to the legacy `huggingface-cli`.
if command -v hf >/dev/null 2>&1; then
  HF_BIN=(hf download)
elif command -v huggingface-cli >/dev/null 2>&1; then
  HF_BIN=(huggingface-cli download)
else
  echo "Missing Hugging Face CLI. Install it with:" >&2
  echo " python -m pip install -U \"huggingface_hub[cli]\"" >&2
  exit 1
fi
34
+
35
# Print usage, supported modes, output layout, and endpoint-override notes.
# The heredoc is quoted ('EOF') so nothing inside it is expanded.
print_help() {
  cat <<'EOF'
Usage:
  bash download_scienceqa_hf.sh [parquet|images|all]

Modes:
  parquet   Download the public Hugging Face parquet files only
  images    Download the original ScienceQA image zip files only
  all       Download both parquet files and images

Output layout:
  /workspace/xiaobin/RL_dataset/data/ScienceQA/hf
  /workspace/xiaobin/RL_dataset/data/ScienceQA/images

Notes:
  - This dataset is public and should not require an HF token.
  - Image URLs are adapted from:
    /workspace/xiaobin/RL_dataset/ScienceQA/tools/download.sh
  - Proxies are unset before download.
  - Default HF endpoint: https://hf-mirror.com
  - To override and use the official endpoint:
    HF_ENDPOINT=https://huggingface.co bash download_scienceqa_hf.sh parquet
EOF
}
59
+
60
# Handle help flags before doing any work (no directories were touched yet
# beyond mkdir -p above, which is idempotent).
if [[ "${MODE}" == "-h" || "${MODE}" == "--help" || "${MODE}" == "help" ]]; then
  print_help
  exit 0
fi
64
+
65
# Abort the script unless at least one existing path matches the glob.
verify_glob() {
  local glob="$1"

  # compgen -G succeeds iff the pattern expands to at least one path.
  if compgen -G "${glob}" >/dev/null; then
    return 0
  fi

  echo "Missing expected file matching: ${glob}" >&2
  exit 1
}
73
+
74
# Download the parquet splits (plus README and loader script) from the HF
# repo into HF_DIR, then verify that each split produced at least one
# parquet shard.
download_parquet() {
  "${HF_BIN[@]}" "${REPO_ID}" \
    --repo-type dataset \
    --cache-dir "${CACHE_DIR}" \
    --local-dir "${HF_DIR}" \
    --include "data/*.parquet" \
    --include "README.md" \
    --include "ScienceQA.py"

  verify_glob "${HF_DIR}/data/train-*.parquet"
  verify_glob "${HF_DIR}/data/validation-*.parquet"
  verify_glob "${HF_DIR}/data/test-*.parquet"
}
87
+
88
# Download and extract one image split (train/val/test) from the public
# ScienceQA S3 bucket. Skips splits that are already extracted, so the
# script is safe to re-run.
# Arguments:
#   $1  split name; must match both the zip name and the extracted dir name
download_one_split() {
  local split="$1"
  local zip_path="${IMG_DIR}/${split}.zip"
  local split_dir="${IMG_DIR}/${split}"
  local url="https://scienceqa.s3.us-west-1.amazonaws.com/images/${split}.zip"

  if [[ -d "${split_dir}" ]]; then
    echo "Image split already exists: ${split_dir}"
    return 0
  fi

  # Under `set -euo pipefail` a missing tool would kill the script with a
  # bare "command not found"; fail with an explicit message instead.
  local tool
  for tool in wget unzip; do
    if ! command -v "${tool}" >/dev/null 2>&1; then
      echo "Missing required tool: ${tool}" >&2
      exit 1
    fi
  done

  # -c resumes a previously interrupted download of the same zip.
  wget -c -O "${zip_path}" "${url}"
  unzip -q -o "${zip_path}" -d "${IMG_DIR}"
  rm -f "${zip_path}"

  # The zip is expected to contain a top-level "<split>/" directory.
  if [[ ! -d "${split_dir}" ]]; then
    echo "Failed to extract image split: ${split}" >&2
    exit 1
  fi
}
108
+
109
# Fetch all three image splits, in dataset order.
download_images() {
  local split
  for split in train val test; do
    download_one_split "${split}"
  done
}
114
+
115
# Dispatch on the requested mode; an unrecognized value prints the help
# text to stderr and exits non-zero.
if [[ "${MODE}" == "parquet" ]]; then
  download_parquet
elif [[ "${MODE}" == "images" ]]; then
  download_images
elif [[ "${MODE}" == "all" ]]; then
  download_parquet
  download_images
else
  echo "Unknown mode: ${MODE}" >&2
  print_help >&2
  exit 1
fi

echo "Download completed."
echo "Parquet dir: ${HF_DIR}"
echo "Image dir: ${IMG_DIR}"
download_hf.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
"""
Resumable Hugging Face download script.

Mirror endpoint: hf-mirror.com
Target repo: MMInstruction/M3IT
"""

import os
import sys

# Point huggingface_hub at the mirror endpoint.
# Must be set BEFORE importing huggingface_hub, which reads it at import time.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

from huggingface_hub import snapshot_download
from huggingface_hub import hf_hub_download  # NOTE(review): unused in this script
import huggingface_hub  # NOTE(review): unused in this script

REPO_ID = "MMInstruction/M3IT"
LOCAL_DIR = "/workspace/xiaobin/dataset"
REPO_TYPE = "dataset"  # M3IT is a dataset repo
21
+
22
+
23
def download():
    """Download a snapshot of REPO_ID into LOCAL_DIR, resuming if interrupted.

    Prints progress banners, copies real files (no symlinks), and exits the
    process with status 1 on any failure.
    """
    # Banner: mirror, repo, and destination, followed by a separator line.
    for line in (
        f"镜像站: {os.environ['HF_ENDPOINT']}",
        f"下载仓库: {REPO_ID}",
        f"保存目录: {LOCAL_DIR}",
        "-" * 50,
    ):
        print(line)

    os.makedirs(LOCAL_DIR, exist_ok=True)

    options = dict(
        repo_id=REPO_ID,
        repo_type=REPO_TYPE,
        local_dir=LOCAL_DIR,
        local_dir_use_symlinks=False,  # copy files directly, no symlinks
        resume_download=True,          # continue partially downloaded files
        ignore_patterns=["*.gitattributes"],
    )
    try:
        snapshot_download(**options)
    except Exception as exc:
        print(f"\n出错: {exc}")
        print("提示: 如果是模型仓库,请将 REPO_TYPE 改为 'model' 后重试")
        sys.exit(1)

    print("\n下载完成!")
46
+
47
+
48
+ if __name__ == "__main__":
49
+ download()