train-mbed / train_datasets_creation /create_training_data_ids_batched.py
amos1088's picture
no
cae25d0
"""Create training data - IDS ONLY with BATCHING (best of both worlds)."""
import json
import sys
import random
from pathlib import Path
from collections import defaultdict
import time
def _load_qrels(qrels_path):
    """Load a headered TSV qrels file into {query_id: {doc_id: int score}}."""
    qrels = defaultdict(dict)
    with open(qrels_path, 'r', encoding='utf-8') as f:
        next(f)  # Skip header row
        for line in f:
            parts = line.strip().split('\t')
            if len(parts) == 3:
                qrels[parts[0]][parts[1]] = int(parts[2])
    return qrels


def _load_queries(queries_path):
    """Load a queries.jsonl file into {query_id: query_text}."""
    queries = {}
    with open(queries_path, 'r', encoding='utf-8') as f:
        for line in f:
            q = json.loads(line)
            queries[q['_id']] = q['text']
    return queries


def _categorize_docs(docs, score_map, both_scores_positive):
    """Split one query's judged docs into four ID lists by category.

    Returns (easy_positive_ids, hard_positive_ids, hard_negative_ids,
    easy_negative_ids). When both score 1 and score 2 map to 'positive',
    score 1 docs are treated as hard positives and every other positive
    score as easy (previously positives with other scores were silently
    dropped in that mode).
    """
    easy_pos, hard_pos, hard_neg, easy_neg = [], [], [], []
    for doc_id, score in docs.items():
        # Unmapped scores default to easy negatives.
        category = score_map.get(str(score), 'easy_negative')
        if category == 'positive':
            if both_scores_positive and score == 1:
                hard_pos.append(doc_id)
            else:
                easy_pos.append(doc_id)
        elif category == 'hard_negative':
            hard_neg.append(doc_id)
        elif category == 'easy_negative':
            easy_neg.append(doc_id)
    return easy_pos, hard_pos, hard_neg, easy_neg


def _sample_batch_negatives(qrels, batch_qids, qid, excluded, need):
    """Sample up to `need` extra easy-negative doc IDs for `qid`.

    Draws doc IDs judged for OTHER queries in the same batch, skipping any
    ID already in `excluded` (judged for the current query or previously
    sampled). `excluded` is updated in place, so the result contains no
    duplicates (the original could append the same doc ID several times).
    """
    sampled = []
    other_qids = [q for q in batch_qids if q != qid]
    random.shuffle(other_qids)
    for other_qid in other_qids:
        for doc_id in qrels[other_qid]:
            if doc_id not in excluded:
                sampled.append(doc_id)
                excluded.add(doc_id)
                if len(sampled) >= need:
                    return sampled
    return sampled


def process_dataset(dataset_name, config, negatives_ratio=10, batch_size=5000):
    """Build ID-only training examples for one BEIR dataset, in batches.

    Reads ../beir_data/<dataset_name>/qrels/merged.tsv and queries.jsonl,
    writes datasets/<dataset_name>/training_ids.jsonl with one JSON object
    per query holding document ID lists only (no document text).

    Args:
        dataset_name: BEIR dataset directory name.
        config: parsed training config; uses
            config['datasets'][dataset_name]['score_to_category']
            (a str(score) -> category mapping).
        negatives_ratio: target negatives per positive (default 1:10).
        batch_size: batch size for very large datasets (>50k queries);
            smaller datasets get a smaller adaptive batch. (Previously this
            parameter was ignored and 5000 was hard-coded.)
    """
    print(f"\n{'='*50}\nProcessing: {dataset_name}")
    qrels_path = Path(f'../beir_data/{dataset_name}/qrels/merged.tsv')
    if not qrels_path.exists():
        print(f" ⚠️ No merged.tsv found")
        return
    print(f" Loading qrels...")
    qrels = _load_qrels(qrels_path)
    # Count scores for info only.
    score_counts = defaultdict(int)
    for docs in qrels.values():
        for score in docs.values():
            score_counts[score] += 1
    print(f" Loaded {len(qrels):,} queries, scores: {dict(score_counts)}")
    print(f" Loading queries...")
    queries = _load_queries(f'../beir_data/{dataset_name}/queries.jsonl')
    # Score mapping for this dataset; detect the "both 1 and 2 positive" mode.
    score_map = config['datasets'][dataset_name]['score_to_category']
    both_scores_positive = (score_map.get('2') == 'positive'
                            and score_map.get('1') == 'positive')
    all_qids = list(qrels.keys())
    random.seed(42)  # Deterministic negative sampling
    output_dir = Path(f'datasets/{dataset_name}')
    output_dir.mkdir(parents=True, exist_ok=True)
    output_file = output_dir / 'training_ids.jsonl'
    # Adaptive batching: huge datasets keep the caller-supplied batch size,
    # mid-size datasets use smaller batches, small datasets run as one batch.
    if len(all_qids) <= 10000:
        # max() guards an empty qrels file (batch_size=0 would make range()
        # raise ValueError in the original).
        batch_size = max(1, len(all_qids))
    elif len(all_qids) <= 50000:
        batch_size = 2000
    # else: keep caller-supplied batch_size (default 5000, as before)
    print(f" Processing {len(all_qids):,} queries in batches of {batch_size:,}...")
    # Stats
    total_examples = 0
    total_easy_pos = 0
    total_hard_pos = 0
    total_hard_neg = 0
    total_easy_neg = 0
    with open(output_file, 'w', encoding='utf-8') as out_f:
        for batch_start in range(0, len(all_qids), batch_size):
            batch_end = min(batch_start + batch_size, len(all_qids))
            batch_qids = all_qids[batch_start:batch_end]
            if len(all_qids) > 10000:  # Only show progress for large datasets
                print(f" Processing batch: queries {batch_start:,}-{batch_end:,}")
            for qid in batch_qids:
                if qid not in queries:
                    continue  # Judged query with no query text; skip
                docs = qrels[qid]
                easy_pos, hard_pos, hard_neg, easy_neg = _categorize_docs(
                    docs, score_map, both_scores_positive)
                total_positives = len(easy_pos) + len(hard_pos)
                if total_positives == 0:
                    continue  # No positives -> no training example
                # Target 1:negatives_ratio based on total positives.
                needed = total_positives * negatives_ratio
                have = len(hard_neg) + len(easy_neg)
                if have < needed:
                    # Top up from other queries in THIS batch only, never
                    # reusing a doc already judged for the current query.
                    excluded = set(docs.keys())
                    easy_neg.extend(_sample_batch_negatives(
                        qrels, batch_qids, qid, excluded, needed - have))
                # Cap easy negatives so hard + easy <= needed. max() prevents
                # a negative slice when hard negatives alone exceed the target
                # (the original negative slice silently dropped IDs from the
                # end of the list).
                keep_easy = max(0, needed - len(hard_neg))
                # Stream each example straight to disk; IDs only, no text.
                example = {
                    'query_id': qid,
                    'query_text': queries[qid],
                    'source_dataset': dataset_name,
                    'easy_positive_ids': easy_pos,
                    'hard_positive_ids': hard_pos,
                    'hard_negative_ids': hard_neg,
                    'easy_negative_ids': easy_neg[:keep_easy],
                }
                out_f.write(json.dumps(example) + '\n')
                # Update stats
                total_examples += 1
                total_easy_pos += len(example['easy_positive_ids'])
                total_hard_pos += len(example['hard_positive_ids'])
                total_hard_neg += len(example['hard_negative_ids'])
                total_easy_neg += len(example['easy_negative_ids'])
    # Print stats
    print(f" βœ“ Created {total_examples:,} examples")
    print(f" Easy positives: {total_easy_pos:,}")
    print(f" Hard positives: {total_hard_pos:,}")
    print(f" Hard negatives: {total_hard_neg:,}")
    print(f" Easy negatives: {total_easy_neg:,}")
def main():
    """Entry point: build ID-only training data for every enabled dataset,
    or for a single dataset named as the first command-line argument."""
    banner = "=" * 50
    print(banner)
    print("TRAINING DATA CREATION - IDS + BATCHING")
    print(banner)
    print("Best of both worlds: IDs only (small files) + Batching (fast)")
    # Load the shared training configuration.
    with open('../test_scores/dataset_reports/training_config.json', 'r', encoding='utf-8') as cfg_file:
        config = json.load(cfg_file)['beir_training_config']
    # Only datasets flagged for use.
    datasets = [name for name, cfg in config['datasets'].items() if cfg['use']]
    # Optional: restrict to a single dataset requested on the command line.
    if len(sys.argv) > 1:
        requested = sys.argv[1]
        if requested not in datasets:
            print(f"❌ Dataset '{requested}' not found or disabled")
            print(f"Available: {', '.join(datasets)}")
            return
        datasets = [requested]
        print(f"Processing only: {requested}")
    print(f"Will process {len(datasets)} datasets")
    total_start = time.time()
    # Process each dataset, timing each one.
    for idx, dataset_name in enumerate(datasets, 1):
        print(f"\n[{idx}/{len(datasets)}] {dataset_name}")
        dataset_start = time.time()
        process_dataset(dataset_name, config)
        dataset_time = time.time() - dataset_start
        print(f" Dataset completed in {dataset_time:.2f} seconds")
    total_time = time.time() - total_start
    print(f"\nβœ… Complete! Total time: {total_time:.2f} seconds")
    print("\nπŸ“ Output: datasets/<dataset>/training_ids.jsonl (IDs only)")
    print("πŸ’Ύ File sizes: ~100x smaller than full text")
    print("⚑ Speed: As fast as efficient version")


if __name__ == "__main__":
    main()