| """
|
| Tab & Chord Generation Module for TouchGrass.
|
| Generates guitar tabs, chord diagrams, and validates musical correctness.
|
| """
|
|
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| from typing import Optional, Tuple, List, Dict
|
|
|
|
|
class TabChordModule(nn.Module):
    """
    Generates and validates guitar tabs and chord diagrams.

    Features:
    - Generates fret-class token sequences with a GRU decoder, and formats
      per-string fret lists as ASCII tablature for guitar, bass and ukulele
    - Creates compact chord strings (e.g. "x32010") and parses them back
    - Validates tab text: string count, row format and fret range
    - Difficulty-aware: suggests easier voicings for beginners
    - Supports multiple tunings for tab row labels

    The neural heads consume base-model hidden states shaped
    ``[batch, seq_len, d_model]``. The text helpers (``format_tab``,
    ``format_chord``, ``parse_chord``, ``validate_tab``,
    ``suggest_easier_voicing``) are plain Python and need no tensors.
    """

    # Tunings list strings from the one drawn on the bottom tab line
    # (e.g. low E on guitar) to the one drawn on the top tab line.
    STANDARD_TUNING = ["E2", "A2", "D3", "G3", "B3", "E4"]
    BASS_TUNING = ["E1", "A1", "D2", "G2"]
    UKULELE_TUNING = ["G4", "C4", "E4", "A4"]
    DROP_D_TUNING = ["D2", "A2", "D3", "G3", "B3", "E4"]
    OPEN_G_TUNING = ["D2", "G2", "D3", "G3", "B3", "D4"]

    MAX_FRET = 24  # highest fret number accepted by the tab validators
    OPEN_FRET = 0  # fret value for an open (unfretted) string
    MUTED_FRET = -1  # fret value for a muted string, rendered as "x"

    # Instrument name -> row of instrument_embed. Unknown names map to 0.
    _INSTRUMENT_INDEX = {
        "guitar": 0,
        "bass": 1,
        "ukulele": 2,
        "piano": 3,
        "drums": 4,
        "vocals": 5,
        "theory": 6,
        "dj": 7,
    }

    def __init__(self, d_model: int, num_strings: int = 6, num_frets: int = 24):
        """
        Initialize TabChordModule.

        Args:
            d_model: Hidden dimension of the base model's hidden states.
            num_strings: Number of strings (6 for guitar, 4 for bass); sizes
                ``string_embed``.
            num_frets: Number of frets (typically 24). The fret vocabulary and
                the ``fret_predictor`` output both have ``num_frets + 2``
                classes.
        """
        super().__init__()
        self.d_model = d_model
        self.num_strings = num_strings
        self.num_frets = num_frets

        # 64-d string and fret embeddings. The fret vocabulary matches the
        # fret_predictor's num_frets + 2 output classes so predicted tokens
        # can be embedded directly.
        # NOTE(review): string_embed is not used anywhere in this file.
        self.string_embed = nn.Embedding(num_strings, 64)
        self.fret_embed = nn.Embedding(num_frets + 2, 64)

        # Pooled hidden state -> probability in (0, 1) that the tab is valid.
        self.tab_validator = nn.Sequential(
            nn.Linear(d_model, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

        # Pooled hidden state -> 3 logits (presumably beginner /
        # intermediate / advanced; the label order is not defined here).
        self.difficulty_head = nn.Linear(d_model, 3)

        # One 64-d embedding per entry of _INSTRUMENT_INDEX (8 instruments).
        self.instrument_embed = nn.Embedding(len(self._INSTRUMENT_INDEX), 64)

        # Input = GRU output (d_model) + instrument embedding (64)
        #       + embedding of the previously generated fret (64).
        self.fret_predictor = nn.Linear(d_model + 128, num_frets + 2)

        # Input = per-step representation (d_model) + instrument embedding (64).
        self.tab_generator = nn.GRU(
            input_size=d_model + 64,
            hidden_size=d_model,
            num_layers=1,
            batch_first=True,
        )

        # Pooled hidden state -> 8 chord-quality logits (label set not
        # defined in this file).
        self.chord_quality_head = nn.Linear(d_model, 8)

        # Pooled hidden state -> 12 root-note logits (presumably the 12
        # pitch classes).
        self.root_note_head = nn.Linear(d_model, 12)

    def forward(
        self,
        hidden_states: torch.Tensor,
        instrument: str = "guitar",
        skill_level: str = "intermediate",
        generate_tab: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass through TabChordModule.

        All classification heads read the mean of ``hidden_states`` over the
        sequence axis.

        Args:
            hidden_states: Base model hidden states [batch, seq_len, d_model].
            instrument: Instrument type ("guitar", "bass", "ukulele", ...);
                only used when ``generate_tab`` is True.
            skill_level: "beginner", "intermediate", or "advanced". Accepted
                for API compatibility; this method does not use it.
            generate_tab: Whether to also decode a tab token sequence.

        Returns:
            Dict with:
              - "tab_validity": [batch, 1], values in (0, 1)
              - "difficulty_logits": [batch, 3]
              - "chord_quality_logits": [batch, 8]
              - "root_note_logits": [batch, 12]
              - "tab_sequence": [batch, 100] long, only if ``generate_tab``

        Raises:
            ValueError: If ``hidden_states`` is not 3-dimensional.
        """
        if hidden_states.dim() != 3:
            raise ValueError(
                "hidden_states must be [batch, seq_len, d_model], "
                f"got shape {tuple(hidden_states.shape)}"
            )

        pooled = hidden_states.mean(dim=1)

        outputs = {
            "tab_validity": self.tab_validator(pooled),
            "difficulty_logits": self.difficulty_head(pooled),
            "chord_quality_logits": self.chord_quality_head(pooled),
            "root_note_logits": self.root_note_head(pooled),
        }

        if generate_tab:
            outputs["tab_sequence"] = self._generate_tab_sequence(hidden_states, instrument)

        return outputs

    def _generate_tab_sequence(
        self,
        hidden_states: torch.Tensor,
        instrument: str,
        max_length: int = 100,
    ) -> torch.Tensor:
        """
        Greedily decode a sequence of fret-class indices with the GRU decoder.

        The GRU state starts from the mean of ``hidden_states`` over the
        sequence axis. Each step does the following:

        1. GRU input = [step representation (d_model), instrument emb (64)].
        2. logits = fret_predictor([GRU output (d_model), instrument emb (64),
           embedding of the previously chosen fret (64)]).
        3. The argmax class is emitted.

        The first step uses ``hidden_states[:, 0]`` as its representation and
        a zero vector as the "previous fret". Later steps feed the GRU its own
        previous output. This wiring matches the existing layer sizes, so
        parameter shapes (and saved checkpoints) are unaffected.

        Args:
            hidden_states: Base model hidden states [batch, seq_len, d_model].
            instrument: Instrument name, mapped via ``_instrument_to_idx``.
            max_length: Number of tokens to generate; ``<= 0`` yields an
                empty sequence.

        Returns:
            Long tensor [batch, max(max_length, 0)] of class indices in
            [0, num_frets + 1].
        """
        batch_size = hidden_states.size(0)
        if max_length <= 0:
            return torch.empty(batch_size, 0, dtype=torch.long, device=hidden_states.device)

        # Decoding is greedy (argmax), so no gradient can reach the result.
        # Skipping autograd avoids keeping a graph for every step.
        with torch.no_grad():
            instrument_idx = torch.tensor(
                [self._instrument_to_idx(instrument)], device=hidden_states.device
            )
            # [1, 64] -> [batch, 64] -> [batch, 1, 64]
            instrument_emb = (
                self.instrument_embed(instrument_idx).expand(batch_size, -1).unsqueeze(1)
            )

            # nn.GRU takes h_0 as [num_layers, batch, hidden] = [1, batch, d_model].
            hidden = hidden_states.mean(dim=1).unsqueeze(0)
            step_input = hidden_states[:, :1, :]
            prev_fret_emb = instrument_emb.new_zeros(
                batch_size, 1, self.fret_embed.embedding_dim
            )

            generated = []
            for _ in range(max_length):
                gru_input = torch.cat([step_input, instrument_emb], dim=-1)
                output, hidden = self.tab_generator(gru_input, hidden)

                predictor_input = torch.cat([output, instrument_emb, prev_fret_emb], dim=-1)
                next_token = self.fret_predictor(predictor_input).argmax(dim=-1)  # [batch, 1]
                generated.append(next_token.squeeze(1))

                prev_fret_emb = self.fret_embed(next_token)  # [batch, 1, 64]
                step_input = output

        return torch.stack(generated, dim=1)

    def _instrument_to_idx(self, instrument: str) -> int:
        """Convert an instrument name to its ``instrument_embed`` row (unknown -> 0)."""
        return self._INSTRUMENT_INDEX.get(instrument, 0)

    def _default_tuning(self, instrument: str) -> List[str]:
        """Return the class tuning for ``instrument`` (standard guitar if unknown)."""
        if instrument == "bass":
            return self.BASS_TUNING
        if instrument == "ukulele":
            return self.UKULELE_TUNING
        return self.STANDARD_TUNING

    def validate_tab(
        self,
        tab_strings: List[str],
        instrument: str = "guitar",
    ) -> Tuple[bool, List[str]]:
        """
        Validate ASCII tab text.

        Args:
            tab_strings: One text row per string, e.g. ``"e|3-5-|"``
                (6 rows for guitar).
            instrument: Instrument type; determines the expected row count.

        Returns:
            ``(is_valid, error_messages)``; ``is_valid`` is True iff no errors.
        """
        errors = []

        expected_strings = self._get_expected_strings(instrument)
        if len(tab_strings) != expected_strings:
            errors.append(f"Expected {expected_strings} strings, got {len(tab_strings)}")

        for i, string_row in enumerate(tab_strings):
            if not self._validate_tab_row(string_row, i, instrument):
                errors.append(f"Invalid format on string {i}: {string_row}")

        if not self._check_musical_consistency(tab_strings):
            errors.append("Tab has musical inconsistencies (impossible fingering)")

        return len(errors) == 0, errors

    def _get_expected_strings(self, instrument: str) -> int:
        """Return the string count of the instrument's default tuning (6 if unknown)."""
        return len(self._default_tuning(instrument))

    @staticmethod
    def _tab_row_tokens(row: str) -> List[str]:
        """
        Return the fret tokens found between the first and last ``|`` of a row.

        Dashes and whitespace separate tokens, so ``"e|3-12--x-|"`` yields
        ``["3", "12", "x"]``. The label before the first bar and any text
        after the last bar are ignored.
        """
        tokens: List[str] = []
        for segment in row.split("|")[1:-1]:
            tokens.extend(segment.replace("-", " ").split())
        return tokens

    def _is_valid_fret_token(self, token: str) -> bool:
        """True if ``token`` is a muted marker ("x"/"X") or an int in [0, MAX_FRET]."""
        if token.lower() == "x":
            return True
        try:
            fret = int(token)
        except ValueError:
            return False
        return 0 <= fret <= self.MAX_FRET

    def _validate_tab_row(self, row: str, string_idx: int, instrument: str) -> bool:
        """
        Validate a single tab row.

        A row must contain at least one ``|``, and every fret token between
        its first and last bar must be valid. Bars containing only dashes are
        allowed. ``string_idx`` and ``instrument`` are accepted for API
        compatibility and are not used.
        """
        if "|" not in row:
            return False
        return all(self._is_valid_fret_token(tok) for tok in self._tab_row_tokens(row))

    def _check_musical_consistency(self, tab_strings: List[str]) -> bool:
        """
        Check every fret token across all rows.

        NOTE: despite the name (and the "impossible fingering" message in
        ``validate_tab``), no stretch or fingering analysis is performed. This
        only confirms that every token is "x" or an integer in [0, MAX_FRET].
        TODO: add real stretch checks.
        """
        return all(
            self._is_valid_fret_token(tok)
            for row in tab_strings
            for tok in self._tab_row_tokens(row)
        )

    @staticmethod
    def _string_labels(tuning: List[str]) -> List[str]:
        """
        Derive tab row labels (top tab line first) from a tuning.

        Octave numbers are stripped and the order is reversed. If the top
        string shares its note name with another string, it is lower-cased.
        Standard tuning therefore gives ``["e", "B", "G", "D", "A", "E"]``.
        Labels are padded to a common width so the bars line up.
        """
        names = [note.rstrip("0123456789-") for note in reversed(tuning)]
        if names and any(name.upper() == names[0].upper() for name in names[1:]):
            names[0] = names[0].lower()
        width = max((len(name) for name in names), default=0)
        return [name.ljust(width) for name in names]

    def format_tab(
        self,
        frets: List[List[int]],
        instrument: str = "guitar",
        tuning: Optional[List[str]] = None,
    ) -> List[str]:
        """
        Format per-string fret positions into ASCII tab rows.

        Args:
            frets: One list per string, ordered from the top tab line (highest
                string, e.g. high e) downward. Each inner list holds that
                string's fret numbers over time (0 = open, -1 = muted). Extra
                rows beyond the tuning's string count are dropped.
            instrument: Instrument type; selects the default tuning when
                ``tuning`` is None.
            tuning: Optional custom tuning in class order (bottom tab line
                first), used only to build row labels.

        Returns:
            Rows like ``"e|3-0-x-|"``, one per provided string.
        """
        if tuning is None:
            tuning = self._default_tuning(instrument)

        tab_strings = []
        for label, fret_row in zip(self._string_labels(tuning), frets):
            cells = "".join(
                ("x" if fret == self.MUTED_FRET else f"{fret}") + "-" for fret in fret_row
            )
            tab_strings.append(f"{label}|{cells}|")
        return tab_strings

    def format_chord(
        self,
        frets: List[int],
        instrument: str = "guitar",
    ) -> str:
        """
        Format a chord as a compact fret string.

        Args:
            frets: Fret number for each string (low to high); negative values
                mean muted.
            instrument: Instrument type (accepted for API compatibility; not
                used).

        Returns:
            Chord string, e.g. ``"320003"`` for G major. When any fret needs
            more than one character (fret >= 10), tokens are joined with "-"
            (e.g. ``"x-10-12-12-11-10"``) so ``parse_chord`` can recover them.
        """
        tokens = ["x" if fret < 0 else str(fret) for fret in frets]
        separator = "-" if any(len(tok) > 1 for tok in tokens) else ""
        return separator.join(tokens)

    def parse_chord(self, chord_str: str) -> List[int]:
        """
        Parse a chord string into fret positions (the inverse of format_chord).

        Two forms are accepted:
          - compact, one character per string: ``"320003"``, ``"x32010"``
          - tokens separated by "-", "," or whitespace, needed for frets >= 10:
            ``"x-10-12-12-11-10"``, ``"x 3 2 0 1 0"``

        Args:
            chord_str: Chord string; surrounding whitespace is ignored.

        Returns:
            List of fret positions, with -1 for muted ("x" or "X").

        Raises:
            ValueError: If a token is neither "x" nor an integer.
        """
        text = chord_str.strip()
        if any(ch in "-," or ch.isspace() for ch in text):
            tokens = text.replace(",", " ").replace("-", " ").split()
        else:
            tokens = list(text)
        return [self.MUTED_FRET if tok.lower() == "x" else int(tok) for tok in tokens]

    def suggest_easier_voicing(
        self,
        chord_frets: List[int],
        skill_level: str = "beginner",
    ) -> List[int]:
        """
        Suggest an easier chord voicing for beginners.

        Heuristic: if some fret (> 0) is used on three or more strings (a
        barre-like shape), strings at even indices played on that fret are
        replaced by open strings (0). This changes those strings' pitches; it
        is a rough simplification, not a voicing search.

        Args:
            chord_frets: Original chord frets.
            skill_level: Only "beginner" triggers simplification.

        Returns:
            The same ``chord_frets`` object for other skill levels; otherwise
            a new simplified list.
        """
        if skill_level != "beginner":
            return chord_frets

        counts: Dict[int, int] = {}
        for fret in chord_frets:
            if fret > 0:
                counts[fret] = counts.get(fret, 0) + 1
        barre_frets = {fret for fret, count in counts.items() if count >= 3}

        return [
            0 if i % 2 == 0 and fret in barre_frets else fret
            for i, fret in enumerate(chord_frets)
        ]
|
|
|
|
|
def test_tab_chord_module():
    """Smoke-test TabChordModule: run the heads and exercise the tab/chord helpers."""
    # A small width keeps the demo fast; the module's logic does not depend on it.
    d_model = 256
    batch_size = 2
    seq_len = 10

    module = TabChordModule(d_model=d_model, num_strings=6, num_frets=24)
    hidden_states = torch.randn(batch_size, seq_len, d_model)

    outputs = module.forward(
        hidden_states,
        instrument="guitar",
        skill_level="beginner",
        generate_tab=True,
    )

    print("Outputs:")
    for key, value in outputs.items():
        if isinstance(value, torch.Tensor):
            print(f" {key}: {value.shape}")
        else:
            print(f" {key}: {value}")
    # _generate_tab_sequence decodes 100 tokens by default.
    assert outputs["tab_sequence"].shape == (batch_size, 100)

    # format_tab takes one row per string (top tab line first: e, B, G, D, A, E),
    # each holding that string's frets over time. This is a single G major strum.
    g_major_rows = [[3], [3], [0], [0], [2], [3]]
    tab = module.format_tab(g_major_rows, instrument="guitar")
    print("\nFormatted tab:")
    for line in tab:
        print(f" {line}")

    # format_chord lists frets from the lowest string to the highest.
    chord = module.format_chord([3, 2, 0, 0, 3, 3])
    print(f"\nChord: {chord}")
    assert chord == "320033"
    assert module.parse_chord(chord) == [3, 2, 0, 0, 3, 3]

    is_valid, errors = module.validate_tab(tab, instrument="guitar")
    print(f"\nTab valid: {is_valid}")
    if errors:
        print(f"Errors: {errors}")
    assert is_valid, errors

    print("\nTabChordModule test complete!")


if __name__ == "__main__":
    test_tab_chord_module()