import argparse
import gc
import os

import pandas as pd
import tqdm
from datasets import load_dataset
|
|
# Section headers after which Wikipedia articles contain only boilerplate
# (references, link lists, ...). Everything from the first of these onward
# is dropped before chunking.
_STOP_HEADERS = (
    "\nReferences",
    "\nSee also",
    "\nExternal links",
    "\nNotes",
    "\nFurther reading",
    "\nBibliography",
)


def _extract_chunks(text, max_chunks):
    """Return up to *max_chunks* body paragraphs of *text*.

    The article body is truncated at the earliest boilerplate header in
    ``_STOP_HEADERS`` (if any), split on blank lines, and only paragraphs
    longer than 20 whitespace-separated words are kept.
    """
    # Cut at the earliest stop header present; len(text) means "no cut".
    # (Equivalent to repeatedly taking split(header)[0] for every header.)
    cut = min(
        (i for i in (text.find(h) for h in _STOP_HEADERS) if i != -1),
        default=len(text),
    )
    body = text[:cut]
    paragraphs = [p.strip() for p in body.split("\n\n") if len(p.split()) > 20]
    return paragraphs[:max_chunks]


def main():
    """Chunk one shard of a Wikipedia dump into paragraph-level text chunks.

    Loads the 2023-11-01 Wikipedia dump for ``--lang`` from the HF hub,
    selects shard ``--shard_id`` of ``--num_shards``, keeps up to
    ``--max_chunks`` long paragraphs per article, and writes them to a
    snappy-compressed parquet file under ``--output_dir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--lang", type=str, default="en")
    parser.add_argument("--shard_id", type=int, required=True)
    parser.add_argument("--num_shards", type=int, default=20)
    parser.add_argument("--max_chunks", type=int, default=15)
    parser.add_argument(
        "--output_dir",
        type=str,
        default="/home/mshahidul/readctrl/data/wiki_chunks",
        help="Directory for the output parquet shard (created if missing).",
    )
    args = parser.parse_args()

    # Fail fast on an impossible shard index instead of deep inside datasets.
    if not 0 <= args.shard_id < args.num_shards:
        parser.error(f"--shard_id must be in [0, {args.num_shards})")

    print(f"Loading {args.lang} Wikipedia shard {args.shard_id}...")
    ds = load_dataset("wikimedia/wikipedia", f"20231101.{args.lang}", split='train')
    ds_shard = ds.shard(num_shards=args.num_shards, index=args.shard_id)

    wiki_chunks = []
    for article in tqdm.tqdm(ds_shard):
        wiki_chunks.extend(_extract_chunks(article['text'], args.max_chunks))

    # Drop the dataset references before building the DataFrame to keep
    # peak memory down (the full dump shard can be large).
    del ds, ds_shard
    gc.collect()

    # Create the output directory so a fresh checkout doesn't crash on write.
    os.makedirs(args.output_dir, exist_ok=True)
    save_path = os.path.join(
        args.output_dir,
        f"wiki_chunks_{args.lang}_shard_{args.shard_id}.parquet",
    )
    df = pd.DataFrame({"text": wiki_chunks})
    df.to_parquet(save_path, compression='snappy')

    print(f"Saved {len(wiki_chunks)} chunks to {save_path}")
|
|
| if __name__ == "__main__": |
| main() |