| """Utility functions for data processing."""
|
|
|
| import json
|
| import logging
|
| from pathlib import Path
|
| from typing import Any, Dict, List, Optional
|
|
|
| import numpy as np
|
|
|
| logger = logging.getLogger(__name__)
|
|
|
|
|
def load_jsonl(file_path: str) -> List[Dict[str, Any]]:
    """Load a JSONL file into a list of records.

    Blank lines are skipped; lines that fail to parse are logged as
    warnings and dropped rather than aborting the load.
    """
    records: List[Dict[str, Any]] = []
    with open(file_path, 'r', encoding='utf-8') as handle:
        for raw_line in handle:
            stripped = raw_line.strip()
            if not stripped:
                continue
            try:
                records.append(json.loads(stripped))
            except json.JSONDecodeError as e:
                # Same module-level logger instance as the rest of the file.
                logging.getLogger(__name__).warning(f"Failed to parse line: {e}")
    return records
|
|
|
|
|
def save_jsonl(data: List[Dict[str, Any]], file_path: str):
    """Write each item of *data* as one JSON object per line.

    Parent directories are created on demand.
    """
    destination = Path(file_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    with open(file_path, 'w', encoding='utf-8') as handle:
        handle.writelines(
            json.dumps(item, ensure_ascii=False) + '\n' for item in data
        )
|
|
|
|
|
def load_json(file_path: str) -> Any:
    """Read and deserialize a single JSON document from *file_path*."""
    with open(file_path, 'r', encoding='utf-8') as handle:
        content = handle.read()
    return json.loads(content)
|
|
|
|
|
def save_json(data: Any, file_path: str):
    """Serialize *data* as pretty-printed JSON to *file_path*.

    Parent directories are created on demand.
    """
    target = Path(file_path)
    target.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(data, ensure_ascii=False, indent=2)
    with open(file_path, 'w', encoding='utf-8') as handle:
        handle.write(serialized)
|
|
|
|
|
def merge_datasets(
    datasets: List[Any],
    weights: Optional[List[float]] = None,
) -> List[Dict[str, Any]]:
    """Merge multiple datasets by randomly subsampling each one.

    Weights are normalized to sum to 1; each dataset contributes
    ``int(len(ds) * normalized_weight)`` samples drawn uniformly at
    random without replacement.

    Args:
        datasets: Indexable datasets (e.g. lists of dict samples).
        weights: Optional per-dataset weights; defaults to uniform.

    Returns:
        The merged list of sampled items.

    Raises:
        ValueError: If ``weights`` does not match ``datasets`` in length
            or does not sum to a positive value.
    """
    if not datasets:
        # Original raised ZeroDivisionError here (sum of no weights == 0).
        return []
    if weights is None:
        weights = [1.0] * len(datasets)
    if len(weights) != len(datasets):
        # Original zip() silently dropped trailing datasets on mismatch.
        raise ValueError("weights must have one entry per dataset")
    total_weight = sum(weights)
    if total_weight <= 0:
        # Original raised ZeroDivisionError on all-zero weights.
        raise ValueError("weights must sum to a positive value")
    weights = [w / total_weight for w in weights]

    merged = []
    for ds, weight in zip(datasets, weights):
        # Clamp so sampling without replacement is always valid.
        n_samples = min(int(len(ds) * weight), len(ds))
        if n_samples <= 0:
            # Also skips empty datasets, where np.random.choice would raise.
            continue
        indices = np.random.choice(len(ds), size=n_samples, replace=False)
        # int() so callers receive plain Python indexing, not numpy ints.
        merged.extend(ds[int(idx)] for idx in indices)

    logging.getLogger(__name__).info(f"Merged dataset size: {len(merged)}")
    return merged
|
|
|
|
|
def compute_dataset_statistics(
    dataset: Any,
    text_key: str = "text",
) -> Dict[str, Any]:
    """Compute comprehensive statistics for dataset.

    Args:
        dataset: Sized iterable of dict-like samples.
        text_key: Key holding the sample text used for length stats.

    Returns:
        Dict with the sample count, whitespace-token length statistics,
        relative domain frequencies, and the fraction of samples that
        carry a truthy ``"thoughts"`` field.
    """
    # Guard: the original divided by len(...) and crashed on empty input.
    if not len(dataset):
        return {
            "num_samples": 0,
            "length_stats": {},
            "domain_distribution": {},
            "thoughts_coverage": 0.0,
        }

    lengths = []
    domains = []
    has_thoughts = []
    for sample in dataset:
        text = sample.get(text_key, "")
        lengths.append(len(text.split()))
        domains.append(sample.get("domain", "unknown"))
        has_thoughts.append(1 if sample.get("thoughts") else 0)

    return {
        "num_samples": len(dataset),
        # NOTE: compute_length_statistics is a sibling helper defined
        # elsewhere in this module.
        "length_stats": compute_length_statistics(lengths),
        "domain_distribution": {d: domains.count(d) / len(domains) for d in set(domains)},
        "thoughts_coverage": sum(has_thoughts) / len(has_thoughts),
    }
|
|
|
|
|
def validate_dataset(dataset: Any, required_keys: List[str]) -> List[str]:
    """Check every sample for the required keys.

    Returns one error string per (sample, missing key) pair, in dataset
    order; an empty list means the dataset is structurally valid.
    """
    return [
        f"Sample {i} missing required key: {key}"
        for i, sample in enumerate(dataset)
        for key in required_keys
        if key not in sample
    ]
|
|
|
|
|
def deduplicate_dataset(dataset: List[Dict[str, Any]], key: str = "text") -> List[Dict[str, Any]]:
    """Drop samples whose *key* value was already seen, keeping first occurrences."""
    kept: List[Dict[str, Any]] = []
    seen_values = set()
    for sample in dataset:
        marker = sample.get(key, "")
        if marker in seen_values:
            continue
        seen_values.add(marker)
        kept.append(sample)
    # Same module-level logger instance as the rest of the file.
    logging.getLogger(__name__).info(f"Deduplicated: {len(dataset)} -> {len(kept)}")
    return kept
|
|
|
|
|
def balance_dataset(
    dataset: List[Dict[str, Any]],
    by_key: str = "domain",
    max_per_category: Optional[int] = None,
) -> List[Dict[str, Any]]:
    """Balance dataset across categories.

    Args:
        dataset: Samples to balance.
        by_key: Sample key whose value names the category.
        max_per_category: Cap per category; defaults to the smallest
            category size so every category ends up equally represented.

    Returns:
        Samples with each over-sized category randomly subsampled
        (without replacement) down to ``max_per_category``.
    """
    # Guard: the original called min() over zero categories and raised
    # ValueError on an empty dataset.
    if not dataset:
        return []

    categories: Dict[Any, List[Dict[str, Any]]] = {}
    for sample in dataset:
        categories.setdefault(sample.get(by_key, "unknown"), []).append(sample)

    if max_per_category is None:
        max_per_category = min(len(samples) for samples in categories.values())

    balanced = []
    for samples in categories.values():
        if len(samples) > max_per_category:
            # Sample indices instead of the dicts themselves: the original
            # np.random.choice(samples, ...) coerced the samples into a
            # numpy object array, which is fragile and unnecessary.
            chosen = np.random.choice(len(samples), size=max_per_category, replace=False)
            samples = [samples[int(i)] for i in chosen]
        balanced.extend(samples)

    logging.getLogger(__name__).info(f"Balanced dataset: {len(dataset)} -> {len(balanced)}")
    return balanced
|
|
|