"""Create training data - IDS ONLY with BATCHING (best of both worlds)."""
import json
import random
import sys
import time
from collections import defaultdict
from pathlib import Path
def process_dataset(dataset_name, config, negatives_ratio=10, batch_size=5000):
    """Process one BEIR dataset in batches, writing ID-only training examples.

    Reads qrels from ../beir_data/<dataset_name>/qrels/merged.tsv and query
    texts from ../beir_data/<dataset_name>/queries.jsonl, categorizes each
    judged document via the dataset's ``score_to_category`` map, tops up easy
    negatives by sampling unjudged doc ids from other queries in the same
    batch, and streams one JSON line per query to
    datasets/<dataset_name>/training_ids.jsonl.

    Args:
        dataset_name: BEIR dataset directory name.
        config: Parsed training config; must provide
            ``config['datasets'][dataset_name]['score_to_category']``.
        negatives_ratio: Target negatives per positive (default 10).
        batch_size: Initial hint only — it is re-derived from the dataset's
            query count below.

    Returns:
        None. Writes the output file and prints progress/stats.
    """
    print(f"\n{'='*50}\nProcessing: {dataset_name}")

    # Load qrels: query_id -> {doc_id: relevance score}.
    qrels_path = Path(f'../beir_data/{dataset_name}/qrels/merged.tsv')
    if not qrels_path.exists():
        print(f" β οΈ No merged.tsv found")
        return
    print(f" Loading qrels...")
    qrels = defaultdict(dict)
    with open(qrels_path, 'r', encoding='utf-8') as f:
        next(f)  # Skip header row.
        for line in f:
            parts = line.strip().split('\t')
            if len(parts) == 3:
                qrels[parts[0]][parts[1]] = int(parts[2])

    # Score distribution — informational only.
    score_counts = defaultdict(int)
    for docs in qrels.values():
        for score in docs.values():
            score_counts[score] += 1
    print(f" Loaded {len(qrels):,} queries, scores: {dict(score_counts)}")

    # Load query texts: query_id -> text.
    print(f" Loading queries...")
    queries = {}
    with open(f'../beir_data/{dataset_name}/queries.jsonl', 'r', encoding='utf-8') as f:
        for line in f:
            q = json.loads(line)
            queries[q['_id']] = q['text']

    score_map = config['datasets'][dataset_name]['score_to_category']
    # When BOTH score 1 and score 2 map to 'positive', score-1 docs become
    # the "hard" positives and all other positives the "easy" ones.
    # Otherwise every positive is treated as easy.
    both_scores_positive = (score_map.get('2') == 'positive'
                            and score_map.get('1') == 'positive')

    all_qids = list(qrels.keys())
    random.seed(42)  # Deterministic negative sampling across runs.

    output_dir = Path(f'datasets/{dataset_name}')
    output_dir.mkdir(parents=True, exist_ok=True)
    output_file = output_dir / 'training_ids.jsonl'

    # Re-derive batch size from dataset size; small datasets run in one
    # batch. max(1, ...) guards against an empty dataset, which would
    # otherwise make range() raise on a zero step.
    if len(all_qids) > 50000:
        batch_size = 5000
    elif len(all_qids) > 10000:
        batch_size = 2000
    else:
        batch_size = max(1, len(all_qids))
    print(f" Processing {len(all_qids):,} queries in batches of {batch_size:,}...")

    # Running totals for the summary printed at the end.
    total_examples = 0
    total_easy_pos = 0
    total_hard_pos = 0
    total_hard_neg = 0
    total_easy_neg = 0

    with open(output_file, 'w', encoding='utf-8') as out_f:
        for batch_start in range(0, len(all_qids), batch_size):
            batch_end = min(batch_start + batch_size, len(all_qids))
            batch_qids = all_qids[batch_start:batch_end]
            if len(all_qids) > 10000:  # Only show progress for large datasets.
                print(f" Processing batch: queries {batch_start:,}-{batch_end:,}")

            for qid in batch_qids:
                if qid not in queries:
                    continue  # Qrels row without a matching query text.
                docs = qrels[qid]

                # Categorize judged documents by mapped score (IDs only).
                hard_positive_ids = []
                easy_positive_ids = []
                hard_negative_ids = []
                easy_negative_ids = []
                for doc_id, score in docs.items():
                    category = score_map.get(str(score), 'easy_negative')
                    if category == 'positive':
                        # Score-1 positives are "hard" only in the dual-score
                        # case; any other positive score is an easy positive
                        # (the original silently DROPPED positives whose
                        # score was neither 1 nor 2 in that case).
                        if both_scores_positive and score == 1:
                            hard_positive_ids.append(doc_id)
                        else:
                            easy_positive_ids.append(doc_id)
                    elif category == 'hard_negative':
                        hard_negative_ids.append(doc_id)
                    elif category == 'easy_negative':
                        easy_negative_ids.append(doc_id)

                all_positive_ids = easy_positive_ids + hard_positive_ids
                if not all_positive_ids:
                    continue  # No positives: nothing to train on.

                # Target a 1:negatives_ratio positive/negative balance.
                total_positives = len(all_positive_ids)
                total_negatives_have = len(hard_negative_ids) + len(easy_negative_ids)
                total_negatives_needed = total_positives * negatives_ratio

                # Top up easy negatives by sampling doc ids judged for OTHER
                # queries in this batch only (keeps memory bounded).
                if total_negatives_have < total_negatives_needed:
                    need_more = total_negatives_needed - total_negatives_have
                    other_batch_qids = [q for q in batch_qids if q != qid]
                    random.shuffle(other_batch_qids)
                    # Never sample a doc already judged for the current query,
                    # and never add the same sampled id twice.
                    current_query_docs = set(docs.keys())
                    sampled = set()
                    added = 0
                    for other_qid in other_batch_qids:
                        if added >= need_more:
                            break
                        for doc_id in qrels[other_qid]:
                            if doc_id in current_query_docs or doc_id in sampled:
                                continue
                            easy_negative_ids.append(doc_id)
                            sampled.add(doc_id)
                            added += 1
                            if added >= need_more:
                                break

                # Cap easy negatives at the remaining budget. BUGFIX: the
                # budget is clamped at 0 — the original used a raw slice that
                # went NEGATIVE when hard negatives alone exceeded the
                # budget, keeping the wrong items instead of none.
                easy_budget = max(0, total_negatives_needed - len(hard_negative_ids))
                example = {
                    'query_id': qid,
                    'query_text': queries[qid],
                    'source_dataset': dataset_name,
                    'easy_positive_ids': easy_positive_ids,
                    'hard_positive_ids': hard_positive_ids,
                    'hard_negative_ids': hard_negative_ids,
                    'easy_negative_ids': easy_negative_ids[:easy_budget],
                }
                out_f.write(json.dumps(example) + '\n')

                total_examples += 1
                total_easy_pos += len(example['easy_positive_ids'])
                total_hard_pos += len(example['hard_positive_ids'])
                total_hard_neg += len(example['hard_negative_ids'])
                total_easy_neg += len(example['easy_negative_ids'])

    print(f" β Created {total_examples:,} examples")
    print(f" Easy positives: {total_easy_pos:,}")
    print(f" Hard positives: {total_hard_pos:,}")
    print(f" Hard negatives: {total_hard_neg:,}")
    print(f" Easy negatives: {total_easy_neg:,}")
def main():
    """Entry point: load the training config and process each enabled dataset.

    Optionally takes a single CLI argument naming one dataset to process;
    otherwise every dataset flagged ``use`` in the config is processed.
    """
    banner = "=" * 50
    print(banner)
    print("TRAINING DATA CREATION - IDS + BATCHING")
    print(banner)
    print("Best of both worlds: IDs only (small files) + Batching (fast)")

    # Read the shared training configuration.
    with open('../test_scores/dataset_reports/training_config.json', 'r', encoding='utf-8') as f:
        config = json.load(f)['beir_training_config']

    # Only datasets explicitly enabled in the config.
    datasets = [name for name, cfg in config['datasets'].items() if cfg['use']]

    # A CLI argument narrows the run to one enabled dataset.
    if len(sys.argv) > 1:
        if sys.argv[1] not in datasets:
            print(f"β Dataset '{sys.argv[1]}' not found or disabled")
            print(f"Available: {', '.join(datasets)}")
            return
        datasets = [sys.argv[1]]
        print(f"Processing only: {sys.argv[1]}")

    print(f"Will process {len(datasets)} datasets")
    total_start = time.time()

    # Process datasets one by one, timing each.
    for idx, dataset_name in enumerate(datasets, 1):
        print(f"\n[{idx}/{len(datasets)}] {dataset_name}")
        dataset_start = time.time()
        process_dataset(dataset_name, config)
        dataset_time = time.time() - dataset_start
        print(f" Dataset completed in {dataset_time:.2f} seconds")

    total_time = time.time() - total_start
    print(f"\nβ Complete! Total time: {total_time:.2f} seconds")
    print("\nπ Output: datasets/<dataset>/training_ids.jsonl (IDs only)")
    print("πΎ File sizes: ~100x smaller than full text")
    print("β‘ Speed: As fast as efficient version")


if __name__ == "__main__":
    main()