| import os |
| from pathlib import Path |
| from typing import List, Union |
|
|
| from datasets import load_dataset, concatenate_datasets |
| from transformers import AutoTokenizer |
|
|
|
|
| |
| |
| |
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# Input parquet file(s) to merge and filter. If this list is empty,
# collect_parquet_files() falls back to scanning a directory instead.
parquet_paths: List[str] = [
    "/home/data/train_10k_sys_3round.parquet",
]
# Local path (or hub id) of the tokenizer used to count tokens per sample.
tokenizer_path = "/home/rm3.4.1_9e-6"
# Destination parquet file for the filtered dataset.
output_path = "/home/data/prefiltered.parquet"
# Use half the available CPUs for datasets map/filter workers (at least 1);
# falls back to 4 CPUs when os.cpu_count() returns None.
num_proc = max(1, (os.cpu_count() or 4) // 2)
# Keep only samples whose "chosen" AND "reject" token counts fall in this
# inclusive range.
min_tokens, max_tokens = 20, 80
| |
|
|
|
|
def collect_parquet_files(data_dir: Union[str, Path, None] = None) -> List[str]:
    """Return the list of parquet files to process.

    Prefers the explicit module-level ``parquet_paths`` list; when that list
    is empty, falls back to globbing ``*.parquet`` under ``data_dir``.

    Args:
        data_dir: Directory to scan when ``parquet_paths`` is empty.
            Optional; defaults to ``None`` so existing no-arg callers keep
            working.

    Returns:
        Sorted list of parquet file paths as strings.

    Raises:
        ValueError: If ``parquet_paths`` is empty and no ``data_dir`` was
            provided.
        FileNotFoundError: If ``data_dir`` does not exist or contains no
            ``*.parquet`` files.
    """
    if parquet_paths:
        return [str(Path(p)) for p in parquet_paths]
    # Bug fix: the original referenced an undefined global ``data_dir`` and
    # would crash with NameError here; it is now an explicit parameter.
    if data_dir is None:
        raise ValueError(
            "parquet_paths is empty and no data_dir was provided"
        )
    p = Path(data_dir)
    if not p.exists():
        raise FileNotFoundError(f"目录不存在:{p}")
    files = sorted(str(fp) for fp in p.glob("*.parquet"))
    if not files:
        raise FileNotFoundError(f"目录中未找到 .parquet 文件:{p}")
    return files
|
|
|
|
def main():
    """Merge the configured parquet files, filter by response length, save.

    Pipeline:
      1. Load and concatenate all input parquet files as one "train" split.
      2. Tokenize the ``chosen`` and ``reject`` columns and record their
         token counts.
      3. Keep only samples where BOTH counts lie in
         [``min_tokens``, ``max_tokens``] (inclusive).
      4. Drop the helper count columns and write the result to
         ``output_path``.
    """
    files = collect_parquet_files()
    print(f"发现 {len(files)} 个 parquet 文件,将合并处理:")
    for f in files:
        print(" -", f)

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

    # Passing the whole list as data_files merges all files into one split.
    dataset = load_dataset("parquet", data_files=files, split="train")

    total_before = len(dataset)
    print(f"\n合并后样本数:{total_before}")

    def add_token_lengths(batch):
        """Batched map fn: add per-sample token counts for both responses."""
        # assumes "chosen" / "reject" are plain-text columns — TODO confirm
        chosen = batch["chosen"]
        reject = batch["reject"]

        # No special tokens so the count reflects content length only.
        chosen_ids = tokenizer(chosen, add_special_tokens=False)["input_ids"]
        reject_ids = tokenizer(reject, add_special_tokens=False)["input_ids"]

        return {
            "chosen_tokens": [len(x) for x in chosen_ids],
            "reject_tokens": [len(x) for x in reject_ids],
        }

    dataset = dataset.map(
        add_token_lengths,
        batched=True,
        num_proc=num_proc,
        desc="计算 token 数",
    )

    def in_range_filter(batch):
        """Batched filter fn: keep samples with BOTH counts in range."""
        ct = batch["chosen_tokens"]
        rt = batch["reject_tokens"]
        return [
            (min_tokens <= c <= max_tokens) and (min_tokens <= r <= max_tokens)
            for c, r in zip(ct, rt)
        ]

    dataset = dataset.filter(
        in_range_filter,
        batched=True,
        num_proc=num_proc,
        desc=f"过滤:保留 {min_tokens}~{max_tokens} tokens",
    )

    kept = len(dataset)
    # Bug fix: guard against ZeroDivisionError when the input is empty.
    ratio = kept / total_before if total_before else 0.0
    print(f"过滤完成:保留 {kept} / {total_before} (保留率 {ratio:.2%})")

    # Drop the helper columns before saving; tolerate their absence.
    for col in ["chosen_tokens", "reject_tokens"]:
        if col in dataset.column_names:
            dataset = dataset.remove_columns(col)

    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    dataset.to_parquet(output_path)
    print(f"已保存到:{output_path}")
|
|
|
|
# Script entry point: run the full preprocessing pipeline when executed
# directly (not when imported).
if __name__ == "__main__":
    main()
|
|