File size: 8,950 Bytes
cae25d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
"""Create training data - IDS ONLY with BATCHING (best of both worlds)."""

import json
import sys
import random
from pathlib import Path
from collections import defaultdict
import time

def process_dataset(dataset_name, config, negatives_ratio=10, batch_size=5000):
    """Process one dataset in batches, writing training examples as doc IDs only.

    Reads qrels and queries for ``dataset_name`` from ../beir_data/, buckets
    every judged document into easy/hard positive/negative using the dataset's
    ``score_to_category`` map from ``config``, tops negatives up to a
    1:``negatives_ratio`` budget by borrowing documents judged for *other*
    queries in the same batch, and streams one JSON object per query to
    datasets/<dataset_name>/training_ids.jsonl.

    Args:
        dataset_name: Dataset folder name under ../beir_data/.
        config: Parsed training config; only
            config['datasets'][dataset_name]['score_to_category'] is used.
        negatives_ratio: Target number of negatives per positive (1:N budget).
        batch_size: Nominal batch size. NOTE(review): this argument is
            unconditionally overwritten by the size heuristic below, so
            callers cannot actually tune it — confirm whether that is intended.

    Returns:
        None. Output is written to disk; summary stats are printed.
    """
    print(f"\n{'='*50}\nProcessing: {dataset_name}")

    # Load qrels: query_id -> {doc_id: relevance score}.
    qrels_path = Path(f'../beir_data/{dataset_name}/qrels/merged.tsv')
    if not qrels_path.exists():
        print(f"  ⚠️ No merged.tsv found")
        return

    print("  Loading qrels...")
    qrels = defaultdict(dict)
    with open(qrels_path, 'r', encoding='utf-8') as f:
        next(f)  # Skip the TSV header row.
        for line in f:
            parts = line.strip().split('\t')
            if len(parts) == 3:
                qrels[parts[0]][parts[1]] = int(parts[2])

    # Score histogram — used only for the info line below.
    score_counts = defaultdict(int)
    for docs in qrels.values():
        for score in docs.values():
            score_counts[score] += 1
    print(f"  Loaded {len(qrels):,} queries, scores: {dict(score_counts)}")

    # Load query texts: query_id -> text.
    print("  Loading queries...")
    queries = {}
    with open(f'../beir_data/{dataset_name}/queries.jsonl', 'r', encoding='utf-8') as f:
        for line in f:
            q = json.loads(line)
            queries[q['_id']] = q['text']

    # Relevance-score -> category mapping (keys are strings in the config).
    score_map = config['datasets'][dataset_name]['score_to_category']

    # When BOTH score 1 and score 2 map to 'positive', score-2 docs become
    # easy positives and score-1 docs become hard positives.
    has_score_2_positive = score_map.get('2') == 'positive'
    has_score_1_positive = score_map.get('1') == 'positive'
    both_scores_positive = has_score_2_positive and has_score_1_positive

    all_qids = list(qrels.keys())
    random.seed(42)  # Deterministic negative sampling across runs.

    # Create output file
    output_dir = Path(f'datasets/{dataset_name}')
    output_dir.mkdir(parents=True, exist_ok=True)
    output_file = output_dir / 'training_ids.jsonl'

    # Size-based batch heuristic (overrides the batch_size argument — see the
    # NOTE in the docstring).
    if len(all_qids) > 50000:
        batch_size = 5000
    elif len(all_qids) > 10000:
        batch_size = 2000
    else:
        batch_size = len(all_qids)  # Process small datasets in one batch

    print(f"  Processing {len(all_qids):,} queries in batches of {batch_size:,}...")

    # Running totals for the summary printed at the end.
    total_examples = 0
    total_easy_pos = 0
    total_hard_pos = 0
    total_hard_neg = 0
    total_easy_neg = 0

    with open(output_file, 'w', encoding='utf-8') as out_f:
        for batch_start in range(0, len(all_qids), batch_size):
            batch_end = min(batch_start + batch_size, len(all_qids))
            batch_qids = all_qids[batch_start:batch_end]

            if len(all_qids) > 10000:  # Only show progress for large datasets
                print(f"    Processing batch: queries {batch_start:,}-{batch_end:,}")

            for qid in batch_qids:
                # Skip qrels entries whose query text is missing.
                if qid not in queries:
                    continue

                docs = qrels[qid]

                # Categorize judged documents by score (IDs only!).
                hard_positive_ids = []
                easy_positive_ids = []
                hard_negative_ids = []
                easy_negative_ids = []

                for doc_id, score in docs.items():
                    category = score_map.get(str(score), 'easy_negative')

                    if category == 'positive':
                        if both_scores_positive:
                            # Differentiate the two positive grades.
                            if score == 2:
                                easy_positive_ids.append(doc_id)
                            elif score == 1:
                                hard_positive_ids.append(doc_id)
                        else:
                            # Only one score is positive; treat all as easy positives.
                            easy_positive_ids.append(doc_id)
                    elif category == 'hard_negative':
                        hard_negative_ids.append(doc_id)
                    elif category == 'easy_negative':
                        easy_negative_ids.append(doc_id)

                # Queries with no positives produce no training example.
                all_positive_ids = easy_positive_ids + hard_positive_ids
                if not all_positive_ids:
                    continue

                # Negative budget: negatives_ratio negatives per positive.
                total_positives = len(all_positive_ids)
                total_negatives_have = len(hard_negative_ids) + len(easy_negative_ids)
                total_negatives_needed = total_positives * negatives_ratio

                # Top up with docs judged for OTHER queries in this batch.
                if total_negatives_have < total_negatives_needed:
                    need_more = total_negatives_needed - total_negatives_have

                    # Sample from batch queries only (keeps memory bounded).
                    other_batch_qids = [q for q in batch_qids if q != qid]
                    random.shuffle(other_batch_qids)

                    # Never borrow a doc already judged for the current query.
                    # NOTE(review): the same doc_id can still be appended more
                    # than once if it was judged for several other queries —
                    # confirm whether duplicate negatives are acceptable.
                    current_query_docs = set(docs.keys())

                    added = 0
                    for other_qid in other_batch_qids:
                        if added >= need_more:
                            break
                        for doc_id in qrels[other_qid]:
                            if doc_id not in current_query_docs:
                                easy_negative_ids.append(doc_id)
                                added += 1
                                if added >= need_more:
                                    break

                # Cap easy negatives so hard + easy together respect the budget.
                # BUGFIX: the old slice `[:total_negatives_needed - len(hard_negative_ids)]`
                # went negative when hard negatives alone exceeded the budget,
                # and a negative slice bound trims from the END — silently
                # keeping easy negatives that should have been dropped to zero.
                easy_negative_budget = max(0, total_negatives_needed - len(hard_negative_ids))
                example = {
                    'query_id': qid,
                    'query_text': queries[qid],
                    'source_dataset': dataset_name,
                    'easy_positive_ids': easy_positive_ids,
                    'hard_positive_ids': hard_positive_ids,
                    'hard_negative_ids': hard_negative_ids,
                    'easy_negative_ids': easy_negative_ids[:easy_negative_budget]
                }

                # Write the example directly to file (streaming).
                out_f.write(json.dumps(example) + '\n')

                # Update stats
                total_examples += 1
                total_easy_pos += len(example['easy_positive_ids'])
                total_hard_pos += len(example['hard_positive_ids'])
                total_hard_neg += len(example['hard_negative_ids'])
                total_easy_neg += len(example['easy_negative_ids'])

    # Print stats
    print(f"  βœ“ Created {total_examples:,} examples")
    print(f"    Easy positives: {total_easy_pos:,}")
    print(f"    Hard positives: {total_hard_pos:,}")
    print(f"    Hard negatives: {total_hard_neg:,}")
    print(f"    Easy negatives: {total_easy_neg:,}")

def main():
    """Entry point: load config, resolve which datasets to run, process each."""
    banner = "=" * 50
    print(banner)
    print("TRAINING DATA CREATION - IDS + BATCHING")
    print(banner)
    print("Best of both worlds: IDs only (small files) + Batching (fast)")

    # Pull the training section out of the shared config file.
    config_path = '../test_scores/dataset_reports/training_config.json'
    with open(config_path, 'r', encoding='utf-8') as fh:
        config = json.load(fh)['beir_training_config']

    # Only datasets explicitly enabled in the config are eligible.
    enabled = [name for name, cfg in config['datasets'].items() if cfg['use']]

    # An optional CLI argument narrows the run to a single dataset.
    if len(sys.argv) > 1:
        requested = sys.argv[1]
        if requested not in enabled:
            print(f"❌ Dataset '{requested}' not found or disabled")
            print(f"Available: {', '.join(enabled)}")
            return
        enabled = [requested]
        print(f"Processing only: {requested}")

    print(f"Will process {len(enabled)} datasets")

    run_started = time.time()

    for position, ds_name in enumerate(enabled, 1):
        print(f"\n[{position}/{len(enabled)}] {ds_name}")
        ds_started = time.time()
        process_dataset(ds_name, config)
        print(f"  Dataset completed in {time.time() - ds_started:.2f} seconds")

    print(f"\nβœ… Complete! Total time: {time.time() - run_started:.2f} seconds")
    print("\nπŸ“ Output: datasets/<dataset>/training_ids.jsonl (IDs only)")
    print("πŸ’Ύ File sizes: ~100x smaller than full text")
    print("⚑ Speed: As fast as efficient version")
# Run only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()