"""
Merge all qrels splits (train/dev/test) into one merged.tsv file for each dataset.

We don't care about splits - we want ALL the data!
"""
import os
from pathlib import Path
def merge_qrels_for_dataset(dataset_path):
    """Merge every qrels split file in a dataset into a single merged.tsv.

    Reads each ``*.tsv`` under ``<dataset_path>/qrels`` (skipping any
    previously generated ``merged.tsv``), deduplicates on the
    (query_id, doc_id) pair keeping the highest relevance score seen, and
    writes the result to ``<dataset_path>/qrels/merged.tsv``.

    Args:
        dataset_path: pathlib.Path of a single dataset directory.

    Returns:
        Number of unique qrel entries written, or None when there is
        nothing to merge (no qrels directory / no split files).
    """
    qrels_dir = dataset_path / 'qrels'
    if not qrels_dir.exists():
        print(f" ⚠️ No qrels directory found")
        return

    # BUGFIX: exclude merged.tsv up front. Previously it only got skipped
    # inside the read loop, so a directory containing nothing but a stale
    # merged.tsv passed the emptiness check and later crashed on a None
    # header. Filtering here also makes the function idempotent on reruns.
    tsv_files = [p for p in qrels_dir.glob('*.tsv') if p.name != 'merged.tsv']
    if not tsv_files:
        print(f" ⚠️ No TSV files found")
        return

    # (query_id, doc_id) -> highest relevance score seen across all splits.
    all_entries = {}
    header = None
    for tsv_file in tsv_files:
        print(f" Reading {tsv_file.name}...")
        with open(tsv_file, 'r', encoding='utf-8') as f:
            lines = f.readlines()

        # Take the header from the first non-empty file.
        if header is None and lines:
            header = lines[0].strip()

        for line in lines[1:]:  # skip this file's header row
            if not line.strip():
                continue
            parts = line.strip().split('\t')
            if len(parts) >= 3:
                qid = parts[0]
                doc_id = parts[1]
                try:
                    score = int(parts[2])
                except ValueError:
                    # BUGFIX: one malformed score used to abort the whole
                    # merge; skip the bad row instead.
                    continue
                key = (qid, doc_id)
                # Keep the highest score when splits disagree on a pair.
                if key not in all_entries or score > all_entries[key]:
                    all_entries[key] = score

    # BUGFIX: header stays None when every split file is empty; fall back
    # to the standard BEIR qrels header instead of crashing on 'None + str'.
    if header is None:
        header = 'query-id\tcorpus-id\tscore'

    merged_file = qrels_dir / 'merged.tsv'
    with open(merged_file, 'w', encoding='utf-8') as f:
        f.write(header + '\n')
        # Deterministic output: sorted by query_id then doc_id.
        for (qid, doc_id), score in sorted(all_entries.items()):
            f.write(f"{qid}\t{doc_id}\t{score}\n")

    print(f" ✓ Merged {len(all_entries)} unique entries into merged.tsv")
    print(f" From splits: {', '.join(p.stem for p in tsv_files)}")
    return len(all_entries)
def main():
    """Merge all splits for all datasets in beir_data."""
    root = Path('../beir_data')
    if not root.exists():
        print(f"Error: {root} not found!")
        return

    # Process datasets in stable alphabetical order.
    datasets = sorted(entry for entry in root.iterdir() if entry.is_dir())

    print(f"Found {len(datasets)} datasets in beir_data")
    print("=" * 60)

    total_entries = 0
    for dataset in datasets:
        print(f"\nProcessing {dataset.name}...")
        merged_count = merge_qrels_for_dataset(dataset)
        if merged_count:
            total_entries += merged_count

    print("\n" + "=" * 60)
    print(f"DONE! Merged {total_entries} total qrel entries across all datasets")
    print("All datasets now have a 'merged.tsv' file combining all splits")


if __name__ == "__main__":
    main()