import os
import re
import json
from pathlib import Path
from typing import List, Dict, Optional, Any
from dataclasses import dataclass, field

from logger_config import config_logger

logger = config_logger(__name__)


@dataclass
class TaskmasterDialogue:
    conversation_id: str
    instruction_id: Optional[str]
    scenario: Optional[str]
    domain: Optional[str]
    turns: List[Dict[str, Any]]
    original_metadata: Dict[str, Any] = field(default_factory=dict)

    def __str__(self):
        return f"TaskmasterDialogue(conversation_id={self.conversation_id}, turns={len(self.turns)} turns)"

    def validate(self) -> bool:
        return bool(self.conversation_id and isinstance(self.turns, list))


class RawDataProcessingConfig:
    """
    Simple config for raw dataset processing.
    """
    def __init__(
        self,
        debug: bool = True,
        max_length: int = 512,
        min_turns: int = 4,
        min_user_words: int = 3
    ):
        self.debug = debug
        self.max_length = max_length
        self.min_turns = min_turns
        self.min_user_words = min_user_words


class TaskmasterProcessor:
    """
    Loads Taskmaster-1 dialogues and extracts a domain for each one.
    Cleans, filters, and converts them to the pipeline format.
    """
    def __init__(self, config: RawDataProcessingConfig):
        self.config = config

    def load_taskmaster_dataset(
        self,
        base_dir: str,
        max_examples: Optional[int] = None
    ) -> List[TaskmasterDialogue]:
        """
        Load & parse Taskmaster-1 JSON for self-dialogs & woz-dialogs.
        """
        required_files = {
            "self-dialogs": "self-dialogs.json",
            "woz-dialogs": "woz-dialogs.json",
            "ontology": "ontology.json",
        }

        # Check for missing files
        missing = [k for k, v in required_files.items() if not Path(base_dir, v).exists()]
        if missing:
            raise FileNotFoundError(f"Missing Taskmaster files: {missing}")

        # Load ontology
        ontology_path = Path(base_dir, required_files["ontology"])
        with open(ontology_path, 'r', encoding='utf-8') as f:
            ontology = json.load(f)
        if self.config.debug:
            logger.info(f"[TaskmasterProcessor] Loaded ontology with {len(ontology.keys())} top-level keys (unused).")

        dialogues: List[TaskmasterDialogue] = []

        # Process each file
        file_keys = ["self-dialogs", "woz-dialogs"]
        for file_key in file_keys:
            file_path = Path(base_dir, required_files[file_key])
            with open(file_path, 'r', encoding='utf-8') as f:
                raw_data = json.load(f)

            for d in raw_data:
                conversation_id = d.get("conversation_id", "")
                instruction_id = d.get("instruction_id", None)
                scenario_text = d.get("scenario", "")

                # Handle utterances
                utterances = d.get("utterances", [])
                turns = self._process_utterances(utterances)

                # Detect domain
                domain = self._extract_domain(scenario_text, turns)

                # Build the object
                new_dlg = TaskmasterDialogue(
                    conversation_id=conversation_id,
                    instruction_id=instruction_id,
                    scenario=scenario_text,
                    domain=domain,
                    turns=turns,
                    original_metadata={}
                )
                dialogues.append(new_dlg)

                if max_examples and len(dialogues) >= max_examples:
                    break

            # Stop reading further files once the example cap is reached.
            if max_examples and len(dialogues) >= max_examples:
                break

        if self.config.debug:
            logger.info(f"[TaskmasterProcessor] Loaded {len(dialogues)} total dialogues from Taskmaster-1.")
        return dialogues

    def _extract_domain(self, scenario: str, turns: List[Dict[str, str]]) -> str:
        """
        Combine scenario text + all turn texts to detect domain more robustly.
""" combined_text = scenario.lower() for turn in turns: txt = turn.get('text', '').lower() combined_text += " " + txt # Domain patterns domain_patterns = { 'restaurant': r'\b(restaurant|dining|food|reservation|table|menu|cuisine|eat|hungry)\b', 'movie': r'\b(movie|cinema|film|ticket|showtime|theater|flick|screening)\b', 'ride_share': r'\b(ride|taxi|uber|lyft|car\s?service|pickup|dropoff|driver)\b', 'coffee': r'\b(coffee|café|cafe|starbucks|espresso|latte|mocha|americano)\b', 'pizza': r'\b(pizza|delivery|order\s?food|pepperoni|topping|pizzeria|slice)\b', 'auto': r'\b(car|vehicle|repair|maintenance|mechanic|oil\s?change)\b' } for domain, pattern in domain_patterns.items(): if re.search(pattern, combined_text): # Optional: logger.info if debug if self.config.debug: logger.info(f"Matched domain: {domain} in scenario/turns") return domain if self.config.debug: logger.info("No domain match, returning 'other'") return 'other' def _clean_text(self, text: str) -> str: """ Simple text normalization """ # Strip multiple spaces, remove unnecessary punctuation text = re.sub(r'\s+', ' ', text) text = re.sub(r'([!?.,])\1+', r'\1', text) return text.strip() def _is_numeric_line(self, text: str) -> bool: """ Return True if line is purely digits/punctuation/spaces, e.g. "4 3 13" and similar found in Taskmaster-1 dataset. """ pattern = r'^[\s]*[\d]+([\s\d.,]+)*[\s]*$' return bool(re.match(pattern, text)) def filter_and_convert(self, dialogues: List[TaskmasterDialogue]) -> List[Dict]: """ Filter out dialogues that don't meet min length requirements. Convert to pipeline format. { "dialogue_id": "...", "domain": "...", "turns": [ {"speaker": "user", "text": "..."}, ... ] } """ total = len(dialogues) invalid = 0 too_few_turns = 0 short_user_turns = 0 results = [] for dlg in dialogues: if not dlg.validate(): invalid += 1 continue # Skip if too few turns if len(dlg.turns) < self.config.min_turns: too_few_turns += 1 continue # Skip if any user turn is too short keep = True for turn in dlg.turns: if turn['speaker'] == 'user': words_count = len(turn['text'].split()) if words_count < self.config.min_user_words: short_user_turns += 1 keep = False break if not keep: continue pipeline_dlg = { 'dialogue_id': dlg.conversation_id, 'domain': dlg.domain, 'turns': dlg.turns } results.append(pipeline_dlg) if self.config.debug: logger.info(f"\nFiltering Statistics:") logger.info(f"Total dialogues: {total}") logger.info(f"Invalid dialogues: {invalid}") logger.info(f"Too few turns: {too_few_turns}") logger.info(f"Short user turns: {short_user_turns}") logger.info(f"Remaining dialogues: {len(results)}") logger.info(f"Filtering rate: {((total - len(results)) / total) * 100:.1f}%\n") return results def _process_utterances(self, utterances: List[Dict[str, Any]]) -> List[Dict[str, str]]: """Added logging to track utterance filtering""" total = len(utterances) empty = 0 numeric = 0 too_short = 0 cleaned_turns = [] for utt in utterances: speaker = 'assistant' if utt.get('speaker') == 'ASSISTANT' else 'user' raw_text = utt.get('text', '').strip() text = self._clean_text(raw_text) if not text: empty += 1 continue if self._is_numeric_line(text): numeric += 1 continue if len(text.split()) < 3: too_short += 1 continue cleaned_turns.append({ 'speaker': speaker, 'text': text }) if self.config.debug and total > 0: logger.info(f"\nUtterance Cleaning Statistics (Dialogue {utterances[0].get('conversation_id', 'unknown')}):") logger.info(f"Total utterances: {total}") logger.info(f"Empty/blank: {empty}") logger.info(f"Numeric only: {numeric}") 
logger.info(f"Too short (<3 words): {too_short}") logger.info(f"Remaining turns: {len(cleaned_turns)}") logger.info(f"Filtering rate: {((total - len(cleaned_turns)) / total) * 100:.1f}%\n") return cleaned_turns