| """
|
| Dataset Loader for TouchGrass.
|
| Handles loading and preprocessing of music QA data for fine-tuning.
|
| """
|
|
|
| from typing import List, Dict, Any, Optional
|
| from pathlib import Path
|
| import json
|
| import random
|
| from torch.utils.data import Dataset, DataLoader
|
| from transformers import AutoTokenizer
|
|
|
|
|
class TouchGrassDataset(Dataset):
    """Dataset for TouchGrass fine-tuning.

    Loads chat-formatted JSONL data (one ``{"messages": [...]}`` object per
    line) and tokenizes each sample into ``input_ids`` / ``attention_mask`` /
    ``labels`` tensors suitable for causal-LM training.
    """

    def __init__(
        self,
        data_path: str,
        tokenizer,
        max_seq_length: int = 4096,
        mode: str = "train",
    ):
        """
        Initialize dataset.

        Args:
            data_path: Path to JSONL file with chat data
            tokenizer: Tokenizer (extended Qwen tokenizer)
            max_seq_length: Maximum sequence length
            mode: "train" (pad to max_seq_length) or "eval" (no padding)
        """
        self.data_path = Path(data_path)
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.mode = mode

        self.samples = self._load_data()

        print(f"Loaded {len(self.samples)} samples from {data_path}")

    def _load_data(self) -> List[Dict[str, Any]]:
        """Load data from JSONL file, skipping blank lines."""
        samples = []
        with open(self.data_path, "r", encoding="utf-8") as f:
            for line in f:
                if line.strip():
                    samples.append(json.loads(line))
        return samples

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Tokenize sample ``idx`` into training tensors.

        Returns:
            Dict with 1-D tensors "input_ids", "attention_mask", and
            "labels". Labels are a copy of the input ids with padding
            positions set to -100, the ignore index of cross-entropy
            loss, so padded positions do not contribute to training.
        """
        sample = self.samples[idx]
        messages = sample["messages"]

        formatted_text = self._format_chat_qwen(messages)

        encoding = self.tokenizer(
            formatted_text,
            truncation=True,
            max_length=self.max_seq_length,
            # Fixed-length padding only for training batches; eval keeps
            # the natural sequence length.
            padding="max_length" if self.mode == "train" else False,
            return_tensors="pt",
        )

        # Tokenizer returns a leading batch dim of 1 -- drop it.
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)

        # Causal-LM labels: next-token targets are the inputs themselves,
        # but padding must not be learned -- mask it with -100 so the
        # loss function ignores those positions.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }

    def _format_chat_qwen(self, messages: List[Dict[str, str]]) -> str:
        """
        Format messages into Qwen chat format.

        Qwen chat format:
            <|im_start|>system
            You are a helpful assistant.<|im_end|>
            <|im_start|>user
            Hello!<|im_end|>
            <|im_start|>assistant
            Hi there!<|im_end|>

        Messages with unrecognized roles are silently skipped.
        """
        formatted = []
        for msg in messages:
            role = msg["role"]
            content = msg["content"].strip()

            if role == "system":
                formatted.append(f"<|im_start|>system\n{content}<|im_end|>")
            elif role == "user":
                formatted.append(f"<|im_start|>user\n{content}<|im_end|>")
            elif role == "assistant":
                formatted.append(f"<|im_start|>assistant\n{content}<|im_end|>")
            else:
                # Unknown role: drop the message rather than corrupt the
                # chat template.
                continue

        return "\n".join(formatted)

    def get_sample(self, idx: int) -> str:
        """Get raw formatted text for inspection."""
        sample = self.samples[idx]
        messages = sample["messages"]
        return self._format_chat_qwen(messages)
|
|
|
|
|
def test_dataset():
    """Smoke-test the dataset loader end to end.

    Tries the project's extended music tokenizer first; if that fails
    (e.g. the package is absent), falls back to the plain Qwen tokenizer.
    Then builds a dataset from the processed training file and prints the
    shapes of one tokenized sample.
    """
    from transformers import AutoTokenizer

    print("Loading tokenizer...")
    try:
        from tokenizer.music_token_extension import MusicTokenizerExtension

        extension = MusicTokenizerExtension(
            base_tokenizer_name="Qwen/Qwen3.5-3B-Instruct",
        )
        tokenizer = extension.get_tokenizer()
    except Exception as e:
        print(f"Could not load tokenizer: {e}")
        print("Using dummy tokenizer for testing...")
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(
            "Qwen/Qwen3.5-3B-Instruct",
            trust_remote_code=True,
        )
        # The base tokenizer ships without a pad token; reuse EOS.
        tokenizer.pad_token = tokenizer.eos_token

    print("\nCreating dataset...")
    dataset = TouchGrassDataset(
        data_path="data/processed/train.jsonl",
        tokenizer=tokenizer,
        max_seq_length=1024,
        mode="train",
    )

    print(f"Dataset size: {len(dataset)}")

    if len(dataset):
        sample = dataset[0]
        print("\nSample keys:", list(sample.keys()))
        # Report the tensor shape of each field in a fixed order.
        for heading, field in (
            ("Input IDs shape:", "input_ids"),
            ("Attention mask shape:", "attention_mask"),
            ("Labels shape:", "labels"),
        ):
            print(heading, sample[field].shape)

        decoded = tokenizer.decode(sample["input_ids"][:100])
        print(f"\nFirst 100 tokens:\n{decoded}...")

    print("\nDataset test complete!")
|
|
|
|
|
# Allow running this module directly as a smoke test.
if __name__ == "__main__":

    test_dataset()