| """
|
| Chat Formatter for TouchGrass.
|
| Formats data into chat format compatible with Qwen3.5 fine-tuning.
|
| """
|
|
|
| from typing import List, Dict, Any, Optional
|
| import json
|
| from pathlib import Path
|
|
|
|
|
class ChatFormatter:
    """
    Formats music QA data into chat format for instruction tuning.

    Handles:
    - System prompt injection
    - Context tags (instrument, skill level, emotion)
    - Tokenization-ready format
    - Multi-turn conversations

    Length checks and truncation only run when a tokenizer is supplied.
    The tokenizer must expose ``encode(text)`` and ``decode(ids)``;
    ``encode`` may return either a mapping with an ``"input_ids"`` entry
    or a plain sequence of token ids.
    """

    # Tokens held back from ``max_seq_length`` whenever a sample is truncated.
    # NOTE(review): presumably headroom for chat-template tokens -- confirm.
    _LENGTH_MARGIN = 10

    # Tokens given up from the answer budget before "..." is appended.
    _ELLIPSIS_RESERVE = 3

    def __init__(
        self,
        tokenizer=None,
        max_seq_length: int = 4096,
        system_prompt: Optional[str] = None,
    ):
        """
        Initialize chat formatter.

        Args:
            tokenizer: Optional tokenizer for length validation
            max_seq_length: Maximum sequence length
            system_prompt: Optional custom system prompt; an empty or
                missing value falls back to the built-in default
        """
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.default_system_prompt = system_prompt or self._get_default_system_prompt()

    def _get_default_system_prompt(self) -> str:
        """Get default system prompt."""
        return """You are Touch Grass 🌿, a warm, encouraging, and knowledgeable music assistant.

You help people with:
- Learning instruments (guitar, bass, piano, keys, drums, vocals)
- Understanding music theory at any level
- Writing songs (lyrics, chord progressions, structure)
- Ear training and developing musicality
- DJ skills and music production
- Genre knowledge and music history

Your personality:
- Patient and encouraging — learning music is hard and takes time
- Adapt to the learner's level automatically — simpler for beginners, deeper for advanced
- When someone is frustrated, acknowledge it warmly before helping
- Use tabs, chord diagrams, and notation when helpful
- Make learning fun, not intimidating
- Celebrate small wins

When generating tabs use this format:
[TAB]
e|---------|
B|---------|
G|---------|
D|---------|
A|---------|
E|---------|
[/TAB]

When showing chord progressions use: [PROGRESSION]I - IV - V - I[/PROGRESSION]"""

    def format_qa_pair(
        self,
        question: str,
        answer: str,
        context: Optional[str] = None,
        system_prompt: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Format a single QA pair into chat format.

        When a tokenizer is configured and the sample is longer than
        ``max_seq_length``, a warning is printed and the answer is truncated.

        Args:
            question: User question
            answer: Assistant answer
            context: Optional context tags (e.g., "[GUITAR][BEGINNER]")
            system_prompt: Optional system prompt override

        Returns:
            Formatted chat dictionary ``{"messages": [system, user, assistant]}``
        """
        system = system_prompt or self.default_system_prompt

        # Context tags are prepended to the question, separated by one space.
        user_message = f"{context} {question}".strip() if context else question

        messages = [
            {"role": "system", "content": system},
            {"role": "user", "content": user_message},
            {"role": "assistant", "content": answer},
        ]

        if self.tokenizer is not None:
            total_length = self._estimate_length(messages)
            if total_length > self.max_seq_length:
                print(f"Warning: Sample exceeds max length ({total_length} > {self.max_seq_length})")
                messages = self._truncate_answers(messages)

        return {"messages": messages}

    def format_multi_turn(
        self,
        conversations: List[Dict[str, str]],
        system_prompt: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Format multi-turn conversation.

        A system message is prepended unless the conversation already starts
        with one. An empty conversation yields only the system message.

        Args:
            conversations: List of {"role": "...", "content": "..."} dicts
            system_prompt: Optional system prompt

        Returns:
            Formatted chat dictionary
        """
        system = system_prompt or self.default_system_prompt

        # Work on a copy so the returned list never aliases the caller's list.
        turns = list(conversations)
        if turns and turns[0]["role"] == "system":
            messages = turns
        else:
            messages = [{"role": "system", "content": system}] + turns

        if self.tokenizer is not None:
            total_length = self._estimate_length(messages)
            if total_length > self.max_seq_length:
                print(f"Warning: Multi-turn sample exceeds max length ({total_length} > {self.max_seq_length})")
                messages = self._truncate_multi_turn(messages)

        return {"messages": messages}

    def _token_ids(self, text: str) -> List[Any]:
        """
        Encode ``text`` with the configured tokenizer and return the ids as a list.

        Supports ``encode`` results that are a mapping with ``"input_ids"``
        as well as results that are already a plain sequence of ids.
        """
        encoded = self.tokenizer.encode(text)
        try:
            ids = encoded["input_ids"]
        except (TypeError, IndexError):
            # Not subscriptable by string: treat the result as the id sequence.
            ids = encoded
        return list(ids)

    def _estimate_length(self, messages: List[Dict[str, str]]) -> int:
        """
        Estimate token length of messages.

        Only message contents are counted; returns 0 without a tokenizer.
        """
        if self.tokenizer is None:
            return 0
        return sum(len(self._token_ids(msg["content"])) for msg in messages)

    def _truncate_answers(self, messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
        """
        Truncate answer to fit max length.

        Expects ``messages`` laid out as ``[system, user, assistant, ...]``.
        Returns a new list when the answer is shortened; the input list and
        its dicts are left untouched.
        """
        if self.tokenizer is None or len(messages) < 3:
            return messages

        system_len = self._estimate_length([messages[0]])
        user_len = self._estimate_length([messages[1]])
        available = self.max_seq_length - system_len - user_len - self._LENGTH_MARGIN

        answer_ids = self._token_ids(messages[2]["content"])
        if len(answer_ids) <= available:
            return messages

        # Clamp at zero: a negative bound would slice relative to the END of
        # the answer and leave it far longer than the budget.
        keep = max(available - self._ELLIPSIS_RESERVE, 0)
        answer_msg = dict(messages[2])
        answer_msg["content"] = self.tokenizer.decode(answer_ids[:keep]) + "..."

        return messages[:2] + [answer_msg] + messages[3:]

    def _truncate_multi_turn(self, messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
        """
        Truncate multi-turn conversation by dropping turns from the end.

        The first message is always kept. Following messages are kept in
        order until one no longer fits; that message and everything after it
        are dropped.
        """
        if self.tokenizer is None:
            return messages

        system_msg = messages[0]
        budget = self.max_seq_length - self._LENGTH_MARGIN
        current_length = self._estimate_length([system_msg])
        kept_msgs = []

        for msg in messages[1:]:
            msg_len = self._estimate_length([msg])
            if current_length + msg_len > budget:
                break
            kept_msgs.append(msg)
            current_length += msg_len

        return [system_msg] + kept_msgs

    def save_as_jsonl(
        self,
        samples: List[Dict[str, Any]],
        output_path: str,
    ):
        """
        Save formatted samples as JSONL.

        Parent directories are created when missing; an existing file is
        overwritten.

        Args:
            samples: List of formatted samples
            output_path: Output file path
        """
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        with open(output_path, "w", encoding="utf-8") as f:
            for sample in samples:
                f.write(json.dumps(sample, ensure_ascii=False) + "\n")

        print(f"Saved {len(samples)} samples to {output_path}")

    def load_from_jsonl(
        self,
        input_path: str,
    ) -> List[Dict[str, Any]]:
        """
        Load formatted samples from JSONL.

        Blank lines (for example a trailing empty line) are skipped.

        Args:
            input_path: Input file path

        Returns:
            List of samples

        Raises:
            json.JSONDecodeError: If a non-blank line is not valid JSON
        """
        with open(input_path, "r", encoding="utf-8") as f:
            samples = [json.loads(line) for line in f if line.strip()]

        print(f"Loaded {len(samples)} samples from {input_path}")
        return samples

    def validate_sample(
        self,
        sample: Dict[str, Any],
    ) -> bool:
        """
        Validate a formatted sample.

        Checks that the sample has a ``messages`` list of at least two
        entries, that it starts with a system message, and that the remaining
        roles alternate user, assistant, user, ... A message that is not a
        dict or has no ``"role"`` key is reported as invalid.

        Args:
            sample: Sample to validate

        Returns:
            True if valid
        """
        if "messages" not in sample:
            print("Error: Missing 'messages' field")
            return False

        messages = sample["messages"]
        if len(messages) < 2:
            print("Error: At least 2 messages required (system + user)")
            return False

        # Missing or malformed entries map to None so they fail the role checks
        # below instead of raising.
        roles = [m.get("role") if isinstance(m, dict) else None for m in messages]

        if roles[0] != "system":
            print("Error: First message must be system")
            return False

        for position, role in enumerate(roles[1:], start=1):
            expected = "user" if position % 2 == 1 else "assistant"
            if role != expected:
                print(f"Error: Expected {expected} at position {position}, got {role}")
                return False

        return True

    def create_pretraining_dataset(
        self,
        qa_samples: List[Dict[str, Any]],
        output_dir: str,
        train_split: float = 0.9,
        seed: Optional[int] = None,
    ) -> Dict[str, str]:
        """
        Create train/val splits for fine-tuning.

        The samples are shuffled on a copy, so the caller's list keeps its
        order. Files are written as ``train.jsonl`` and ``val.jsonl``.

        Args:
            qa_samples: List of QA samples
            output_dir: Output directory
            train_split: Train split ratio (0-1)
            seed: Optional seed for a reproducible shuffle; when omitted the
                module-level ``random`` generator is used

        Returns:
            Dictionary with train/val file paths
        """
        import random

        shuffled = list(qa_samples)
        if seed is None:
            random.shuffle(shuffled)
        else:
            random.Random(seed).shuffle(shuffled)

        split_idx = int(len(shuffled) * train_split)
        train_samples = shuffled[:split_idx]
        val_samples = shuffled[split_idx:]

        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        train_path = output_dir / "train.jsonl"
        val_path = output_dir / "val.jsonl"

        self.save_as_jsonl(train_samples, str(train_path))
        self.save_as_jsonl(val_samples, str(val_path))

        print(f"Created splits: train={len(train_samples)}, val={len(val_samples)}")

        return {
            "train": str(train_path),
            "val": str(val_path),
        }
|
|
|
|
|
def test_chat_formatter():
    """Smoke-test ChatFormatter: print a formatted QA pair, its validity, and a multi-turn chat."""
    chat_formatter = ChatFormatter()
    print("Testing ChatFormatter...\n")

    # Single question/answer pair carrying context tags.
    qa_sample = chat_formatter.format_qa_pair(
        context="[GUITAR][BEGINNER]",
        answer="[TAB]...[/TAB] Here's how...",
        question="How do I play a G chord?",
    )
    print("Formatted QA pair:")
    for msg in qa_sample["messages"]:
        print(f" {msg['role']}: {msg['content'][:80]}...")

    # Structural validation of the sample just built.
    sample_ok = chat_formatter.validate_sample(qa_sample)
    print(f"\nSample valid: {sample_ok}")

    # Two user/assistant exchanges without an explicit system message.
    dialogue = [
        {"role": "user", "content": "What is a chord?"},
        {"role": "assistant", "content": "A chord is..."},
        {"role": "user", "content": "Can you give an example?"},
        {"role": "assistant", "content": "C major is C-E-G"},
    ]
    multi_turn_sample = chat_formatter.format_multi_turn(dialogue)

    print("\nMulti-turn format:")
    for msg in multi_turn_sample["messages"]:
        print(f" {msg['role']}: {msg['content'][:60]}...")

    print("\nChatFormatter test complete!")
|
|
|
|
|
# Script entry point: run the ChatFormatter smoke test when this file is
# executed directly rather than imported.
if __name__ == "__main__":
    test_chat_formatter()