RGBMetrics/src/data_loader.py
RGB Evaluation
fix: Information Integration evaluation - handle multiple answer variants with pipe-separated format
5253a83
"""
Data Loader for RGB Dataset
Handles loading and preprocessing of RGB benchmark datasets:
- en_refine.json: For noise robustness and negative rejection
- en_int.json: For information integration
- en_fact.json: For counterfactual robustness
Dataset structure (from https://github.com/chen700564/RGB):
- en_refine.json: {id, query, answer, positive, negative}
- en_int.json: {id, query, answer, answer1, answer2, positive, negative}
- en_fact.json: {id, query, answer, fakeanswer, positive_wrong, positive, negative}
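
Example usage (minimal sketch; assumes the RGB JSON files have been downloaded
into the data/ directory, e.g. via download_datasets.py):
    loader = RGBDataLoader(data_dir="data", passage_num=5)
    samples = loader.load_noise_robustness(max_samples=10, noise_rate=0.4)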
"""
import json
import os
import random
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass
from enum import Enum
class TaskType(Enum):
"""Types of RAG evaluation tasks."""
NOISE_ROBUSTNESS = "noise_robustness"
NEGATIVE_REJECTION = "negative_rejection"
INFORMATION_INTEGRATION = "information_integration"
COUNTERFACTUAL_ROBUSTNESS = "counterfactual_robustness"
@dataclass
class RGBSample:
"""A single sample from the RGB dataset."""
id: int
question: str
    answer: str  # Ground truth answer; pipe-separated when multiple variants are accepted
documents: List[str] # Retrieved documents/passages
task_type: TaskType
noise_level: Optional[int] = None # Number of noise documents
has_answer: Optional[bool] = None # Whether docs contain the answer
num_docs_needed: Optional[int] = None # Docs needed for answer
has_counterfactual: Optional[bool] = None # Whether docs contain counterfactual
counterfactual_answer: Optional[str] = None # The counterfactual (wrong) answer
raw_data: Optional[Dict] = None # Original raw data
class RGBDataLoader:
"""
Loader for RGB benchmark datasets.
Implements data loading as per the RGB paper and repository.
"""
def __init__(self, data_dir: str = "data", passage_num: int = 5):
"""
Initialize the data loader.
Args:
data_dir: Directory containing the RGB dataset files.
passage_num: Number of passages to include per sample (default 5).
"""
self.data_dir = data_dir
self.passage_num = passage_num
self._validate_data_dir()
def _validate_data_dir(self) -> None:
"""Check if data directory exists."""
if not os.path.exists(self.data_dir):
os.makedirs(self.data_dir)
print(f"Created data directory: {self.data_dir}")
print("Please run: python download_datasets.py")
def _get_file_path(self, filename: str) -> str:
"""Get full path to a data file."""
return os.path.join(self.data_dir, filename)
def _load_jsonl(self, filepath: str) -> List[Dict]:
"""Load a JSONL file (one JSON object per line)."""
data = []
with open(filepath, 'r', encoding='utf-8') as f:
for line in f:
line = line.strip()
if line:
data.append(json.loads(line))
return data
def _format_answer(self, answer: Any) -> str:
"""
        Format the answer as a string for comparison.
        Nested lists (information integration answer variants) are flattened into a
        single pipe-separated string of alternatives; simple lists are likewise joined
        with "|"; anything else is converted with str().
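
        Illustrative examples (placeholder values):
            [["variant 1", "variant 2"], "other answer"]  ->  "variant 1|variant 2|other answer"
            ["answer A", "answer B"]  ->  "answer A|answer B"
            "answer A"  ->  "answer A"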
"""
if isinstance(answer, list):
# Check if it's a nested list (from en_int.json with answer variants)
if answer and isinstance(answer[0], list):
# Flatten nested list: [['variant1', 'variant2'], 'other_answer'] → all variants
variants = []
for item in answer:
if isinstance(item, list):
                        variants.extend(str(v) for v in item)  # ensure each variant is a string before joining
else:
variants.append(str(item))
# Return as pipe-separated alternatives for matching
return "|".join(variants)
else:
# Simple list: join with pipe as alternatives
return "|".join(str(a) for a in answer)
return str(answer)
def load_noise_robustness(
self,
max_samples: Optional[int] = None,
noise_rate: float = 0.4
) -> List[RGBSample]:
"""
Load data for Noise Robustness evaluation.
Uses en_refine.json - tests LLM's ability to handle noisy documents.
Args:
max_samples: Maximum number of samples to load (None for all).
            noise_rate: Fraction of the passage_num documents that are noise (0.0 to 0.8).
                For example, passage_num=5 with noise_rate=0.4 gives 2 noise and 3 relevant passages.
Returns:
List of RGBSample objects for noise robustness evaluation.
"""
filepath = self._get_file_path("en_refine.json")
if not os.path.exists(filepath):
raise FileNotFoundError(
f"Dataset file not found: {filepath}\n"
"Please run: python download_datasets.py"
)
data = self._load_jsonl(filepath)
samples = []
for idx, item in enumerate(data):
if max_samples and idx >= max_samples:
break
# Calculate number of positive and negative documents
neg_num = int(self.passage_num * noise_rate)
pos_num = self.passage_num - neg_num
# Get positive and negative documents
positive_docs = item.get('positive', [])[:pos_num]
negative_docs = item.get('negative', [])[:neg_num]
# Combine and shuffle documents
documents = positive_docs + negative_docs
random.shuffle(documents)
if not documents:
continue
sample = RGBSample(
id=item.get('id', idx),
question=item.get('query', ''),
answer=self._format_answer(item.get('answer', '')),
documents=documents,
task_type=TaskType.NOISE_ROBUSTNESS,
noise_level=neg_num,
has_answer=True,
raw_data=item
)
samples.append(sample)
print(f"Loaded {len(samples)} samples for Noise Robustness (noise_rate={noise_rate})")
return samples
def load_negative_rejection(
self,
max_samples: Optional[int] = None
) -> List[RGBSample]:
"""
Load data for Negative Rejection evaluation.
Uses en_refine.json with noise_rate=1.0 (all negative documents).
Tests LLM's ability to reject when documents don't contain the answer.
Args:
max_samples: Maximum number of samples to load (None for all).
Returns:
List of RGBSample objects for negative rejection evaluation.
"""
filepath = self._get_file_path("en_refine.json")
if not os.path.exists(filepath):
raise FileNotFoundError(
f"Dataset file not found: {filepath}\n"
"Please run: python download_datasets.py"
)
data = self._load_jsonl(filepath)
samples = []
for idx, item in enumerate(data):
if max_samples and idx >= max_samples:
break
# For negative rejection, use only negative documents
negative_docs = item.get('negative', [])[:self.passage_num]
if not negative_docs:
continue
sample = RGBSample(
id=item.get('id', idx),
question=item.get('query', ''),
answer=self._format_answer(item.get('answer', '')),
documents=negative_docs,
task_type=TaskType.NEGATIVE_REJECTION,
has_answer=False, # Documents don't contain the answer
raw_data=item
)
samples.append(sample)
print(f"Loaded {len(samples)} samples for Negative Rejection")
return samples
def load_information_integration(
self,
max_samples: Optional[int] = None
) -> List[RGBSample]:
"""
Load data for Information Integration evaluation.
Uses en_int.json - tests LLM's ability to integrate info from multiple docs.
Args:
max_samples: Maximum number of samples to load (None for all).
Returns:
List of RGBSample objects for information integration evaluation.
"""
filepath = self._get_file_path("en_int.json")
if not os.path.exists(filepath):
raise FileNotFoundError(
f"Dataset file not found: {filepath}\n"
"Please run: python download_datasets.py"
)
data = self._load_jsonl(filepath)
samples = []
for idx, item in enumerate(data):
if max_samples and idx >= max_samples:
break
# For information integration, we need documents from different sources
# The 'positive' field contains lists of documents for each answer component
positive_docs = item.get('positive', [])
# Flatten and get one document from each source
documents = []
if isinstance(positive_docs, list):
for doc_group in positive_docs:
if isinstance(doc_group, list) and doc_group:
documents.append(doc_group[0]) # Take first from each group
elif isinstance(doc_group, str):
documents.append(doc_group)
# Add some negative docs if needed
neg_num = max(0, self.passage_num - len(documents))
negative_docs = item.get('negative', [])[:neg_num]
documents.extend(negative_docs)
if not documents:
continue
random.shuffle(documents)
sample = RGBSample(
id=item.get('id', idx),
question=item.get('query', ''),
answer=self._format_answer(item.get('answer', '')),
documents=documents[:self.passage_num],
task_type=TaskType.INFORMATION_INTEGRATION,
num_docs_needed=len(positive_docs) if isinstance(positive_docs, list) else 1,
raw_data=item
)
samples.append(sample)
print(f"Loaded {len(samples)} samples for Information Integration")
return samples
def load_counterfactual_robustness(
self,
max_samples: Optional[int] = None
) -> List[RGBSample]:
"""
Load data for Counterfactual Robustness evaluation.
Uses en_fact.json - tests LLM's ability to detect/correct factual errors.
Args:
max_samples: Maximum number of samples to load (None for all).
Returns:
List of RGBSample objects for counterfactual robustness evaluation.
"""
filepath = self._get_file_path("en_fact.json")
if not os.path.exists(filepath):
raise FileNotFoundError(
f"Dataset file not found: {filepath}\n"
"Please run: python download_datasets.py"
)
data = self._load_jsonl(filepath)
samples = []
for idx, item in enumerate(data):
if max_samples and idx >= max_samples:
break
            # For counterfactual robustness, use the positive_wrong documents (which
            # assert the fake answer), padded with negative documents; correct documents
            # are kept only as a fallback when positive_wrong is empty.
wrong_docs = item.get('positive_wrong', [])
correct_docs = item.get('positive', [])
negative_docs = item.get('negative', [])
            # Use mainly wrong docs with some negative; the 3 + 2 split assumes the
            # default passage_num of 5 (the list is trimmed to passage_num below)
            documents = wrong_docs[:3] + negative_docs[:2]
if not documents:
# Fallback to any available docs
documents = wrong_docs or correct_docs or negative_docs
if not documents:
continue
random.shuffle(documents)
sample = RGBSample(
id=item.get('id', idx),
question=item.get('query', ''),
answer=self._format_answer(item.get('answer', '')),
documents=documents[:self.passage_num],
task_type=TaskType.COUNTERFACTUAL_ROBUSTNESS,
has_counterfactual=True,
counterfactual_answer=self._format_answer(item.get('fakeanswer', '')),
raw_data=item
)
samples.append(sample)
print(f"Loaded {len(samples)} samples for Counterfactual Robustness")
return samples
def load_all_for_task(
self,
task_type: TaskType,
max_samples: Optional[int] = None,
**kwargs
) -> List[RGBSample]:
"""
Load data for a specific task type.
Args:
task_type: The type of evaluation task.
max_samples: Maximum samples to load.
**kwargs: Additional arguments for specific loaders.
Returns:
List of RGBSample objects.
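
        Example (illustrative):
            loader.load_all_for_task(TaskType.NOISE_ROBUSTNESS, max_samples=10, noise_rate=0.4)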
"""
loaders = {
TaskType.NOISE_ROBUSTNESS: self.load_noise_robustness,
TaskType.NEGATIVE_REJECTION: self.load_negative_rejection,
TaskType.INFORMATION_INTEGRATION: self.load_information_integration,
TaskType.COUNTERFACTUAL_ROBUSTNESS: self.load_counterfactual_robustness,
}
return loaders[task_type](max_samples, **kwargs)
def get_dataset_stats(self) -> Dict[str, Any]:
"""Get statistics about the loaded datasets."""
stats = {}
files = {
"en_refine.json": "Noise Robustness & Negative Rejection",
"en_int.json": "Information Integration",
"en_fact.json": "Counterfactual Robustness"
}
for filename, description in files.items():
filepath = self._get_file_path(filename)
if os.path.exists(filepath):
data = self._load_jsonl(filepath)
stats[filename] = {
"description": description,
"num_samples": len(data),
"file_size_bytes": os.path.getsize(filepath)
}
else:
stats[filename] = {"error": "File not found"}
return stats
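
def example_variant_match(model_response: str, formatted_answer: str) -> bool:
    """
    Illustrative sketch, not part of the original RGB evaluation code: shows one way
    a downstream evaluator could consume the pipe-separated answer strings produced
    by RGBDataLoader._format_answer. A sample counts as answered correctly if any
    accepted variant appears in the model response (case-insensitive substring match).
    """
    variants = [v.strip() for v in formatted_answer.split("|") if v.strip()]
    response = model_response.lower()
    return any(v.lower() in response for v in variants)
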
def test_loader():
"""Test the data loader with actual data."""
loader = RGBDataLoader()
print("="*60)
print("RGB Dataset Loader Test")
print("="*60)
# Get stats
stats = loader.get_dataset_stats()
print("\nDataset Statistics:")
for filename, info in stats.items():
print(f" {filename}: {info}")
# Test loading a few samples from each task
print("\n" + "-"*60)
try:
samples = loader.load_noise_robustness(max_samples=2)
if samples:
print(f"\nNoise Robustness Sample:")
print(f" Question: {samples[0].question[:80]}...")
print(f" Answer: {samples[0].answer}")
print(f" Num Docs: {len(samples[0].documents)}")
except FileNotFoundError as e:
print(f" Skipping: {e}")
try:
samples = loader.load_negative_rejection(max_samples=2)
if samples:
print(f"\nNegative Rejection Sample:")
print(f" Question: {samples[0].question[:80]}...")
print(f" Num Docs: {len(samples[0].documents)}")
except FileNotFoundError as e:
print(f" Skipping: {e}")
try:
samples = loader.load_information_integration(max_samples=2)
if samples:
print(f"\nInformation Integration Sample:")
print(f" Question: {samples[0].question[:80]}...")
print(f" Answer: {samples[0].answer}")
except FileNotFoundError as e:
print(f" Skipping: {e}")
try:
samples = loader.load_counterfactual_robustness(max_samples=2)
if samples:
print(f"\nCounterfactual Robustness Sample:")
print(f" Question: {samples[0].question[:80]}...")
print(f" Correct Answer: {samples[0].answer}")
print(f" Fake Answer: {samples[0].counterfactual_answer}")
except FileNotFoundError as e:
print(f" Skipping: {e}")
print("\n" + "="*60)
if __name__ == "__main__":
test_loader()