况兑 committed on
Commit e45d7fc · 1 Parent(s): ded23ce

eval: greedy decode + numeric strict; system: force full decimals; regressions: A/B/C/noisy

Files changed (4)
  1. eval_simple.py +128 -0
  2. test_regress.sh +13 -0
  3. test_regress_full.sh +8 -0
  4. train_lora.py +348 -0
eval_simple.py ADDED
@@ -0,0 +1,128 @@
+ import json, re, argparse, torch
+
+ def _collapse_digit_separators(s: str) -> str:
+     # Strip whitespace/commas that appear between two digits (including common narrow-space characters)
+     return re.sub(r'(?<=\d)[,\s\u00A0\u202F\u2009\u2007\u2060]+(?=\d)', '', s)
+
+
+ def _postproc_model_text(s: str) -> str:
+     # collapse spaces/commas between digits like '330.7 6' -> '330.76', '1,234' -> '1234'
+     s = re.sub(r'(?<=\d)[,\s]+(?=\d)', '', s)
+     return s
+
+
+ def _preprocess_user_text(s: str) -> str:
+     # Note: not called in this script.
+     # Full-width punctuation -> half-width
+     s = s.replace(",", ",").replace("。", ".").replace(":", ":").replace("(","(").replace(")",")")
+     # Strip commas/spaces inside numbers (keep the decimal point)
+     s = re.sub(r'(?<=\d)[,\s]+(?=\d)', '', s)
+     # Collapse whitespace
+     s = re.sub(r'\s+', ' ', s).strip()
+     return s
+
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from peft import PeftModel
+
+ def messages_to_pairs(messages):
+     pairs, buf = [], []
+     for m in messages:
+         if m.get("role") == "user":
+             buf.append(m.get("content", ""))
+         elif m.get("role") == "assistant" and buf:
+             pairs.append({"prompt": "\n\n".join(buf), "response": m.get("content", "")})
+             buf = []
+     return pairs
+
+ def normalize(s: str) -> str:
+     s = s.replace("\u3000", " ").strip()
+     trans = str.maketrans(",。:!?【】()%+-×÷=“”‘’", ",.:!?[]()%+-*/=\"\"''")
+     s = s.translate(trans)
+     s = re.sub(r"\s+", " ", s)
+     return s
+
+ def to_num(x):
+     try:
+         return float(x)
+     except (TypeError, ValueError):
+         if not isinstance(x, str):
+             x = str(x)
+         x = _collapse_digit_separators(x)
+         m = re.search(r"[-+]?\d*\.?\d+(?:e[-+]?\d+)?", x, re.I)
+         return float(m.group(0)) if m else None
+
+ def main():
+     ap = argparse.ArgumentParser()
+     ap.add_argument("--base", default="Qwen/Qwen2.5-0.5B-Instruct")
+     ap.add_argument("--adapter", required=True)
+     ap.add_argument("--data", required=True)
+     ap.add_argument("--max_new", type=int, default=64)
+     ap.add_argument("--limit", type=int, default=0, help="only evaluate the first N examples; 0 = all")
+     args = ap.parse_args()
+
+     tok = AutoTokenizer.from_pretrained(args.base, trust_remote_code=True, use_fast=True)
+     if tok.pad_token is None:
+         tok.pad_token = tok.eos_token
+
+     base = AutoModelForCausalLM.from_pretrained(args.base, trust_remote_code=True, torch_dtype=torch.float32)
+     model = PeftModel.from_pretrained(base, args.adapter)
+     model.eval()
+     device = "mps" if torch.backends.mps.is_available() else "cpu"
+     model.to(device)
+
+     # Clear the default sampling options in generation_config to avoid warnings and guarantee greedy decoding
+     gc = model.generation_config
+     gc.do_sample = False
+     gc.temperature = None
+     gc.top_p = None
+     gc.top_k = None
+
+     golds, preds = [], []
+
+     with open(args.data, "r", encoding="utf-8") as f:
+         for i, line in enumerate(f):
+             if args.limit and i >= args.limit: break
+             row = json.loads(line)
+             pairs = messages_to_pairs(row["messages"])
+             if not pairs: continue
+             ex = pairs[0]
+             gold = ex["response"]
+
+             # System prompt (zh): "Output only the final answer, no explanation. Output a single number;
+             # if it is a decimal, output every decimal place, without rounding or truncation."
+             ctx = [
+                 {"role": "system", "content": "请只输出最终答案,不要解释。只输出一个数字;若为小数,完整输出全部小数位,不要四舍五入或截断。"},
+                 {"role": "user", "content": ex["prompt"]}
+             ]
+             prompt_text = tok.apply_chat_template(ctx, tokenize=False, add_generation_prompt=True)
+             inputs = tok(prompt_text, return_tensors="pt").to(device)
+
+             with torch.no_grad():
+                 out = model.generate(
+                     **inputs,
+                     max_new_tokens=args.max_new,
+                     do_sample=False,
+                     temperature=None,
+                     eos_token_id=tok.eos_token_id,
+                     pad_token_id=tok.eos_token_id
+                 )
+             pred = tok.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
+
+             pred = _postproc_model_text(pred)
+             golds.append(gold)
+             preds.append(pred)
+
+             print(f"[{i}] GT={repr(gold)} | PRED={repr(pred)}")
+
+     # Compute the three EM variants
+     strict = sum(1 for g, p in zip(golds, preds) if p == g)
+     loose = sum(1 for g, p in zip(golds, preds) if normalize(p) == normalize(g))
+     numem = 0
+     for g, p in zip(golds, preds):
+         g_num, p_num = to_num(g), to_num(p)
+         if g_num is not None and p_num is not None and abs(g_num - p_num) < 1e-6:
+             numem += 1
+
+     n = len(golds) if golds else 1
+     print(f"\n==> EM strict={strict/n:.3f} EM loose={loose/n:.3f} EM numeric={numem/n:.3f} (N={n})")
+
+ if __name__ == "__main__":
+     main()
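
To make the three metrics above concrete, here is a minimal, self-contained sketch (not part of the commit; the gold/prediction strings are made up) that mirrors the strict / loose / numeric EM rules from eval_simple.py:

import re

def normalize(s: str) -> str:
    # same full-width -> half-width mapping and whitespace collapsing as eval_simple.normalize
    s = s.replace("\u3000", " ").strip()
    trans = str.maketrans(",。:!?【】()%+-×÷=“”‘’", ",.:!?[]()%+-*/=\"\"''")
    return re.sub(r"\s+", " ", s.translate(trans))

def to_num(x):
    # same digit-separator collapsing and first-number extraction as eval_simple.to_num
    try:
        return float(x)
    except (TypeError, ValueError):
        x = re.sub(r"(?<=\d)[,\s\u00A0\u202F\u2009\u2007\u2060]+(?=\d)", "", str(x))
        m = re.search(r"[-+]?\d*\.?\d+(?:e[-+]?\d+)?", x, re.I)
        return float(m.group(0)) if m else None

gold = "330.76"
pred = "答案是 330.7 6 。"                                    # model inserted a space inside the number
print(pred == gold)                                           # strict EM: False
print(normalize(pred) == normalize(gold))                     # loose EM: still False (extra text)
g, p = to_num(gold), to_num(pred)
print(g is not None and p is not None and abs(g - p) < 1e-6)  # numeric EM: True

Numeric EM is the "numeric strict" criterion from the commit message: it ignores surrounding text and digit-group separators, but the extracted number still has to match the gold value to within 1e-6.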
test_regress.sh ADDED
@@ -0,0 +1,13 @@
+ set -euo pipefail
+ DATA1=subset10.numeric.jsonl
+ DATA2=subset10.perturbed.chat.jsonl
+ DATA3=subset10.perturbed.chat.norm.jsonl
+ AD=./runs/overfit10_gold
+
+ echo "[A] original set"
+ python eval_simple.py --adapter "$AD" --data "$DATA1"
+ echo "[B] perturbed set"
+ python eval_simple.py --adapter "$AD" --data "$DATA2"
+ echo "[C] perturbed + normalized set"
+ python eval_simple.py --adapter "$AD" --data "$DATA3"
+ echo "==> regression tests finished"
test_regress_full.sh ADDED
@@ -0,0 +1,8 @@
+ set -euo pipefail
+ AD=./runs/overfit10_gold
+ for D in subset10.numeric.jsonl subset10.perturbed.chat.jsonl subset10.perturbed.chat.norm.jsonl subset10.noisy.chat.jsonl
+ do
+   echo "==> $D"
+   python eval_simple.py --adapter "$AD" --data "$D"
+ done
+ echo "OK: full regression passed."
train_lora.py ADDED
@@ -0,0 +1,348 @@
+ import os, json, random, math
+ from dataclasses import dataclass
+ from typing import Dict, List, Optional
+ from datasets import Dataset, DatasetDict
+ from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
+ import torch
+ from peft import LoraConfig, get_peft_model
+
+ # --------------------
+ # Config via env
+ # --------------------
+ BASE_MODEL = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-0.5B-Instruct")
+ DATA_PATH = os.environ.get("DATA_PATH", "s1k_chat_1.1_small.jsonl")
+ OUTPUT_DIR = os.environ.get("OUTPUT_DIR", "./runs/qwen25-0p5b-lora")
+ SEED = int(os.environ.get("SEED", "42"))
+ EPOCHS = float(os.environ.get("EPOCHS", "3"))
+ BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "4"))
+ GRAD_ACCUM = int(os.environ.get("GRAD_ACCUM", "8"))
+ LR = float(os.environ.get("LR", "2e-4"))
+ MAX_LEN = int(os.environ.get("MAX_LEN", "1024"))
+ WARMUP_RATIO = float(os.environ.get("WARMUP_RATIO", "0.05"))
+ LORA_R = int(os.environ.get("LORA_R", "16"))
+ LORA_ALPHA = int(os.environ.get("LORA_ALPHA", "32"))
+ LORA_DROPOUT = float(os.environ.get("LORA_DROPOUT", "0.05"))
+ SAVE_STEPS = int(os.environ.get("SAVE_STEPS", "0"))  # 0 -> only save at the end
+ VAL_RATIO = float(os.environ.get("VAL_RATIO", "0.1"))
+
+ random.seed(SEED)
+
+ # --------------------
+ # Load & split dataset
+ # --------------------
+ def load_jsonl_messages(path: str) -> List[Dict]:
+     rows = []
+     with open(path, "r", encoding="utf-8") as f:
+         for line in f:
+             line = line.strip()
+             if not line:
+                 continue
+             obj = json.loads(line)
+             rows.append(obj)
+     return rows
+
+ raw = load_jsonl_messages(DATA_PATH)
+
+ # Basic shuffle & split
+ random.shuffle(raw)
+ val_n = max(1, int(len(raw) * VAL_RATIO))
+ val_list = raw[:val_n]
+ train_list = raw[val_n:]
+
+ def messages_to_pairs(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
+     """
+     Flatten multi-turn messages into (prompt -> response) pairs:
+     consecutive user turns are merged into one prompt, and each assistant turn closes a pair.
+     """
+     pairs = []
+     last_user = []
+     for m in messages:
+         role = m.get("role", "")
+         content = m.get("content", "")
+         if role == "user":
+             last_user.append(content)
+         elif role == "assistant" and last_user:
+             prompt = "\n\n".join(last_user)
+             pairs.append({"prompt": prompt, "response": content})
+             last_user = []
+     return pairs
+
+ def flatten_jsonl_to_pairs(jsonl_rows: List[Dict]) -> List[Dict]:
+     pairs_all = []
+     for r in jsonl_rows:
+         msgs = r.get("messages", [])
+         pairs = messages_to_pairs(msgs)
+         pairs_all.extend(pairs)
+     return pairs_all
+
+ train_pairs = flatten_jsonl_to_pairs(train_list)
+ val_pairs = flatten_jsonl_to_pairs(val_list)
+
+ train_ds = Dataset.from_list(train_pairs)
+ val_ds = Dataset.from_list(val_pairs)
+ ds = DatasetDict({"train": train_ds, "validation": val_ds})
+
+ # --------------------
+ # Tokenizer & chat template
+ # --------------------
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True, trust_remote_code=True)
+
+ if tokenizer.pad_token is None:
+     tokenizer.pad_token = tokenizer.eos_token
+ tokenizer.padding_side = "right"
+
+
+ # === chat_template tokenization ===
+
+ def _sft_tokenize_with_chat_template(example):
+     # Rebuild (prompt, response) as messages and format them with the chat template
+     ctx_msgs = [{"role": "user", "content": example["prompt"]}]
+     tgt = example["response"]
+
+     # Context only, with the generation prompt appended ("ready to generate")
+     prompt_text = tokenizer.apply_chat_template(
+         ctx_msgs, tokenize=False, add_generation_prompt=True
+     )
+     # Answer only, plus eos
+     target_text = tgt + (tokenizer.eos_token or "")
+
+     prompt_ids = tokenizer(prompt_text, add_special_tokens=False)["input_ids"]
+     target_ids = tokenizer(target_text, add_special_tokens=False)["input_ids"]
+
+     # Truncation: keep the answer; drop prompt tokens from the left if needed
+     max_p = MAX_LEN - len(target_ids)
+     if max_p <= 0:
+         target_ids = target_ids[-(MAX_LEN-1):]
+         prompt_ids = []
+     else:
+         prompt_ids = prompt_ids[-max_p:]
+
+     ids = prompt_ids + target_ids
+     labels = [-100]*len(prompt_ids) + target_ids[:]  # loss only on the answer span
+     attn = [1]*len(ids)
+
+     return {"input_ids": ids, "labels": labels, "attention_mask": attn}
+
+
+ IGNORE_INDEX = -100
+
+ def tokenize(example: Dict) -> Dict:
+     # NOTE: unused alternative path; it relies on a build_chat_prompt() helper that is not defined in this file.
+     # Loss is computed only on the assistant span.
+     prompt_text = build_chat_prompt(example["prompt"], None)
+     full_text = build_chat_prompt(example["prompt"], example["response"])
+
+     prompt_ids = tokenizer(prompt_text, add_special_tokens=False)["input_ids"]
+     full = tokenizer(
+         full_text,
+         max_length=MAX_LEN,
+         truncation=True,
+         padding=False,
+         add_special_tokens=False,
+     )["input_ids"]
+
+     labels = [IGNORE_INDEX] * len(full)
+     start = len(prompt_ids)
+     for i in range(start, len(full)):
+         labels[i] = full[i]
+
+     return {
+         "input_ids": full,
+         "labels": labels,
+         "attention_mask": [1] * len(full),
+     }
+
+ tokenized = ds.map(_sft_tokenize_with_chat_template, remove_columns=ds["train"].column_names, desc="Tokenizing with chat_template")
+
+ # --------------------
+ # Model & LoRA (on Mac/MPS: disable mixed precision and use fp32)
+ # --------------------
+ use_mps = torch.backends.mps.is_available()
+ if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
+     compute_dtype = torch.bfloat16
+ elif torch.cuda.is_available():
+     compute_dtype = torch.float16
+ else:
+     compute_dtype = torch.float32  # full precision on MPS/CPU
+
+ device_map = "auto" if torch.cuda.is_available() else None
+
+ model = AutoModelForCausalLM.from_pretrained(
+     BASE_MODEL,
+     torch_dtype=compute_dtype,
+     device_map=device_map,
+     trust_remote_code=True,
+ )
+
+ if use_mps:
+     model.to("mps")
+
+ model.gradient_checkpointing_enable()
+ model.enable_input_require_grads()
+
+ lora_cfg = LoraConfig(
+     r=LORA_R,
+     lora_alpha=LORA_ALPHA,
+     lora_dropout=LORA_DROPOUT,
+     bias="none",
+     task_type="CAUSAL_LM",
+     target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"],
+ )
+
+ model = get_peft_model(model, lora_cfg)
+
+ def print_trainable(model):
+     trainable = 0
+     total = 0
+     for n, p in model.named_parameters():
+         c = p.numel()
+         total += c
+         if p.requires_grad:
+             trainable += c
+     print(f"[PARAMS] total={total} trainable={trainable} ratio={trainable/max(total,1):.6f}")
+ print_trainable(model)
+
+ # --------------------
+ # Collator (pads without disturbing the label masking)
+ # --------------------
+ @dataclass
+ class DataCollatorForCausalLM:
+     tokenizer: AutoTokenizer
+     pad_to_multiple_of: Optional[int] = 8
+
+     def __call__(self, features: List[Dict]) -> Dict[str, torch.Tensor]:
+         max_len = max(len(f["input_ids"]) for f in features)
+         if self.pad_to_multiple_of:
+             max_len = int(math.ceil(max_len / self.pad_to_multiple_of) * self.pad_to_multiple_of)
+
+         input_ids, labels, attention_mask = [], [], []
+         for f in features:
+             ids = f["input_ids"]
+             labs = f["labels"]
+             mask = f["attention_mask"]
+             pad_len = max_len - len(ids)
+             input_ids.append(ids + [self.tokenizer.pad_token_id] * pad_len)
+             attention_mask.append(mask + [0] * pad_len)
+             labels.append(labs + [IGNORE_INDEX] * pad_len)
+
+         return {
+             "input_ids": torch.tensor(input_ids, dtype=torch.long),
+             "labels": torch.tensor(labels, dtype=torch.long),
+             "attention_mask": torch.tensor(attention_mask, dtype=torch.long)
+         }
+
+ collator = DataCollatorForCausalLM(tokenizer)
+
+ # --------------------
+ # Training (force bf16/fp16 off on Mac)
+ # --------------------
+ steps_per_epoch = max(1, len(tokenized["train"]) // (BATCH_SIZE * GRAD_ACCUM))
+ save_strategy = "steps" if SAVE_STEPS > 0 else "epoch"
+
+ training_args = TrainingArguments(
+     output_dir=OUTPUT_DIR,
+     num_train_epochs=EPOCHS,
+     per_device_train_batch_size=BATCH_SIZE,
+     per_device_eval_batch_size=BATCH_SIZE,
+     gradient_accumulation_steps=GRAD_ACCUM,
+     learning_rate=LR,
+     warmup_ratio=WARMUP_RATIO,
+     logging_steps=max(1, steps_per_epoch // 5),
+     evaluation_strategy="epoch",
+     save_strategy=save_strategy,
+     save_steps=SAVE_STEPS if SAVE_STEPS > 0 else 500,  # ignored when save_strategy="epoch"
+     save_total_limit=2,
+     bf16=False,
+     fp16=False,
+     weight_decay=0.0,
+     lr_scheduler_type="cosine",
+     seed=SEED,
+     max_grad_norm=1.0,
+     remove_unused_columns=False,
+     report_to=["none"],
+ )
+
+ trainer = Trainer(
+     model=model,
+     args=training_args,
+     train_dataset=tokenized["train"],
+     eval_dataset=tokenized["validation"],
+     data_collator=collator,
+     tokenizer=tokenizer,
+ )
+
+ trainer.train()
+ metrics = trainer.evaluate()
+ trainer.save_model()
+ tokenizer.save_pretrained(OUTPUT_DIR)
+
+ with open(os.path.join(OUTPUT_DIR, "eval_metrics.json"), "w", encoding="utf-8") as f:
+     json.dump(metrics, f, indent=2, ensure_ascii=False)
+
+ print("==> Training done. Eval metrics:", metrics)
+
+
+
+ def _build_sft_examples(examples, tokenizer, max_len=1024):
+     # NOTE: unused alternative that supervises the full conversation text.
+     # Expects each row to be {"messages": [{"role": "user"/"system"/"assistant", "content": ...}, ...]}
+     texts = []
+     for msgs in examples["messages"]:
+         # Use the last assistant turn as the supervision target; everything else is context
+         if not isinstance(msgs, list) or not msgs:
+             continue
+         # Context: user/system turns (no assistant turns)
+         ctx = [m for m in msgs if m.get("role") != "assistant"]
+         # Target: the last assistant turn (skip the row if there is none)
+         tgt = None
+         for m in reversed(msgs):
+             if m.get("role") == "assistant":
+                 tgt = m["content"]
+                 break
+         if tgt is None:
+             continue
+         # Build: context + target
+         prompt = tokenizer.apply_chat_template(ctx + [{"role": "assistant", "content": tgt}],
+                                                tokenize=False, add_generation_prompt=False)
+         texts.append(prompt)
+
+     tokenized = tokenizer(texts, truncation=True, max_length=max_len)
+     return tokenized
+
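
As a quick way to see what _sft_tokenize_with_chat_template produces, the following sanity check (not part of the commit; it assumes the Qwen2.5 tokenizer can be downloaded or is already cached, and the prompt/answer pair is made up) rebuilds the prompt/answer split and confirms that only the answer tokens, plus eos, carry labels other than -100:

from transformers import AutoTokenizer

# Hypothetical sanity check: mirrors the label construction in _sft_tokenize_with_chat_template
tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", trust_remote_code=True)

prompt = "123.4 + 207.36 = ?"   # made-up example
response = "330.76"

prompt_text = tok.apply_chat_template(
    [{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True
)
prompt_ids = tok(prompt_text, add_special_tokens=False)["input_ids"]
target_ids = tok(response + (tok.eos_token or ""), add_special_tokens=False)["input_ids"]

labels = [-100] * len(prompt_ids) + target_ids
supervised = [t for t in labels if t != -100]

assert len(supervised) == len(target_ids)
print(tok.decode(supervised, skip_special_tokens=True))  # -> "330.76"

Masking the prompt with -100 means the Trainer's cross-entropy only covers the answer span, which is also why the collator pads labels with IGNORE_INDEX rather than with pad_token_id.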