| | |
| | import re |
| | import random |
| | from datasets import load_dataset |
| |
|
| | |
| | SYS_HEAD = re.compile(r"^<\|im_start\|>system\s.*?<\|im_end\|>\s*", re.S) |
| | TURN_WITH_ROLE = re.compile(r"(<\|im_start\|>(user|assistant)\s*.*?<\|im_end\|>)", re.S) |
| |
|
| | |
| | NAME_COLON = re.compile(r"^[\w\u4e00-\u9fa5][\w\u4e00-\u9fa5 _]{0,40}:\s*$") |
| |
|
def last_4rounds_user_to_open_assistant(chatml: str) -> str:
    """Keep the last four rounds of a ChatML prompt, ending in an open assistant turn.

    Target layout (8 segments = 4 rounds):
        user, assistant, user, assistant, user, assistant, user, assistant (open)

    Steps:
      * Drop a leading system turn (``SYS_HEAD``).
      * The last ``<|im_start|>assistant`` header becomes the open turn: its
        ``<|im_end|>`` and everything after it are cut off.
        If a user turn starts after that last assistant header (the prompt
        ends on a user turn), the whole text is kept as history instead, and
        a fresh open ``<|im_start|>assistant`` header (plus newline) is
        appended, so the final user message is not lost.
      * From the complete turns before the open assistant, keep the (up to) 7
        that end at the last user turn (U, A, U, A, U, A, U). If there are
        fewer than 7 complete turns, all of them are kept as-is, including
        any assistant turns after the last user turn.

    Args:
        chatml: ChatML-formatted prompt. Non-string values (e.g. ``None``)
            are returned unchanged.

    Returns:
        The truncated prompt. Text with no user/assistant header is returned
        with only the system turn removed. Text with user headers but no
        assistant header is returned stripped, without an open assistant turn.
    """
    if not isinstance(chatml, str):
        return chatml

    text = SYS_HEAD.sub("", chatml)

    if "<|im_start|>user" not in text and "<|im_start|>assistant" not in text:
        return text

    last_ast = text.rfind("<|im_start|>assistant")
    if last_ast == -1:
        return text.strip()

    if text.rfind("<|im_start|>user") > last_ast:
        # The prompt ends on a user turn. Treating the last (already complete)
        # assistant turn as the open one would silently drop every later user
        # message, so keep all turns as history and open a new assistant turn.
        head = text
        final_assistant_open = "<|im_start|>assistant\n"
    else:
        head = text[:last_ast]
        # Cut the final assistant turn at its <|im_end|> (if present) so that
        # generation continues from it.
        final_assistant_open = re.sub(
            r"<\|im_end\|>.*$", "", text[last_ast:], flags=re.S
        )

    turns = [(m.group(2), m.group(1)) for m in TURN_WITH_ROLE.finditer(head)]

    if len(turns) < 7:
        # Short history: keep every complete turn unchanged.
        selected = [turn for _, turn in turns]
    else:
        # Index of the last user turn, scanning from the end.
        j = next((i for i in range(len(turns) - 1, -1, -1) if turns[i][0] == "user"), None)
        if j is None:
            selected = [turn for _, turn in turns[-7:]]
        else:
            selected = [turn for _, turn in turns[max(0, j - 6): j + 1]]

    prefix = ("\n".join(selected) + "\n") if selected else ""
    return prefix + final_assistant_open
| |
|
| |
|
| | |
| | in_path = "/home/data/train_v3full.parquet" |
| | out_path = "/home/data/train_3round.parquet" |
| | ds = load_dataset("parquet", data_files=in_path, split="train") |
| |
|
| | |
| | keep_cols = ["chosen_prompt", "chosen", "reject"] |
| | drop_cols = [c for c in ds.column_names if c not in keep_cols] |
| | if drop_cols: |
| | ds = ds.remove_columns(drop_cols) |
| |
|
def ensure_linebreak_after_assistant(chosen_prompt: str) -> str:
    """Normalize the line break right after every ``<|im_start|>assistant`` header.

    Rules:
      * ``<|im_start|>assistant`` must be followed by a newline. One is
        inserted wherever the header is not already followed by optional
        whitespace and a newline, including a header at the very end.
      * A "Name:" line directly after an assistant header must not be
        followed by a line break: the whitespace/newline run after the colon
        is replaced with a single space (e.g. "Alice:" + newline becomes
        "Alice: "). This is applied to every assistant turn, not only the
        first one that matches.

    Args:
        chosen_prompt: ChatML prompt text. Non-string values (e.g. ``None``)
            are returned unchanged, mirroring how
            ``last_4rounds_user_to_open_assistant`` passes them through.

    Returns:
        The normalized prompt.
    """
    if not isinstance(chosen_prompt, str):
        return chosen_prompt

    chosen_prompt = re.sub(
        r"(<\|im_start\|>assistant)(?!\s*\n)",
        r"\1\n",
        chosen_prompt,
    )

    # NOTE(review): "[^\n]{1,60}:" matches any short first line ending in ":",
    # not only speaker names (e.g. "Here are the steps:"). Consider tightening it
    # to a name-only pattern if that is unintended.
    chosen_prompt = re.sub(
        r"(<\|im_start\|>assistant\s*\n)([^\n]{1,60}:)\s*\r?\n\s*",
        r"\1\2 ",
        chosen_prompt,
    )
    return chosen_prompt
| |
|
def _map_fn(ex):
    """Per-example callback for ``ds.map``: rewrite ``ex["chosen_prompt"]`` and return ``ex``.

    The prompt is first truncated to its last rounds ending in an open
    assistant turn. The assistant-header line-break rules are then applied.
    The example dict is modified in place; other fields are left untouched.
    """
    truncated = last_4rounds_user_to_open_assistant(ex["chosen_prompt"])
    ex["chosen_prompt"] = ensure_linebreak_after_assistant(truncated)
    return ex
| |
|
| | |
| | ds = ds.map(_map_fn, desc="Build last 4 rounds (open assistant) + linebreak rules") |
| |
|
| | ds.to_parquet(out_path) |
| | print(f"✅ Saved -> {out_path}") |
| |
|
| | |
| | idxs = random.sample(range(len(ds)), min(5, len(ds))) |
| | sampled = ds.select(idxs) |
| | for i, ex in enumerate(sampled): |
| | print(f"===== Sample {i+1} / chosen_prompt 原样 =====") |
| | print(ex["chosen_prompt"]) |
| | print(f"===== Sample {i+1} / chosen_prompt + chosen =====") |
| | print(ex["chosen_prompt"] + ex["chosen"]) |
| | print(f"===== Sample {i+1} / chosen_prompt + reject =====") |
| | print(ex["chosen_prompt"] + ex["reject"]) |
| | print() |
| |
|