| | """ |
| | 5Hz LM (Language Model) Handler |
| | Handles all LM-related operations including initialization and generation |
| | """ |
| | import os |
| | import traceback |
| | import time |
| | import random |
| | from typing import Optional, Dict, Any, Tuple, List, Union |
| | from contextlib import contextmanager |
| |
|
| | import yaml |
| | import torch |
| | from loguru import logger |
| | from transformers import AutoTokenizer, AutoModelForCausalLM |
| | from transformers.generation.streamers import BaseStreamer |
| | from transformers.generation.logits_process import ( |
| | LogitsProcessorList, |
| | RepetitionPenaltyLogitsProcessor, |
| | ) |
| | from acestep.constrained_logits_processor import MetadataConstrainedLogitsProcessor |
| | from acestep.constants import DEFAULT_LM_INSTRUCTION, DEFAULT_LM_UNDERSTAND_INSTRUCTION, DEFAULT_LM_INSPIRED_INSTRUCTION, DEFAULT_LM_REWRITE_INSTRUCTION |
| |
|
| |
|
| | class LLMHandler: |
| | """5Hz LM Handler for audio code generation""" |
| |
|
| | STOP_REASONING_TAG = "</think>" |
| | |
| | def __init__(self): |
| | """Initialize LLMHandler with default values""" |
| | self.llm = None |
| | self.llm_tokenizer = None |
| | self.llm_initialized = False |
| | self.llm_backend = None |
| | self.max_model_len = 4096 |
| | self.device = "cpu" |
| | self.dtype = torch.float32 |
| | self.offload_to_cpu = False |
| | |
| | |
| | self.constrained_processor: Optional[MetadataConstrainedLogitsProcessor] = None |
| | |
| | |
| | self._hf_model_for_scoring = None |
| | |
| | def get_available_5hz_lm_models(self) -> List[str]: |
| | """Scan and return all model directory names starting with 'acestep-5Hz-lm-'""" |
| | current_file = os.path.abspath(__file__) |
| | project_root = os.path.dirname(os.path.dirname(current_file)) |
| | checkpoint_dir = os.path.join(project_root, "checkpoints") |
| | |
| | models = [] |
| | if os.path.exists(checkpoint_dir): |
| | for item in os.listdir(checkpoint_dir): |
| | item_path = os.path.join(checkpoint_dir, item) |
| | if os.path.isdir(item_path) and item.startswith("acestep-5Hz-lm-"): |
| | models.append(item) |
| | |
| | models.sort() |
| | return models |
| | |
| | def get_gpu_memory_utilization(self, minimal_gpu: float = 8, min_ratio: float = 0.2, max_ratio: float = 0.9) -> Tuple[float, bool]: |
| | """Get GPU memory utilization ratio""" |
| | try: |
| | device = torch.device("cuda:0") |
| | total_gpu_mem_bytes = torch.cuda.get_device_properties(device).total_memory |
| | allocated_mem_bytes = torch.cuda.memory_allocated(device) |
| | reserved_mem_bytes = torch.cuda.memory_reserved(device) |
| | |
| | total_gpu = total_gpu_mem_bytes / 1024**3 |
| | low_gpu_memory_mode = False |
| | if total_gpu < minimal_gpu: |
| | minimal_gpu = 0.5 * total_gpu |
| | low_gpu_memory_mode = True |
| | allocated_gpu = allocated_mem_bytes / 1024**3 |
| | reserved_gpu = reserved_mem_bytes / 1024**3 |
| | available_gpu = total_gpu - reserved_gpu |
| | |
| | if available_gpu >= minimal_gpu: |
| | ratio = min(max_ratio, max(min_ratio, minimal_gpu / total_gpu)) |
| | else: |
| | ratio = min(max_ratio, max(min_ratio, (available_gpu * 0.8) / total_gpu)) |
| | |
| | return ratio, low_gpu_memory_mode |
| | except Exception as e: |
| | return 0.9, False |
| | |
| | def _has_meaningful_negative_prompt(self, negative_prompt: str) -> bool: |
| | """Check if negative prompt is meaningful (not default/empty)""" |
| | return negative_prompt and negative_prompt.strip() and negative_prompt.strip() != "NO USER INPUT" |
| | |
| | def _build_logits_processor(self, repetition_penalty: float) -> LogitsProcessorList: |
| | """Build logits processor list with repetition penalty if needed""" |
| | logits_processor = LogitsProcessorList() |
| | if repetition_penalty != 1.0: |
| | logits_processor.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)) |
| | return logits_processor |
| | |
| | def _setup_constrained_processor( |
| | self, |
| | use_constrained_decoding: bool, |
| | constrained_decoding_debug: bool, |
| | target_duration: Optional[float], |
| | user_metadata: Optional[Dict[str, Optional[str]]], |
| | stop_at_reasoning: bool, |
| | skip_genres: bool, |
| | skip_caption: bool, |
| | skip_language: bool, |
| | generation_phase: str, |
| | is_batch: bool = False, |
| | metadata_temperature: Optional[float] = None, |
| | codes_temperature: Optional[float] = None, |
| | ) -> Optional[MetadataConstrainedLogitsProcessor]: |
| | """Setup and configure constrained processor for generation""" |
| | use_phase_temperatures = not is_batch and (metadata_temperature is not None or codes_temperature is not None) |
| | |
| | if not use_constrained_decoding and not use_phase_temperatures: |
| | return None |
| | |
| | |
| | self.constrained_processor.reset() |
| | |
| | |
| | self.constrained_processor.enabled = use_constrained_decoding |
| | self.constrained_processor.debug = constrained_decoding_debug |
| | |
| | |
| | if use_phase_temperatures: |
| | self.constrained_processor.metadata_temperature = metadata_temperature |
| | self.constrained_processor.codes_temperature = codes_temperature |
| | else: |
| | self.constrained_processor.metadata_temperature = None |
| | self.constrained_processor.codes_temperature = None |
| | |
| | self.constrained_processor.set_target_duration(target_duration) |
| | |
| | |
| | if is_batch: |
| | self.constrained_processor.set_user_metadata(None) |
| | self.constrained_processor.set_stop_at_reasoning(False) |
| | self.constrained_processor.set_skip_genres(True) |
| | self.constrained_processor.set_skip_caption(True) |
| | self.constrained_processor.set_skip_language(True) |
| | else: |
| | |
| | self.constrained_processor.set_user_metadata(user_metadata) |
| | self.constrained_processor.set_stop_at_reasoning(stop_at_reasoning) |
| | self.constrained_processor.set_skip_genres(skip_genres) |
| | self.constrained_processor.set_skip_caption(skip_caption) |
| | self.constrained_processor.set_skip_language(skip_language) |
| | |
| | |
| | self.constrained_processor.set_generation_phase(generation_phase) |
| | |
| | return self.constrained_processor |
| | |
| | def _build_unconditional_prompt( |
| | self, |
| | caption: str, |
| | lyrics: str, |
| | cot_text: str, |
| | negative_prompt: str, |
| | generation_phase: str, |
| | is_batch: bool = False, |
| | ) -> str: |
| | """Build unconditional prompt for CFG based on generation phase and batch mode""" |
| | if is_batch or generation_phase == "codes": |
| | |
| | return self.build_formatted_prompt_with_cot( |
| | caption, lyrics, cot_text, is_negative_prompt=True, negative_prompt=negative_prompt |
| | ) |
| | else: |
| | |
| | |
| | return self.build_formatted_prompt( |
| | caption, lyrics, is_negative_prompt=True, generation_phase="cot", negative_prompt=negative_prompt |
| | ) |
| | |
| | def _load_pytorch_model(self, model_path: str, device: str) -> Tuple[bool, str]: |
| | """Load PyTorch model from path and return (success, status_message)""" |
| | try: |
| | self.llm = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True) |
| | if not self.offload_to_cpu: |
| | self.llm = self.llm.to(device).to(self.dtype) |
| | else: |
| | self.llm = self.llm.to("cpu").to(self.dtype) |
| | self.llm.eval() |
| | self.llm_backend = "pt" |
| | self.llm_initialized = True |
| | logger.info(f"5Hz LM initialized successfully using PyTorch backend on {device}") |
| | status_msg = f"✅ 5Hz LM initialized successfully\nModel: {model_path}\nBackend: PyTorch\nDevice: {device}" |
| | return True, status_msg |
| | except Exception as e: |
| | return False, f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}" |
| | |
| | def _apply_top_k_filter(self, logits: torch.Tensor, top_k: Optional[int]) -> torch.Tensor: |
| | """Apply top-k filtering to logits""" |
| | if top_k is not None and top_k > 0: |
| | indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] |
| | logits[indices_to_remove] = float('-inf') |
| | return logits |
| | |
| | def _apply_top_p_filter(self, logits: torch.Tensor, top_p: Optional[float]) -> torch.Tensor: |
| | """Apply top-p (nucleus) filtering to logits""" |
| | if top_p is not None and 0.0 < top_p < 1.0: |
| | sorted_logits, sorted_indices = torch.sort(logits, descending=True) |
| | cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) |
| | sorted_indices_to_remove = cumulative_probs > top_p |
| | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() |
| | sorted_indices_to_remove[..., 0] = 0 |
| | indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) |
| | logits[indices_to_remove] = float('-inf') |
| | return logits |
| | |
| | def _sample_tokens(self, logits: torch.Tensor, temperature: float) -> torch.Tensor: |
| | """Sample tokens from logits with temperature""" |
| | if temperature > 0: |
| | logits = logits / temperature |
| | probs = torch.softmax(logits, dim=-1) |
| | return torch.multinomial(probs, num_samples=1).squeeze(1) |
| | else: |
| | return torch.argmax(logits, dim=-1) |
| | |
| | def _check_eos_token(self, tokens: torch.Tensor, eos_token_id: int, pad_token_id: Optional[int]) -> bool: |
| | """Check if any token in the batch is EOS or pad token""" |
| | if torch.any(tokens == eos_token_id): |
| | return True |
| | if pad_token_id is not None and pad_token_id != eos_token_id: |
| | if torch.any(tokens == pad_token_id): |
| | return True |
| | return False |
| | |
| | def _update_constrained_processor_state(self, constrained_processor: Optional[MetadataConstrainedLogitsProcessor], tokens: torch.Tensor): |
| | """Update constrained processor state with generated tokens""" |
| | if constrained_processor is not None: |
| | for b in range(tokens.shape[0]): |
| | constrained_processor.update_state(tokens[b].item()) |
| | |
| | def _forward_pass( |
| | self, |
| | model: Any, |
| | generated_ids: torch.Tensor, |
| | model_kwargs: Dict[str, Any], |
| | past_key_values: Optional[Any], |
| | use_cache: bool, |
| | ) -> Any: |
| | """Perform forward pass with KV cache support""" |
| | if past_key_values is None: |
| | outputs = model( |
| | input_ids=generated_ids, |
| | **model_kwargs, |
| | use_cache=use_cache, |
| | ) |
| | else: |
| | outputs = model( |
| | input_ids=generated_ids[:, -1:], |
| | past_key_values=past_key_values, |
| | **model_kwargs, |
| | use_cache=use_cache, |
| | ) |
| | return outputs |
| | |
| | def _normalize_batch_input(self, formatted_prompts: Union[str, List[str]]) -> Tuple[List[str], bool]: |
| | """Normalize batch input: convert single string to list and return (list, is_batch)""" |
| | is_batch = isinstance(formatted_prompts, list) |
| | if is_batch: |
| | return formatted_prompts, is_batch |
| | else: |
| | return [formatted_prompts], is_batch |
| | |
| | def initialize( |
| | self, |
| | checkpoint_dir: str, |
| | lm_model_path: str, |
| | backend: str = "vllm", |
| | device: str = "auto", |
| | offload_to_cpu: bool = False, |
| | dtype: Optional[torch.dtype] = None, |
| | ) -> Tuple[str, bool]: |
| | """ |
| | Initialize 5Hz LM model |
| | |
| | Args: |
| | checkpoint_dir: Checkpoint directory path |
| | lm_model_path: LM model path (relative to checkpoint_dir) |
| | backend: Backend type ("vllm" or "pt") |
| | device: Device type ("auto", "cuda", or "cpu") |
| | offload_to_cpu: Whether to offload to CPU |
| | dtype: Data type (if None, auto-detect based on device) |
| | |
| | Returns: |
| | (status_message, success) |
| | """ |
| | try: |
| | if device == "auto": |
| | device = "cuda" if torch.cuda.is_available() else "cpu" |
| | |
| | self.device = device |
| | self.offload_to_cpu = offload_to_cpu |
| | |
| | if dtype is None: |
| | self.dtype = torch.bfloat16 if device in ["cuda", "xpu"] else torch.float32 |
| | else: |
| | self.dtype = dtype |
| | |
| | full_lm_model_path = os.path.join(checkpoint_dir, lm_model_path) |
| | if not os.path.exists(full_lm_model_path): |
| | return f"❌ 5Hz LM model not found at {full_lm_model_path}", False |
| | |
| | logger.info("loading 5Hz LM tokenizer... it may take 80~90s") |
| | start_time = time.time() |
| | |
| | llm_tokenizer = AutoTokenizer.from_pretrained(full_lm_model_path, use_fast=True) |
| | logger.info(f"5Hz LM tokenizer loaded successfully in {time.time() - start_time:.2f} seconds") |
| | self.llm_tokenizer = llm_tokenizer |
| | |
| | |
| | logger.info("Initializing constrained decoding processor...") |
| | processor_start = time.time() |
| | self.constrained_processor = MetadataConstrainedLogitsProcessor( |
| | tokenizer=self.llm_tokenizer, |
| | enabled=True, |
| | debug=False, |
| | ) |
| | logger.info(f"Constrained processor initialized in {time.time() - processor_start:.2f} seconds") |
| | |
| | |
| | if backend == "vllm": |
| | |
| | status_msg = self._initialize_5hz_lm_vllm(full_lm_model_path) |
| | logger.info(f"5Hz LM status message: {status_msg}") |
| | |
| | if status_msg.startswith("❌"): |
| | |
| | if not self.llm_initialized: |
| | logger.warning("vllm initialization failed, falling back to PyTorch backend") |
| | success, status_msg = self._load_pytorch_model(full_lm_model_path, device) |
| | if not success: |
| | return status_msg, False |
| | status_msg = f"✅ 5Hz LM initialized successfully (PyTorch fallback)\nModel: {full_lm_model_path}\nBackend: PyTorch" |
| | |
| | else: |
| | |
| | success, status_msg = self._load_pytorch_model(full_lm_model_path, device) |
| | if not success: |
| | return status_msg, False |
| | |
| | return status_msg, True |
| | |
| | except Exception as e: |
| | return f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}", False |
| | |
| | def _initialize_5hz_lm_vllm(self, model_path: str) -> str: |
| | """Initialize 5Hz LM model using vllm backend""" |
| | if not torch.cuda.is_available(): |
| | self.llm_initialized = False |
| | logger.error("CUDA is not available. Please check your GPU setup.") |
| | return "❌ CUDA is not available. Please check your GPU setup." |
| | try: |
| | from nanovllm import LLM, SamplingParams |
| | except ImportError: |
| | self.llm_initialized = False |
| | logger.error("nano-vllm is not installed. Please install it using 'cd acestep/third_parts/nano-vllm && pip install .") |
| | return "❌ nano-vllm is not installed. Please install it using 'cd acestep/third_parts/nano-vllm && pip install ." |
| | |
| | try: |
| | current_device = torch.cuda.current_device() |
| | device_name = torch.cuda.get_device_name(current_device) |
| | |
| | torch.cuda.empty_cache() |
| | gpu_memory_utilization, low_gpu_memory_mode = self.get_gpu_memory_utilization( |
| | minimal_gpu=8, |
| | min_ratio=0.2, |
| | max_ratio=0.9 |
| | ) |
| | if low_gpu_memory_mode: |
| | self.max_model_len = 2048 |
| | else: |
| | self.max_model_len = 4096 |
| | |
| | logger.info(f"Initializing 5Hz LM with model: {model_path}, enforce_eager: False, tensor_parallel_size: 1, max_model_len: {self.max_model_len}, gpu_memory_utilization: {gpu_memory_utilization}") |
| | start_time = time.time() |
| | self.llm = LLM( |
| | model=model_path, |
| | enforce_eager=False, |
| | tensor_parallel_size=1, |
| | max_model_len=self.max_model_len, |
| | gpu_memory_utilization=gpu_memory_utilization, |
| | tokenizer=self.llm_tokenizer, |
| | ) |
| | logger.info(f"5Hz LM initialized successfully in {time.time() - start_time:.2f} seconds") |
| | self.llm_initialized = True |
| | self.llm_backend = "vllm" |
| | return f"✅ 5Hz LM initialized successfully\nModel: {model_path}\nDevice: {device_name}\nGPU Memory Utilization: {gpu_memory_utilization:.2f}" |
| | except Exception as e: |
| | self.llm_initialized = False |
| | return f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}" |
| |
|
| | def _run_vllm( |
| | self, |
| | formatted_prompts: Union[str, List[str]], |
| | temperature: float, |
| | cfg_scale: float, |
| | negative_prompt: str, |
| | top_k: Optional[int], |
| | top_p: Optional[float], |
| | repetition_penalty: float, |
| | use_constrained_decoding: bool = True, |
| | constrained_decoding_debug: bool = False, |
| | metadata_temperature: Optional[float] = None, |
| | codes_temperature: Optional[float] = None, |
| | target_duration: Optional[float] = None, |
| | user_metadata: Optional[Dict[str, Optional[str]]] = None, |
| | stop_at_reasoning: bool = False, |
| | skip_genres: bool = True, |
| | skip_caption: bool = False, |
| | skip_language: bool = False, |
| | generation_phase: str = "cot", |
| | caption: str = "", |
| | lyrics: str = "", |
| | cot_text: str = "", |
| | seeds: Optional[List[int]] = None, |
| | ) -> Union[str, List[str]]: |
| | """ |
| | Unified vllm generation function supporting both single and batch modes. |
| | Accepts either a single formatted prompt (str) or a list of formatted prompts (List[str]). |
| | Returns a single string for single mode, or a list of strings for batch mode. |
| | """ |
| | from nanovllm import SamplingParams |
| |
|
| | |
| | formatted_prompt_list, is_batch = self._normalize_batch_input(formatted_prompts) |
| | batch_size = len(formatted_prompt_list) |
| |
|
| | |
| | |
| | |
| | use_phase_temperatures = not is_batch and (metadata_temperature is not None or codes_temperature is not None) |
| | effective_sampler_temp = 1.0 if use_phase_temperatures else temperature |
| |
|
| | |
| | constrained_processor = self._setup_constrained_processor( |
| | use_constrained_decoding=use_constrained_decoding or use_phase_temperatures, |
| | constrained_decoding_debug=constrained_decoding_debug, |
| | target_duration=target_duration, |
| | user_metadata=user_metadata, |
| | stop_at_reasoning=stop_at_reasoning, |
| | skip_genres=skip_genres, |
| | skip_caption=skip_caption, |
| | skip_language=skip_language, |
| | generation_phase=generation_phase, |
| | is_batch=is_batch, |
| | metadata_temperature=metadata_temperature, |
| | codes_temperature=codes_temperature, |
| | ) |
| |
|
| | sampling_params = SamplingParams( |
| | max_tokens=self.max_model_len - 64, |
| | temperature=effective_sampler_temp, |
| | cfg_scale=cfg_scale, |
| | top_k=top_k, |
| | top_p=top_p, |
| | repetition_penalty=repetition_penalty, |
| | logits_processor=constrained_processor, |
| | logits_processor_update_state=constrained_processor.update_state if constrained_processor else None, |
| | ) |
| |
|
| | if cfg_scale > 1.0: |
| | |
| | formatted_unconditional_prompt = self._build_unconditional_prompt( |
| | caption=caption, |
| | lyrics=lyrics, |
| | cot_text=cot_text, |
| | negative_prompt=negative_prompt, |
| | generation_phase=generation_phase, |
| | is_batch=is_batch, |
| | ) |
| | unconditional_prompts = [formatted_unconditional_prompt] * batch_size |
| | |
| | outputs = self.llm.generate( |
| | formatted_prompt_list, |
| | sampling_params, |
| | unconditional_prompts=unconditional_prompts, |
| | ) |
| | else: |
| | outputs = self.llm.generate(formatted_prompt_list, sampling_params) |
| |
|
| | |
| | output_texts = [] |
| | for output in outputs: |
| | if hasattr(output, "outputs") and len(output.outputs) > 0: |
| | output_texts.append(output.outputs[0].text) |
| | elif hasattr(output, "text"): |
| | output_texts.append(output.text) |
| | elif isinstance(output, dict) and "text" in output: |
| | output_texts.append(output["text"]) |
| | else: |
| | output_texts.append(str(output)) |
| |
|
| | |
| | return output_texts[0] if not is_batch else output_texts |
| |
|
| | def _run_pt_single( |
| | self, |
| | formatted_prompt: str, |
| | temperature: float, |
| | cfg_scale: float, |
| | negative_prompt: str, |
| | top_k: Optional[int], |
| | top_p: Optional[float], |
| | repetition_penalty: float, |
| | use_constrained_decoding: bool, |
| | constrained_decoding_debug: bool, |
| | target_duration: Optional[float], |
| | user_metadata: Optional[Dict[str, Optional[str]]], |
| | stop_at_reasoning: bool, |
| | skip_genres: bool, |
| | skip_caption: bool, |
| | skip_language: bool, |
| | generation_phase: str, |
| | caption: str, |
| | lyrics: str, |
| | cot_text: str, |
| | ) -> str: |
| | """Internal helper function for single-item PyTorch generation.""" |
| | inputs = self.llm_tokenizer( |
| | formatted_prompt, |
| | return_tensors="pt", |
| | padding=False, |
| | truncation=True, |
| | ) |
| |
|
| | |
| | constrained_processor = self._setup_constrained_processor( |
| | use_constrained_decoding=use_constrained_decoding, |
| | constrained_decoding_debug=constrained_decoding_debug, |
| | target_duration=target_duration, |
| | user_metadata=user_metadata, |
| | stop_at_reasoning=stop_at_reasoning, |
| | skip_genres=skip_genres, |
| | skip_caption=skip_caption, |
| | skip_language=skip_language, |
| | generation_phase=generation_phase, |
| | is_batch=False, |
| | ) |
| |
|
| | with self._load_model_context(): |
| | inputs = {k: v.to(self.device) for k, v in inputs.items()} |
| | max_new_tokens = getattr(self.llm.config, "max_new_tokens", 4096) |
| | if hasattr(self, "max_model_len"): |
| | max_new_tokens = min(max_new_tokens, self.max_model_len - 64) |
| |
|
| | |
| | logits_processor = self._build_logits_processor(repetition_penalty) |
| |
|
| | if cfg_scale > 1.0: |
| | |
| | formatted_unconditional_prompt = self._build_unconditional_prompt( |
| | caption=caption, |
| | lyrics=lyrics, |
| | cot_text=cot_text, |
| | negative_prompt=negative_prompt, |
| | generation_phase=generation_phase, |
| | is_batch=False, |
| | ) |
| | |
| | |
| | |
| | batch_texts = [formatted_prompt, formatted_unconditional_prompt] |
| | original_padding_side = self.llm_tokenizer.padding_side |
| | self.llm_tokenizer.padding_side = 'left' |
| | batch_inputs_tokenized = self.llm_tokenizer( |
| | batch_texts, |
| | return_tensors="pt", |
| | padding=True, |
| | truncation=True, |
| | ) |
| | self.llm_tokenizer.padding_side = original_padding_side |
| | batch_inputs_tokenized = {k: v.to(self.device) for k, v in batch_inputs_tokenized.items()} |
| | |
| | |
| | batch_input_ids = batch_inputs_tokenized['input_ids'] |
| | batch_attention_mask = batch_inputs_tokenized.get('attention_mask', None) |
| |
|
| | |
| | outputs = self._generate_with_cfg_custom( |
| | batch_input_ids=batch_input_ids, |
| | batch_attention_mask=batch_attention_mask, |
| | max_new_tokens=max_new_tokens, |
| | temperature=temperature, |
| | cfg_scale=cfg_scale, |
| | top_k=top_k, |
| | top_p=top_p, |
| | repetition_penalty=repetition_penalty, |
| | pad_token_id=self.llm_tokenizer.pad_token_id or self.llm_tokenizer.eos_token_id, |
| | streamer=None, |
| | constrained_processor=constrained_processor, |
| | ) |
| | |
| | |
| | outputs = outputs[0:1] |
| | elif use_constrained_decoding: |
| | |
| | outputs = self._generate_with_constrained_decoding( |
| | input_ids=inputs["input_ids"], |
| | attention_mask=inputs.get("attention_mask"), |
| | max_new_tokens=max_new_tokens, |
| | temperature=temperature, |
| | top_k=top_k, |
| | top_p=top_p, |
| | repetition_penalty=repetition_penalty, |
| | pad_token_id=self.llm_tokenizer.pad_token_id or self.llm_tokenizer.eos_token_id, |
| | streamer=None, |
| | constrained_processor=constrained_processor, |
| | ) |
| | else: |
| | |
| | with torch.no_grad(): |
| | outputs = self.llm.generate( |
| | **inputs, |
| | max_new_tokens=max_new_tokens, |
| | temperature=temperature if temperature > 0 else 1.0, |
| | do_sample=True if temperature > 0 else False, |
| | top_k=top_k if top_k is not None and top_k > 0 else None, |
| | top_p=top_p if top_p is not None and 0.0 < top_p < 1.0 else None, |
| | logits_processor=logits_processor if len(logits_processor) > 0 else None, |
| | pad_token_id=self.llm_tokenizer.pad_token_id or self.llm_tokenizer.eos_token_id, |
| | streamer=None, |
| | ) |
| |
|
| | |
| | |
| | if isinstance(outputs, torch.Tensor): |
| | if outputs.dim() == 2: |
| | generated_ids = outputs[0] |
| | else: |
| | generated_ids = outputs |
| | else: |
| | generated_ids = outputs[0] |
| | |
| | |
| | |
| | if cfg_scale > 1.0: |
| | |
| | |
| | input_length = batch_inputs_tokenized['input_ids'].shape[1] |
| | else: |
| | input_length = inputs["input_ids"].shape[1] |
| | |
| | generated_ids = generated_ids[input_length:] |
| | |
| | |
| | if generated_ids.is_cuda: |
| | generated_ids = generated_ids.cpu() |
| | |
| | output_text = self.llm_tokenizer.decode(generated_ids, skip_special_tokens=False) |
| | return output_text |
| |
|
| | def _run_pt( |
| | self, |
| | formatted_prompts: Union[str, List[str]], |
| | temperature: float, |
| | cfg_scale: float, |
| | negative_prompt: str, |
| | top_k: Optional[int], |
| | top_p: Optional[float], |
| | repetition_penalty: float, |
| | use_constrained_decoding: bool = True, |
| | constrained_decoding_debug: bool = False, |
| | target_duration: Optional[float] = None, |
| | user_metadata: Optional[Dict[str, Optional[str]]] = None, |
| | stop_at_reasoning: bool = False, |
| | skip_genres: bool = True, |
| | skip_caption: bool = False, |
| | skip_language: bool = False, |
| | generation_phase: str = "cot", |
| | caption: str = "", |
| | lyrics: str = "", |
| | cot_text: str = "", |
| | seeds: Optional[List[int]] = None, |
| | ) -> Union[str, List[str]]: |
| | """ |
| | Unified PyTorch generation function supporting both single and batch modes. |
| | Accepts either a single formatted prompt (str) or a list of formatted prompts (List[str]). |
| | Returns a single string for single mode, or a list of strings for batch mode. |
| | Note: PyTorch backend processes batch items sequentially (doesn't support true batching efficiently). |
| | """ |
| | |
| | formatted_prompt_list, is_batch = self._normalize_batch_input(formatted_prompts) |
| |
|
| | |
| | if is_batch: |
| | output_texts = [] |
| | for i, formatted_prompt in enumerate(formatted_prompt_list): |
| | |
| | if seeds and i < len(seeds): |
| | torch.manual_seed(seeds[i]) |
| | if torch.cuda.is_available(): |
| | torch.cuda.manual_seed_all(seeds[i]) |
| | |
| | |
| | output_text = self._run_pt_single( |
| | formatted_prompt=formatted_prompt, |
| | temperature=temperature, |
| | cfg_scale=cfg_scale, |
| | negative_prompt=negative_prompt, |
| | top_k=top_k, |
| | top_p=top_p, |
| | repetition_penalty=repetition_penalty, |
| | use_constrained_decoding=use_constrained_decoding, |
| | constrained_decoding_debug=constrained_decoding_debug, |
| | target_duration=target_duration, |
| | user_metadata=None, |
| | stop_at_reasoning=False, |
| | skip_genres=True, |
| | skip_caption=True, |
| | skip_language=True, |
| | generation_phase=generation_phase, |
| | caption=caption, |
| | lyrics=lyrics, |
| | cot_text=cot_text, |
| | ) |
| | |
| | output_texts.append(output_text) |
| | |
| | return output_texts |
| |
|
| | |
| | formatted_prompt = formatted_prompt_list[0] |
| | |
| | return self._run_pt_single( |
| | formatted_prompt=formatted_prompt, |
| | temperature=temperature, |
| | cfg_scale=cfg_scale, |
| | negative_prompt=negative_prompt, |
| | top_k=top_k, |
| | top_p=top_p, |
| | repetition_penalty=repetition_penalty, |
| | use_constrained_decoding=use_constrained_decoding, |
| | constrained_decoding_debug=constrained_decoding_debug, |
| | target_duration=target_duration, |
| | user_metadata=user_metadata, |
| | stop_at_reasoning=stop_at_reasoning, |
| | skip_genres=skip_genres, |
| | skip_caption=skip_caption, |
| | skip_language=skip_language, |
| | generation_phase=generation_phase, |
| | caption=caption, |
| | lyrics=lyrics, |
| | cot_text=cot_text, |
| | ) |
| |
|
| | def has_all_metas(self, user_metadata: Optional[Dict[str, Optional[str]]]) -> bool: |
| | """Check if all required metadata are present.""" |
| | if user_metadata is None: |
| | return False |
| | if 'bpm' in user_metadata and 'keyscale' in user_metadata and 'timesignature' in user_metadata and 'duration' in user_metadata: |
| | return True |
| | return False |
| | |
| | def _format_metadata_as_cot(self, metadata: Dict[str, Any]) -> str: |
| | """ |
| | Format parsed metadata as CoT text using YAML format (matching training format). |
| | |
| | Args: |
| | metadata: Dictionary with keys: bpm, caption, duration, keyscale, language, timesignature |
| | |
| | Returns: |
| | Formatted CoT text: "<think>\n{yaml_content}\n</think>" |
| | """ |
| | |
| | cot_items = {} |
| | for key in ['bpm', 'caption', 'duration', 'keyscale', 'language', 'timesignature']: |
| | if key in metadata and metadata[key] is not None: |
| | value = metadata[key] |
| | if key == "timesignature" and value.endswith("/4"): |
| | value = value.split("/")[0] |
| | if isinstance(value, str) and value.isdigit(): |
| | value = int(value) |
| | cot_items[key] = value |
| | |
| | |
| | if len(cot_items) > 0: |
| | cot_yaml = yaml.dump(cot_items, allow_unicode=True, sort_keys=True).strip() |
| | else: |
| | cot_yaml = "" |
| | |
| | return f"<think>\n{cot_yaml}\n</think>" |
| |
|
| | def generate_with_stop_condition( |
| | self, |
| | caption: str, |
| | lyrics: str, |
| | infer_type: str, |
| | temperature: float = 0.85, |
| | cfg_scale: float = 1.0, |
| | negative_prompt: str = "NO USER INPUT", |
| | top_k: Optional[int] = None, |
| | top_p: Optional[float] = None, |
| | repetition_penalty: float = 1.0, |
| | use_constrained_decoding: bool = True, |
| | constrained_decoding_debug: bool = False, |
| | target_duration: Optional[float] = None, |
| | user_metadata: Optional[Dict[str, Optional[str]]] = None, |
| | use_cot_metas: bool = True, |
| | use_cot_caption: bool = True, |
| | use_cot_language: bool = True, |
| | batch_size: Optional[int] = None, |
| | seeds: Optional[List[int]] = None, |
| | progress=None, |
| | ) -> Dict[str, Any]: |
| | """Two-phase LM generation: CoT generation followed by audio codes generation. |
| | |
| | - infer_type='dit': Phase 1 only - generate CoT and return metas (no audio codes) |
| | - infer_type='llm_dit': Phase 1 + Phase 2 - generate CoT then audio codes |
| | |
| | Args: |
| | target_duration: Target duration in seconds for codes generation constraint. |
| | 5 codes = 1 second. If specified, blocks EOS until target reached. |
| | user_metadata: User-provided metadata fields (e.g. bpm/duration/keyscale/timesignature). |
| | If specified, constrained decoding will inject these values directly. |
| | use_cot_caption: Whether to generate caption in CoT (default True). |
| | use_cot_language: Whether to generate language in CoT (default True). |
| | batch_size: Optional batch size for batch generation. If None or 1, returns single result. |
| | If > 1, returns batch results (lists). |
| | seeds: Optional list of seeds for batch generation (for reproducibility). |
| | Only used when batch_size > 1. TODO: not used yet |
| | |
| | Returns: |
| | Dictionary containing: |
| | - metadata: Dict or List[Dict] - Generated metadata |
| | - audio_codes: str or List[str] - Generated audio codes |
| | - success: bool - Whether generation succeeded |
| | - error: Optional[str] - Error message if failed |
| | - extra_outputs: Dict with time_costs and other info |
| | """ |
| | if progress is None: |
| | def progress(*args, **kwargs): |
| | pass |
| |
|
| | infer_type = (infer_type or "").strip().lower() |
| | if infer_type not in {"dit", "llm_dit"}: |
| | error_msg = f"invalid infer_type: {infer_type!r} (expected 'dit' or 'llm_dit')" |
| | return { |
| | "metadata": [] if (batch_size and batch_size > 1) else {}, |
| | "audio_codes": [] if (batch_size and batch_size > 1) else "", |
| | "success": False, |
| | "error": error_msg, |
| | "extra_outputs": {"time_costs": {}}, |
| | } |
| | |
| | |
| | is_batch = batch_size and batch_size > 1 |
| | actual_batch_size = batch_size if is_batch else 1 |
| | |
| | |
| | metadata = {} |
| | audio_codes = "" |
| | has_all_metas = self.has_all_metas(user_metadata) |
| | phase1_time = 0.0 |
| | phase2_time = 0.0 |
| | |
| | |
| | if is_batch: |
| | if seeds is None: |
| | seeds = [random.randint(0, 2**32 - 1) for _ in range(actual_batch_size)] |
| | elif len(seeds) < actual_batch_size: |
| | seeds = list(seeds) + [random.randint(0, 2**32 - 1) for _ in range(actual_batch_size - len(seeds))] |
| | else: |
| | seeds = seeds[:actual_batch_size] |
| | |
| | |
| | |
| | progress(0.1, f"Phase 1: Generating CoT metadata (once for all items)...") |
| | if not has_all_metas and use_cot_metas: |
| | if is_batch: |
| | logger.info("Batch Phase 1: Generating CoT metadata (once for all items)...") |
| | else: |
| | logger.info("Phase 1: Generating CoT metadata...") |
| | phase1_start = time.time() |
| | |
| | |
| | formatted_prompt = self.build_formatted_prompt(caption, lyrics, generation_phase="cot") |
| | |
| | logger.info(f"generate_with_stop_condition: formatted_prompt={formatted_prompt}") |
| | |
| | cot_output_text, status = self.generate_from_formatted_prompt( |
| | formatted_prompt=formatted_prompt, |
| | cfg={ |
| | "temperature": temperature, |
| | "cfg_scale": cfg_scale, |
| | "negative_prompt": negative_prompt, |
| | "top_k": top_k, |
| | "top_p": top_p, |
| | "repetition_penalty": repetition_penalty, |
| | "target_duration": None, |
| | "user_metadata": user_metadata, |
| | "skip_caption": not use_cot_caption, |
| | "skip_language": not use_cot_language, |
| | "skip_genres": True, |
| | "generation_phase": "cot", |
| | |
| | "caption": caption, |
| | "lyrics": lyrics, |
| | }, |
| | use_constrained_decoding=use_constrained_decoding, |
| | constrained_decoding_debug=constrained_decoding_debug, |
| | stop_at_reasoning=True, |
| | ) |
| | |
| | phase1_time = time.time() - phase1_start |
| | |
| | if not cot_output_text: |
| | return { |
| | "metadata": [] if is_batch else {}, |
| | "audio_codes": [] if is_batch else "", |
| | "success": False, |
| | "error": status, |
| | "extra_outputs": {"time_costs": {"phase1_time": phase1_time}}, |
| | } |
| | |
| | |
| | metadata, _ = self.parse_lm_output(cot_output_text) |
| | if is_batch: |
| | logger.info(f"Batch Phase 1 completed in {phase1_time:.2f}s. Generated metadata: {list(metadata.keys())}") |
| | else: |
| | logger.info(f"Phase 1 completed in {phase1_time:.2f}s. Generated metadata: {list(metadata.keys())}") |
| | else: |
| | |
| | if is_batch: |
| | logger.info("Batch Phase 1: Using user-provided metadata (skipping generation)") |
| | else: |
| | logger.info("Phase 1: Using user-provided metadata (skipping generation)") |
| | metadata = {k: v for k, v in user_metadata.items() if v is not None} |
| | |
| | |
| | if infer_type == "dit": |
| | if is_batch: |
| | metadata_list = [metadata.copy() for _ in range(actual_batch_size)] |
| | return { |
| | "metadata": metadata_list, |
| | "audio_codes": [""] * actual_batch_size, |
| | "success": True, |
| | "error": None, |
| | "extra_outputs": { |
| | "time_costs": { |
| | "phase1_time": phase1_time, |
| | "total_time": phase1_time, |
| | } |
| | }, |
| | } |
| | else: |
| | return { |
| | "metadata": metadata, |
| | "audio_codes": "", |
| | "success": True, |
| | "error": None, |
| | "extra_outputs": { |
| | "time_costs": { |
| | "phase1_time": phase1_time, |
| | "total_time": phase1_time, |
| | } |
| | }, |
| | } |
| | |
| | |
| | if is_batch: |
| | logger.info(f"Batch Phase 2: Generating audio codes for {actual_batch_size} items...") |
| | else: |
| | logger.info("Phase 2: Generating audio codes...") |
| | phase2_start = time.time() |
| | |
| | |
| | cot_text = self._format_metadata_as_cot(metadata) |
| | |
| | |
| | formatted_prompt_with_cot = self.build_formatted_prompt_with_cot(caption, lyrics, cot_text) |
| | logger.info(f"generate_with_stop_condition: formatted_prompt_with_cot={formatted_prompt_with_cot}") |
| | |
| | progress(0.5, f"Phase 2: Generating audio codes for {actual_batch_size} items...") |
| | if is_batch: |
| | |
| | formatted_prompts = [formatted_prompt_with_cot] * actual_batch_size |
| | |
| | |
| | try: |
| | if self.llm_backend == "vllm": |
| | codes_outputs = self._run_vllm( |
| | formatted_prompts=formatted_prompts, |
| | temperature=temperature, |
| | cfg_scale=cfg_scale, |
| | negative_prompt=negative_prompt, |
| | top_k=top_k, |
| | top_p=top_p, |
| | repetition_penalty=repetition_penalty, |
| | use_constrained_decoding=use_constrained_decoding, |
| | constrained_decoding_debug=constrained_decoding_debug, |
| | target_duration=target_duration, |
| | generation_phase="codes", |
| | caption=caption, |
| | lyrics=lyrics, |
| | cot_text=cot_text, |
| | seeds=seeds, |
| | ) |
| | else: |
| | codes_outputs = self._run_pt( |
| | formatted_prompts=formatted_prompts, |
| | temperature=temperature, |
| | cfg_scale=cfg_scale, |
| | negative_prompt=negative_prompt, |
| | top_k=top_k, |
| | top_p=top_p, |
| | repetition_penalty=repetition_penalty, |
| | use_constrained_decoding=use_constrained_decoding, |
| | constrained_decoding_debug=constrained_decoding_debug, |
| | target_duration=target_duration, |
| | generation_phase="codes", |
| | caption=caption, |
| | lyrics=lyrics, |
| | cot_text=cot_text, |
| | seeds=seeds, |
| | ) |
| | except Exception as e: |
| | error_msg = f"Error in batch codes generation: {str(e)}" |
| | logger.error(error_msg) |
| | return { |
| | "metadata": [], |
| | "audio_codes": [], |
| | "success": False, |
| | "error": error_msg, |
| | "extra_outputs": { |
| | "time_costs": { |
| | "phase1_time": phase1_time, |
| | "phase2_time": 0.0, |
| | "total_time": phase1_time, |
| | } |
| | }, |
| | } |
| | |
| | |
| | audio_codes_list = [] |
| | metadata_list = [] |
| | for output_text in codes_outputs: |
| | _, audio_codes_item = self.parse_lm_output(output_text) |
| | audio_codes_list.append(audio_codes_item) |
| | metadata_list.append(metadata.copy()) |
| | |
| | phase2_time = time.time() - phase2_start |
| | |
| | |
| | codes_counts = [len(codes.split('<|audio_code_')) - 1 if codes else 0 for codes in audio_codes_list] |
| | logger.info(f"Batch Phase 2 completed in {phase2_time:.2f}s. Generated codes: {codes_counts}") |
| | |
| | total_time = phase1_time + phase2_time |
| | return { |
| | "metadata": metadata_list, |
| | "audio_codes": audio_codes_list, |
| | "success": True, |
| | "error": None, |
| | "extra_outputs": { |
| | "time_costs": { |
| | "phase1_time": phase1_time, |
| | "phase2_time": phase2_time, |
| | "total_time": total_time, |
| | }, |
| | "codes_counts": codes_counts, |
| | "total_codes": sum(codes_counts), |
| | }, |
| | } |
| | else: |
| | |
| | codes_output_text, status = self.generate_from_formatted_prompt( |
| | formatted_prompt=formatted_prompt_with_cot, |
| | cfg={ |
| | "temperature": temperature, |
| | "cfg_scale": cfg_scale, |
| | "negative_prompt": negative_prompt, |
| | "top_k": top_k, |
| | "top_p": top_p, |
| | "repetition_penalty": repetition_penalty, |
| | "target_duration": target_duration, |
| | "user_metadata": None, |
| | "skip_caption": True, |
| | "skip_language": True, |
| | "generation_phase": "codes", |
| | |
| | "caption": caption, |
| | "lyrics": lyrics, |
| | "cot_text": cot_text, |
| | }, |
| | use_constrained_decoding=use_constrained_decoding, |
| | constrained_decoding_debug=constrained_decoding_debug, |
| | stop_at_reasoning=False, |
| | ) |
| | |
| | if not codes_output_text: |
| | total_time = phase1_time + phase2_time |
| | return { |
| | "metadata": metadata, |
| | "audio_codes": "", |
| | "success": False, |
| | "error": status, |
| | "extra_outputs": { |
| | "time_costs": { |
| | "phase1_time": phase1_time, |
| | "phase2_time": phase2_time, |
| | "total_time": total_time, |
| | } |
| | }, |
| | } |
| | |
| | phase2_time = time.time() - phase2_start |
| | |
| | |
| | _, audio_codes = self.parse_lm_output(codes_output_text) |
| | |
| | codes_count = len(audio_codes.split('<|audio_code_')) - 1 if audio_codes else 0 |
| | logger.info(f"Phase 2 completed in {phase2_time:.2f}s. Generated {codes_count} audio codes") |
| | |
| | total_time = phase1_time + phase2_time |
| | return { |
| | "metadata": metadata, |
| | "audio_codes": audio_codes, |
| | "success": True, |
| | "error": None, |
| | "extra_outputs": { |
| | "time_costs": { |
| | "phase1_time": phase1_time, |
| | "phase2_time": phase2_time, |
| | "total_time": total_time, |
| | }, |
| | "codes_count": codes_count, |
| | }, |
| | } |
| | |
| | def build_formatted_prompt(self, caption: str, lyrics: str = "", is_negative_prompt: bool = False, generation_phase: str = "cot", negative_prompt: str = "NO USER INPUT") -> str: |
| | """ |
| | Build the chat-formatted prompt for 5Hz LM from caption/lyrics. |
| | Raises a ValueError if the tokenizer is not initialized. |
| | |
| | Args: |
| | caption: Caption text |
| | lyrics: Lyrics text |
| | is_negative_prompt: If True, builds unconditional prompt for CFG |
| | generation_phase: "cot" or "codes" - affects unconditional prompt format |
| | negative_prompt: Negative prompt for CFG (used when is_negative_prompt=True) |
| | |
| | Example: |
| | prompt = handler.build_formatted_prompt("calm piano", "hello world") |
| | """ |
| | if self.llm_tokenizer is None: |
| | raise ValueError("LLM tokenizer is not initialized. Call initialize() first.") |
| | |
| | if is_negative_prompt: |
| | |
| | |
| | has_negative_prompt = self._has_meaningful_negative_prompt(negative_prompt) |
| | |
| | if generation_phase == "cot": |
| | |
| | if has_negative_prompt: |
| | |
| | prompt = f"# Caption\n{negative_prompt}\n\n# Lyric\n{lyrics}\n" |
| | else: |
| | |
| | prompt = f"# Lyric\n{lyrics}\n" |
| | else: |
| | |
| | |
| | prompt = caption |
| | else: |
| | |
| | prompt = f"# Caption\n{caption}\n\n# Lyric\n{lyrics}\n" |
| | |
| | return self.llm_tokenizer.apply_chat_template( |
| | [ |
| | {"role": "system", "content": f"# Instruction\n{DEFAULT_LM_INSTRUCTION}\n\n"}, |
| | {"role": "user", "content": prompt}, |
| | ], |
| | tokenize=False, |
| | add_generation_prompt=True, |
| | ) |
| | |
| | def build_formatted_prompt_with_cot(self, caption: str, lyrics: str, cot_text: str, is_negative_prompt: bool = False, negative_prompt: str = "NO USER INPUT") -> str: |
| | """ |
| | Build the chat-formatted prompt for codes generation phase with pre-generated CoT. |
| | |
| | Args: |
| | caption: Caption text |
| | lyrics: Lyrics text |
| | cot_text: Pre-generated CoT text (e.g., "<think>\\nbpm: 120\\n...\\n</think>") |
| | is_negative_prompt: If True, uses empty CoT for CFG unconditional prompt |
| | negative_prompt: Negative prompt for CFG (used when is_negative_prompt=True) |
| | |
| | Returns: |
| | Formatted prompt string |
| | |
| | Example: |
| | cot = "<think>\\nbpm: 120\\ncaption: calm piano\\n...\\n</think>" |
| | prompt = handler.build_formatted_prompt_with_cot("calm piano", "hello", cot) |
| | """ |
| | if self.llm_tokenizer is None: |
| | raise ValueError("LLM tokenizer is not initialized. Call initialize() first.") |
| | |
| | if is_negative_prompt: |
| | |
| | |
| | has_negative_prompt = self._has_meaningful_negative_prompt(negative_prompt) |
| | |
| | |
| | cot_for_prompt = "<think>\n</think>" |
| | |
| | if has_negative_prompt: |
| | |
| | caption_for_prompt = negative_prompt |
| | else: |
| | |
| | caption_for_prompt = caption |
| | else: |
| | |
| | cot_for_prompt = cot_text |
| | caption_for_prompt = caption |
| | |
| | |
| | |
| | user_prompt = f"# Caption\n{caption_for_prompt}\n\n# Lyric\n{lyrics}\n" |
| | |
| | |
| | |
| | formatted = self.llm_tokenizer.apply_chat_template( |
| | [ |
| | {"role": "system", "content": f"# Instruction\n{DEFAULT_LM_INSTRUCTION}\n\n"}, |
| | {"role": "user", "content": user_prompt}, |
| | {"role": "assistant", "content": cot_for_prompt}, |
| | ], |
| | tokenize=False, |
| | add_generation_prompt=False, |
| | ) |
| | |
| | |
| | if not formatted.endswith('\n'): |
| | formatted += '\n' |
| | |
| | return formatted |
| | |
| | def build_formatted_prompt_for_understanding( |
| | self, |
| | audio_codes: str, |
| | is_negative_prompt: bool = False, |
| | negative_prompt: str = "NO USER INPUT" |
| | ) -> str: |
| | """ |
| | Build the chat-formatted prompt for audio understanding from codes. |
| | |
| | This is the reverse of generation: given audio codes, generate metadata and lyrics. |
| | |
| | Args: |
| | audio_codes: Audio code string (e.g., "<|audio_code_123|><|audio_code_456|>...") |
| | is_negative_prompt: If True, builds unconditional prompt for CFG |
| | negative_prompt: Negative prompt for CFG (used when is_negative_prompt=True) |
| | |
| | Returns: |
| | Formatted prompt string |
| | |
| | Example: |
| | codes = "<|audio_code_18953|><|audio_code_13833|>..." |
| | prompt = handler.build_formatted_prompt_for_understanding(codes) |
| | """ |
| | if self.llm_tokenizer is None: |
| | raise ValueError("LLM tokenizer is not initialized. Call initialize() first.") |
| | |
| | |
| | |
| | if is_negative_prompt: |
| | user_content = negative_prompt if negative_prompt and negative_prompt.strip() else "" |
| | else: |
| | user_content = audio_codes |
| | |
| | return self.llm_tokenizer.apply_chat_template( |
| | [ |
| | { |
| | "role": "system", |
| | "content": f"# Instruction\n{DEFAULT_LM_UNDERSTAND_INSTRUCTION}\n\n" |
| | }, |
| | { |
| | "role": "user", |
| | "content": user_content |
| | }, |
| | ], |
| | tokenize=False, |
| | add_generation_prompt=True, |
| | ) |
| | |
| | def understand_audio_from_codes( |
| | self, |
| | audio_codes: str, |
| | temperature: float = 0.3, |
| | top_k: Optional[int] = None, |
| | top_p: Optional[float] = None, |
| | repetition_penalty: float = 1.0, |
| | use_constrained_decoding: bool = True, |
| | constrained_decoding_debug: bool = False, |
| | ) -> Tuple[Dict[str, Any], str]: |
| | """ |
| | Understand audio codes and generate metadata + lyrics. |
| | |
| | This is the reverse of the normal generation flow: |
| | - Input: Audio codes |
| | - Output: Metadata (bpm, caption, duration, etc.) + Lyrics |
| | |
| | Note: cfg_scale and negative_prompt are not supported in understand mode. |
| | |
| | Args: |
| | audio_codes: String of audio code tokens (e.g., "<|audio_code_123|><|audio_code_456|>...") |
| | temperature: Sampling temperature for generation |
| | top_k: Top-K sampling (None = disabled) |
| | top_p: Top-P (nucleus) sampling (None = disabled) |
| | repetition_penalty: Repetition penalty (1.0 = no penalty) |
| | use_constrained_decoding: Whether to use FSM-based constrained decoding for metadata |
| | constrained_decoding_debug: Whether to enable debug logging for constrained decoding |
| | |
| | Returns: |
| | Tuple of (metadata_dict, status_message) |
| | metadata_dict contains: |
| | - bpm: int or str |
| | - caption: str |
| | - duration: int or str |
| | - keyscale: str |
| | - language: str |
| | - timesignature: str |
| | - lyrics: str (extracted from output after </think>) |
| | |
| | Example: |
| | codes = "<|audio_code_18953|><|audio_code_13833|>..." |
| | metadata, status = handler.understand_audio_from_codes(codes) |
| | print(metadata['caption']) # "A cinematic orchestral piece..." |
| | print(metadata['lyrics']) # "[Intro: ...]\\n..." |
| | """ |
| | if not getattr(self, "llm_initialized", False): |
| | return {}, "❌ 5Hz LM not initialized. Please initialize it first." |
| | |
| | if not audio_codes or not audio_codes.strip(): |
| | return {}, "❌ No audio codes provided. Please paste audio codes first." |
| | |
| | logger.info(f"Understanding audio codes (length: {len(audio_codes)} chars)") |
| | |
| | |
| | formatted_prompt = self.build_formatted_prompt_for_understanding(audio_codes) |
| | print(f"formatted_prompt: {formatted_prompt}") |
| | |
| | |
| | |
| | output_text, status = self.generate_from_formatted_prompt( |
| | formatted_prompt=formatted_prompt, |
| | cfg={ |
| | "temperature": temperature, |
| | "top_k": top_k, |
| | "top_p": top_p, |
| | "repetition_penalty": repetition_penalty, |
| | "target_duration": None, |
| | "user_metadata": None, |
| | "skip_caption": False, |
| | "skip_language": False, |
| | "skip_genres": False, |
| | "generation_phase": "understand", |
| | |
| | "caption": "", |
| | "lyrics": "", |
| | }, |
| | use_constrained_decoding=use_constrained_decoding, |
| | constrained_decoding_debug=constrained_decoding_debug, |
| | stop_at_reasoning=False, |
| | ) |
| | |
| | if not output_text: |
| | return {}, status |
| | |
| | |
| | metadata, _ = self.parse_lm_output(output_text) |
| | |
| | |
| | lyrics = self._extract_lyrics_from_output(output_text) |
| | if lyrics: |
| | metadata['lyrics'] = lyrics |
| | |
| | logger.info(f"Understanding completed. Generated {len(metadata)} metadata fields") |
| | if constrained_decoding_debug: |
| | logger.debug(f"Generated metadata: {list(metadata.keys())}") |
| | logger.debug(f"Output text preview: {output_text[:200]}...") |
| | |
| | status_msg = f"✅ Understanding completed successfully\nGenerated fields: {', '.join(metadata.keys())}" |
| | return metadata, status_msg |
| | |
| | def _extract_lyrics_from_output(self, output_text: str) -> str: |
| | """ |
| | Extract lyrics section from LLM output. |
| | |
| | The lyrics appear after the </think> tag and typically start with "# Lyric" |
| | or directly with lyric content. |
| | |
| | Args: |
| | output_text: Full LLM output text |
| | |
| | Returns: |
| | Extracted lyrics string, or empty string if no lyrics found |
| | """ |
| | import re |
| | |
| | |
| | think_end_pattern = r'</think>' |
| | match = re.search(think_end_pattern, output_text) |
| | |
| | if not match: |
| | |
| | return "" |
| | |
| | |
| | after_think = output_text[match.end():].strip() |
| | |
| | if not after_think: |
| | return "" |
| | |
| | |
| | lyric_header_pattern = r'^#\s*Lyri[c|cs]?\s*\n' |
| | after_think = re.sub(lyric_header_pattern, '', after_think, flags=re.IGNORECASE) |
| | |
| | |
| | after_think = re.sub(r'<\|im_end\|>\s*$', '', after_think) |
| | |
| | return after_think.strip() |
| | |
| | def build_formatted_prompt_for_inspiration( |
| | self, |
| | query: str, |
| | instrumental: bool = False, |
| | is_negative_prompt: bool = False, |
| | negative_prompt: str = "NO USER INPUT" |
| | ) -> str: |
| | """ |
| | Build the chat-formatted prompt for inspiration/simple mode. |
| | |
| | This generates a complete sample (caption, lyrics, metadata) from a user's |
| | natural language music description query. |
| | |
| | Args: |
| | query: User's natural language music description |
| | instrumental: Whether to generate instrumental music (no vocals) |
| | is_negative_prompt: If True, builds unconditional prompt for CFG |
| | negative_prompt: Negative prompt for CFG (used when is_negative_prompt=True) |
| | |
| | Returns: |
| | Formatted prompt string |
| | |
| | Example: |
| | query = "a soft Bengali love song for a quiet evening" |
| | prompt = handler.build_formatted_prompt_for_inspiration(query, instrumental=False) |
| | """ |
| | if self.llm_tokenizer is None: |
| | raise ValueError("LLM tokenizer is not initialized. Call initialize() first.") |
| | |
| | |
| | instrumental_str = "true" if instrumental else "false" |
| | |
| | if is_negative_prompt: |
| | |
| | user_content = negative_prompt if negative_prompt and negative_prompt.strip() else "" |
| | else: |
| | |
| | user_content = f"{query}\n\ninstrumental: {instrumental_str}" |
| | |
| | return self.llm_tokenizer.apply_chat_template( |
| | [ |
| | { |
| | "role": "system", |
| | "content": f"# Instruction\n{DEFAULT_LM_INSPIRED_INSTRUCTION}\n\n" |
| | }, |
| | { |
| | "role": "user", |
| | "content": user_content |
| | }, |
| | ], |
| | tokenize=False, |
| | add_generation_prompt=True, |
| | ) |
| | |
| | def create_sample_from_query( |
| | self, |
| | query: str, |
| | instrumental: bool = False, |
| | vocal_language: Optional[str] = None, |
| | temperature: float = 0.85, |
| | top_k: Optional[int] = None, |
| | top_p: Optional[float] = None, |
| | repetition_penalty: float = 1.0, |
| | use_constrained_decoding: bool = True, |
| | constrained_decoding_debug: bool = False, |
| | ) -> Tuple[Dict[str, Any], str]: |
| | """ |
| | Create a complete music sample from a user's natural language query. |
| | |
| | This is the "Simple Mode" / "Inspiration Mode" feature that generates: |
| | - Metadata (bpm, caption, duration, keyscale, language, timesignature) |
| | - Lyrics (unless instrumental=True) |
| | |
| | Args: |
| | query: User's natural language music description |
| | instrumental: Whether to generate instrumental music (no vocals) |
| | vocal_language: Allowed vocal language for constrained decoding (e.g., "en", "zh"). |
| | If provided and not "unknown", it will be used. |
| | temperature: Sampling temperature for generation (0.0-2.0) |
| | top_k: Top-K sampling (None = disabled) |
| | top_p: Top-P (nucleus) sampling (None = disabled) |
| | repetition_penalty: Repetition penalty (1.0 = no penalty) |
| | use_constrained_decoding: Whether to use FSM-based constrained decoding |
| | constrained_decoding_debug: Whether to enable debug logging |
| | |
| | Returns: |
| | Tuple of (metadata_dict, status_message) |
| | metadata_dict contains: |
| | - bpm: int or str |
| | - caption: str |
| | - duration: int or str |
| | - keyscale: str |
| | - language: str |
| | - timesignature: str |
| | - lyrics: str (extracted from output after </think>) |
| | - instrumental: bool (echoed back) |
| | |
| | Example: |
| | query = "a soft Bengali love song for a quiet evening" |
| | metadata, status = handler.create_sample_from_query(query, instrumental=False, vocal_language="bn") |
| | print(metadata['caption']) # "A gentle romantic acoustic pop ballad..." |
| | print(metadata['lyrics']) # "[Intro: ...]\\n..." |
| | """ |
| | if not getattr(self, "llm_initialized", False): |
| | return {}, "❌ 5Hz LM not initialized. Please initialize it first." |
| | |
| | if not query or not query.strip(): |
| | query = "NO USER INPUT" |
| | |
| | logger.info(f"Creating sample from query: {query[:100]}... (instrumental={instrumental}, vocal_language={vocal_language})") |
| | |
| | |
| | formatted_prompt = self.build_formatted_prompt_for_inspiration( |
| | query=query, |
| | instrumental=instrumental, |
| | ) |
| | logger.debug(f"Formatted prompt for inspiration: {formatted_prompt}") |
| | |
| | |
| | user_metadata = None |
| | skip_language = False |
| | if vocal_language and vocal_language.strip() and vocal_language.strip().lower() != "unknown": |
| | |
| | user_metadata = {"language": vocal_language.strip()} |
| | |
| | logger.info(f"Using user-specified language: {vocal_language.strip()}") |
| | |
| | |
| | |
| | |
| | output_text, status = self.generate_from_formatted_prompt( |
| | formatted_prompt=formatted_prompt, |
| | cfg={ |
| | "temperature": temperature, |
| | "top_k": top_k, |
| | "top_p": top_p, |
| | "repetition_penalty": repetition_penalty, |
| | "target_duration": None, |
| | "user_metadata": user_metadata, |
| | "skip_caption": False, |
| | "skip_language": False, |
| | "skip_genres": False, |
| | "generation_phase": "understand", |
| | "caption": "", |
| | "lyrics": "", |
| | }, |
| | use_constrained_decoding=use_constrained_decoding, |
| | constrained_decoding_debug=constrained_decoding_debug, |
| | stop_at_reasoning=False, |
| | ) |
| | |
| | if not output_text: |
| | return {}, status |
| | |
| | |
| | metadata, _ = self.parse_lm_output(output_text) |
| | |
| | |
| | lyrics = self._extract_lyrics_from_output(output_text) |
| | if lyrics: |
| | metadata['lyrics'] = lyrics |
| | elif instrumental: |
| | |
| | metadata['lyrics'] = "[Instrumental]" |
| | |
| | |
| | metadata['instrumental'] = instrumental |
| | |
| | logger.info(f"Sample created successfully. Generated {metadata} fields") |
| | if constrained_decoding_debug: |
| | logger.debug(f"Generated metadata: {list(metadata.keys())}") |
| | logger.debug(f"Output text preview: {output_text[:300]}...") |
| | |
| | status_msg = f"✅ Sample created successfully\nGenerated fields: {metadata}" |
| | return metadata, status_msg |
| | |
| | def build_formatted_prompt_for_format( |
| | self, |
| | caption: str, |
| | lyrics: str, |
| | is_negative_prompt: bool = False, |
| | negative_prompt: str = "NO USER INPUT" |
| | ) -> str: |
| | """ |
| | Build the chat-formatted prompt for format/rewrite mode. |
| | |
| | This formats user-provided caption and lyrics into a more detailed and specific |
| | musical description with metadata. |
| | |
| | Args: |
| | caption: User's caption/description of the music |
| | lyrics: User's lyrics |
| | is_negative_prompt: If True, builds unconditional prompt for CFG |
| | negative_prompt: Negative prompt for CFG (used when is_negative_prompt=True) |
| | |
| | Returns: |
| | Formatted prompt string |
| | |
| | Example: |
| | caption = "Latin pop, reggaeton, flamenco-pop" |
| | lyrics = "[Verse 1]\\nTengo un nudo..." |
| | prompt = handler.build_formatted_prompt_for_format(caption, lyrics) |
| | """ |
| | if self.llm_tokenizer is None: |
| | raise ValueError("LLM tokenizer is not initialized. Call initialize() first.") |
| | |
| | if is_negative_prompt: |
| | |
| | user_content = negative_prompt if negative_prompt and negative_prompt.strip() else "" |
| | else: |
| | |
| | user_content = f"# Caption\n{caption}\n\n# Lyric\n{lyrics}" |
| | |
| | return self.llm_tokenizer.apply_chat_template( |
| | [ |
| | { |
| | "role": "system", |
| | "content": f"# Instruction\n{DEFAULT_LM_REWRITE_INSTRUCTION}\n\n" |
| | }, |
| | { |
| | "role": "user", |
| | "content": user_content |
| | }, |
| | ], |
| | tokenize=False, |
| | add_generation_prompt=True, |
| | ) |
| | |
| | def format_sample_from_input( |
| | self, |
| | caption: str, |
| | lyrics: str, |
| | user_metadata: Optional[Dict[str, Any]] = None, |
| | temperature: float = 0.85, |
| | top_k: Optional[int] = None, |
| | top_p: Optional[float] = None, |
| | repetition_penalty: float = 1.0, |
| | use_constrained_decoding: bool = True, |
| | constrained_decoding_debug: bool = False, |
| | ) -> Tuple[Dict[str, Any], str]: |
| | """ |
| | Format user-provided caption and lyrics into structured music metadata. |
| | |
| | This is the "Format" feature that takes user input and generates: |
| | - Enhanced caption with detailed music description |
| | - Metadata (bpm, duration, keyscale, language, timesignature) |
| | - Formatted lyrics (preserved from input) |
| | |
| | Note: cfg_scale and negative_prompt are not supported in format mode. |
| | |
| | Args: |
| | caption: User's caption/description (e.g., "Latin pop, reggaeton") |
| | lyrics: User's lyrics with structure tags |
| | user_metadata: Optional dict with user-provided metadata to constrain decoding. |
| | Supported keys: bpm, duration, keyscale, timesignature, language |
| | temperature: Sampling temperature for generation (0.0-2.0) |
| | top_k: Top-K sampling (None = disabled) |
| | top_p: Top-P (nucleus) sampling (None = disabled) |
| | repetition_penalty: Repetition penalty (1.0 = no penalty) |
| | use_constrained_decoding: Whether to use FSM-based constrained decoding |
| | constrained_decoding_debug: Whether to enable debug logging |
| | |
| | Returns: |
| | Tuple of (metadata_dict, status_message) |
| | metadata_dict contains: |
| | - bpm: int or str |
| | - caption: str (enhanced) |
| | - duration: int or str |
| | - keyscale: str |
| | - language: str |
| | - timesignature: str |
| | - lyrics: str (from input, possibly formatted) |
| | |
| | Example: |
| | caption = "Latin pop, reggaeton, flamenco-pop" |
| | lyrics = "[Verse 1]\\nTengo un nudo en la garganta..." |
| | metadata, status = handler.format_sample_from_input(caption, lyrics) |
| | print(metadata['caption']) # "A dramatic and powerful Latin pop track..." |
| | print(metadata['bpm']) # 100 |
| | """ |
| | if not getattr(self, "llm_initialized", False): |
| | return {}, "❌ 5Hz LM not initialized. Please initialize it first." |
| | |
| | if not caption or not caption.strip(): |
| | caption = "NO USER INPUT" |
| | if not lyrics or not lyrics.strip(): |
| | lyrics = "[Instrumental]" |
| | |
| | logger.info(f"Formatting sample from input: caption={caption[:50]}..., lyrics length={len(lyrics)}") |
| | |
| | |
| | formatted_prompt = self.build_formatted_prompt_for_format( |
| | caption=caption, |
| | lyrics=lyrics, |
| | ) |
| | logger.debug(f"Formatted prompt for format: {formatted_prompt}") |
| | |
| | |
| | constrained_metadata = None |
| | if user_metadata: |
| | constrained_metadata = {} |
| | if user_metadata.get('bpm') is not None: |
| | try: |
| | bpm_val = int(user_metadata['bpm']) |
| | if bpm_val > 0: |
| | constrained_metadata['bpm'] = bpm_val |
| | except (ValueError, TypeError): |
| | pass |
| | if user_metadata.get('duration') is not None: |
| | try: |
| | dur_val = int(user_metadata['duration']) |
| | if dur_val > 0: |
| | constrained_metadata['duration'] = dur_val |
| | except (ValueError, TypeError): |
| | pass |
| | if user_metadata.get('keyscale'): |
| | constrained_metadata['keyscale'] = user_metadata['keyscale'] |
| | if user_metadata.get('timesignature'): |
| | constrained_metadata['timesignature'] = user_metadata['timesignature'] |
| | if user_metadata.get('language'): |
| | constrained_metadata['language'] = user_metadata['language'] |
| | |
| | |
| | if not constrained_metadata: |
| | constrained_metadata = None |
| | else: |
| | logger.info(f"Using user-provided metadata constraints: {constrained_metadata}") |
| | |
| | |
| | |
| | |
| | output_text, status = self.generate_from_formatted_prompt( |
| | formatted_prompt=formatted_prompt, |
| | cfg={ |
| | "temperature": temperature, |
| | "top_k": top_k, |
| | "top_p": top_p, |
| | "repetition_penalty": repetition_penalty, |
| | "target_duration": None, |
| | "user_metadata": constrained_metadata, |
| | "skip_caption": False, |
| | "skip_language": constrained_metadata.get('language') is not None if constrained_metadata else False, |
| | "skip_genres": False, |
| | "generation_phase": "understand", |
| | "caption": "", |
| | "lyrics": "", |
| | }, |
| | use_constrained_decoding=use_constrained_decoding, |
| | constrained_decoding_debug=constrained_decoding_debug, |
| | stop_at_reasoning=False, |
| | ) |
| | |
| | if not output_text: |
| | return {}, status |
| | |
| | |
| | metadata, _ = self.parse_lm_output(output_text) |
| | |
| | |
| | formatted_lyrics = self._extract_lyrics_from_output(output_text) |
| | if formatted_lyrics: |
| | metadata['lyrics'] = formatted_lyrics |
| | else: |
| | |
| | metadata['lyrics'] = lyrics |
| | |
| | logger.info(f"Format completed successfully. Generated {metadata} fields") |
| | if constrained_decoding_debug: |
| | logger.debug(f"Generated metadata: {list(metadata.keys())}") |
| | logger.debug(f"Output text preview: {output_text[:300]}...") |
| | |
| | status_msg = f"✅ Format completed successfully\nGenerated fields: {', '.join(metadata.keys())}" |
| | return metadata, status_msg |
| | |
| | def generate_from_formatted_prompt( |
| | self, |
| | formatted_prompt: str, |
| | cfg: Optional[Dict[str, Any]] = None, |
| | use_constrained_decoding: bool = True, |
| | constrained_decoding_debug: bool = False, |
| | stop_at_reasoning: bool = False, |
| | ) -> Tuple[str, str]: |
| | """ |
| | Generate raw LM text output from a pre-built formatted prompt. |
| | |
| | Args: |
| | formatted_prompt: Prompt that is already formatted by `build_formatted_prompt`. |
| | cfg: Optional dict supporting keys: |
| | - temperature (float) |
| | - cfg_scale (float) |
| | - negative_prompt (str) used when cfg_scale > 1 |
| | - top_k (int), top_p (float), repetition_penalty (float) |
| | - target_duration (float): Target duration in seconds for codes generation |
| | - generation_phase (str): "cot" or "codes" for phase-aware CFG |
| | use_constrained_decoding: Whether to use FSM-based constrained decoding |
| | constrained_decoding_debug: Whether to enable debug logging for constrained decoding |
| | stop_at_reasoning: If True, stop generation immediately after </think> tag (no audio codes) |
| | |
| | Returns: |
| | (output_text, status_message) |
| | |
| | Example: |
| | prompt = handler.build_formatted_prompt(caption, lyric) |
| | text, status = handler.generate_from_formatted_prompt(prompt, {"temperature": 0.7}) |
| | """ |
| | if not getattr(self, "llm_initialized", False): |
| | return "", "❌ 5Hz LM not initialized. Please initialize it first." |
| | if self.llm is None or self.llm_tokenizer is None: |
| | return "", "❌ 5Hz LM is missing model or tokenizer." |
| |
|
| | cfg = cfg or {} |
| | temperature = cfg.get("temperature", 0.6) |
| | cfg_scale = cfg.get("cfg_scale", 1.0) |
| | negative_prompt = cfg.get("negative_prompt", "NO USER INPUT") |
| | top_k = cfg.get("top_k") |
| | top_p = cfg.get("top_p") |
| | repetition_penalty = cfg.get("repetition_penalty", 1.0) |
| | target_duration = cfg.get("target_duration") |
| | user_metadata = cfg.get("user_metadata") |
| | skip_caption = cfg.get("skip_caption", False) |
| | skip_language = cfg.get("skip_language", False) |
| | skip_genres = cfg.get("skip_genres", False) |
| | generation_phase = cfg.get("generation_phase", "cot") |
| | |
| | caption = cfg.get("caption", "") |
| | lyrics = cfg.get("lyrics", "") |
| | cot_text = cfg.get("cot_text", "") |
| |
|
| | try: |
| | if self.llm_backend == "vllm": |
| | output_text = self._run_vllm( |
| | formatted_prompts=formatted_prompt, |
| | temperature=temperature, |
| | cfg_scale=cfg_scale, |
| | negative_prompt=negative_prompt, |
| | top_k=top_k, |
| | top_p=top_p, |
| | repetition_penalty=repetition_penalty, |
| | use_constrained_decoding=use_constrained_decoding, |
| | constrained_decoding_debug=constrained_decoding_debug, |
| | target_duration=target_duration, |
| | user_metadata=user_metadata, |
| | stop_at_reasoning=stop_at_reasoning, |
| | skip_genres=skip_genres, |
| | skip_caption=skip_caption, |
| | skip_language=skip_language, |
| | generation_phase=generation_phase, |
| | caption=caption, |
| | lyrics=lyrics, |
| | cot_text=cot_text, |
| | ) |
| | return output_text, f"✅ Generated successfully (vllm) | length={len(output_text)}" |
| |
|
| | |
| | output_text = self._run_pt( |
| | formatted_prompts=formatted_prompt, |
| | temperature=temperature, |
| | cfg_scale=cfg_scale, |
| | negative_prompt=negative_prompt, |
| | top_k=top_k, |
| | top_p=top_p, |
| | repetition_penalty=repetition_penalty, |
| | use_constrained_decoding=use_constrained_decoding, |
| | constrained_decoding_debug=constrained_decoding_debug, |
| | target_duration=target_duration, |
| | user_metadata=user_metadata, |
| | stop_at_reasoning=stop_at_reasoning, |
| | skip_genres=skip_genres, |
| | skip_caption=skip_caption, |
| | skip_language=skip_language, |
| | generation_phase=generation_phase, |
| | caption=caption, |
| | lyrics=lyrics, |
| | cot_text=cot_text, |
| | ) |
| | return output_text, f"✅ Generated successfully (pt) | length={len(output_text)}" |
| |
|
| | except Exception as e: |
| | return "", f"❌ Error generating from formatted prompt: {e}" |
| | |
| | def _generate_with_constrained_decoding( |
| | self, |
| | input_ids: torch.Tensor, |
| | attention_mask: Optional[torch.Tensor], |
| | max_new_tokens: int, |
| | temperature: float, |
| | top_k: Optional[int], |
| | top_p: Optional[float], |
| | repetition_penalty: float, |
| | pad_token_id: int, |
| | streamer: Optional[BaseStreamer], |
| | constrained_processor: Optional[MetadataConstrainedLogitsProcessor] = None, |
| | ) -> torch.Tensor: |
| | """ |
| | Custom generation loop with constrained decoding support (non-CFG). |
| | This allows us to call update_state() after each token generation. |
| | """ |
| | model = self.llm |
| | device = self.device |
| | |
| | |
| | generated_ids = input_ids.clone() |
| | if attention_mask is not None: |
| | attn_mask = attention_mask.clone() |
| | else: |
| | attn_mask = torch.ones_like(input_ids) |
| | |
| | |
| | model_kwargs = {'attention_mask': attn_mask} |
| | |
| | |
| | past_key_values = None |
| | use_cache = hasattr(model, 'generation_config') and getattr(model.generation_config, 'use_cache', True) |
| | |
| | |
| | eos_token_id = self.llm_tokenizer.eos_token_id |
| | if eos_token_id is None: |
| | eos_token_id = pad_token_id |
| | |
| | |
| | logits_processor = self._build_logits_processor(repetition_penalty) |
| | |
| | with torch.no_grad(): |
| | for step in range(max_new_tokens): |
| | |
| | outputs = self._forward_pass(model, generated_ids, model_kwargs, past_key_values, use_cache) |
| | |
| | |
| | next_token_logits = outputs.logits[:, -1, :] |
| | |
| | |
| | if constrained_processor is not None: |
| | next_token_logits = constrained_processor(generated_ids, next_token_logits) |
| | |
| | |
| | for processor in logits_processor: |
| | next_token_logits = processor(generated_ids, next_token_logits) |
| | |
| | |
| | next_token_logits = self._apply_top_k_filter(next_token_logits, top_k) |
| | next_token_logits = self._apply_top_p_filter(next_token_logits, top_p) |
| | |
| | |
| | next_tokens = self._sample_tokens(next_token_logits, temperature) |
| | |
| | |
| | self._update_constrained_processor_state(constrained_processor, next_tokens) |
| | |
| | |
| | should_stop = self._check_eos_token(next_tokens, eos_token_id, pad_token_id) |
| | |
| | |
| | next_tokens_unsqueezed = next_tokens.unsqueeze(1) |
| | generated_ids = torch.cat([generated_ids, next_tokens_unsqueezed], dim=1) |
| | attn_mask = torch.cat([attn_mask, torch.ones((input_ids.shape[0], 1), device=device, dtype=attn_mask.dtype)], dim=1) |
| | model_kwargs['attention_mask'] = attn_mask |
| | |
| | |
| | if use_cache and hasattr(outputs, 'past_key_values'): |
| | past_key_values = outputs.past_key_values |
| | |
| | |
| | if streamer is not None: |
| | streamer.put(next_tokens_unsqueezed) |
| | |
| | if should_stop: |
| | break |
| | |
| | if streamer is not None: |
| | streamer.end() |
| | |
| | return generated_ids |
| | |
| | def _generate_with_cfg_custom( |
| | self, |
| | batch_input_ids: torch.Tensor, |
| | batch_attention_mask: Optional[torch.Tensor], |
| | max_new_tokens: int, |
| | temperature: float, |
| | cfg_scale: float, |
| | top_k: Optional[int], |
| | top_p: Optional[float], |
| | repetition_penalty: float, |
| | pad_token_id: int, |
| | streamer: Optional[BaseStreamer], |
| | constrained_processor: Optional[MetadataConstrainedLogitsProcessor] = None, |
| | ) -> torch.Tensor: |
| | """ |
| | Custom CFG generation loop that: |
| | 1. Processes both conditional and unconditional sequences in parallel |
| | 2. Applies CFG formula to logits |
| | 3. Samples tokens only for conditional sequences |
| | 4. Applies the same sampled tokens to both conditional and unconditional sequences |
| | 5. Optionally applies constrained decoding via FSM-based logits processor |
| | |
| | Batch format: [cond_input, uncond_input] |
| | """ |
| | model = self.llm |
| | device = self.device |
| | batch_size = batch_input_ids.shape[0] // 2 |
| | cond_start_idx = 0 |
| | uncond_start_idx = batch_size |
| | |
| | |
| | generated_ids = batch_input_ids.clone() |
| | if batch_attention_mask is not None: |
| | attention_mask = batch_attention_mask.clone() |
| | else: |
| | attention_mask = torch.ones_like(batch_input_ids) |
| | |
| | |
| | model_kwargs = {} |
| | if batch_attention_mask is not None: |
| | model_kwargs['attention_mask'] = attention_mask |
| | |
| | |
| | past_key_values = None |
| | use_cache = hasattr(model, 'generation_config') and getattr(model.generation_config, 'use_cache', True) |
| | |
| | |
| | eos_token_id = self.llm_tokenizer.eos_token_id |
| | if eos_token_id is None: |
| | eos_token_id = pad_token_id |
| | |
| | |
| | logits_processor = self._build_logits_processor(repetition_penalty) |
| | |
| | with torch.no_grad(): |
| | for step in range(max_new_tokens): |
| | |
| | outputs = self._forward_pass(model, generated_ids, model_kwargs, past_key_values, use_cache) |
| | |
| | |
| | next_token_logits = outputs.logits[:, -1, :] |
| | |
| | |
| | cond_logits = next_token_logits[cond_start_idx:cond_start_idx+batch_size] |
| | uncond_logits = next_token_logits[uncond_start_idx:uncond_start_idx+batch_size] |
| | |
| | |
| | cfg_logits = uncond_logits + cfg_scale * (cond_logits - uncond_logits) |
| | |
| | |
| | if constrained_processor is not None: |
| | current_input_ids = generated_ids[cond_start_idx:cond_start_idx+batch_size] |
| | cfg_logits = constrained_processor(current_input_ids, cfg_logits) |
| | |
| | |
| | |
| | current_input_ids = generated_ids[cond_start_idx:cond_start_idx+batch_size] |
| | for processor in logits_processor: |
| | cfg_logits = processor(current_input_ids, cfg_logits) |
| | |
| | |
| | cfg_logits = self._apply_top_k_filter(cfg_logits, top_k) |
| | cfg_logits = self._apply_top_p_filter(cfg_logits, top_p) |
| | |
| | |
| | next_tokens = self._sample_tokens(cfg_logits, temperature) |
| | |
| | |
| | self._update_constrained_processor_state(constrained_processor, next_tokens) |
| | |
| | |
| | |
| | |
| | should_stop = self._check_eos_token(next_tokens, eos_token_id, pad_token_id) |
| | |
| | |
| | next_tokens_unsqueezed = next_tokens.unsqueeze(1) |
| | generated_ids = torch.cat([generated_ids, next_tokens_unsqueezed.repeat(2, 1)], dim=1) |
| | attention_mask = torch.cat([attention_mask, torch.ones((batch_size*2, 1), device=device, dtype=attention_mask.dtype)], dim=1) |
| | model_kwargs['attention_mask'] = attention_mask |
| | |
| | |
| | if use_cache and hasattr(outputs, 'past_key_values'): |
| | past_key_values = outputs.past_key_values |
| | |
| | |
| | if streamer is not None: |
| | streamer.put(next_tokens_unsqueezed) |
| | |
| | |
| | if should_stop: |
| | break |
| | |
| | if streamer is not None: |
| | streamer.end() |
| | |
| | |
| | |
| | return generated_ids |
| | |
| | def parse_lm_output(self, output_text: str) -> Tuple[Dict[str, Any], str]: |
| | """ |
| | Parse LM output to extract metadata and audio codes. |
| | |
| | Expected format: |
| | <think> |
| | bpm: 73 |
| | caption: A calm piano melody |
| | duration: 273 |
| | genres: Chinese folk |
| | keyscale: G major |
| | language: en |
| | timesignature: 4 |
| | </think> |
| | |
| | <|audio_code_56535|><|audio_code_62918|>... |
| | |
| | Returns: |
| | Tuple of (metadata_dict, audio_codes_string) |
| | """ |
| | debug_output_text = output_text.split("</think>")[0] |
| | logger.debug(f"Debug output text: {debug_output_text}") |
| | metadata = {} |
| | audio_codes = "" |
| | |
| | import re |
| | |
| | |
| | code_pattern = r'<\|audio_code_\d+\|>' |
| | code_matches = re.findall(code_pattern, output_text) |
| | if code_matches: |
| | audio_codes = "".join(code_matches) |
| | |
| | |
| | |
| | reasoning_patterns = [ |
| | r'<think>(.*?)</think>', |
| | r'<think>(.*?)</think>', |
| | r'<reasoning>(.*?)</reasoning>', |
| | ] |
| | |
| | reasoning_text = None |
| | for pattern in reasoning_patterns: |
| | match = re.search(pattern, output_text, re.DOTALL) |
| | if match: |
| | reasoning_text = match.group(1).strip() |
| | break |
| | |
| | |
| | if not reasoning_text: |
| | |
| | lines_before_codes = output_text.split('<|audio_code_')[0] if '<|audio_code_' in output_text else output_text |
| | reasoning_text = lines_before_codes.strip() |
| | |
| | |
| | if reasoning_text: |
| | lines = reasoning_text.split('\n') |
| | current_key = None |
| | current_value_lines = [] |
| | |
| | def save_current_field(): |
| | """Save the accumulated field value""" |
| | nonlocal current_key, current_value_lines |
| | if current_key and current_value_lines: |
| | |
| | value = '\n'.join(current_value_lines) |
| | |
| | if current_key == 'bpm': |
| | try: |
| | metadata['bpm'] = int(value.strip()) |
| | except: |
| | metadata['bpm'] = value.strip() |
| | elif current_key == 'caption': |
| | |
| | metadata['caption'] = MetadataConstrainedLogitsProcessor.postprocess_caption(value) |
| | elif current_key == 'duration': |
| | try: |
| | metadata['duration'] = int(value.strip()) |
| | except: |
| | metadata['duration'] = value.strip() |
| | elif current_key == 'genres': |
| | metadata['genres'] = value.strip() |
| | elif current_key == 'keyscale': |
| | metadata['keyscale'] = value.strip() |
| | elif current_key == 'language': |
| | metadata['language'] = value.strip() |
| | elif current_key == 'timesignature': |
| | metadata['timesignature'] = value.strip() |
| | |
| | current_key = None |
| | current_value_lines = [] |
| | |
| | for line in lines: |
| | |
| | if line.strip().startswith('<'): |
| | continue |
| | |
| | |
| | if line and not line[0].isspace() and ':' in line: |
| | |
| | save_current_field() |
| | |
| | |
| | parts = line.split(':', 1) |
| | if len(parts) == 2: |
| | current_key = parts[0].strip().lower() |
| | |
| | first_value = parts[1] |
| | if first_value.strip(): |
| | current_value_lines.append(first_value) |
| | elif line.startswith(' ') or line.startswith('\t'): |
| | |
| | if current_key: |
| | current_value_lines.append(line) |
| | |
| | |
| | save_current_field() |
| | |
| | return metadata, audio_codes |
| | |
| | @contextmanager |
| | def _load_model_context(self): |
| | """ |
| | Context manager to load a model to GPU and offload it back to CPU after use. |
| | Only used for PyTorch backend when offload_to_cpu is True. |
| | """ |
| | if not self.offload_to_cpu: |
| | yield |
| | return |
| | |
| | |
| | if self.llm_backend == "vllm": |
| | yield |
| | return |
| | |
| | model = self.llm |
| | if model is None: |
| | yield |
| | return |
| | |
| | |
| | logger.info(f"Loading LLM to {self.device}") |
| | start_time = time.time() |
| | if hasattr(model, "to"): |
| | model.to(self.device).to(self.dtype) |
| | load_time = time.time() - start_time |
| | logger.info(f"Loaded LLM to {self.device} in {load_time:.4f}s") |
| |
|
| | try: |
| | yield |
| | finally: |
| | |
| | logger.info(f"Offloading LLM to CPU") |
| | start_time = time.time() |
| | if hasattr(model, "to"): |
| | model.to("cpu") |
| | torch.cuda.empty_cache() |
| | offload_time = time.time() - start_time |
| | logger.info(f"Offloaded LLM to CPU in {offload_time:.4f}s") |
| | |
| | def get_hf_model_for_scoring(self): |
| | """ |
| | Get HuggingFace model for perplexity scoring. |
| | |
| | For vllm backend, loads HuggingFace model from disk (weights are cached by transformers). |
| | For pt backend, returns the existing model. |
| | |
| | Returns: |
| | HuggingFace model instance |
| | """ |
| | if self.llm_backend == "pt": |
| | |
| | return self.llm |
| | |
| | elif self.llm_backend == "vllm": |
| | |
| | |
| | if self._hf_model_for_scoring is None: |
| | logger.info("Loading HuggingFace model for scoring (from checkpoint)") |
| | |
| | |
| | model_runner = self.llm.model_runner |
| | model_path = model_runner.config.model |
| | |
| | |
| | |
| | import time |
| | start_time = time.time() |
| | self._hf_model_for_scoring = AutoModelForCausalLM.from_pretrained( |
| | model_path, |
| | trust_remote_code=True, |
| | torch_dtype=self.dtype |
| | ) |
| | load_time = time.time() - start_time |
| | logger.info(f"HuggingFace model loaded in {load_time:.2f}s") |
| | |
| | |
| | device = next(model_runner.model.parameters()).device |
| | self._hf_model_for_scoring = self._hf_model_for_scoring.to(device) |
| | self._hf_model_for_scoring.eval() |
| | |
| | logger.info(f"HuggingFace model for scoring ready on {device}") |
| | |
| | return self._hf_model_for_scoring |
| | |
| | else: |
| | raise ValueError(f"Unknown backend: {self.llm_backend}") |
| |
|