| """
|
| Ear Training Module for TouchGrass.
|
| Guides ear training exercises without audio, using descriptive language.
|
| """
|
|
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| from typing import Optional, List, Dict, Tuple
|
|
|
|
|
class EarTrainingModule(nn.Module):
    """
    Guides ear training exercises without audio.

    Can:
    - Describe interval sounds in relatable terms
      ("a perfect 5th sounds like the Star Wars theme opening")
    - Generate solfege exercises (Do Re Mi Fa Sol La Ti Do)
    - Create interval identification quizzes in text form
    - Explain chord quality by ear ("major chords sound happy/bright,
      minor chords sound sad/dark, diminished chords sound tense/unstable")
    - Guide relative pitch training
    - Suggest listening exercises with specific songs/moments

    Tracks user progress through session context.

    The learned heads used by ``forward`` score a mean-pooled hidden state.
    The text helpers (``describe_*``, ``generate_*``, ``suggest_*``,
    ``track_progress``) are rule-based and do not use any learned weights.
    """

    # Semitone distance -> interval name (0 = unison ... 12 = octave).
    INTERVALS = {
        0: "unison",
        1: "minor 2nd",
        2: "major 2nd",
        3: "minor 3rd",
        4: "major 3rd",
        5: "perfect 4th",
        6: "tritone",
        7: "perfect 5th",
        8: "minor 6th",
        9: "major 6th",
        10: "minor 7th",
        11: "major 7th",
        12: "octave",
    }

    QUALITIES = ["perfect", "major", "minor", "augmented", "diminished"]

    # Index 7 is the upper-octave Do; generate_solfege_exercise only emits 0-6.
    SOLFEGE = ["Do", "Re", "Mi", "Fa", "Sol", "La", "Ti", "Do"]

    CHORD_DESCRIPTIONS = {
        "major": "bright, happy, stable",
        "minor": "sad, dark, melancholic",
        "diminished": "tense, unstable, dissonant",
        "augmented": "bright, dreamy, suspenseful",
        "dominant7": "bluesy, tense, wants to resolve",
        "major7": "smooth, jazzy, dreamy",
        "minor7": "smooth, soulful, mellow",
    }

    # Well-known melodies whose opening (ascending) interval matches each size.
    INTERVAL_SONGS = {
        0: "any note played twice",
        1: "Jaws theme (da-dum)",
        2: "Happy Birthday (2nd note)",
        3: "Greensleeves (minor 3rd)",
        4: "Oh When the Saints (major 3rd)",
        5: "Here Comes the Bride (perfect 4th)",
        6: "The Simpsons theme (tritone)",
        7: "Star Wars theme (perfect 5th)",
        8: "The Entertainer (minor 6th)",
        9: "My Bonnie Lies Over the Ocean (major 6th)",
        10: "Star Trek original theme (minor 7th)",
        11: "Take On Me (major 7th)",
        12: "Somewhere Over the Rainbow (octave)",
    }

    # Emotional character per interval, used by describe_interval("emotion").
    # Kept at class level so the table is not rebuilt on every call.
    INTERVAL_EMOTIONS = {
        0: "familiar, consonant",
        1: "tense, dissonant",
        2: "slightly tense",
        3: "sad, soulful",
        4: "bright, happy",
        5: "stable, resolved",
        6: "very tense, mysterious",
        7: "strong, stable",
        8: "sweet, melancholic",
        9: "bright, hopeful",
        10: "bluesy, tense",
        11: "smooth, jazzy",
        12: "complete, resolved",
    }

    def __init__(self, d_model: int):
        """
        Initialize EarTrainingModule.

        Args:
            d_model: Hidden dimension from base model
        """
        super().__init__()
        self.d_model = d_model

        # Sizes are derived from the lookup tables so they cannot drift out of
        # sync with them (13 intervals, 5 qualities, 7 chord types).
        self.interval_embed = nn.Embedding(len(self.INTERVALS), 64)
        self.quality_embed = nn.Embedding(len(self.QUALITIES), 64)

        # Five-way difficulty head (presumably difficulty levels 1-5 — confirm).
        self.difficulty_tracker = nn.Linear(d_model, 5)

        # Six exercise types, matching forward()'s documented exercise_type 0-5.
        self.exercise_type_head = nn.Linear(d_model, 6)

        self.interval_predictor = nn.Linear(d_model, len(self.INTERVALS))

        self.chord_quality_predictor = nn.Linear(
            d_model, len(self.CHORD_DESCRIPTIONS)
        )

        # NOTE(review): the modules below are registered (and saved in the
        # state dict) but are not used by forward() yet.
        self.solfege_generator = nn.GRU(
            input_size=d_model + 64,
            hidden_size=d_model,
            num_layers=1,
            batch_first=True,
        )

        self.progress_tracker = nn.GRU(
            input_size=5,
            hidden_size=64,
            num_layers=1,
            batch_first=True,
        )

        self.success_predictor = nn.Linear(64, 1)

    def forward(
        self,
        hidden_states: torch.Tensor,
        exercise_type: Optional[int] = None,
        user_response: Optional[str] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass through EarTrainingModule.

        Args:
            hidden_states: Base model hidden states [batch, seq_len, d_model]
            exercise_type: Optional exercise type ID (0-5). Currently unused.
            user_response: Optional user's answer for progress tracking.
                Currently unused.

        Returns:
            Dictionary of [batch, n] logits with keys "difficulty_logits" (5),
            "exercise_type_logits" (6), "interval_logits" (13) and
            "chord_quality_logits" (7).

        Raises:
            ValueError: if hidden_states is not 3-dimensional.
        """
        if hidden_states.dim() != 3:
            raise ValueError(
                "hidden_states must be [batch, seq_len, d_model], got shape "
                f"{tuple(hidden_states.shape)}"
            )

        # Mean-pool over the sequence axis; padding positions are not masked.
        pooled = hidden_states.mean(dim=1)

        return {
            "difficulty_logits": self.difficulty_tracker(pooled),
            "exercise_type_logits": self.exercise_type_head(pooled),
            "interval_logits": self.interval_predictor(pooled),
            "chord_quality_logits": self.chord_quality_predictor(pooled),
        }

    def describe_interval(self, interval_semitones: int, reference: str = "song") -> str:
        """
        Describe an interval in relatable terms.

        Args:
            interval_semitones: Number of semitones (0-12)
            reference: Type of reference ("song", "emotion"); any other value
                gives a plain technical description.

        Returns:
            Descriptive string, or "Unknown interval: N semitones" when the
            interval is outside 0-12.
        """
        if interval_semitones not in self.INTERVALS:
            return f"Unknown interval: {interval_semitones} semitones"

        interval_name = self.INTERVALS[interval_semitones]
        # "An octave", but "A unison"/"A major ..." (unison starts with a "yoo" sound).
        article = "An" if interval_name[0] in "aeio" else "A"

        if reference == "song":
            song = self.INTERVAL_SONGS.get(interval_semitones, "a generic interval")
            return f"{article} {interval_name} ({interval_semitones} semitones) — like {song}."
        if reference == "emotion":
            emotion = self.INTERVAL_EMOTIONS.get(interval_semitones, "neutral")
            return f"{article} {interval_name} feels {emotion}."
        return f"{article} {interval_name} spans {interval_semitones} semitones."

    def _interval_quality(self, interval_semitones: int) -> str:
        """Return the interval quality ("perfect", "major", ...) for 0-12 semitones."""
        name = self.INTERVALS[interval_semitones]
        if name in ("unison", "octave"):
            return "perfect"
        if name == "tritone":
            # Enharmonically an augmented 4th; reported as a diminished 5th.
            return "diminished"
        # Remaining names are "<quality> <number>", e.g. "minor 3rd".
        return name.split()[0]

    def generate_solfege_exercise(
        self,
        key: str = "C",
        difficulty: int = 1,
        num_notes: int = 5,
    ) -> List[str]:
        """
        Generate solfege exercise.

        Args:
            key: Key signature. Currently unused: output is movable-do
                syllables and contains no accidentals.
            difficulty: 1-5. Difficulty <= 2 gives a stepwise run; higher
                values allow random leaps of up to difficulty + 2 scale steps
                (capped at 7).
            num_notes: Number of notes in exercise

        Returns:
            List of solfege syllables drawn from Do..Ti (the upper Do is not
            used).
        """
        import random

        if difficulty <= 2:
            # Stepwise run starting on Do..Sol. The modulo wraps Ti back to Do.
            start = random.randint(0, 4)
            return [self.SOLFEGE[(start + step) % 7] for step in range(num_notes)]

        max_jump = min(difficulty + 2, 7)
        exercise = []
        degree = 0
        for _ in range(num_notes):
            jump = random.randint(-max_jump, max_jump)
            # Clamp into Do..Ti rather than wrapping, so leaps stay audible.
            degree = max(0, min(6, degree + jump))
            exercise.append(self.SOLFEGE[degree])
        return exercise

    def generate_interval_quiz(
        self,
        num_questions: int = 5,
        max_interval: int = 12,
        include_desc: bool = True,
    ) -> List[Dict]:
        """
        Generate interval identification quiz.

        Args:
            num_questions: Number of questions
            max_interval: Maximum interval size in semitones; values above 12
                are clamped to 12 (octave).
            include_desc: Include descriptive hints

        Returns:
            List of quiz questions with keys "interval_semitones",
            "interval_name", "quality" and, optionally, "hint".

        Raises:
            ValueError: if max_interval < 1.
        """
        import random

        if max_interval < 1:
            raise ValueError(f"max_interval must be >= 1, got {max_interval}")
        # INTERVALS stops at the octave; clamp so lookups cannot KeyError.
        upper = min(max_interval, max(self.INTERVALS))

        questions = []
        for _ in range(num_questions):
            interval = random.randint(1, upper)
            question = {
                "interval_semitones": interval,
                "interval_name": self.INTERVALS[interval],
                # Quality is derived from the interval itself, so it always
                # agrees with interval_name.
                "quality": self._interval_quality(interval),
            }
            if include_desc:
                question["hint"] = self.describe_interval(interval, reference="song")
            questions.append(question)

        return questions

    def describe_chord_quality(self, chord_type: str) -> str:
        """
        Describe how a chord quality sounds.

        Args:
            chord_type: Chord type (major, minor, etc)

        Returns:
            Descriptive string; unknown types are described as "unique sounding".
        """
        description = self.CHORD_DESCRIPTIONS.get(chord_type, "unique sounding")
        return f"{chord_type} chords sound {description}."

    def suggest_listening_exercise(
        self,
        interval: Optional[int] = None,
        chord_quality: Optional[str] = None,
    ) -> Dict[str, str]:
        """
        Suggest specific songs/moments to listen for intervals or chords.

        Args:
            interval: Optional specific interval to practice (0-12 semitones;
                0 = unison is accepted)
            chord_quality: Optional chord quality to practice

        Returns:
            Dictionary with optional "interval" and "chord" suggestions and a
            "tip". When both are requested, both tips are combined.
        """
        suggestions: Dict[str, str] = {}
        tips: List[str] = []

        # Compare against None so that 0 (unison) is not silently skipped.
        if interval is not None:
            if interval in self.INTERVALS:
                song = self.INTERVAL_SONGS.get(interval, "various songs")
                suggestions["interval"] = f"Listen for {self.INTERVALS[interval]} in: {song}"
                tips.append("Try to hum along to internalize the sound.")
            else:
                # Same wording as describe_interval for out-of-range input.
                suggestions["interval"] = f"Unknown interval: {interval} semitones"

        if chord_quality:
            examples = {
                "major": ["Happy Birthday", "Let It Be (chorus)"],
                "minor": ["House of the Rising Sun", "Greensleeves"],
                "diminished": ["The Simpsons theme (tritone)"],
                "dominant7": ["Blues progressions", "Purple Haze"],
                "major7": ["Something (The Beatles)", "So What (Miles Davis)"],
            }
            songs = examples.get(chord_quality, ["various songs"])
            suggestions["chord"] = f"Listen for {chord_quality} chords in: {', '.join(songs)}"
            tips.append("Focus on the emotional character.")

        if tips:
            suggestions["tip"] = " ".join(tips)

        return suggestions

    def track_progress(
        self,
        exercise_history: List[Dict],
        current_performance: float,
    ) -> Dict[str, object]:
        """
        Track user's progress over session.

        Args:
            exercise_history: List of past exercises; each dict's "score"
                (missing -> 0) is averaged.
            current_performance: Current success rate (0-1)

        Returns:
            Progress analysis with "level" and "suggestion". When history is
            non-empty it also has "average_score". "current_score" and
            "exercises_completed" are always included.
        """
        if not exercise_history:
            return {
                "level": "beginner",
                "suggestion": "Start with interval identification",
                "current_score": current_performance,
                "exercises_completed": 0,
            }

        avg_performance = sum(ex.get("score", 0) for ex in exercise_history) / len(exercise_history)

        if avg_performance < 0.5:
            level = "beginner"
            suggestion = "Practice more interval identification with smaller intervals (2nd-5th)."
        elif avg_performance < 0.7:
            level = "intermediate"
            suggestion = "Try more complex intervals and chord qualities."
        else:
            level = "advanced"
            suggestion = "Challenge yourself with inversions and advanced chords."

        return {
            "level": level,
            "average_score": avg_performance,
            "current_score": current_performance,
            "suggestion": suggestion,
            "exercises_completed": len(exercise_history),
        }
|
|
|
|
|
def test_ear_training_module():
    """Smoke-test EarTrainingModule: forward-pass shapes plus every text helper.

    Results are printed rather than asserted, so running this file as a
    script doubles as a usage demo.
    """
    batch_size = 2
    seq_len = 10
    # Single source of truth for the width. Note that d_model=4096 makes
    # solfege_generator's GRU roughly 100M parameters.
    d_model = 4096

    module = EarTrainingModule(d_model=d_model)
    hidden_states = torch.randn(batch_size, seq_len, d_model)

    # Call the module rather than .forward() so nn.Module hooks run as in real
    # use. No gradients are needed for an inference-only demo.
    with torch.no_grad():
        outputs = module(hidden_states)

    print("Ear Training Module outputs:")
    for key, value in outputs.items():
        print(f" {key}: {value.shape}")

    print("\nInterval descriptions:")
    for semitones in [3, 4, 5, 7, 10]:
        desc = module.describe_interval(semitones, reference="song")
        print(f" {semitones} semitones: {desc}")

    print("\nSolfege exercise (C, difficulty 2):")
    solfege = module.generate_solfege_exercise(key="C", difficulty=2, num_notes=8)
    print(f" {' '.join(solfege)}")

    print("\nInterval quiz (3 questions):")
    quiz = module.generate_interval_quiz(num_questions=3)
    for i, q in enumerate(quiz):
        print(f" Q{i+1}: {q['interval_name']} ({q['interval_semitones']} semitones)")
        if 'hint' in q:
            print(f" Hint: {q['hint']}")

    print("\nChord quality descriptions:")
    for chord in ["major", "minor", "diminished", "major7"]:
        desc = module.describe_chord_quality(chord)
        print(f" {chord}: {desc}")

    print("\nListening exercise suggestions:")
    suggestions = module.suggest_listening_exercise(interval=7, chord_quality="major")
    for key, value in suggestions.items():
        print(f" {key}: {value}")

    print("\nProgress tracking:")
    history = [
        {"exercise": "interval", "score": 0.6},
        {"exercise": "interval", "score": 0.7},
        {"exercise": "chord", "score": 0.5},
    ]
    progress = module.track_progress(history, current_performance=0.8)
    for key, value in progress.items():
        print(f" {key}: {value}")

    print("\nEar Training Module test complete!")
|
|
|
|
|
if __name__ == "__main__":
    # Executing this file directly runs the module's demo / smoke test.
    test_ear_training_module()