"""Create training data - IDS ONLY with BATCHING (best of both worlds)."""
import json
import sys
import random
from pathlib import Path
from collections import defaultdict
import time


def _load_qrels(qrels_path):
    """Load a TSV qrels file (header skipped) into {query_id: {doc_id: score}}."""
    qrels = defaultdict(dict)
    with open(qrels_path, 'r', encoding='utf-8') as f:
        next(f)  # Skip header row
        for line in f:
            parts = line.strip().split('\t')
            if len(parts) == 3:
                qrels[parts[0]][parts[1]] = int(parts[2])
    return qrels


def _load_queries(queries_path):
    """Load a JSONL queries file into {query_id: query_text}."""
    queries = {}
    with open(queries_path, 'r', encoding='utf-8') as f:
        for line in f:
            q = json.loads(line)
            queries[q['_id']] = q['text']
    return queries


def process_dataset(dataset_name, config, negatives_ratio=10, batch_size=5000):
    """Process one dataset in batches, writing ID-only training examples.

    Reads qrels and queries from ../beir_data/<dataset_name>/ and streams one
    JSON line per usable query (positive/negative doc IDs only, no text) to
    datasets/<dataset_name>/training_ids.jsonl.

    Args:
        dataset_name: Name of the dataset directory under ../beir_data/.
        config: Parsed training config; the score-to-category mapping is read
            from config['datasets'][dataset_name]['score_to_category'].
        negatives_ratio: Target number of negatives per positive (default 1:10).
        batch_size: Initial batch size; overridden below based on dataset size.
    """
    print(f"\n{'='*50}\nProcessing: {dataset_name}")

    # Load qrels
    qrels_path = Path(f'../beir_data/{dataset_name}/qrels/merged.tsv')
    if not qrels_path.exists():
        print(f" ⚠️ No merged.tsv found")
        return

    print(f" Loading qrels...")
    qrels = _load_qrels(qrels_path)

    # Count scores for info
    score_counts = defaultdict(int)
    for docs in qrels.values():
        for score in docs.values():
            score_counts[score] += 1
    print(f" Loaded {len(qrels):,} queries, scores: {dict(score_counts)}")

    # Load queries
    print(f" Loading queries...")
    queries = _load_queries(f'../beir_data/{dataset_name}/queries.jsonl')

    # Get score mapping from config
    score_map = config['datasets'][dataset_name]['score_to_category']
    # When both score 1 and score 2 map to 'positive', keep them apart as
    # hard (score 1) vs easy (score 2) positives; otherwise every positive
    # is treated as an easy positive.
    both_scores_positive = (score_map.get('2') == 'positive'
                            and score_map.get('1') == 'positive')

    all_qids = list(qrels.keys())
    random.seed(42)  # Deterministic negative sampling across runs

    # Create output file
    output_dir = Path(f'datasets/{dataset_name}')
    output_dir.mkdir(parents=True, exist_ok=True)
    output_file = output_dir / 'training_ids.jsonl'

    # Adjust batch size for dataset size
    if len(all_qids) > 50000:
        batch_size = 5000
    elif len(all_qids) > 10000:
        batch_size = 2000
    else:
        # Process small datasets in one batch.
        # BUGFIX: clamp to >= 1 so an empty qrels file does not produce
        # range(0, 0, 0), which raises ValueError.
        batch_size = max(1, len(all_qids))
    print(f" Processing {len(all_qids):,} queries in batches of {batch_size:,}...")

    # Stats
    total_examples = 0
    total_easy_pos = 0
    total_hard_pos = 0
    total_hard_neg = 0
    total_easy_neg = 0

    with open(output_file, 'w', encoding='utf-8') as out_f:
        for batch_start in range(0, len(all_qids), batch_size):
            batch_end = min(batch_start + batch_size, len(all_qids))
            batch_qids = all_qids[batch_start:batch_end]
            if len(all_qids) > 10000:  # Only show progress for large datasets
                print(f" Processing batch: queries {batch_start:,}-{batch_end:,}")

            # Process queries in this batch
            for qid in batch_qids:
                if qid not in queries:
                    continue
                docs = qrels[qid]

                # Categorize judged documents by score (IDs only!)
                hard_positive_ids = []
                easy_positive_ids = []
                hard_negative_ids = []
                easy_negative_ids = []
                for doc_id, score in docs.items():
                    category = score_map.get(str(score), 'easy_negative')
                    if category == 'positive':
                        if both_scores_positive:
                            # Differentiate: score 2 -> easy, score 1 -> hard
                            if score == 2:
                                easy_positive_ids.append(doc_id)
                            elif score == 1:
                                hard_positive_ids.append(doc_id)
                        else:
                            # Only one score is positive; treat all as easy
                            easy_positive_ids.append(doc_id)
                    elif category == 'hard_negative':
                        hard_negative_ids.append(doc_id)
                    elif category == 'easy_negative':
                        easy_negative_ids.append(doc_id)

                # Queries with no positives produce no training example
                all_positive_ids = easy_positive_ids + hard_positive_ids
                if not all_positive_ids:
                    continue

                # Target 1:negatives_ratio based on total positives
                total_negatives_needed = len(all_positive_ids) * negatives_ratio
                total_negatives_have = (len(hard_negative_ids)
                                        + len(easy_negative_ids))

                # Top up negatives by sampling docs judged for OTHER queries
                # in this batch (cheap: avoids a full-corpus pass).
                if total_negatives_have < total_negatives_needed:
                    need_more = total_negatives_needed - total_negatives_have
                    other_batch_qids = [q for q in batch_qids if q != qid]
                    random.shuffle(other_batch_qids)
                    # BUGFIX: track everything already taken so a doc judged
                    # for several other queries is not appended twice; the set
                    # starts with all docs judged for the current query.
                    seen = set(docs.keys())
                    added = 0
                    for other_qid in other_batch_qids:
                        if added >= need_more:
                            break
                        for doc_id in qrels[other_qid]:
                            if doc_id not in seen:
                                seen.add(doc_id)
                                easy_negative_ids.append(doc_id)
                                added += 1
                                if added >= need_more:
                                    break

                # BUGFIX: clamp the slice bound at 0. Previously, when hard
                # negatives alone exceeded the target, the bound went negative
                # and Python sliced from the wrong end, silently keeping most
                # easy negatives instead of none.
                easy_keep = max(0, total_negatives_needed - len(hard_negative_ids))

                # Write example directly to file (streaming)
                example = {
                    'query_id': qid,
                    'query_text': queries[qid],
                    'source_dataset': dataset_name,
                    'easy_positive_ids': easy_positive_ids,
                    'hard_positive_ids': hard_positive_ids,
                    'hard_negative_ids': hard_negative_ids,
                    'easy_negative_ids': easy_negative_ids[:easy_keep],
                }
                out_f.write(json.dumps(example) + '\n')

                # Update stats
                total_examples += 1
                total_easy_pos += len(example['easy_positive_ids'])
                total_hard_pos += len(example['hard_positive_ids'])
                total_hard_neg += len(example['hard_negative_ids'])
                total_easy_neg += len(example['easy_negative_ids'])

    # Print stats
    print(f" ✓ Created {total_examples:,} examples")
    print(f" Easy positives: {total_easy_pos:,}")
    print(f" Hard positives: {total_hard_pos:,}")
    print(f" Hard negatives: {total_hard_neg:,}")
    print(f" Easy negatives: {total_easy_neg:,}")


def main():
    """CLI entry point: build ID-only training data for each enabled dataset.

    Optional argv[1] restricts processing to a single (enabled) dataset.
    """
    print("="*50)
    print("TRAINING DATA CREATION - IDS + BATCHING")
    print("="*50)
    print("Best of both worlds: IDs only (small files) + Batching (fast)")

    # Load config
    with open('../test_scores/dataset_reports/training_config.json', 'r',
              encoding='utf-8') as f:
        config = json.load(f)['beir_training_config']

    # Get datasets to process
    datasets = [name for name, cfg in config['datasets'].items() if cfg['use']]

    # Check if a specific dataset was requested on the command line
    if len(sys.argv) > 1:
        if sys.argv[1] in datasets:
            datasets = [sys.argv[1]]
            print(f"Processing only: {sys.argv[1]}")
        else:
            print(f"❌ Dataset '{sys.argv[1]}' not found or disabled")
            print(f"Available: {', '.join(datasets)}")
            return

    print(f"Will process {len(datasets)} datasets")
    total_start = time.time()

    # Process each dataset
    for idx, dataset_name in enumerate(datasets, 1):
        print(f"\n[{idx}/{len(datasets)}] {dataset_name}")
        dataset_start = time.time()
        process_dataset(dataset_name, config)
        dataset_time = time.time() - dataset_start
        print(f" Dataset completed in {dataset_time:.2f} seconds")

    total_time = time.time() - total_start
    print(f"\n✅ Complete! Total time: {total_time:.2f} seconds")
    # BUGFIX: the original printed "datasets//training_ids.jsonl" — the
    # per-dataset path placeholder had been lost.
    print("\n📝 Output: datasets/<dataset>/training_ids.jsonl (IDs only)")
    print("💾 File sizes: ~100x smaller than full text")
    print("⚡ Speed: As fast as efficient version")


if __name__ == "__main__":
    main()