|
""" |
|
Domain Dataset Module for Cross-Domain Uncertainty Quantification |
|
|
|
This module provides functionality for loading and managing datasets from different domains |
|
for evaluating uncertainty quantification methods across domains. |
|
""" |
|
|
|
import os |
|
import json |
|
import pandas as pd |
|
import numpy as np |
|
from typing import List, Dict, Any, Union, Optional, Tuple |
|
from datasets import load_dataset |
|
|
|
class DomainDataset:
    """Abstract base class for a domain-specific QA dataset.

    Subclasses provide the actual data access for one domain (e.g.
    medical, legal, general) by implementing load(), get_samples()
    and get_prompt_template().
    """

    def __init__(self, name: str, domain: str):
        """Record identifying metadata; data is populated lazily by load().

        Args:
            name: Name of the dataset.
            domain: Domain category (e.g., 'medical', 'legal', 'general').
        """
        self.domain = domain
        self.name = name
        # Holds the loaded dataset; None until load() is called.
        self.data = None

    def load(self) -> None:
        """Load the dataset into self.data. Subclass responsibility."""
        raise NotImplementedError("Subclasses must implement this method")

    def get_samples(self, n: Optional[int] = None) -> List[Dict[str, Any]]:
        """Return samples from the dataset. Subclass responsibility.

        Args:
            n: Number of samples to return (None for all).

        Returns:
            List of sample dicts with prompts and expected outputs.
        """
        raise NotImplementedError("Subclasses must implement this method")

    def get_prompt_template(self) -> str:
        """Return the prompt template string for this domain. Subclass responsibility."""
        raise NotImplementedError("Subclasses must implement this method")
|
|
|
|
|
class MedicalQADataset(DomainDataset):
    """Dataset for medical question answering."""

    def __init__(self, data_path: Optional[str] = None):
        """
        Initialize the medical QA dataset.

        Args:
            data_path: Path to the dataset file (None to use default)
        """
        super().__init__("medical_qa", "medical")
        self.data_path = data_path

    def load(self) -> None:
        """Load the medical QA dataset into ``self.data`` as a DataFrame.

        Reads ``self.data_path`` (CSV or JSON) when the file exists;
        otherwise attempts to fetch MedMCQA via ``load_dataset``, falling
        back to a small synthetic question set on any failure.

        Raises:
            ValueError: If ``self.data_path`` has an unsupported extension.
        """
        if self.data_path and os.path.exists(self.data_path):
            if self.data_path.endswith('.csv'):
                self.data = pd.read_csv(self.data_path)
            elif self.data_path.endswith('.json'):
                with open(self.data_path, 'r') as f:
                    # Bug fix: normalize to a DataFrame so get_samples() can
                    # rely on .columns — a raw json.load() result (list/dict)
                    # has no .columns attribute.
                    self.data = pd.DataFrame(json.load(f))
            else:
                raise ValueError(f"Unsupported file format: {self.data_path}")
        else:
            try:
                dataset = load_dataset("medmcqa", split="train[:100]")
                self.data = dataset.to_pandas()
            except Exception as e:
                # Best-effort: offline use or a missing dependency should not
                # be fatal — fall back to the built-in synthetic questions.
                print(f"Failed to load MedMCQA dataset: {e}")
                self.data = self._create_synthetic_data()

    def _create_synthetic_data(self) -> pd.DataFrame:
        """Create synthetic medical QA data for testing."""
        questions = [
            "What are the common symptoms of myocardial infarction?",
            "How does insulin regulate blood glucose levels?",
            "What is the mechanism of action for ACE inhibitors?",
            "What are the diagnostic criteria for rheumatoid arthritis?",
            "How does the SARS-CoV-2 virus enter human cells?",
            "What are the main side effects of chemotherapy?",
            "How does the blood-brain barrier function?",
            "What is the pathophysiology of type 2 diabetes?",
            "How do vaccines create immunity?",
            "What are the stages of chronic kidney disease?"
        ]

        return pd.DataFrame({
            'question': questions,
            'domain': ['medical'] * len(questions)
        })

    def get_samples(self, n: Optional[int] = None) -> List[Dict[str, Any]]:
        """
        Get samples from the medical QA dataset.

        Loads the dataset lazily on first call.

        Args:
            n: Number of samples to return (None for all)

        Returns:
            List of sample dicts with 'domain', 'question' and 'prompt' keys.

        Raises:
            ValueError: If the loaded data has no recognizable question column.
        """
        if self.data is None:
            self.load()

        # MedMCQA and local files may name the question column differently.
        if 'question' in self.data.columns:
            questions = self.data['question'].tolist()
        elif 'question_text' in self.data.columns:
            questions = self.data['question_text'].tolist()
        else:
            raise ValueError("Dataset does not contain question column")

        if n is not None:
            questions = questions[:n]

        samples = []
        for question in questions:
            prompt = self.get_prompt_template().format(question=question)
            samples.append({
                'domain': 'medical',
                'question': question,
                'prompt': prompt
            })

        return samples

    def get_prompt_template(self) -> str:
        """
        Get the prompt template for medical domain.

        Returns:
            Prompt template string with a ``{question}`` placeholder.
        """
        return "You are a medical expert. Please answer the following medical question accurately and concisely:\n\n{question}"
|
|
|
|
|
class LegalQADataset(DomainDataset):
    """Dataset for legal question answering."""

    def __init__(self, data_path: Optional[str] = None):
        """
        Initialize the legal QA dataset.

        Args:
            data_path: Path to the dataset file (None to use default)
        """
        super().__init__("legal_qa", "legal")
        self.data_path = data_path

    def load(self) -> None:
        """Load the legal QA dataset into ``self.data`` as a DataFrame.

        Reads ``self.data_path`` (CSV or JSON) when the file exists;
        otherwise uses the built-in synthetic question set.

        Raises:
            ValueError: If ``self.data_path`` has an unsupported extension.
        """
        if self.data_path and os.path.exists(self.data_path):
            if self.data_path.endswith('.csv'):
                self.data = pd.read_csv(self.data_path)
            elif self.data_path.endswith('.json'):
                with open(self.data_path, 'r') as f:
                    # Bug fix: normalize to a DataFrame so get_samples() can
                    # rely on .columns — a raw json.load() result (list/dict)
                    # has no .columns attribute.
                    self.data = pd.DataFrame(json.load(f))
            else:
                raise ValueError(f"Unsupported file format: {self.data_path}")
        else:
            # No remote source for legal QA; use the synthetic fallback directly.
            self.data = self._create_synthetic_data()

    def _create_synthetic_data(self) -> pd.DataFrame:
        """Create synthetic legal QA data for testing."""
        questions = [
            "What constitutes a breach of contract?",
            "How is intellectual property protected under international law?",
            "What are the elements of negligence in tort law?",
            "How does the doctrine of stare decisis function in common law systems?",
            "What rights are protected under the Fourth Amendment?",
            "What is the difference between a patent and a copyright?",
            "How does arbitration differ from litigation?",
            "What constitutes insider trading under securities law?",
            "What are the legal requirements for a valid will?",
            "How does diplomatic immunity work under international law?"
        ]

        return pd.DataFrame({
            'question': questions,
            'domain': ['legal'] * len(questions)
        })

    def get_samples(self, n: Optional[int] = None) -> List[Dict[str, Any]]:
        """
        Get samples from the legal QA dataset.

        Loads the dataset lazily on first call.

        Args:
            n: Number of samples to return (None for all)

        Returns:
            List of sample dicts with 'domain', 'question' and 'prompt' keys.

        Raises:
            ValueError: If the loaded data has no recognizable question column.
        """
        if self.data is None:
            self.load()

        # Consistency fix: tolerate alternate column names the same way the
        # sibling dataset classes do, instead of assuming 'question' exists.
        if 'question' in self.data.columns:
            questions = self.data['question'].tolist()
        elif 'question_text' in self.data.columns:
            questions = self.data['question_text'].tolist()
        else:
            raise ValueError("Dataset does not contain question column")

        if n is not None:
            questions = questions[:n]

        samples = []
        for question in questions:
            prompt = self.get_prompt_template().format(question=question)
            samples.append({
                'domain': 'legal',
                'question': question,
                'prompt': prompt
            })

        return samples

    def get_prompt_template(self) -> str:
        """
        Get the prompt template for legal domain.

        Returns:
            Prompt template string with a ``{question}`` placeholder.
        """
        return "You are a legal expert. Please answer the following legal question accurately and concisely:\n\n{question}"
|
|
|
|
|
class GeneralKnowledgeDataset(DomainDataset):
    """Dataset for general knowledge question answering."""

    def __init__(self, data_path: Optional[str] = None):
        """
        Initialize the general knowledge dataset.

        Args:
            data_path: Path to the dataset file (None to use default)
        """
        super().__init__("general_knowledge", "general")
        self.data_path = data_path

    def load(self) -> None:
        """Load the general knowledge dataset into ``self.data`` as a DataFrame.

        Reads ``self.data_path`` (CSV or JSON) when the file exists;
        otherwise attempts to fetch TriviaQA via ``load_dataset``, falling
        back to a small synthetic question set on any failure.

        Raises:
            ValueError: If ``self.data_path`` has an unsupported extension.
        """
        if self.data_path and os.path.exists(self.data_path):
            if self.data_path.endswith('.csv'):
                self.data = pd.read_csv(self.data_path)
            elif self.data_path.endswith('.json'):
                with open(self.data_path, 'r') as f:
                    # Bug fix: normalize to a DataFrame so get_samples() can
                    # rely on .columns — a raw json.load() result (list/dict)
                    # has no .columns attribute.
                    self.data = pd.DataFrame(json.load(f))
            else:
                raise ValueError(f"Unsupported file format: {self.data_path}")
        else:
            try:
                dataset = load_dataset("trivia_qa", "unfiltered", split="train[:100]")
                self.data = dataset.to_pandas()
            except Exception as e:
                # Best-effort: offline use or a missing dependency should not
                # be fatal — fall back to the built-in synthetic questions.
                print(f"Failed to load TriviaQA dataset: {e}")
                self.data = self._create_synthetic_data()

    def _create_synthetic_data(self) -> pd.DataFrame:
        """Create synthetic general knowledge data for testing."""
        questions = [
            "What is the capital of France?",
            "Who wrote the novel '1984'?",
            "What is the chemical symbol for gold?",
            "Which planet is known as the Red Planet?",
            "Who painted the Mona Lisa?",
            "What is the largest ocean on Earth?",
            "What year did World War II end?",
            "What is the tallest mountain in the world?",
            "Who was the first person to step on the moon?",
            "What is the speed of light in a vacuum?"
        ]

        return pd.DataFrame({
            'question': questions,
            'domain': ['general'] * len(questions)
        })

    def get_samples(self, n: Optional[int] = None) -> List[Dict[str, Any]]:
        """
        Get samples from the general knowledge dataset.

        Loads the dataset lazily on first call.

        Args:
            n: Number of samples to return (None for all)

        Returns:
            List of sample dicts with 'domain', 'question' and 'prompt' keys.

        Raises:
            ValueError: If the loaded data has no recognizable question column.
        """
        if self.data is None:
            self.load()

        # TriviaQA and local files may name the question column differently.
        if 'question' in self.data.columns:
            questions = self.data['question'].tolist()
        elif 'question_text' in self.data.columns:
            questions = self.data['question_text'].tolist()
        else:
            raise ValueError("Dataset does not contain question column")

        if n is not None:
            questions = questions[:n]

        samples = []
        for question in questions:
            prompt = self.get_prompt_template().format(question=question)
            samples.append({
                'domain': 'general',
                'question': question,
                'prompt': prompt
            })

        return samples

    def get_prompt_template(self) -> str:
        """
        Get the prompt template for general knowledge domain.

        Returns:
            Prompt template string with a ``{question}`` placeholder.
        """
        return "Please answer the following general knowledge question accurately and concisely:\n\n{question}"
|
|
|
|
|
|
|
def create_domain_dataset(domain: str, data_path: Optional[str] = None) -> DomainDataset:
    """
    Factory for domain dataset instances.

    Args:
        domain: Domain category ('medical', 'legal', 'general')
        data_path: Path to the dataset file (None to use default)

    Returns:
        Domain dataset instance for the requested domain.

    Raises:
        ValueError: If ``domain`` is not one of the supported categories.
    """
    # Dispatch table instead of an if/elif chain; extend by adding entries.
    factories = {
        "medical": MedicalQADataset,
        "legal": LegalQADataset,
        "general": GeneralKnowledgeDataset,
    }
    factory = factories.get(domain)
    if factory is None:
        raise ValueError(f"Unsupported domain: {domain}")
    return factory(data_path)
|
|