"""Create training data - IDS ONLY with BATCHING (best of both worlds)."""
import json
import sys
import random
from pathlib import Path
from collections import defaultdict
import time
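

# Usage (script filename illustrative; run from this directory, since paths are relative):
#   python create_training_data.py           # process all datasets enabled in the config
#   python create_training_data.py scifact   # process a single dataset by name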
def process_dataset(dataset_name, config, negatives_ratio=10, batch_size=5000):
"""Process dataset in batches, storing only IDs."""
print(f"\n{'='*50}\nProcessing: {dataset_name}")
# Load qrels
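    # Expected merged.tsv layout (tab-separated, one header row):
    #   query-id    corpus-id    score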
    qrels_path = Path(f'../beir_data/{dataset_name}/qrels/merged.tsv')
    if not qrels_path.exists():
        print("  ⚠️ No merged.tsv found")
        return
    print("  Loading qrels...")
    qrels = defaultdict(dict)
    with open(qrels_path, 'r', encoding='utf-8') as f:
        next(f)  # Skip header
        for line in f:
            parts = line.strip().split('\t')
            if len(parts) == 3:
                qrels[parts[0]][parts[1]] = int(parts[2])
    # Count scores for info
    score_counts = defaultdict(int)
    for docs in qrels.values():
        for score in docs.values():
            score_counts[score] += 1
    print(f"  Loaded {len(qrels):,} queries, scores: {dict(score_counts)}")
    # Load queries
    print("  Loading queries...")
    queries = {}
    with open(f'../beir_data/{dataset_name}/queries.jsonl', 'r', encoding='utf-8') as f:
        for line in f:
            q = json.loads(line)
            queries[q['_id']] = q['text']
    # Get score mapping from config
    score_map = config['datasets'][dataset_name]['score_to_category']
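    # score_to_category maps each judgment score (as a string key) to one of
    # 'positive', 'hard_negative', or 'easy_negative'; unmapped scores fall
    # back to 'easy_negative' below.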
    # Check if this dataset has both score 1 and score 2 mapped to positive
    has_score_2_positive = score_map.get('2') == 'positive'
    has_score_1_positive = score_map.get('1') == 'positive'
    both_scores_positive = has_score_2_positive and has_score_1_positive
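    # When both map to positive, score 2 (typically the stronger judgment) is
    # treated as an "easy" positive and score 1 as a "hard" positive below.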
    # Process in batches
    all_qids = list(qrels.keys())
    random.seed(42)
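    # NOTE: the fixed seed above makes the shuffle-based negative sampling
    # below reproducible across runs.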
    # Create output file
    output_dir = Path(f'datasets/{dataset_name}')
    output_dir.mkdir(parents=True, exist_ok=True)
    output_file = output_dir / 'training_ids.jsonl'
    # Adjust batch size to the dataset size (overrides the batch_size argument)
    if len(all_qids) > 50000:
        batch_size = 5000
    elif len(all_qids) > 10000:
        batch_size = 2000
    else:
        batch_size = max(1, len(all_qids))  # Small datasets in one batch (min 1 keeps range() valid)
    print(f"  Processing {len(all_qids):,} queries in batches of {batch_size:,}...")
    # Stats
    total_examples = 0
    total_easy_pos = 0
    total_hard_pos = 0
    total_hard_neg = 0
    total_easy_neg = 0
    with open(output_file, 'w', encoding='utf-8') as out_f:
        for batch_start in range(0, len(all_qids), batch_size):
            batch_end = min(batch_start + batch_size, len(all_qids))
            batch_qids = all_qids[batch_start:batch_end]
            if len(all_qids) > 10000:  # Only show progress for large datasets
                print(f"    Processing batch: queries {batch_start:,}-{batch_end:,}")
            # Process queries in this batch
            for qid in batch_qids:
                if qid not in queries:
                    continue
                docs = qrels[qid]
                # Categorize documents by score (IDs only!)
                hard_positive_ids = []
                easy_positive_ids = []
                hard_negative_ids = []
                easy_negative_ids = []
                for doc_id, score in docs.items():
                    category = score_map.get(str(score), 'easy_negative')
                    if category == 'positive':
                        # If both 1 and 2 are positive, differentiate them
                        if both_scores_positive:
                            if score == 2:
                                easy_positive_ids.append(doc_id)
                            elif score == 1:
                                hard_positive_ids.append(doc_id)
                        else:
                            # Only one score is positive, treat all as easy positives
                            easy_positive_ids.append(doc_id)
                    elif category == 'hard_negative':
                        hard_negative_ids.append(doc_id)
                    elif category == 'easy_negative':
                        easy_negative_ids.append(doc_id)
                # Combine positives
                all_positive_ids = easy_positive_ids + hard_positive_ids
                if not all_positive_ids:
                    continue
                # Target negatives = positives * negatives_ratio (1:10 by default)
                total_positives = len(all_positive_ids)
                total_negatives_have = len(hard_negative_ids) + len(easy_negative_ids)
                total_negatives_needed = total_positives * negatives_ratio
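                # e.g. 3 positives with negatives_ratio=10 -> a target of 30 negatives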
                # Only sample from other queries IN THIS BATCH if we need more
                if total_negatives_have < total_negatives_needed:
                    need_more = total_negatives_needed - total_negatives_have
                    # Sample from batch queries only (like the efficient version!)
                    other_batch_qids = [q for q in batch_qids if q != qid]
                    random.shuffle(other_batch_qids)
                    # Collect ALL doc IDs already judged for this query
                    current_query_docs = set(docs.keys())
                    sampled_ids = set()  # Avoid adding the same doc twice
                    added = 0
                    for other_qid in other_batch_qids:
                        if added >= need_more:
                            break
                        for doc_id in qrels[other_qid]:
                            # Check: not already judged for the current query, not already sampled
                            if doc_id not in current_query_docs and doc_id not in sampled_ids:
                                sampled_ids.add(doc_id)
                                easy_negative_ids.append(doc_id)
                                added += 1
                                if added >= need_more:
                                    break
                # Write example directly to file (streaming)
                example = {
                    'query_id': qid,
                    'query_text': queries[qid],
                    'source_dataset': dataset_name,
                    'easy_positive_ids': easy_positive_ids,
                    'hard_positive_ids': hard_positive_ids,
                    'hard_negative_ids': hard_negative_ids,
                    # Cap easy negatives; max(0, ...) guards against a surplus of hard negatives
                    'easy_negative_ids': easy_negative_ids[:max(0, total_negatives_needed - len(hard_negative_ids))],
                }
                out_f.write(json.dumps(example) + '\n')
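                # Doc texts are intentionally not stored; training code is expected
                # to resolve them from the corpus by ID. Illustrative output line:
                # {"query_id": "q1", "query_text": "...", "source_dataset": "scifact",
                #  "easy_positive_ids": ["d3"], "hard_positive_ids": [],
                #  "hard_negative_ids": ["d7"], "easy_negative_ids": ["d9", "d12"]}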
                # Update stats
                total_examples += 1
                total_easy_pos += len(example['easy_positive_ids'])
                total_hard_pos += len(example['hard_positive_ids'])
                total_hard_neg += len(example['hard_negative_ids'])
                total_easy_neg += len(example['easy_negative_ids'])
    # Print stats
    print(f"  ✅ Created {total_examples:,} examples")
    print(f"     Easy positives: {total_easy_pos:,}")
    print(f"     Hard positives: {total_hard_pos:,}")
    print(f"     Hard negatives: {total_hard_neg:,}")
    print(f"     Easy negatives: {total_easy_neg:,}")


def main():
    """Main function."""
    print("="*50)
    print("TRAINING DATA CREATION - IDS + BATCHING")
    print("="*50)
    print("Best of both worlds: IDs only (small files) + batching (fast)")
    # Load config
    with open('../test_scores/dataset_reports/training_config.json', 'r', encoding='utf-8') as f:
        config = json.load(f)['beir_training_config']
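    # Assumed training_config.json shape (illustrative values):
    # {"beir_training_config": {"datasets": {"scifact": {"use": true,
    #     "score_to_category": {"0": "easy_negative", "1": "hard_negative", "2": "positive"}}}}}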
    # Get datasets to process
    datasets = [name for name, cfg in config['datasets'].items() if cfg['use']]
    # Check if a specific dataset was requested on the command line
    if len(sys.argv) > 1:
        if sys.argv[1] in datasets:
            datasets = [sys.argv[1]]
            print(f"Processing only: {sys.argv[1]}")
        else:
            print(f"❌ Dataset '{sys.argv[1]}' not found or disabled")
            print(f"Available: {', '.join(datasets)}")
            return
    print(f"Will process {len(datasets)} datasets")
    total_start = time.time()
    # Process each dataset
    for idx, dataset_name in enumerate(datasets, 1):
        print(f"\n[{idx}/{len(datasets)}] {dataset_name}")
        dataset_start = time.time()
        process_dataset(dataset_name, config)
        dataset_time = time.time() - dataset_start
        print(f"  Dataset completed in {dataset_time:.2f} seconds")
    total_time = time.time() - total_start
    print(f"\n✅ Complete! Total time: {total_time:.2f} seconds")
    print("\n📁 Output: datasets/<dataset>/training_ids.jsonl (IDs only)")
    print("💾 File sizes: ~100x smaller than full text")
    print("⚡ Speed: as fast as the efficient version")


if __name__ == "__main__":
    main()