| """ |
| Data preparation utilities for step-level reasoning chains. |
| |
| Provides tools to: |
| 1. Create reasoning chain annotations from existing VQA/reasoning datasets |
| 2. Generate synthetic step-level annotations |
| 3. Validate and convert data formats |
| """ |
|
|
| import json |
| import argparse |
| from pathlib import Path |
| from typing import List, Dict, Any, Optional |
| import logging |
| from PIL import Image |
| from tqdm import tqdm |
|
|
| from src.reasoning.step_data import ReasoningChain, ReasoningStep, StepType |
|
|
# Configure root logging once at import time; all messages in this module go
# through a module-scoped logger named after the module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
|
class ReasoningDataPreparator:
    """Prepares step-level reasoning data from various sources.

    Converts existing VQA/GQA-style datasets into ``ReasoningChain``
    annotations, writing one JSON file per chain under
    ``<output_dir>/<split>/`` and returning the in-memory chains.
    """

    # Keyword triggers used by the heuristic step generator to pick a
    # step template for a question. Matched as substrings, lowercase.
    _COUNT_WORDS = ('how many', 'count')
    _LOCATION_WORDS = ('where', 'location')
    _COLOR_WORDS = ('what color', 'which color')
    _COMPARE_WORDS = ('compare', 'difference', 'similar')

    def __init__(self, output_dir: str):
        """
        Initialize data preparator.

        Args:
            output_dir: Directory to save prepared data. The directory and
                its train/val/test subdirectories are created if missing.
        """
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # One subdirectory per split so chains from different splits never mix.
        for split in ('train', 'val', 'test'):
            (self.output_dir / split).mkdir(exist_ok=True)

    def _save_chain(self, chain: ReasoningChain, split: str) -> None:
        """Persist one chain as ``<output_dir>/<split>/<chain_id>.json``."""
        save_path = self.output_dir / split / f"{chain.chain_id}.json"
        chain.save_json(str(save_path))

    def from_vqa_dataset(
        self,
        vqa_data: List[Dict[str, Any]],
        split: str = 'train',
        generate_steps: bool = True,
    ) -> List[ReasoningChain]:
        """
        Convert VQA dataset to reasoning chains.

        Each sample is converted, saved to disk as JSON, and returned.

        Args:
            vqa_data: List of VQA samples with keys: image_path, question, answer
            split: Dataset split ('train', 'val' or 'test')
            generate_steps: If True, generate heuristic intermediate reasoning
                steps; otherwise emit a single answer-only INFERENCE step.

        Returns:
            List of reasoning chains
        """
        chains = []

        logger.info(f"Converting {len(vqa_data)} VQA samples to reasoning chains")

        for idx, sample in enumerate(tqdm(vqa_data)):
            chain_id = f"{split}_vqa_{idx:06d}"

            if generate_steps:
                steps = self._generate_steps_for_vqa(
                    question=sample['question'],
                    answer=sample['answer'],
                    image_path=sample.get('image_path', ''),
                )
            else:
                # Degenerate one-step chain: just state the answer.
                steps = [
                    ReasoningStep(
                        step_id=0,
                        step_type=StepType.INFERENCE,
                        description=f"Answer: {sample['answer']}",
                        confidence=1.0,
                        reward=1.0,
                    )
                ]

            chain = ReasoningChain(
                chain_id=chain_id,
                image_path=sample.get('image_path', ''),
                prompt=sample['question'],
                steps=steps,
                final_answer=sample['answer'],
                total_reward=sum(step.reward for step in steps),
                # Chains built from ground-truth answers are labelled correct.
                is_correct=True,
                metadata={
                    'source': 'vqa',
                    'original_sample_id': sample.get('id', idx),
                },
            )

            chains.append(chain)
            self._save_chain(chain, split)

        logger.info(f"Saved {len(chains)} reasoning chains to {self.output_dir / split}")

        return chains

    def _generate_steps_for_vqa(
        self,
        question: str,
        answer: str,
        image_path: str = '',
    ) -> List[ReasoningStep]:
        """
        Generate plausible reasoning steps for VQA question.

        This is a heuristic approach: the question text is matched against
        keyword lists (counting, location, color, comparison) and a fixed
        step template is emitted for the matching category. For production,
        use a language model or manual annotation.

        Args:
            question: VQA question
            answer: Ground truth answer
            image_path: Path to image (currently unused by the heuristics)

        Returns:
            List of reasoning steps; always starts with a PERCEPTION step
            and ends with a VERIFICATION step.
        """
        steps = []

        # Step 0: restate the task (perception of the question itself).
        steps.append(ReasoningStep(
            step_id=0,
            step_type=StepType.PERCEPTION,
            description=f"I need to answer: {question}",
            confidence=0.9,
            reward=0.7,
            dependencies=[],
        ))

        question_lower = question.lower()

        if any(word in question_lower for word in self._COUNT_WORDS):
            # Counting questions: locate objects, then count them.
            steps.append(ReasoningStep(
                step_id=1,
                step_type=StepType.LOCALIZATION,
                description="I identify and locate the relevant objects",
                confidence=0.85,
                reward=0.75,
                dependencies=[0],
            ))
            steps.append(ReasoningStep(
                step_id=2,
                step_type=StepType.COUNTING,
                description=f"I count the objects and determine the answer is {answer}",
                confidence=0.9,
                reward=0.85,
                dependencies=[1],
            ))

        elif any(word in question_lower for word in self._LOCATION_WORDS):
            # Location questions: a single localization step yields the answer.
            steps.append(ReasoningStep(
                step_id=1,
                step_type=StepType.LOCALIZATION,
                description=f"I determine the location: {answer}",
                confidence=0.88,
                reward=0.8,
                dependencies=[0],
            ))

        elif any(word in question_lower for word in self._COLOR_WORDS):
            # Color questions: direct perception of the attribute.
            steps.append(ReasoningStep(
                step_id=1,
                step_type=StepType.PERCEPTION,
                description=f"I identify the color as {answer}",
                confidence=0.92,
                reward=0.85,
                dependencies=[0],
            ))

        elif any(word in question_lower for word in self._COMPARE_WORDS):
            # Comparison questions: compare elements, then infer the answer.
            steps.append(ReasoningStep(
                step_id=1,
                step_type=StepType.COMPARISON,
                description="I compare the relevant elements",
                confidence=0.8,
                reward=0.75,
                dependencies=[0],
            ))
            steps.append(ReasoningStep(
                step_id=2,
                step_type=StepType.INFERENCE,
                description=f"Based on comparison, the answer is {answer}",
                confidence=0.85,
                reward=0.8,
                dependencies=[1],
            ))

        else:
            # Fallback for unrecognized question types: a generic inference.
            steps.append(ReasoningStep(
                step_id=1,
                step_type=StepType.INFERENCE,
                description=f"Based on the image, I conclude: {answer}",
                confidence=0.85,
                reward=0.8,
                dependencies=[0],
            ))

        # Final step: verify the answer, depending on the last step emitted.
        steps.append(ReasoningStep(
            step_id=len(steps),
            step_type=StepType.VERIFICATION,
            description=f"I verify my answer: {answer}",
            confidence=0.9,
            reward=0.85,
            dependencies=[len(steps) - 1],
        ))

        return steps

    def from_gqa_dataset(
        self,
        gqa_data: List[Dict[str, Any]],
        split: str = 'train',
    ) -> List[ReasoningChain]:
        """
        Convert GQA dataset (which has semantic parse) to reasoning chains.

        GQA provides structured programs that can be converted to steps. If a
        sample has no 'semantic' key, heuristic VQA-style steps are generated
        instead.

        Args:
            gqa_data: List of GQA samples (keys: question, answer, and
                optionally semantic, imageId, fullAnswer)
            split: Dataset split

        Returns:
            List of reasoning chains
        """
        chains = []

        logger.info(f"Converting {len(gqa_data)} GQA samples to reasoning chains")

        for idx, sample in enumerate(tqdm(gqa_data)):
            steps = []

            if 'semantic' in sample:
                # One step per semantic-program operation, chained linearly.
                for step_idx, operation in enumerate(sample['semantic']):
                    step_type = self._map_gqa_operation_to_step_type(operation)

                    steps.append(ReasoningStep(
                        step_id=step_idx,
                        step_type=step_type,
                        description=operation.get('argument', ''),
                        confidence=0.85,
                        reward=0.8,
                        dependencies=[step_idx - 1] if step_idx > 0 else [],
                    ))
            else:
                # No semantic program available; fall back to heuristics.
                steps = self._generate_steps_for_vqa(
                    sample['question'],
                    sample['answer'],
                    sample.get('imageId', ''),
                )

            chain = ReasoningChain(
                chain_id=f"{split}_gqa_{idx:06d}",
                image_path=sample.get('imageId', ''),
                prompt=sample['question'],
                steps=steps,
                final_answer=sample['answer'],
                total_reward=sum(s.reward for s in steps),
                is_correct=True,
                metadata={'source': 'gqa', 'fullAnswer': sample.get('fullAnswer', '')},
            )

            chains.append(chain)
            self._save_chain(chain, split)

        return chains

    def _map_gqa_operation_to_step_type(self, operation: Dict[str, Any]) -> StepType:
        """Map GQA semantic operation to step type.

        Matching is by substring on the lowercased 'operation' field;
        unknown operations default to INFERENCE.
        """
        op = operation.get('operation', '').lower()

        if 'select' in op or 'relate' in op:
            return StepType.LOCALIZATION
        elif 'query' in op:
            return StepType.PERCEPTION
        elif 'verify' in op or 'choose' in op:
            return StepType.VERIFICATION
        elif 'and' in op or 'or' in op:
            return StepType.COMPOSITION
        else:
            return StepType.INFERENCE

    def validate_dataset(self, split: str = 'train') -> Dict[str, Any]:
        """
        Validate reasoning chain dataset.

        Loads every ``*.json`` chain file in the split directory, checks for
        empty step lists and missing final answers, and aggregates counts.
        Files that fail to load are recorded as errors rather than raised.

        Args:
            split: Dataset split to validate

        Returns:
            Validation statistics: num_chains, num_steps_total,
            avg_steps_per_chain, per-type step counts, and error messages.
        """
        split_dir = self.output_dir / split
        chain_files = list(split_dir.glob('*.json'))

        logger.info(f"Validating {len(chain_files)} chains in {split} split")

        stats = {
            'num_chains': len(chain_files),
            'num_steps_total': 0,
            'avg_steps_per_chain': 0,
            'step_types': {},
            'errors': [],
        }

        for chain_file in tqdm(chain_files):
            try:
                chain = ReasoningChain.load_json(str(chain_file))

                if not chain.steps:
                    stats['errors'].append(f"{chain_file.name}: No steps")

                if not chain.final_answer:
                    stats['errors'].append(f"{chain_file.name}: No final answer")

                stats['num_steps_total'] += len(chain.steps)

                for step in chain.steps:
                    step_type = step.step_type.value
                    stats['step_types'][step_type] = stats['step_types'].get(step_type, 0) + 1

            except Exception as e:
                # Record the failure per-file; keep validating the rest.
                stats['errors'].append(f"{chain_file.name}: {str(e)}")

        # Guard against division by zero for an empty split directory.
        stats['avg_steps_per_chain'] = stats['num_steps_total'] / stats['num_chains'] if stats['num_chains'] > 0 else 0

        logger.info(f"Validation complete: {stats}")

        return stats
|
|
|
|
def main():
    """CLI entry point: convert a JSON dataset file into reasoning chains.

    Reads a JSON list of samples from --input, converts it with the
    converter selected by --source, writes chains under --output_dir/--split,
    and optionally validates the result.
    """
    parser = argparse.ArgumentParser(description="Prepare step-level reasoning data")
    parser.add_argument("--source", type=str, required=True, choices=['vqa', 'gqa', 'custom'])
    parser.add_argument("--input", type=str, required=True, help="Input data file (JSON)")
    parser.add_argument("--output_dir", type=str, required=True, help="Output directory")
    parser.add_argument("--split", type=str, default='train', choices=['train', 'val', 'test'])
    parser.add_argument("--validate", action='store_true', help="Validate after conversion")
    args = parser.parse_args()

    preparator = ReasoningDataPreparator(args.output_dir)

    logger.info(f"Loading data from {args.input}")
    # Explicit UTF-8: dataset JSON may contain non-ASCII questions/answers,
    # and the platform default encoding is not guaranteed to handle them.
    with open(args.input, 'r', encoding='utf-8') as f:
        input_data = json.load(f)

    # Converters iterate the data as a list of sample dicts; fail early with
    # a clear message instead of a confusing downstream KeyError/TypeError.
    if not isinstance(input_data, list):
        logger.error(f"Expected a JSON list of samples in {args.input}, got {type(input_data).__name__}")
        return

    if args.source == 'vqa':
        chains = preparator.from_vqa_dataset(input_data, split=args.split)
    elif args.source == 'gqa':
        chains = preparator.from_gqa_dataset(input_data, split=args.split)
    else:
        # NOTE(review): 'custom' is accepted by argparse but no converter is
        # implemented for it yet; it is reported as unsupported here.
        logger.error(f"Unsupported source: {args.source}")
        return

    logger.info(f"Converted {len(chains)} samples to reasoning chains")

    if args.validate:
        stats = preparator.validate_dataset(args.split)
        logger.info(f"Validation results: {stats}")


if __name__ == "__main__":
    main()
|
|