"""
Dataset Preparation — 2M+ Token Corpus for GraphRAG Hackathon
==============================================================
Downloads, tokenizes, and ingests a 2M+ token corpus into TigerGraph.
Supported sources (pick one or combine):
1. Wikipedia (English) — best entity density, CC-BY-SA
2. arXiv papers (neuralwork/arxiver) — full text, CC-BY-NC-SA
3. BBC News (RealTimeData/bbc_news_alltime) — events, CC-BY
Usage:
python graphrag/prepare_dataset.py --source wikipedia --target-tokens 2500000
python graphrag/prepare_dataset.py --source arxiv --target-tokens 2500000
python graphrag/prepare_dataset.py --source bbc --target-tokens 2500000
python graphrag/prepare_dataset.py --source wikipedia --target-tokens 2500000 --ingest
"""
import argparse
import hashlib
import json
import logging
import os
import sys
import time
from typing import Dict, List, Optional
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def count_tokens(text: str) -> int:
"""Estimate token count. Uses tiktoken if available, otherwise word-based estimate."""
    try:
        import tiktoken
        enc = tiktoken.encoding_for_model("gpt-4o-mini")
        return len(enc.encode(text))
    except (ImportError, KeyError):
        # tiktoken is missing, or too old to map this model name to an encoder.
        # Rough fallback: English runs ~0.75 words per token, i.e. ~1.33 tokens per word.
        return int(len(text.split()) * 1.33)
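# Quick sanity check (illustrative; exact tiktoken counts vary by encoder version):
#   count_tokens("Graph retrieval augments generation.")
#   -> 5 via the word heuristic (4 words * 1.33); tiktoken reports a similar count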
def load_wikipedia(target_tokens: int, domain: str = "science") -> List[Dict]:
"""
Load Wikipedia articles until target token count is reached.
Domain filters available:
- "science": physics, chemistry, biology, mathematics, astronomy
- "history": wars, empires, historical figures, events
- "politics": countries, politicians, governments, elections
- "technology": computing, AI, internet, software, engineering
- "all": no filter (fastest, most diverse)
"""
from datasets import load_dataset
logger.info(f"Loading Wikipedia (domain={domain}, target={target_tokens:,} tokens)...")
domain_keywords = {
"science": ["physicist", "scientist", "chemist", "biologist", "mathematician",
"theory", "equation", "discovery", "experiment", "research",
"university", "professor", "nobel", "quantum", "evolution"],
"history": ["war", "battle", "empire", "dynasty", "revolution", "treaty",
"king", "queen", "president", "ancient", "medieval", "colonial"],
"politics": ["election", "government", "parliament", "president", "minister",
"democrat", "republic", "constitution", "legislation", "political"],
"technology": ["computer", "software", "algorithm", "internet", "artificial",
"programming", "engineer", "processor", "database", "network"],
"all": [],
}
keywords = domain_keywords.get(domain, [])
ds = load_dataset("wikimedia/wikipedia", "20231101.en", split="train", streaming=True)
documents = []
total_tokens = 0
scanned = 0
for article in ds:
scanned += 1
title = article.get("title", "")
text = article.get("text", "")
if not text or len(text) < 200:
continue
# Domain filter
if keywords:
title_lower = title.lower()
text_lower = text[:2000].lower() # check first 2000 chars only for speed
if not any(kw in title_lower or kw in text_lower for kw in keywords):
if scanned % 1000 == 0:
logger.info(f" Scanned {scanned:,} articles, collected {len(documents)}, "
f"tokens: {total_tokens:,}/{target_tokens:,}")
continue
tokens = count_tokens(text)
documents.append({
"id": hashlib.md5(title.encode()).hexdigest()[:12],
"title": title,
"content": text,
"source": "wikipedia",
"tokens": tokens,
"url": article.get("url", ""),
})
total_tokens += tokens
if len(documents) % 100 == 0:
logger.info(f" Collected {len(documents)} articles, "
f"tokens: {total_tokens:,}/{target_tokens:,} "
f"({total_tokens/target_tokens*100:.1f}%)")
if total_tokens >= target_tokens:
break
logger.info(f"✅ Wikipedia: {len(documents)} articles, {total_tokens:,} tokens")
return documents
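# Usage sketch: because the dump is streamed, wall-clock time scales with how
# many articles must be scanned before the token target is hit (narrow domain
# filters reject more, so they scan longer):
#   docs = load_wikipedia(2_500_000, domain="technology")
#   print(docs[0]["title"], docs[0]["tokens"])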
def load_arxiv(target_tokens: int) -> List[Dict]:
"""Load arXiv papers with full markdown text from neuralwork/arxiver."""
from datasets import load_dataset
logger.info(f"Loading arXiv papers (target={target_tokens:,} tokens)...")
    # Stream so we can stop at the token target without downloading the full dump.
    ds = load_dataset("neuralwork/arxiver", split="train", streaming=True)
documents = []
total_tokens = 0
for i, paper in enumerate(ds):
text = paper.get("markdown", "")
if not text or len(text) < 500:
continue
title = paper.get("title", f"Paper {i}")
tokens = count_tokens(text)
documents.append({
"id": paper.get("id", hashlib.md5(title.encode()).hexdigest()[:12]),
"title": title,
"content": text,
"source": "arxiv",
"tokens": tokens,
"authors": paper.get("authors", ""),
"published_date": paper.get("published_date", ""),
})
total_tokens += tokens
if len(documents) % 50 == 0:
logger.info(f" Collected {len(documents)} papers, "
f"tokens: {total_tokens:,}/{target_tokens:,}")
if total_tokens >= target_tokens:
break
logger.info(f"✅ arXiv: {len(documents)} papers, {total_tokens:,} tokens")
return documents
def load_bbc_news(target_tokens: int, year: str = "2022") -> List[Dict]:
"""Load BBC News articles from RealTimeData/bbc_news_alltime."""
    from datasets import load_dataset
logger.info(f"Loading BBC News (year={year}, target={target_tokens:,} tokens)...")
months = [f"{year}-{m:02d}" for m in range(1, 13)]
all_articles = []
for month in months:
try:
ds = load_dataset("RealTimeData/bbc_news_alltime", month, split="train")
all_articles.extend([dict(row) for row in ds])
logger.info(f" Loaded {month}: {len(ds)} articles (total: {len(all_articles)})")
except Exception as e:
logger.warning(f" {month} not available: {e}")
continue
documents = []
total_tokens = 0
for article in all_articles:
text = article.get("content", "")
if not text or len(text) < 200:
continue
title = article.get("title", "Untitled")
tokens = count_tokens(text)
documents.append({
"id": hashlib.md5(f"{title}:{article.get('published_date','')}".encode()).hexdigest()[:12],
"title": title,
"content": text,
"source": "bbc_news",
"tokens": tokens,
"section": article.get("section", ""),
"published_date": article.get("published_date", ""),
})
total_tokens += tokens
if total_tokens >= target_tokens:
break
logger.info(f"✅ BBC News: {len(documents)} articles, {total_tokens:,} tokens")
return documents
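# The BBC loader pulls one monthly config at a time for a single year; for a
# multi-year corpus, call it once per year and concatenate the lists (sketch):
#   docs = load_bbc_news(1_250_000, year="2021") + load_bbc_news(1_250_000, year="2022")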
def save_dataset(documents: List[Dict], output_dir: str = "dataset"):
"""Save prepared dataset to disk."""
os.makedirs(output_dir, exist_ok=True)
# Save as JSONL
output_path = os.path.join(output_dir, "corpus.jsonl")
with open(output_path, "w", encoding="utf-8") as f:
for doc in documents:
f.write(json.dumps(doc, ensure_ascii=False) + "\n")
# Save metadata
total_tokens = sum(d["tokens"] for d in documents)
meta = {
"num_documents": len(documents),
"total_tokens": total_tokens,
"sources": list(set(d["source"] for d in documents)),
"avg_tokens_per_doc": total_tokens // max(len(documents), 1),
"meets_2m_minimum": total_tokens >= 2_000_000,
"created_at": time.strftime("%Y-%m-%d %H:%M:%S"),
}
meta_path = os.path.join(output_dir, "metadata.json")
with open(meta_path, "w") as f:
json.dump(meta, f, indent=2)
logger.info(f"\n{'='*60}")
logger.info(f"DATASET SAVED: {output_path}")
logger.info(f" Documents: {len(documents):,}")
logger.info(f" Total tokens: {total_tokens:,}")
logger.info(f" Meets 2M minimum: {'✅ YES' if total_tokens >= 2_000_000 else '❌ NO'}")
logger.info(f" Metadata: {meta_path}")
logger.info(f"{'='*60}\n")
return meta
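# Reading the saved corpus back later (sketch):
#   with open("dataset/corpus.jsonl", encoding="utf-8") as f:
#       documents = [json.loads(line) for line in f]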
def ingest_to_tigergraph(documents: List[Dict], max_docs: Optional[int] = None, extract_entities: bool = True):
"""Ingest prepared documents into TigerGraph via the ingestion pipeline."""
from graphrag.layers.graph_layer import GraphLayer
from graphrag.layers.llm_layer import LLMLayer
from graphrag.layers.orchestration_layer import EmbeddingManager
from graphrag.ingestion import IngestionPipeline
logger.info("Connecting to TigerGraph...")
graph = GraphLayer(config={
"host": os.getenv("TG_HOST", ""),
"graphname": os.getenv("TG_GRAPH", "GraphRAG"),
"username": os.getenv("TG_USERNAME", "tigergraph"),
"password": os.getenv("TG_PASSWORD", ""),
"token": os.getenv("TG_TOKEN", ""),
})
if not graph.connect():
logger.error("TigerGraph connection failed. Set TG_HOST and TG_PASSWORD.")
return
graph.create_schema()
graph.install_queries()
llm = LLMLayer(api_key=os.getenv("OPENAI_API_KEY", ""),
model=os.getenv("LLM_MODEL", "gpt-4o-mini"))
llm.initialize()
embedder = EmbeddingManager()
embedder.initialize()
pipeline = IngestionPipeline(graph, llm, embedder)
docs_to_ingest = documents[:max_docs] if max_docs else documents
logger.info(f"Ingesting {len(docs_to_ingest)} documents into TigerGraph...")
custom_docs = [{"id": d["id"], "title": d["title"], "content": d["content"],
"source": d["source"]} for d in docs_to_ingest]
try:
stats = pipeline.ingest_custom_documents(custom_docs, extract_entities=extract_entities)
except Exception as e:
import traceback
logger.error(f"Ingestion crashed: {e}")
logger.error(traceback.format_exc())
return
logger.info(f"✅ Ingestion complete: {stats}")
return stats
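# Example environment for --ingest (all values are placeholders, not real endpoints):
#   export TG_HOST="https://your-instance.i.tgcloud.io"
#   export TG_PASSWORD="..."          # or TG_TOKEN for token-based auth
#   export OPENAI_API_KEY="sk-..."    # used by the LLM layer for entity extraction
#   export LLM_MODEL="gpt-4o-mini"    # optional; matches the default above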
def main():
parser = argparse.ArgumentParser(
description="Prepare 2M+ token dataset for GraphRAG Hackathon")
parser.add_argument("--source", choices=["wikipedia", "arxiv", "bbc", "combined"],
default="wikipedia", help="Dataset source")
parser.add_argument("--target-tokens", type=int, default=2_500_000,
help="Target token count (default: 2.5M for safety margin)")
parser.add_argument("--domain", default="science",
help="Domain filter for Wikipedia (science/history/politics/technology/all)")
parser.add_argument("--year", default="2022",
help="Year for BBC News")
parser.add_argument("--output-dir", default="dataset",
help="Output directory")
parser.add_argument("--ingest", action="store_true",
help="Also ingest into TigerGraph (requires TG_HOST, TG_PASSWORD)")
parser.add_argument("--max-ingest", type=int, default=None,
help="Max docs to ingest (default: all)")
parser.add_argument("--no-entities", action="store_true",
help="Skip LLM entity extraction (faster, free)")
args = parser.parse_args()
# Load dataset
if args.source == "wikipedia":
documents = load_wikipedia(args.target_tokens, domain=args.domain)
elif args.source == "arxiv":
documents = load_arxiv(args.target_tokens)
elif args.source == "bbc":
documents = load_bbc_news(args.target_tokens, year=args.year)
elif args.source == "combined":
# Mix: 60% Wikipedia + 25% arXiv + 15% BBC
wiki_target = int(args.target_tokens * 0.6)
arxiv_target = int(args.target_tokens * 0.25)
bbc_target = int(args.target_tokens * 0.15)
documents = (load_wikipedia(wiki_target, domain=args.domain) +
load_arxiv(arxiv_target) +
load_bbc_news(bbc_target, year=args.year))
if not documents:
logger.error("No documents loaded. Check your internet connection.")
sys.exit(1)
# Save to disk
meta = save_dataset(documents, args.output_dir)
if not meta["meets_2m_minimum"]:
logger.warning(f"⚠️ Only {meta['total_tokens']:,} tokens. "
f"Need {2_000_000 - meta['total_tokens']:,} more. "
f"Try --target-tokens {args.target_tokens + 1_000_000}")
# Ingest into TigerGraph
if args.ingest:
ingest_to_tigergraph(documents, max_docs=args.max_ingest,
extract_entities=not args.no_entities)
if __name__ == "__main__":
main()