# rm_code/3round.py — uploaded by hahayang012 via huggingface_hub (commit d8a76be, verified).
# pip install datasets pyarrow regex
import re
import random
from datasets import load_dataset
# ========= Regexes =========
# Leading ChatML system segment; anchored at the very start of the text (no re.M).
SYS_HEAD = re.compile(r"^<\|im_start\|>system\s.*?<\|im_end\|>\s*", re.S)
# One closed user/assistant turn: group(1) is the whole turn, group(2) the role.
TURN_WITH_ROLE = re.compile(r"(<\|im_start\|>(user|assistant)\s*.*?<\|im_end\|>)", re.S)
# Speaker name + colon (letters incl. CJK, digits, spaces, underscores), e.g. "Kerensa:" / "小明:".
# NOTE(review): not referenced anywhere in the visible code.
NAME_COLON = re.compile(r"^[\w\u4e00-\u9fa5][\w\u4e00-\u9fa5 _]{0,40}:\s*$")


def last_4rounds_user_to_open_assistant(chatml: str, rounds: int = 4) -> str:
    """Keep the last ``rounds`` dialogue rounds, ending with an open assistant turn.

    The leading ChatML ``system`` segment is removed. Before the final
    ``<|im_start|>assistant`` header (the "open" assistant), the most recent
    ``2 * rounds - 1`` closed user/assistant turns ending on a ``user`` turn
    are kept. With the default 4 rounds that is U, A, U, A, U, A, U. The open
    assistant segment is then appended as the assistant half of the last
    round. Turns are picked by position only; strict U/A alternation is not
    enforced, and turns with other roles are not captured.

    Args:
        chatml: ChatML-formatted conversation. Non-string values (e.g. ``None``)
            are returned unchanged.
        rounds: Number of rounds to keep. Defaults to 4, the historical
            hard-coded value. Must be >= 1.

    Returns:
        The trimmed prompt. This is best effort for unusual input:

        * no user/assistant tag: the text with the leading system segment
          removed;
        * no assistant tag: that text, stripped;
        * fewer closed turns than needed: all of them are kept.

        The open assistant segment is cut at its first ``<|im_end|>``; that
        tag and everything after it are dropped.

    Raises:
        ValueError: if ``rounds`` is smaller than 1.
    """
    if rounds < 1:
        raise ValueError(f"rounds must be >= 1, got {rounds!r}")
    if not isinstance(chatml, str):
        return chatml

    text = SYS_HEAD.sub("", chatml)
    # Not ChatML: return conservatively.
    if "<|im_start|>user" not in text and "<|im_start|>assistant" not in text:
        return text

    # The last assistant header is where the open (unfinished) turn starts.
    last_ast = text.rfind("<|im_start|>assistant")
    if last_ast == -1:
        return text.strip()

    # Drop the first <|im_end|> after the open header and everything behind it.
    final_assistant_open = re.sub(r"<\|im_end\|>.*$", "", text[last_ast:], flags=re.S)

    # Collect the closed turns that precede the open assistant.
    head = text[:last_ast]
    turns = [(m.group(2), m.group(1)) for m in TURN_WITH_ROLE.finditer(head)]

    n_history = 2 * rounds - 1  # U, A, ..., U
    if len(turns) < n_history:
        # Not enough history: keep everything we have.
        selected = [t[1] for t in turns]
    else:
        # Index of the last user turn, so the history ends on a user message.
        j = next((i for i in range(len(turns) - 1, -1, -1) if turns[i][0] == "user"), None)
        if j is None:
            selected = [t[1] for t in turns[-n_history:]]  # fallback: no user turn at all
        else:
            selected = [t[1] for t in turns[max(0, j - n_history + 1): j + 1]]

    prefix = ("\n".join(selected) + "\n") if selected else ""
    return prefix + final_assistant_open
# ============ Batch processing + sample printing ============
in_path = "/home/data/train_v3full.parquet"  # input
out_path = "/home/data/train_3round.parquet"  # output (different name to tell the variants apart)
ds = load_dataset("parquet", data_files=in_path, split="train")
# Keep only the three columns used below. Fail fast if any is missing, instead of
# mapping and writing the output first and crashing later on a KeyError.
keep_cols = ["chosen_prompt", "chosen", "reject"]
missing_cols = [c for c in keep_cols if c not in ds.column_names]
if missing_cols:
    raise ValueError(
        f"{in_path} is missing required column(s) {missing_cols}; "
        f"available columns: {ds.column_names}"
    )
drop_cols = [c for c in ds.column_names if c not in keep_cols]
if drop_cols:
    ds = ds.remove_columns(drop_cols)
def ensure_linebreak_after_assistant(chosen_prompt: str) -> str:
    """Normalize line breaks after every ``<|im_start|>assistant`` header.

    Rules, applied to every assistant header in the text:

    1. The ``<|im_start|>assistant`` tag must be followed by a line break.
       A ``"\\n"`` is inserted when the tag is not followed by optional
       whitespace and a newline.
    2. If the first line after the header is a "speaker:" line followed by a
       line break, that break and the whitespace around it are collapsed
       into a single space. A "speaker:" line is 1-60 non-newline characters
       ending in a colon, e.g. ``Kerensa:``. This keeps the name and the
       reply on the same line.

    Previously rule 2 was applied only to the first matching header, so a
    later header (typically the final open assistant turn) could be missed.

    Args:
        chosen_prompt: Prompt text. Non-string values (e.g. ``None``, which the
            upstream round-trimming step passes through) are returned unchanged.

    Returns:
        The normalized prompt.
    """
    if not isinstance(chosen_prompt, str):
        return chosen_prompt

    # 1) Add a newline after any assistant tag that is not already followed by one.
    chosen_prompt = re.sub(
        r"(<\|im_start\|>assistant)(?!\s*\n)",  # not followed by (whitespace +) newline
        r"\1\n",
        chosen_prompt,
    )
    # 2) "Name:" + line break right after an assistant header -> "Name: " (same line).
    chosen_prompt = re.sub(
        r"(<\|im_start\|>assistant\s*\n)([^\n]{1,60}:)(\s*\r?\n\s*)",
        r"\1\2 ",
        chosen_prompt,
    )
    return chosen_prompt
def _map_fn(ex):
    """Rewrite ``ex["chosen_prompt"]`` in place and return the same example dict.

    The prompt is first trimmed to the most recent rounds ending in an open
    assistant turn (``last_4rounds_user_to_open_assistant``). Its assistant
    headers are then normalized (``ensure_linebreak_after_assistant``).
    """
    trimmed = last_4rounds_user_to_open_assistant(ex["chosen_prompt"])
    ex["chosen_prompt"] = ensure_linebreak_after_assistant(trimmed)
    return ex
# Passing num_proc=4~8 to map() can speed this up (watch memory usage).
ds = ds.map(_map_fn, desc="Build last 4 rounds (open assistant) + linebreak rules")
ds.to_parquet(out_path)
print(f"✅ Saved -> {out_path}")
# Print up to 5 random samples: the raw prompt, plus prompt+chosen and
# prompt+reject, to check for extra blank lines and that speaker names stay on
# the same line as the reply. Null (None) fields are treated as "" in the
# concatenations, so one null row cannot crash the preview after the output
# file has already been written.
idxs = random.sample(range(len(ds)), min(5, len(ds)))
sampled = ds.select(idxs)
for i, ex in enumerate(sampled):
    prompt = ex["chosen_prompt"] or ""
    print(f"===== Sample {i+1} / chosen_prompt 原样 =====")
    print(ex["chosen_prompt"])
    print(f"===== Sample {i+1} / chosen_prompt + chosen =====")
    print(prompt + (ex["chosen"] or ""))
    print(f"===== Sample {i+1} / chosen_prompt + reject =====")
    print(prompt + (ex["reject"] or ""))
    print()