Spaces:
Sleeping
Sleeping
RGB Evaluation
fix: Information Integration evaluation - handle multiple answer variants with pipe-separated format
5253a83
| """ | |
| Data Loader for RGB Dataset | |
| Handles loading and preprocessing of RGB benchmark datasets: | |
| - en_refine.json: For noise robustness and negative rejection | |
| - en_int.json: For information integration | |
| - en_fact.json: For counterfactual robustness | |
| Dataset structure (from https://github.com/chen700564/RGB): | |
| - en_refine.json: {id, query, answer, positive, negative} | |
| - en_int.json: {id, query, answer, answer1, answer2, positive, negative} | |
| - en_fact.json: {id, query, answer, fakeanswer, positive_wrong, positive, negative} | |
| """ | |
| import json | |
| import os | |
| import random | |
| from typing import List, Dict, Any, Optional, Tuple | |
| from dataclasses import dataclass | |
| from enum import Enum | |
class TaskType(Enum):
    """The four RAG evaluation tasks defined by the RGB benchmark."""

    # Answer correctly despite noisy retrieved documents.
    NOISE_ROBUSTNESS = "noise_robustness"
    # Decline to answer when the documents do not contain the answer.
    NEGATIVE_REJECTION = "negative_rejection"
    # Combine evidence spread across multiple documents.
    INFORMATION_INTEGRATION = "information_integration"
    # Detect/resist documents containing factually wrong answers.
    COUNTERFACTUAL_ROBUSTNESS = "counterfactual_robustness"
@dataclass
class RGBSample:
    """A single sample from the RGB dataset.

    BUG FIX: the ``@dataclass`` decorator was missing.  Every loader in this
    module constructs samples with keyword arguments
    (``RGBSample(id=..., question=..., ...)``), which raises ``TypeError``
    unless the generated ``__init__`` exists.
    """

    id: int
    question: str
    answer: str  # Ground-truth answer; '|'-joined when several variants exist
    documents: List[str]  # Retrieved documents/passages
    task_type: TaskType  # Which RGB evaluation task this sample belongs to
    noise_level: Optional[int] = None  # Number of noise documents
    has_answer: Optional[bool] = None  # Whether docs contain the answer
    num_docs_needed: Optional[int] = None  # Docs needed for answer
    has_counterfactual: Optional[bool] = None  # Whether docs contain counterfactual
    counterfactual_answer: Optional[str] = None  # The counterfactual (wrong) answer
    raw_data: Optional[Dict] = None  # Original raw data
| class RGBDataLoader: | |
| """ | |
| Loader for RGB benchmark datasets. | |
| Implements data loading as per the RGB paper and repository. | |
| """ | |
| def __init__(self, data_dir: str = "data", passage_num: int = 5): | |
| """ | |
| Initialize the data loader. | |
| Args: | |
| data_dir: Directory containing the RGB dataset files. | |
| passage_num: Number of passages to include per sample (default 5). | |
| """ | |
| self.data_dir = data_dir | |
| self.passage_num = passage_num | |
| self._validate_data_dir() | |
| def _validate_data_dir(self) -> None: | |
| """Check if data directory exists.""" | |
| if not os.path.exists(self.data_dir): | |
| os.makedirs(self.data_dir) | |
| print(f"Created data directory: {self.data_dir}") | |
| print("Please run: python download_datasets.py") | |
| def _get_file_path(self, filename: str) -> str: | |
| """Get full path to a data file.""" | |
| return os.path.join(self.data_dir, filename) | |
| def _load_jsonl(self, filepath: str) -> List[Dict]: | |
| """Load a JSONL file (one JSON object per line).""" | |
| data = [] | |
| with open(filepath, 'r', encoding='utf-8') as f: | |
| for line in f: | |
| line = line.strip() | |
| if line: | |
| data.append(json.loads(line)) | |
| return data | |
| def _format_answer(self, answer: Any) -> str: | |
| """ | |
| Format answer to string for comparison. | |
| For nested lists (information integration), flatten to list of alternatives. | |
| For simple lists (noise robustness), take first or join. | |
| """ | |
| if isinstance(answer, list): | |
| # Check if it's a nested list (from en_int.json with answer variants) | |
| if answer and isinstance(answer[0], list): | |
| # Flatten nested list: [['variant1', 'variant2'], 'other_answer'] → all variants | |
| variants = [] | |
| for item in answer: | |
| if isinstance(item, list): | |
| variants.extend(item) | |
| else: | |
| variants.append(str(item)) | |
| # Return as pipe-separated alternatives for matching | |
| return "|".join(variants) | |
| else: | |
| # Simple list: join with pipe as alternatives | |
| return "|".join(str(a) for a in answer) | |
| return str(answer) | |
| def load_noise_robustness( | |
| self, | |
| max_samples: Optional[int] = None, | |
| noise_rate: float = 0.4 | |
| ) -> List[RGBSample]: | |
| """ | |
| Load data for Noise Robustness evaluation. | |
| Uses en_refine.json - tests LLM's ability to handle noisy documents. | |
| Args: | |
| max_samples: Maximum number of samples to load (None for all). | |
| noise_rate: Rate of noise documents (0.0 to 0.8). | |
| Returns: | |
| List of RGBSample objects for noise robustness evaluation. | |
| """ | |
| filepath = self._get_file_path("en_refine.json") | |
| if not os.path.exists(filepath): | |
| raise FileNotFoundError( | |
| f"Dataset file not found: {filepath}\n" | |
| "Please run: python download_datasets.py" | |
| ) | |
| data = self._load_jsonl(filepath) | |
| samples = [] | |
| for idx, item in enumerate(data): | |
| if max_samples and idx >= max_samples: | |
| break | |
| # Calculate number of positive and negative documents | |
| neg_num = int(self.passage_num * noise_rate) | |
| pos_num = self.passage_num - neg_num | |
| # Get positive and negative documents | |
| positive_docs = item.get('positive', [])[:pos_num] | |
| negative_docs = item.get('negative', [])[:neg_num] | |
| # Combine and shuffle documents | |
| documents = positive_docs + negative_docs | |
| random.shuffle(documents) | |
| if not documents: | |
| continue | |
| sample = RGBSample( | |
| id=item.get('id', idx), | |
| question=item.get('query', ''), | |
| answer=self._format_answer(item.get('answer', '')), | |
| documents=documents, | |
| task_type=TaskType.NOISE_ROBUSTNESS, | |
| noise_level=neg_num, | |
| has_answer=True, | |
| raw_data=item | |
| ) | |
| samples.append(sample) | |
| print(f"Loaded {len(samples)} samples for Noise Robustness (noise_rate={noise_rate})") | |
| return samples | |
| def load_negative_rejection( | |
| self, | |
| max_samples: Optional[int] = None | |
| ) -> List[RGBSample]: | |
| """ | |
| Load data for Negative Rejection evaluation. | |
| Uses en_refine.json with noise_rate=1.0 (all negative documents). | |
| Tests LLM's ability to reject when documents don't contain the answer. | |
| Args: | |
| max_samples: Maximum number of samples to load (None for all). | |
| Returns: | |
| List of RGBSample objects for negative rejection evaluation. | |
| """ | |
| filepath = self._get_file_path("en_refine.json") | |
| if not os.path.exists(filepath): | |
| raise FileNotFoundError( | |
| f"Dataset file not found: {filepath}\n" | |
| "Please run: python download_datasets.py" | |
| ) | |
| data = self._load_jsonl(filepath) | |
| samples = [] | |
| for idx, item in enumerate(data): | |
| if max_samples and idx >= max_samples: | |
| break | |
| # For negative rejection, use only negative documents | |
| negative_docs = item.get('negative', [])[:self.passage_num] | |
| if not negative_docs: | |
| continue | |
| sample = RGBSample( | |
| id=item.get('id', idx), | |
| question=item.get('query', ''), | |
| answer=self._format_answer(item.get('answer', '')), | |
| documents=negative_docs, | |
| task_type=TaskType.NEGATIVE_REJECTION, | |
| has_answer=False, # Documents don't contain the answer | |
| raw_data=item | |
| ) | |
| samples.append(sample) | |
| print(f"Loaded {len(samples)} samples for Negative Rejection") | |
| return samples | |
| def load_information_integration( | |
| self, | |
| max_samples: Optional[int] = None | |
| ) -> List[RGBSample]: | |
| """ | |
| Load data for Information Integration evaluation. | |
| Uses en_int.json - tests LLM's ability to integrate info from multiple docs. | |
| Args: | |
| max_samples: Maximum number of samples to load (None for all). | |
| Returns: | |
| List of RGBSample objects for information integration evaluation. | |
| """ | |
| filepath = self._get_file_path("en_int.json") | |
| if not os.path.exists(filepath): | |
| raise FileNotFoundError( | |
| f"Dataset file not found: {filepath}\n" | |
| "Please run: python download_datasets.py" | |
| ) | |
| data = self._load_jsonl(filepath) | |
| samples = [] | |
| for idx, item in enumerate(data): | |
| if max_samples and idx >= max_samples: | |
| break | |
| # For information integration, we need documents from different sources | |
| # The 'positive' field contains lists of documents for each answer component | |
| positive_docs = item.get('positive', []) | |
| # Flatten and get one document from each source | |
| documents = [] | |
| if isinstance(positive_docs, list): | |
| for doc_group in positive_docs: | |
| if isinstance(doc_group, list) and doc_group: | |
| documents.append(doc_group[0]) # Take first from each group | |
| elif isinstance(doc_group, str): | |
| documents.append(doc_group) | |
| # Add some negative docs if needed | |
| neg_num = max(0, self.passage_num - len(documents)) | |
| negative_docs = item.get('negative', [])[:neg_num] | |
| documents.extend(negative_docs) | |
| if not documents: | |
| continue | |
| random.shuffle(documents) | |
| sample = RGBSample( | |
| id=item.get('id', idx), | |
| question=item.get('query', ''), | |
| answer=self._format_answer(item.get('answer', '')), | |
| documents=documents[:self.passage_num], | |
| task_type=TaskType.INFORMATION_INTEGRATION, | |
| num_docs_needed=len(positive_docs) if isinstance(positive_docs, list) else 1, | |
| raw_data=item | |
| ) | |
| samples.append(sample) | |
| print(f"Loaded {len(samples)} samples for Information Integration") | |
| return samples | |
| def load_counterfactual_robustness( | |
| self, | |
| max_samples: Optional[int] = None | |
| ) -> List[RGBSample]: | |
| """ | |
| Load data for Counterfactual Robustness evaluation. | |
| Uses en_fact.json - tests LLM's ability to detect/correct factual errors. | |
| Args: | |
| max_samples: Maximum number of samples to load (None for all). | |
| Returns: | |
| List of RGBSample objects for counterfactual robustness evaluation. | |
| """ | |
| filepath = self._get_file_path("en_fact.json") | |
| if not os.path.exists(filepath): | |
| raise FileNotFoundError( | |
| f"Dataset file not found: {filepath}\n" | |
| "Please run: python download_datasets.py" | |
| ) | |
| data = self._load_jsonl(filepath) | |
| samples = [] | |
| for idx, item in enumerate(data): | |
| if max_samples and idx >= max_samples: | |
| break | |
| # For counterfactual, we use positive_wrong documents (contain fake answer) | |
| # and can mix with some correct documents | |
| wrong_docs = item.get('positive_wrong', []) | |
| correct_docs = item.get('positive', []) | |
| negative_docs = item.get('negative', []) | |
| # Use mainly wrong docs with some negative | |
| documents = wrong_docs[:3] + negative_docs[:2] | |
| if not documents: | |
| # Fallback to any available docs | |
| documents = wrong_docs or correct_docs or negative_docs | |
| if not documents: | |
| continue | |
| random.shuffle(documents) | |
| sample = RGBSample( | |
| id=item.get('id', idx), | |
| question=item.get('query', ''), | |
| answer=self._format_answer(item.get('answer', '')), | |
| documents=documents[:self.passage_num], | |
| task_type=TaskType.COUNTERFACTUAL_ROBUSTNESS, | |
| has_counterfactual=True, | |
| counterfactual_answer=self._format_answer(item.get('fakeanswer', '')), | |
| raw_data=item | |
| ) | |
| samples.append(sample) | |
| print(f"Loaded {len(samples)} samples for Counterfactual Robustness") | |
| return samples | |
| def load_all_for_task( | |
| self, | |
| task_type: TaskType, | |
| max_samples: Optional[int] = None, | |
| **kwargs | |
| ) -> List[RGBSample]: | |
| """ | |
| Load data for a specific task type. | |
| Args: | |
| task_type: The type of evaluation task. | |
| max_samples: Maximum samples to load. | |
| **kwargs: Additional arguments for specific loaders. | |
| Returns: | |
| List of RGBSample objects. | |
| """ | |
| loaders = { | |
| TaskType.NOISE_ROBUSTNESS: self.load_noise_robustness, | |
| TaskType.NEGATIVE_REJECTION: self.load_negative_rejection, | |
| TaskType.INFORMATION_INTEGRATION: self.load_information_integration, | |
| TaskType.COUNTERFACTUAL_ROBUSTNESS: self.load_counterfactual_robustness, | |
| } | |
| return loaders[task_type](max_samples, **kwargs) | |
| def get_dataset_stats(self) -> Dict[str, Any]: | |
| """Get statistics about the loaded datasets.""" | |
| stats = {} | |
| files = { | |
| "en_refine.json": "Noise Robustness & Negative Rejection", | |
| "en_int.json": "Information Integration", | |
| "en_fact.json": "Counterfactual Robustness" | |
| } | |
| for filename, description in files.items(): | |
| filepath = self._get_file_path(filename) | |
| if os.path.exists(filepath): | |
| data = self._load_jsonl(filepath) | |
| stats[filename] = { | |
| "description": description, | |
| "num_samples": len(data), | |
| "file_size_bytes": os.path.getsize(filepath) | |
| } | |
| else: | |
| stats[filename] = {"error": "File not found"} | |
| return stats | |
def test_loader():
    """Smoke-test the data loader against whatever dataset files are present.

    Prints dataset statistics and one example per task; missing dataset
    files are reported and skipped rather than failing the whole run.
    """
    loader = RGBDataLoader()
    print("="*60)
    print("RGB Dataset Loader Test")
    print("="*60)
    # Get stats
    stats = loader.get_dataset_stats()
    print("\nDataset Statistics:")
    for filename, info in stats.items():
        # BUG FIX: the loop variable was unused and a literal "(unknown)"
        # placeholder was printed instead of the filename.
        print(f"  {filename}: {info}")
    # Test loading a few samples from each task
    print("\n" + "-"*60)
    try:
        samples = loader.load_noise_robustness(max_samples=2)
        if samples:
            print(f"\nNoise Robustness Sample:")
            print(f"  Question: {samples[0].question[:80]}...")
            print(f"  Answer: {samples[0].answer}")
            print(f"  Num Docs: {len(samples[0].documents)}")
    except FileNotFoundError as e:
        print(f"  Skipping: {e}")
    try:
        samples = loader.load_negative_rejection(max_samples=2)
        if samples:
            print(f"\nNegative Rejection Sample:")
            print(f"  Question: {samples[0].question[:80]}...")
            print(f"  Num Docs: {len(samples[0].documents)}")
    except FileNotFoundError as e:
        print(f"  Skipping: {e}")
    try:
        samples = loader.load_information_integration(max_samples=2)
        if samples:
            print(f"\nInformation Integration Sample:")
            print(f"  Question: {samples[0].question[:80]}...")
            print(f"  Answer: {samples[0].answer}")
    except FileNotFoundError as e:
        print(f"  Skipping: {e}")
    try:
        samples = loader.load_counterfactual_robustness(max_samples=2)
        if samples:
            print(f"\nCounterfactual Robustness Sample:")
            print(f"  Question: {samples[0].question[:80]}...")
            print(f"  Correct Answer: {samples[0].answer}")
            print(f"  Fake Answer: {samples[0].counterfactual_answer}")
    except FileNotFoundError as e:
        print(f"  Skipping: {e}")
    print("\n" + "="*60)
# Run the smoke test only when this module is executed directly, not on import.
if __name__ == "__main__":
    test_loader()