| |
| """Create deterministic dataset/retrieval-cache shards by question id.""" |
|
|
| from __future__ import annotations |
|
|
import argparse
import contextlib
import json
from pathlib import Path
from typing import Dict
|
|
|
|
def split_json_dataset(dataset_path: Path, out_dir: Path, num_shards: int) -> Dict[str, int]:
    """Round-robin the dataset entries into ``num_shards`` JSON shard files.

    Entry ``idx`` goes to shard ``idx % num_shards``, so the split is
    deterministic for a fixed input order. Shard files are written to
    ``out_dir / "dataset" / shard_NN.json`` (a file is written for every
    shard, even when it receives no rows).

    Args:
        dataset_path: JSON file containing a list of entries, each of which
            must have a ``"question_id"`` key.
        out_dir: Root output directory; a ``dataset`` subdirectory is created.
        num_shards: Number of shards to produce.

    Returns:
        Mapping from question_id to the shard index that holds that entry.
    """
    # Read via a context manager so the input handle is closed promptly;
    # the previous json.load(open(...)) leaked the handle until GC.
    with open(dataset_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    qid_to_shard: Dict[str, int] = {}
    shards = [[] for _ in range(num_shards)]
    dataset_dir = out_dir / "dataset"
    dataset_dir.mkdir(parents=True, exist_ok=True)

    for idx, entry in enumerate(data):
        shard = idx % num_shards
        qid_to_shard[entry["question_id"]] = shard
        shards[shard].append(entry)

    for shard, rows in enumerate(shards):
        path = dataset_dir / f"shard_{shard:02d}.json"
        with open(path, "w", encoding="utf-8") as out:
            json.dump(rows, out, ensure_ascii=False, indent=2)
            out.write("\n")  # trailing newline keeps the file POSIX-friendly
    return qid_to_shard
|
|
|
|
def split_jsonl_cache(cache_path: Path, out_dir: Path, qid_to_shard: Dict[str, int], num_shards: int) -> None:
    """Route each JSONL cache row to the shard that owns its question id.

    Rows whose ``question_id`` is not present in ``qid_to_shard`` are dropped;
    blank lines are skipped. One output file per shard is created (possibly
    empty) under ``out_dir / shard_NN.jsonl``.

    Args:
        cache_path: Input JSONL file; each non-blank line is a JSON object
            with a ``"question_id"`` key.
        out_dir: Output directory for the per-shard JSONL files (created).
        qid_to_shard: Mapping from question_id to shard index, as produced by
            ``split_json_dataset``.
        num_shards: Number of shard files to open.
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    # ExitStack registers each handle as soon as it is opened, so if a later
    # open() fails the earlier handles are still closed. The previous
    # try/finally only protected handles after *all* opens had succeeded.
    with contextlib.ExitStack() as stack:
        handles = [
            stack.enter_context(open(out_dir / f"shard_{shard:02d}.jsonl", "w", encoding="utf-8"))
            for shard in range(num_shards)
        ]
        with open(cache_path, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue
                row = json.loads(line)
                shard = qid_to_shard.get(row["question_id"])
                if shard is not None:
                    handles[shard].write(json.dumps(row, ensure_ascii=False) + "\n")
|
|
|
|
def main() -> None:
    """Shard the dataset and both caches, then write the qid-to-shard index."""
    args = _parse_args()

    out_root = Path(args.out_dir)
    out_root.mkdir(parents=True, exist_ok=True)

    # The dataset split defines the qid -> shard assignment; both caches
    # must be routed with the same mapping so each shard is self-contained.
    mapping = split_json_dataset(Path(args.dataset), out_root, args.num_shards)
    for cache_arg, subdir in ((args.ret_cache, "ret_cache"), (args.semantic_cache, "semantic_cache")):
        split_jsonl_cache(Path(cache_arg), out_root / subdir, mapping, args.num_shards)

    # Persist the assignment so downstream tooling can locate a qid's shard.
    with open(out_root / "qid_to_shard.json", "w", encoding="utf-8") as f:
        json.dump(mapping, f, ensure_ascii=False, indent=2)
        f.write("\n")

    print(f"Wrote {args.num_shards} shards to {out_root}")


def _parse_args() -> argparse.Namespace:
    """Define and parse the command-line interface."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", required=True)
    parser.add_argument("--ret_cache", required=True)
    parser.add_argument("--semantic_cache", required=True)
    parser.add_argument("--out_dir", required=True)
    parser.add_argument("--num_shards", type=int, default=8)
    return parser.parse_args()
|
|
|
|
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
|
|
|