| """ |
| Reasoning Training Module for MangoMAS Local |
| |
| This module implements specialized training for reasoning capabilities, |
| adapted from the AWS backup system for local training. |
| """ |
|
|
| import json |
| import logging |
| import os |
| import random |
| import re |
| from typing import Any, Dict, List |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import Dataset |
|
|
| from ..core_framework import SpecializedTrainingModule, TrainingModuleConfig |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
class ReasoningDataset(Dataset):
    """Dataset of reasoning examples tokenized for language-model training.

    The data file is read line by line; each usable line is a JSON object in
    one of two layouts:

    * ``question`` / ``reasoning`` / ``answer``
    * ``premise`` / ``reasoning`` / ``conclusion``

    The raw records are kept unmodified in ``self.data``; the prompt text is
    rendered and tokenized on access.
    """

    # Accepted record layouts, in priority order: the required keys and the
    # prompt template they are rendered with. The first layout is listed first
    # so that a record carrying both sets of keys is rendered as before.
    _LAYOUTS = (
        (
            ("question", "reasoning", "answer"),
            "Question: {}\nReasoning: {}\nAnswer: {}",
        ),
        (
            ("premise", "reasoning", "conclusion"),
            "Premise: {}\nReasoning: {}\nConclusion: {}",
        ),
    )

    # Label value written on padding positions. It equals the default
    # ``ignore_index`` of ``torch.nn.CrossEntropyLoss``.
    _IGNORE_INDEX = -100

    def __init__(self, data_path: str, tokenizer, max_length: int = 512):
        """
        Initialize the reasoning dataset.

        Args:
            data_path: Path to the reasoning data file (one JSON object per line)
            tokenizer: Tokenizer for text processing
            max_length: Maximum sequence length
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self._skipped_lines = 0
        self.data = self._load_data(data_path)

        logger.info(f"Loaded reasoning dataset with {len(self.data)} examples")
        if self._skipped_lines:
            # Surface dropped lines instead of discarding them silently.
            logger.warning(
                f"Skipped {self._skipped_lines} malformed or incomplete lines in {data_path}"
            )

    @classmethod
    def _layout_for(cls, item: Any):
        """Return the ``(keys, template)`` layout matching ``item``, or ``None``."""
        if not isinstance(item, dict):
            return None
        for keys, template in cls._LAYOUTS:
            if all(key in item for key in keys):
                return keys, template
        return None

    def _load_data(self, data_path: str) -> List[Dict]:
        """Load reasoning training data.

        Blank lines are ignored. Lines that are not valid JSON, are not JSON
        objects, or do not match an accepted layout are skipped and counted in
        ``self._skipped_lines``.

        Args:
            data_path: Path to the reasoning data file

        Returns:
            List of raw record dictionaries, in file order
        """
        records: List[Dict] = []
        skipped = 0
        with open(data_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    item = json.loads(line)
                except json.JSONDecodeError:
                    skipped += 1
                    continue
                if self._layout_for(item) is None:
                    skipped += 1
                    continue
                records.append(item)
        self._skipped_lines = skipped
        return records

    def _format_prompt(self, item: Dict) -> str:
        """Render a raw record as the full training prompt text.

        Raises:
            KeyError: If the record matches none of the accepted layouts
        """
        layout = self._layout_for(item)
        if layout is None:
            raise KeyError("Record does not match any supported reasoning layout")
        keys, template = layout
        return template.format(*(item[key] for key in keys))

    def __len__(self):
        """Return the number of loaded examples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize example ``idx``.

        Returns:
            Dictionary with ``input_ids``, ``attention_mask`` and ``labels``.
            ``labels`` is a copy of ``input_ids`` in which positions whose
            attention mask is 0 are set to -100.
        """
        prompt = self._format_prompt(self.data[idx])

        encoding = self.tokenizer(
            prompt,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        # Drop only the leading batch dimension; a bare squeeze() would also
        # collapse the sequence dimension when max_length == 1.
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)

        # Clone so that masking the labels cannot modify input_ids, then mark
        # padding positions so they do not contribute to the loss.
        labels = input_ids.clone()
        labels[attention_mask == 0] = self._IGNORE_INDEX

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }
|
|
|
|
class ReasoningEvaluator:
    """Evaluator for reasoning capabilities.

    Generates a continuation for each example's premise and scores it with
    simple lexical heuristics. The four metrics are averaged over the examples
    that were scored.
    """

    # Upper bound on the number of examples evaluated per call.
    _MAX_EVAL_EXAMPLES = 100

    # Patterns that split generated text into reasoning and conclusion parts.
    _REASONING_RE = re.compile(r"Reasoning:(.*?)(?:Conclusion:|$)", re.DOTALL)
    _CONCLUSION_RE = re.compile(r"Conclusion:(.*?)$", re.DOTALL)
    _WORD_RE = re.compile(r"\w+")

    # Words treated as evidence of step-by-step structure.
    _FLOW_MARKERS = frozenset(
        {
            "first",
            "second",
            "third",
            "then",
            "next",
            "finally",
            "because",
            "thus",
            "hence",
        }
    )

    def __init__(self, tokenizer):
        """
        Initialize the reasoning evaluator.

        Args:
            tokenizer: Tokenizer for text processing
        """
        self.tokenizer = tokenizer
        self.metrics = {
            "logical_consistency": 0.0,
            "premise_relevance": 0.0,
            "conclusion_validity": 0.0,
            "steps_coherence": 0.0,
        }

    def evaluate(self, model, eval_dataset: "ReasoningDataset") -> Dict[str, float]:
        """
        Evaluate reasoning capabilities on the provided dataset.

        At most the first 100 examples are used. The raw text record of each
        example is read from ``eval_dataset.data`` when that attribute exists
        (indexing the dataset itself yields tokenized tensors, not text);
        otherwise ``eval_dataset[idx]`` is used directly. The premise may be
        stored under ``premise`` or ``question``, the conclusion under
        ``conclusion`` or ``answer``.

        Args:
            model: The model to evaluate
            eval_dataset: Dataset of reasoning examples

        Returns:
            Dictionary of evaluation metrics averaged over the scored
            examples; all zeros when no example could be scored
        """
        model.eval()
        device = next(model.parameters()).device

        # Reset the running totals from any previous call.
        for key in self.metrics:
            self.metrics[key] = 0.0

        total_examples = min(len(eval_dataset), self._MAX_EVAL_EXAMPLES)
        scored = 0

        with torch.no_grad():
            for idx in range(total_examples):
                example = self._raw_example(eval_dataset, idx)
                premise = self._first_present(example, "premise", "question")
                if premise is None:
                    logger.warning(
                        f"Skipping example {idx}: no premise/question text available"
                    )
                    continue

                prompt = f"Premise: {premise}\nReasoning:"
                input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(
                    device
                )

                # NOTE(review): `temperature` is passed without `do_sample=True`;
                # with Hugging Face `generate` this presumably results in greedy
                # decoding with the temperature ignored. Confirm the intent.
                generated_ids = model.generate(
                    input_ids, max_length=512, temperature=0.7, num_return_sequences=1
                )

                generated_text = self.tokenizer.decode(
                    generated_ids[0], skip_special_tokens=True
                )

                try:
                    gen_reasoning, gen_conclusion = self._split_generation(
                        generated_text
                    )
                    self._update_metrics(
                        premise=premise,
                        expected_reasoning=example.get("reasoning", ""),
                        expected_conclusion=self._first_present(
                            example, "conclusion", "answer", default=""
                        ),
                        generated_reasoning=gen_reasoning,
                        generated_conclusion=gen_conclusion,
                    )
                    scored += 1
                except Exception as e:
                    # Best effort: one bad example must not abort the whole run.
                    logger.error(f"Error evaluating reasoning: {e}")

        # Guard against division by zero for empty or fully skipped datasets.
        if scored:
            for key in self.metrics:
                self.metrics[key] /= scored

        # Return a copy so callers cannot mutate the evaluator's state.
        return dict(self.metrics)

    @staticmethod
    def _raw_example(eval_dataset, idx: int):
        """Return the raw record for ``idx``, preferring ``eval_dataset.data``."""
        records = getattr(eval_dataset, "data", None)
        if records is not None:
            return records[idx]
        return eval_dataset[idx]

    @staticmethod
    def _first_present(example, *keys: str, default=None):
        """Return the value of the first key found in ``example``, else ``default``."""
        if isinstance(example, dict):
            for key in keys:
                if key in example:
                    return example[key]
        return default

    @classmethod
    def _split_generation(cls, generated_text: str):
        """Split generated text into ``(reasoning, conclusion)`` strings.

        Either part is an empty string when its marker is not found.
        """
        reasoning_match = cls._REASONING_RE.search(generated_text)
        conclusion_match = cls._CONCLUSION_RE.search(generated_text)
        reasoning = reasoning_match.group(1).strip() if reasoning_match else ""
        conclusion = conclusion_match.group(1).strip() if conclusion_match else ""
        return reasoning, conclusion

    @classmethod
    def _tokenize(cls, text: str) -> List[str]:
        """Lower-case ``text`` and return its word tokens without punctuation."""
        return cls._WORD_RE.findall(text.lower())

    def _update_metrics(
        self,
        premise: str,
        expected_reasoning: str,
        expected_conclusion: str,
        generated_reasoning: str,
        generated_conclusion: str,
    ) -> None:
        """
        Update reasoning metrics based on a single example.

        All scores are computed before any total is modified, so an exception
        leaves the running totals untouched.

        Args:
            premise: Input premise
            expected_reasoning: Expected reasoning steps (currently not used by
                any heuristic)
            expected_conclusion: Expected conclusion (currently not used by any
                heuristic)
            generated_reasoning: Generated reasoning steps
            generated_conclusion: Generated conclusion
        """
        premise_terms = set(self._tokenize(premise))
        reasoning_terms = set(self._tokenize(generated_reasoning))
        conclusion_terms = set(self._tokenize(generated_conclusion))

        # Fraction of distinct premise words that reappear in the reasoning.
        # Word tokens are used so that "raining," still matches "raining".
        relevance = len(premise_terms & reasoning_terms) / max(len(premise_terms), 1)

        # Fixed scores: higher when the conclusion contains "therefore".
        validity = 0.7 if "therefore" in conclusion_terms else 0.3

        # Count distinct flow markers that occur as whole words. Substring
        # matching would count "then" inside "strengthen".
        marker_count = len(self._FLOW_MARKERS & reasoning_terms)
        coherence = min(1.0, marker_count / 3)

        # Constant placeholder score; no consistency check is implemented.
        self.metrics["logical_consistency"] += 0.5
        self.metrics["premise_relevance"] += relevance
        self.metrics["conclusion_validity"] += validity
        self.metrics["steps_coherence"] += coherence
|
|
|
|
class ReasoningTrainingModule(SpecializedTrainingModule):
    """Specialized training module for reasoning capabilities."""

    # Weight of the teacher/student KL term added to the token-level loss.
    _KL_WEIGHT = 0.1

    # Seed examples for synthetic data; each sample is a verbatim copy of one.
    _TEMPLATES = (
        {
            "premise": "If it rains, the ground gets wet. It is raining now.",
            "reasoning": "Since it is raining, and rain makes the ground wet, we can conclude that the ground is getting wet.",
            "conclusion": "Therefore, the ground is wet.",
        },
        {
            "premise": "All mammals are warm-blooded. Whales are mammals.",
            "reasoning": "Whales are classified as mammals. All mammals are warm-blooded animals. Therefore, as a mammal, a whale must be warm-blooded.",
            "conclusion": "Therefore, whales are warm-blooded.",
        },
        {
            "premise": "If you study hard, you will pass the exam. You studied hard.",
            "reasoning": "The premise states a conditional relationship between studying hard and passing the exam. Since you studied hard, the condition is met.",
            "conclusion": "Therefore, you will pass the exam.",
        },
    )

    def __init__(self, config: TrainingModuleConfig, tokenizer):
        """
        Initialize the reasoning training module.

        Args:
            config: Module configuration
            tokenizer: Tokenizer for text processing
        """
        super().__init__(config, tokenizer)

        # Token-level loss; label positions equal to -100 are ignored.
        self.reasoning_loss = nn.CrossEntropyLoss(ignore_index=-100)
        self.metrics = {
            "reasoning_loss": 0.0,
            "reasoning_accuracy": 0.0,
            "reasoning_perplexity": 0.0,
        }

        logger.info("Initialized ReasoningTrainingModule")

    def prepare_batch(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Prepare a batch of data for reasoning training.

        Tensor values are moved to ``self.device``; all other values are passed
        through unchanged.

        Args:
            batch: The input batch from the dataloader

        Returns:
            Processed batch ready for reasoning training
        """
        return {
            key: value.to(self.device) if isinstance(value, torch.Tensor) else value
            for key, value in batch.items()
        }

    def compute_loss(
        self, student_outputs: Any, teacher_outputs: Any, batch: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """
        Compute the reasoning-specific loss.

        The loss is the next-token cross entropy of the student logits against
        the batch labels (falling back to ``input_ids`` when no ``labels`` key
        is present), plus 0.1 times the KL divergence between the student and
        teacher distributions when teacher outputs are given. The result is
        scaled by ``self.loss_weight``. ``self.metrics`` is refreshed with the
        loss, token accuracy and perplexity of the current batch.

        Args:
            student_outputs: Outputs from the student model (an object with a
                ``logits`` attribute, or the logits themselves)
            teacher_outputs: Outputs from the teacher model, in the same form,
                or ``None`` to skip the KL term
            batch: The processed input batch

        Returns:
            Loss tensor for reasoning training. If any step fails, the error is
            logged and a small constant tensor is returned instead.
        """
        try:
            student_logits = getattr(student_outputs, "logits", student_outputs)
            teacher_logits = getattr(teacher_outputs, "logits", teacher_outputs)

            labels = batch.get("labels", batch.get("input_ids"))

            # Shift so that the logits at position t are scored against the
            # token at position t + 1.
            shift_logits = student_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            reasoning_loss = self.reasoning_loss(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )

            if teacher_logits is not None:
                kl_loss = F.kl_div(
                    F.log_softmax(student_logits, dim=-1),
                    F.softmax(teacher_logits, dim=-1),
                    reduction="batchmean",
                )
                total_loss = reasoning_loss + self._KL_WEIGHT * kl_loss
            else:
                total_loss = reasoning_loss

            self._record_metrics(reasoning_loss, shift_logits, shift_labels)

            return total_loss * self.loss_weight

        except Exception as e:
            # logger.exception keeps the traceback, which the message alone loses.
            logger.exception(f"Error computing reasoning loss: {e}")
            # Deliberate best effort: return a small constant loss so training
            # can continue, created on the module's device when one is set.
            return torch.tensor(
                0.01, requires_grad=True, device=getattr(self, "device", None)
            )

    def _record_metrics(
        self,
        reasoning_loss: torch.Tensor,
        shift_logits: torch.Tensor,
        shift_labels: torch.Tensor,
    ) -> None:
        """Store loss, perplexity and token accuracy of the current batch."""
        with torch.no_grad():
            self.metrics["reasoning_loss"] = reasoning_loss.item()
            # Perplexity is the exponential of the mean token cross entropy.
            self.metrics["reasoning_perplexity"] = torch.exp(
                reasoning_loss.detach()
            ).item()

            # Accuracy is measured only on positions the loss does not ignore.
            scored = shift_labels != self.reasoning_loss.ignore_index
            scored_count = int(scored.sum().item())
            if scored_count:
                hits = (shift_logits.argmax(dim=-1) == shift_labels) & scored
                self.metrics["reasoning_accuracy"] = (
                    hits.sum().item() / scored_count
                )
            else:
                self.metrics["reasoning_accuracy"] = 0.0

    def get_metrics(self) -> Dict[str, float]:
        """
        Get metrics specific to reasoning training.

        Returns:
            Copy of the dictionary of reasoning metrics
        """
        return self.metrics.copy()

    def generate_synthetic_reasoning_data(
        self, output_path: str, num_samples: int = 1000
    ) -> None:
        """
        Generate synthetic reasoning data for training.

        Writes ``num_samples`` JSON lines, each a randomly chosen template with
        ``premise``, ``reasoning`` and ``conclusion`` fields plus a ``metadata``
        object. ``metadata.timestamp`` is the UTC generation time in ISO 8601
        format; ``metadata.device`` is the CUDA device name, or "CPU".

        Args:
            output_path: Path to save the generated data
            num_samples: Number of samples to generate
        """
        # Function-scope import: the file's import block is left untouched.
        from datetime import datetime, timezone

        # Both values are constant for the whole run, so compute them once
        # rather than once per sample.
        device_name = str(
            torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU"
        )
        generated_at = datetime.now(timezone.utc).isoformat()

        output_data = []
        for _ in range(num_samples):
            template = random.choice(self._TEMPLATES)
            output_data.append(
                {
                    "premise": template["premise"],
                    "reasoning": template["reasoning"],
                    "conclusion": template["conclusion"],
                    "metadata": {
                        "generated": True,
                        "timestamp": generated_at,
                        "device": device_name,
                    },
                }
            )

        # A bare file name has an empty dirname, and os.makedirs("") raises.
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        with open(output_path, "w", encoding="utf-8") as f:
            f.writelines(json.dumps(item) + "\n" for item in output_data)

        logger.info(
            f"Generated {len(output_data)} synthetic reasoning examples at {output_path}"
        )
|
|