| """
|
| DatasetLoader: Loads and processes open scientific datasets.
|
| Supports streaming from HuggingFace datasets with sharding.
|
| """
|
|
|
| import os
|
| import json
|
| from typing import List, Dict, Optional, Iterator
|
| from pathlib import Path
|
|
|
| try:
|
| from datasets import load_dataset, Dataset, IterableDataset
|
| import pyarrow.parquet as pq
|
| except ImportError:
|
| print("Please install datasets and pyarrow: pip install datasets pyarrow")
|
| raise
|
|
|
|
|
class VortexDatasetLoader:
    """
    Loads and processes open scientific datasets.

    Supports streaming from HuggingFace datasets, round-robin interleaving
    of multiple sources, and sharding to Parquet files.
    """

    # Registry of supported datasets. Each entry records the HF hub path,
    # an optional subset/config name, the split to load, the field that
    # holds the raw text, and a coarse domain label attached to samples.
    DATASETS = {
        "pile_scientific": {
            "path": "EleutherAI/pile",
            "subset": "pubmed_central",
            "split": "train",
            "text_field": "text",
            "domain": "biology",
        },
        "s2orc": {
            "path": "allenai/s2orc",
            "subset": None,
            "split": "train",
            "text_field": "text",
            "domain": "multidisciplinary",
        },
        "pes2o": {
            "path": "allenai/peS2o",
            "subset": None,
            "split": "train",
            "text_field": "text",
            "domain": "multidisciplinary",
        },
        "automath": {
            "path": "math-ai/AutoMathText",
            "subset": None,
            "split": "train",
            "text_field": "text",
            "domain": "math",
        },
        "deepmind_math": {
            "path": "deepmind/math_dataset",
            "subset": "algebra__linear_1d",
            "split": "train",
            "text_field": "text",
            "domain": "math",
        },
        "pubmed_qa": {
            "path": "bigbio/pubmed_qa",
            "subset": "pubmed_qa_labeled_fold0_source",
            "split": "train",
            "text_field": "question",
            "domain": "biology",
        },
    }

    def __init__(
        self,
        cache_dir: str = "./data/cache",
        output_dir: str = "./data/processed",
        num_proc: int = 4,
    ):
        """
        Initialize dataset loader.

        Args:
            cache_dir: Directory for caching downloaded datasets
            output_dir: Directory for processed shards
            num_proc: Number of processes for data processing
        """
        self.cache_dir = Path(cache_dir)
        self.output_dir = Path(output_dir)
        self.num_proc = num_proc

        # Create both directories up front so later writes cannot fail on
        # a missing parent.
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.output_dir.mkdir(parents=True, exist_ok=True)

    def load_dataset(
        self,
        dataset_name: str,
        streaming: bool = True,
        max_samples: Optional[int] = None,
    ) -> Iterator[Dict]:
        """
        Load a dataset as an iterator of normalized sample dicts.

        Best-effort by design: download/parse errors are printed and the
        iterator simply ends instead of raising, so a multi-dataset run
        survives one bad source.

        Args:
            dataset_name: Name from DATASETS config
            streaming: Use streaming mode for large datasets
            max_samples: Maximum number of samples to yield
                (None means unlimited; 0 yields no samples)

        Yields:
            Dictionary with text and metadata
            (keys: "text", "dataset", "domain", "source")

        Raises:
            ValueError: If dataset_name is not a key of DATASETS.
        """
        if dataset_name not in self.DATASETS:
            raise ValueError(f"Unknown dataset: {dataset_name}. Available: {list(self.DATASETS.keys())}")

        config = self.DATASETS[dataset_name]

        print(f"Loading dataset: {dataset_name}")
        print(f"  Path: {config['path']}")
        print(f"  Streaming: {streaming}")

        count = 0
        try:
            dataset = load_dataset(
                config["path"],
                name=config["subset"],
                split=config["split"],
                streaming=streaming,
                cache_dir=str(self.cache_dir),
            )

            # Honor max_samples == 0 (yield nothing). The original used a
            # truthiness test, which silently treated 0 as "unlimited".
            if max_samples is not None and max_samples <= 0:
                print(f"Loaded {count} samples from {dataset_name}")
                return

            for sample in dataset:
                text = sample.get(config["text_field"], "")
                # Skip empty or non-string payloads (some sources mix types).
                if not text or not isinstance(text, str):
                    continue

                yield {
                    "text": text,
                    "dataset": dataset_name,
                    "domain": config["domain"],
                    "source": config["path"],
                }

                count += 1
                if max_samples is not None and count >= max_samples:
                    break

            print(f"Loaded {count} samples from {dataset_name}")

        except Exception as e:
            # Best-effort: report and end the stream rather than crash the run.
            print(f"Error loading dataset {dataset_name}: {e}")
            return

    def load_multiple_datasets(
        self,
        dataset_names: List[str],
        streaming: bool = True,
        max_per_dataset: Optional[int] = None,
    ) -> Iterator[Dict]:
        """
        Load multiple datasets and yield samples interleaved round-robin.

        Exhausted sources drop out of the rotation; the remaining ones keep
        being interleaved fairly until all are exhausted. (The original
        restarted its scan from the first iterator whenever any iterator
        exhausted, which over-sampled earlier datasets in each round.)

        Args:
            dataset_names: List of dataset names
            streaming: Use streaming mode
            max_per_dataset: Max samples per dataset

        Yields:
            Dictionary with text and metadata
        """
        iterators = [
            self.load_dataset(name, streaming=streaming, max_samples=max_per_dataset)
            for name in dataset_names
        ]

        while iterators:
            still_active = []
            for it in iterators:
                try:
                    yield next(it)
                except StopIteration:
                    continue  # exhausted: drop from the rotation
                still_active.append(it)
            iterators = still_active

    def shard_to_parquet(
        self,
        samples: Iterator[Dict],
        output_prefix: str,
        samples_per_shard: int = 10000,
    ):
        """
        Write samples to sharded Parquet files.

        Args:
            samples: Iterator of sample dictionaries
            output_prefix: Prefix for output files (e.g., "train")
            samples_per_shard: Number of samples per shard
        """
        shard_index = 0
        buffer = []

        for sample in samples:
            buffer.append(sample)

            if len(buffer) >= samples_per_shard:
                self._write_shard(buffer, output_prefix, shard_index)
                shard_index += 1
                buffer = []

        # Flush the final partial shard, if any.
        if buffer:
            self._write_shard(buffer, output_prefix, shard_index)
            shard_index += 1

        # shard_index now equals the number of shards actually written.
        # The original printed shard_index + 1, over-counting by one when
        # the sample count was an exact multiple of samples_per_shard
        # (and reporting "1 shard" for zero samples).
        print(f"Wrote {shard_index} shards to {self.output_dir}")

    def _write_shard(
        self,
        buffer: List[Dict],
        output_prefix: str,
        shard_index: int,
    ):
        """Write a single shard to Parquet (zero-padded 5-digit index)."""
        import pandas as pd

        df = pd.DataFrame(buffer)
        output_path = self.output_dir / f"{output_prefix}_{shard_index:05d}.parquet"
        df.to_parquet(output_path, index=False)

    def get_shard_list(
        self,
        prefix: str,
    ) -> List[Path]:
        """Get sorted list of shard files matching prefix in output_dir."""
        return sorted(self.output_dir.glob(f"{prefix}_*.parquet"))

    def read_shard(
        self,
        shard_path: Path,
    ) -> List[Dict]:
        """Read a single shard back as a list of row dictionaries."""
        import pandas as pd
        df = pd.read_parquet(shard_path)
        return df.to_dict('records')
|
|
|
|
|
def test_dataset_loader():
    """Smoke test: stream a few samples from pubmed_qa and print them."""
    loader = VortexDatasetLoader()

    print("Testing dataset loader...")
    seen = 0
    stream = loader.load_dataset("pubmed_qa", streaming=True, max_samples=10)
    for index, sample in enumerate(stream):
        print(f"Sample {index}: {sample['text'][:100]}...")
        seen = index + 1

    print(f"Loaded {seen} samples")
    print("DatasetLoader test passed!")
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test; requires network access to the HuggingFace hub.
    test_dataset_loader()
|
|
|