# dei-model/utils/prepare_reasoning_data.py
# NOTE(review): the lines below were web-scrape residue (author handle and
# commit info) that would be a SyntaxError in a .py file; preserved as comments.
#   renpas22 — "Add utils directory" (da76488)
"""
Data preparation utilities for step-level reasoning chains.
Provides tools to:
1. Create reasoning chain annotations from existing VQA/reasoning datasets
2. Generate synthetic step-level annotations
3. Validate and convert data formats
"""
import json
import argparse
from pathlib import Path
from typing import List, Dict, Any, Optional
import logging
from PIL import Image
from tqdm import tqdm
from src.reasoning.step_data import ReasoningChain, ReasoningStep, StepType
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class ReasoningDataPreparator:
    """Prepares step-level reasoning data from various sources.

    Converts VQA/GQA-style samples into ``ReasoningChain`` objects and
    persists each chain as an individual JSON file under
    ``<output_dir>/<split>/<chain_id>.json``.
    """

    def __init__(self, output_dir: str):
        """
        Initialize data preparator.

        Args:
            output_dir: Directory to save prepared data. Created (along with
                'train'/'val'/'test' split subdirectories) if it does not exist.
        """
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        # Create split directories up front so save paths are always valid.
        for split in ['train', 'val', 'test']:
            (self.output_dir / split).mkdir(exist_ok=True)

    def from_vqa_dataset(
        self,
        vqa_data: List[Dict[str, Any]],
        split: str = 'train',
        generate_steps: bool = True,
    ) -> List[ReasoningChain]:
        """
        Convert VQA dataset to reasoning chains.

        Each chain is saved to ``<output_dir>/<split>/`` as it is built, so
        partial progress survives an interruption.

        Args:
            vqa_data: List of VQA samples with keys: image_path, question, answer
            split: Dataset split ('train', 'val', or 'test')
            generate_steps: If True, generate heuristic intermediate reasoning
                steps; otherwise emit a single direct-answer step.

        Returns:
            List of reasoning chains (also written to disk as a side effect)
        """
        chains = []
        logger.info(f"Converting {len(vqa_data)} VQA samples to reasoning chains")
        for idx, sample in enumerate(tqdm(vqa_data)):
            chain_id = f"{split}_vqa_{idx:06d}"
            # Generate reasoning steps
            if generate_steps:
                steps = self._generate_steps_for_vqa(
                    question=sample['question'],
                    answer=sample['answer'],
                    image_path=sample.get('image_path', ''),
                )
            else:
                # Single step (direct answer)
                steps = [
                    ReasoningStep(
                        step_id=0,
                        step_type=StepType.INFERENCE,
                        description=f"Answer: {sample['answer']}",
                        confidence=1.0,
                        reward=1.0,
                    )
                ]
            # Chain-level reward is the sum of per-step rewards.
            total_reward = sum(step.reward for step in steps)
            chain = ReasoningChain(
                chain_id=chain_id,
                image_path=sample.get('image_path', ''),
                prompt=sample['question'],
                steps=steps,
                final_answer=sample['answer'],
                total_reward=total_reward,
                is_correct=True,  # Assume correct if from ground truth
                metadata={
                    'source': 'vqa',
                    'original_sample_id': sample.get('id', idx),
                }
            )
            chains.append(chain)
            # Save chain immediately (one JSON file per chain).
            save_path = self.output_dir / split / f"{chain_id}.json"
            chain.save_json(str(save_path))
        logger.info(f"Saved {len(chains)} reasoning chains to {self.output_dir / split}")
        return chains

    def _generate_steps_for_vqa(
        self,
        question: str,
        answer: str,
        image_path: str = '',
    ) -> List[ReasoningStep]:
        """
        Generate plausible reasoning steps for a VQA question.

        This is a heuristic approach keyed off question phrasing. For
        production, use a language model or manual annotation.

        Args:
            question: VQA question
            answer: Ground truth answer
            image_path: Path to image (currently unused by the heuristic)

        Returns:
            List of reasoning steps: perception, one or two type-specific
            steps, then a final verification step.
        """
        steps = []
        # Step 1: Perception (understand the question)
        steps.append(ReasoningStep(
            step_id=0,
            step_type=StepType.PERCEPTION,
            description=f"I need to answer: {question}",
            confidence=0.9,
            reward=0.7,
            dependencies=[],
        ))
        # Step 2: Type-specific reasoning, dispatched on question keywords.
        question_lower = question.lower()
        if any(word in question_lower for word in ['how many', 'count']):
            # Counting question: localize then count.
            steps.append(ReasoningStep(
                step_id=1,
                step_type=StepType.LOCALIZATION,
                description="I identify and locate the relevant objects",
                confidence=0.85,
                reward=0.75,
                dependencies=[0],
            ))
            steps.append(ReasoningStep(
                step_id=2,
                step_type=StepType.COUNTING,
                description=f"I count the objects and determine the answer is {answer}",
                confidence=0.9,
                reward=0.85,
                dependencies=[1],
            ))
        elif any(word in question_lower for word in ['where', 'location']):
            # Localization question
            steps.append(ReasoningStep(
                step_id=1,
                step_type=StepType.LOCALIZATION,
                description=f"I determine the location: {answer}",
                confidence=0.88,
                reward=0.8,
                dependencies=[0],
            ))
        elif any(word in question_lower for word in ['what color', 'which color']):
            # Color perception
            steps.append(ReasoningStep(
                step_id=1,
                step_type=StepType.PERCEPTION,
                description=f"I identify the color as {answer}",
                confidence=0.92,
                reward=0.85,
                dependencies=[0],
            ))
        elif any(word in question_lower for word in ['compare', 'difference', 'similar']):
            # Comparison: compare then infer.
            steps.append(ReasoningStep(
                step_id=1,
                step_type=StepType.COMPARISON,
                description="I compare the relevant elements",
                confidence=0.8,
                reward=0.75,
                dependencies=[0],
            ))
            steps.append(ReasoningStep(
                step_id=2,
                step_type=StepType.INFERENCE,
                description=f"Based on comparison, the answer is {answer}",
                confidence=0.85,
                reward=0.8,
                dependencies=[1],
            ))
        else:
            # General inference fallback
            steps.append(ReasoningStep(
                step_id=1,
                step_type=StepType.INFERENCE,
                description=f"Based on the image, I conclude: {answer}",
                confidence=0.85,
                reward=0.8,
                dependencies=[0],
            ))
        # Final verification step, depending on the last reasoning step.
        steps.append(ReasoningStep(
            step_id=len(steps),
            step_type=StepType.VERIFICATION,
            description=f"I verify my answer: {answer}",
            confidence=0.9,
            reward=0.85,
            dependencies=[len(steps) - 1],
        ))
        return steps

    def from_gqa_dataset(
        self,
        gqa_data: List[Dict[str, Any]],
        split: str = 'train',
    ) -> List[ReasoningChain]:
        """
        Convert GQA dataset (which has semantic parse) to reasoning chains.

        GQA provides structured programs that can be converted to steps.

        Args:
            gqa_data: List of GQA samples
            split: Dataset split

        Returns:
            List of reasoning chains (also written to disk as a side effect)
        """
        chains = []
        logger.info(f"Converting {len(gqa_data)} GQA samples to reasoning chains")
        for idx, sample in enumerate(tqdm(gqa_data)):
            # GQA has a 'semantic' program whose operations map to step types.
            steps = []
            # FIX: use truthiness, not key presence — a present-but-empty
            # 'semantic' list previously produced a chain with zero steps.
            if sample.get('semantic'):
                # Parse semantic structure into steps (linear dependency chain).
                for step_idx, operation in enumerate(sample['semantic']):
                    step_type = self._map_gqa_operation_to_step_type(operation)
                    steps.append(ReasoningStep(
                        step_id=step_idx,
                        step_type=step_type,
                        description=operation.get('argument', ''),
                        confidence=0.85,
                        reward=0.8,
                        dependencies=[step_idx - 1] if step_idx > 0 else [],
                    ))
            else:
                # Fallback to heuristically generated steps.
                steps = self._generate_steps_for_vqa(
                    sample['question'],
                    sample['answer'],
                    sample.get('imageId', ''),
                )
            chain = ReasoningChain(
                chain_id=f"{split}_gqa_{idx:06d}",
                image_path=sample.get('imageId', ''),
                prompt=sample['question'],
                steps=steps,
                final_answer=sample['answer'],
                total_reward=sum(s.reward for s in steps),
                is_correct=True,
                metadata={'source': 'gqa', 'fullAnswer': sample.get('fullAnswer', '')},
            )
            chains.append(chain)
            save_path = self.output_dir / split / f"{chain.chain_id}.json"
            chain.save_json(str(save_path))
        return chains

    def _map_gqa_operation_to_step_type(self, operation: Dict[str, Any]) -> StepType:
        """Map a GQA semantic operation to a StepType.

        FIX: matches on whole words rather than substrings. The previous
        substring test ``'or' in op`` misclassified any operation whose text
        merely contains "or" (e.g. "filter color") as COMPOSITION, and
        ``'and' in op`` had the same hazard.
        """
        op = operation.get('operation', '').lower()
        # Tokenize on whitespace and parentheses so e.g. "choose rel" and
        # "verify color" match on their operation word only.
        tokens = set(op.replace('(', ' ').replace(')', ' ').split())
        if 'select' in tokens or 'relate' in tokens:
            return StepType.LOCALIZATION
        elif 'query' in tokens:
            return StepType.PERCEPTION
        elif 'verify' in tokens or 'choose' in tokens:
            return StepType.VERIFICATION
        elif 'and' in tokens or 'or' in tokens:
            return StepType.COMPOSITION
        else:
            return StepType.INFERENCE

    def validate_dataset(self, split: str = 'train') -> Dict[str, Any]:
        """
        Validate reasoning chain dataset on disk.

        Args:
            split: Dataset split to validate

        Returns:
            Validation statistics: chain/step counts, per-step-type counts,
            and a list of human-readable error strings (never raises for a
            malformed chain file — the error is recorded instead).
        """
        split_dir = self.output_dir / split
        chain_files = list(split_dir.glob('*.json'))
        logger.info(f"Validating {len(chain_files)} chains in {split} split")
        stats = {
            'num_chains': len(chain_files),
            'num_steps_total': 0,
            'avg_steps_per_chain': 0,
            'step_types': {},
            'errors': [],
        }
        for chain_file in tqdm(chain_files):
            try:
                chain = ReasoningChain.load_json(str(chain_file))
                # Validate structure
                if not chain.steps:
                    stats['errors'].append(f"{chain_file.name}: No steps")
                if not chain.final_answer:
                    stats['errors'].append(f"{chain_file.name}: No final answer")
                # Count steps
                stats['num_steps_total'] += len(chain.steps)
                # Count step types
                for step in chain.steps:
                    step_type = step.step_type.value
                    stats['step_types'][step_type] = stats['step_types'].get(step_type, 0) + 1
            except Exception as e:
                # Record (don't raise) so one bad file doesn't abort validation.
                stats['errors'].append(f"{chain_file.name}: {str(e)}")
        stats['avg_steps_per_chain'] = stats['num_steps_total'] / stats['num_chains'] if stats['num_chains'] > 0 else 0
        logger.info(f"Validation complete: {stats}")
        return stats
def main():
    """CLI entry point: convert a JSON dataset into reasoning-chain files.

    Reads the input JSON, dispatches on --source ('vqa' or 'gqa'; 'custom'
    is accepted by argparse but not yet implemented and logs an error), and
    optionally validates the produced split.
    """
    parser = argparse.ArgumentParser(description="Prepare step-level reasoning data")
    parser.add_argument("--source", type=str, required=True, choices=['vqa', 'gqa', 'custom'])
    parser.add_argument("--input", type=str, required=True, help="Input data file (JSON)")
    parser.add_argument("--output_dir", type=str, required=True, help="Output directory")
    parser.add_argument("--split", type=str, default='train', choices=['train', 'val', 'test'])
    parser.add_argument("--validate", action='store_true', help="Validate after conversion")
    args = parser.parse_args()
    # Initialize preparator (creates output/split directories).
    preparator = ReasoningDataPreparator(args.output_dir)
    # Load input data. JSON is specified as UTF-8, so pin the encoding
    # rather than relying on the platform default.
    logger.info(f"Loading data from {args.input}")
    with open(args.input, 'r', encoding='utf-8') as f:
        input_data = json.load(f)
    # Convert based on source
    if args.source == 'vqa':
        chains = preparator.from_vqa_dataset(input_data, split=args.split)
    elif args.source == 'gqa':
        chains = preparator.from_gqa_dataset(input_data, split=args.split)
    else:
        # 'custom' is reserved; bail out rather than silently doing nothing.
        logger.error(f"Unsupported source: {args.source}")
        return
    logger.info(f"Converted {len(chains)} samples to reasoning chains")
    # Validate if requested
    if args.validate:
        stats = preparator.validate_dataset(args.split)
        logger.info(f"Validation results: {stats}")
# Script entry guard: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()