# train-mbed / train_datasets_creation / merge_all_splits.py
"""
Merge all qrels splits (train/dev/test) into one merged.tsv file for each dataset.
We don't care about splits - we want ALL the data!
"""
import os
from pathlib import Path
def merge_qrels_for_dataset(dataset_path):
    """Merge all qrels split files of one dataset into qrels/merged.tsv.

    Reads every ``*.tsv`` under ``<dataset_path>/qrels`` (except a previously
    generated ``merged.tsv``), de-duplicates on (query_id, doc_id) keeping the
    highest relevance score, and writes the union sorted by query_id then
    doc_id.

    Args:
        dataset_path: Path to one dataset directory (expects a ``qrels``
            subdirectory in BEIR layout: header line, then
            ``qid<TAB>doc_id<TAB>score`` rows).

    Returns:
        The number of unique (query_id, doc_id) entries written, or ``None``
        when there is nothing to merge (no qrels dir, or no split files).
    """
    qrels_dir = dataset_path / 'qrels'
    if not qrels_dir.exists():
        print(f" ⚠️ No qrels directory found")
        return

    # Collect split files up front, excluding any merged.tsv from a previous
    # run. Filtering here (rather than inside the loop) also guarantees we
    # never reach the write step with header=None because every file was
    # skipped — that crashed the original on re-runs.
    tsv_files = [p for p in qrels_dir.glob('*.tsv') if p.name != 'merged.tsv']
    if not tsv_files:
        print(f" ⚠️ No TSV files found")
        return

    # (qid, doc_id) -> best score seen across all splits.
    all_entries = {}
    header = None
    for tsv_file in tsv_files:
        print(f" Reading {tsv_file.name}...")
        with open(tsv_file, 'r', encoding='utf-8') as f:
            lines = f.readlines()

        # Keep the header of the first non-empty file.
        if header is None and lines:
            header = lines[0].strip()

        for line in lines[1:]:  # Skip header
            if not line.strip():
                continue
            parts = line.strip().split('\t')
            if len(parts) < 3:
                continue
            qid, doc_id = parts[0], parts[1]
            try:
                score = int(parts[2])
            except ValueError:
                # A single malformed score line should not abort the whole
                # dataset merge — skip it.
                continue
            # Keep the highest score seen for a duplicated pair.
            key = (qid, doc_id)
            if key not in all_entries or score > all_entries[key]:
                all_entries[key] = score

    if header is None:
        # All split files were empty; fall back to the standard BEIR header.
        header = "query-id\tcorpus-id\tscore"

    # Write the merged file sorted by query_id then doc_id.
    merged_file = qrels_dir / 'merged.tsv'
    with open(merged_file, 'w', encoding='utf-8') as f:
        f.write(header + '\n')
        for (qid, doc_id), score in sorted(all_entries.items()):
            f.write(f"{qid}\t{doc_id}\t{score}\n")

    print(f" ✓ Merged {len(all_entries)} unique entries into merged.tsv")
    print(f" From splits: {', '.join(p.stem for p in tsv_files)}")
    return len(all_entries)
def main():
    """Merge all splits for all datasets in beir_data."""
    root = Path('../beir_data')
    if not root.exists():
        print(f"Error: {root} not found!")
        return

    # Every immediate subdirectory of beir_data is one dataset.
    datasets = sorted(entry for entry in root.iterdir() if entry.is_dir())
    print(f"Found {len(datasets)} datasets in beir_data")
    print("=" * 60)

    total = 0
    for dataset in datasets:
        print(f"\nProcessing {dataset.name}...")
        merged_count = merge_qrels_for_dataset(dataset)
        if merged_count:
            total += merged_count

    print("\n" + "=" * 60)
    print(f"DONE! Merged {total} total qrel entries across all datasets")
    print("All datasets now have a 'merged.tsv' file combining all splits")


if __name__ == "__main__":
    main()