| | """ |
| | Memory Training Module for MangoMAS Local |
| | |
| | This module implements specialized training for memory and context retention capabilities, |
| | adapted from the AWS backup system for local training. |
| | """ |
| |
|
| | import json |
| | import logging |
| | import os |
| | import random |
| | from typing import Any, Dict, List |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from torch.utils.data import Dataset |
| |
|
| | from ..core_framework import SpecializedTrainingModule, TrainingModuleConfig |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
class MemoryDataset(Dataset):
    """Dataset for training memory and context retention capabilities.

    Each example is a JSONL record containing a multi-turn ``conversation``;
    the final turn is the target response and all preceding turns form the
    context prompt.
    """

    def __init__(self, data_path: str, tokenizer, max_length: int = 1024):
        """
        Initialize the memory dataset.

        Args:
            data_path: Path to the JSONL memory training data file
            tokenizer: Tokenizer for text processing (HF-style callable
                returning ``input_ids`` / ``attention_mask`` tensors)
            max_length: Maximum sequence length
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self._load_data(data_path)

        logger.info(f"Loaded memory dataset with {len(self.data)} examples")

    def _load_data(self, data_path: str) -> List[Dict]:
        """Load memory training data, skipping malformed or unusable records."""
        data = []
        with open(data_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Skip blank lines rather than feeding them to json.loads.
                    continue
                try:
                    item = json.loads(line)
                except json.JSONDecodeError:
                    continue
                conversation = item.get("conversation")
                # Require a NON-EMPTY list: __getitem__ reads conversation[-1],
                # so an empty list would crash with IndexError at train time.
                if isinstance(conversation, list) and conversation:
                    data.append(item)
        return data

    def __len__(self):
        """Number of usable conversation examples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize one example into ``input_ids`` / ``attention_mask`` / ``labels``.

        Labels are a copy of ``input_ids`` with padding positions set to -100
        so a cross-entropy loss with ``ignore_index=-100`` skips them instead
        of training on pad tokens.
        """
        item = self.data[idx]

        conversation = item["conversation"]
        context = "\n".join(
            f"{turn['role']}: {turn['content']}" for turn in conversation[:-1]
        )
        target = conversation[-1]["content"]

        prompt = f"Context:\n{context}\nResponse: {target}"

        encoding = self.tokenizer(
            prompt,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        # squeeze(0) drops only the batch dim; a bare squeeze() would also
        # collapse the sequence dim in the degenerate max_length=1 case.
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)

        # Mask padding so the loss (ignore_index=-100) ignores those positions.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }
| |
|
| |
|
class MemoryTrainingModule(SpecializedTrainingModule):
    """Specialized training module for memory and context retention capabilities.

    Combines a causal language-modeling cross-entropy loss over conversation
    data with an optional KL-divergence distillation term against a teacher
    model.
    """

    def __init__(self, config: TrainingModuleConfig, tokenizer):
        """
        Initialize the memory training module.

        Args:
            config: Module configuration
            tokenizer: Tokenizer for text processing
        """
        super().__init__(config, tokenizer)

        # ignore_index=-100 matches the HF convention for masked label positions.
        self.memory_loss = nn.CrossEntropyLoss(ignore_index=-100)
        self.metrics = {
            "memory_loss": 0.0,
            "context_retention": 0.0,
            "coherence_score": 0.0,
        }

        logger.info("Initialized MemoryTrainingModule")

    def prepare_batch(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Move all tensor entries of a batch to the module's device.

        Args:
            batch: The input batch from the dataloader

        Returns:
            Processed batch ready for memory training
        """
        prepared_batch = {}
        for key, value in batch.items():
            if isinstance(value, torch.Tensor):
                prepared_batch[key] = value.to(self.device)
            else:
                # Non-tensor entries (e.g. metadata strings) pass through as-is.
                prepared_batch[key] = value
        return prepared_batch

    def compute_loss(
        self, student_outputs: Any, teacher_outputs: Any, batch: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """
        Compute the memory-specific loss.

        The loss is a next-token cross-entropy on the student logits, plus a
        KL-divergence distillation term (weighted 0.1) when teacher logits
        are available. The result is scaled by ``self.loss_weight``.

        Args:
            student_outputs: Outputs from the student model (an object with a
                ``.logits`` attribute, or a raw logits tensor)
            teacher_outputs: Outputs from the teacher model, or None
            batch: The processed input batch

        Returns:
            Loss tensor for memory training
        """
        try:
            # Accept both HF-style output objects and bare logits tensors.
            student_logits = getattr(student_outputs, "logits", student_outputs)
            teacher_logits = getattr(teacher_outputs, "logits", teacher_outputs)

            # Fall back to input_ids as labels when the batch provides none.
            labels = batch.get("labels", batch.get("input_ids"))

            # Standard causal-LM shift: predict token t+1 from position t.
            shift_logits = student_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            memory_loss = self.memory_loss(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )

            if teacher_logits is not None:
                # kl_div expects log-probabilities for the input argument.
                kl_loss = F.kl_div(
                    F.log_softmax(student_logits, dim=-1),
                    F.softmax(teacher_logits, dim=-1),
                    reduction="batchmean",
                )
                total_loss = memory_loss + 0.1 * kl_loss
            else:
                total_loss = memory_loss

            self.metrics["memory_loss"] = memory_loss.item()

            return total_loss * self.loss_weight

        except Exception:
            # Best-effort fallback so a single bad batch does not kill the
            # run: log the full traceback and return a small constant loss.
            logger.exception("Error computing memory loss")
            # Place the fallback on the module's device to avoid a CPU/CUDA
            # mismatch when the caller combines it with other losses.
            device = getattr(self, "device", "cpu")
            return torch.tensor(0.01, device=device, requires_grad=True)

    def get_metrics(self) -> Dict[str, float]:
        """
        Get metrics specific to memory training.

        Returns:
            Dictionary of memory metrics (a copy; safe for callers to mutate)
        """
        return self.metrics.copy()

    def generate_synthetic_memory_data(
        self, output_path: str, num_samples: int = 1000
    ) -> None:
        """
        Generate synthetic memory training data as a JSONL file.

        Each sample is a conversation template; when a matching recall
        template exists, a recall question is appended and the expected
        answer is recorded so the example exercises context retention.

        Args:
            output_path: Path to save the generated data
            num_samples: Number of samples to generate
        """
        conversation_templates = [
            [
                {
                    "role": "user",
                    "content": "Hi, my name is Alex and I'm interested in machine learning.",
                },
                {
                    "role": "assistant",
                    "content": "Hello Alex! I'd be happy to discuss machine learning with you. What aspects are you most interested in?",
                },
                {
                    "role": "user",
                    "content": "I'm particularly interested in natural language processing.",
                },
                {
                    "role": "assistant",
                    "content": "NLP is a fascinating field! It's used for tasks like translation, summarization, and question answering.",
                },
                {
                    "role": "user",
                    "content": "What do you think would be a good first project?",
                },
                {
                    "role": "assistant",
                    "content": "For a beginner in NLP, I'd recommend starting with a text classification project, like sentiment analysis.",
                },
            ],
            [
                {
                    "role": "user",
                    "content": "I'm planning a trip to Japan next spring.",
                },
                {
                    "role": "assistant",
                    "content": "That sounds exciting! Japan is beautiful in spring with cherry blossoms. What cities are you planning to visit?",
                },
                {
                    "role": "user",
                    "content": "I'm thinking Tokyo, Kyoto, and maybe Osaka.",
                },
                {
                    "role": "assistant",
                    "content": "Great choices! Tokyo has modern attractions, Kyoto has historical temples, and Osaka is known for amazing food.",
                },
                {
                    "role": "user",
                    "content": "What's the best way to travel between these cities?",
                },
                {
                    "role": "assistant",
                    "content": "The Shinkansen (bullet train) is the most efficient way to travel between these cities. It's fast, comfortable, and reliable.",
                },
            ],
        ]

        # Recall templates are index-aligned with conversation_templates.
        recall_templates = [
            {
                "recall_context": "what was my name again?",
                "recall_target": "Your name is Alex, as you mentioned at the beginning of our conversation.",
            },
            {
                "recall_context": "which cities did I say I wanted to visit?",
                "recall_target": "You mentioned you're planning to visit Tokyo, Kyoto, and possibly Osaka during your trip to Japan.",
            },
        ]

        output_data = []
        for _ in range(num_samples):
            template_idx = random.randint(0, len(conversation_templates) - 1)
            # Shallow copy is enough: we only append to the list, never
            # mutate the shared turn dicts.
            conversation = conversation_templates[template_idx].copy()

            if template_idx < len(recall_templates):
                recall_template = recall_templates[template_idx]

                # Append the recall question as the final user turn.
                conversation.append(
                    {"role": "user", "content": recall_template["recall_context"]}
                )

                example = {
                    "conversation": conversation,
                    "recall_context": recall_template["recall_context"],
                    "recall_target": recall_template["recall_target"],
                    "metadata": {"generated": True, "requires_memory": True},
                }
            else:
                # No matching recall template: plain conversation sample.
                example = {
                    "conversation": conversation,
                    "metadata": {"generated": True, "requires_memory": False},
                }

            output_data.append(example)

        # os.makedirs("") raises FileNotFoundError, so only create the parent
        # directory when output_path actually has one.
        parent_dir = os.path.dirname(output_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        with open(output_path, "w", encoding="utf-8") as f:
            for item in output_data:
                f.write(json.dumps(item) + "\n")

        logger.info(
            f"Generated {len(output_data)} synthetic memory examples at {output_path}"
        )
| |
|