fix all review issues: dedup sampling/unwrap, thread-safe lock, cleanup, retry, security docs
- app.py +63 -22
- train_engine.py +42 -72
app.py
CHANGED
@@ -13,12 +13,14 @@ import string
 import random
 import requests
 import logging
+import threading
 
 from train_engine import (
     preprocess_audio,
     train_lora_generator,
     cancel_training,
     get_trained_loras as _get_trained_loras_engine,
+    MAX_TRAINING_TIME,
 )
 
 logger = logging.getLogger(__name__)
|
@@ -28,7 +30,7 @@ logger = logging.getLogger(__name__)
 # ---------------------------------------------------------------------------
 
 MAX_TOTAL_AUDIO = 1800  # seconds total across all uploaded files (30 min)
-MAX_TRAINING_TIME
+# MAX_TRAINING_TIME is imported from train_engine (single source of truth)
 MAX_AUDIO_FILES = 50  # max number of training audio files per run
 
 # ---------------------------------------------------------------------------
|
@@ -39,6 +41,21 @@ ACE_SERVER = os.environ.get("ACE_SERVER", "http://127.0.0.1:8085")
 OUTPUT_DIR = os.environ.get("ACE_OUTPUT_DIR", "/app/outputs")
 os.makedirs(OUTPUT_DIR, exist_ok=True)
 
+# Clean up old inference temp files (older than 1 hour) at startup
+_CLEANUP_MAX_AGE = 3600  # seconds
+try:
+    _now = time.time()
+    for _fname in os.listdir(OUTPUT_DIR):
+        if _fname.lower().endswith((".wav", ".mp3")):
+            _fpath = os.path.join(OUTPUT_DIR, _fname)
+            try:
+                if os.path.isfile(_fpath) and (_now - os.path.getmtime(_fpath)) > _CLEANUP_MAX_AGE:
+                    os.remove(_fpath)
+            except OSError:
+                pass
+except Exception:
+    pass
+
 ACE_CHECKPOINT_DIR = os.environ.get("ACE_CHECKPOINT_DIR", "/app/checkpoints")
 ACE_SOURCE_DIR = "/app/ace-step-source"
 ACE_HF_MODEL = "ACE-Step/Ace-Step1.5"
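A quick standalone check of the age cutoff above (a sketch, not part of this commit): backdate a file's mtime with os.utime and confirm the same predicate flags it as stale.

    import os
    import tempfile
    import time

    _CLEANUP_MAX_AGE = 3600  # same cutoff as the startup sweep above

    with tempfile.TemporaryDirectory() as d:
        stale = os.path.join(d, "old.wav")
        open(stale, "w").close()
        # Pretend the file was written two hours ago
        two_hours_ago = time.time() - 7200
        os.utime(stale, (two_hours_ago, two_hours_ago))
        assert (time.time() - os.path.getmtime(stale)) > _CLEANUP_MAX_AGE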
|
@@ -49,7 +66,7 @@ ACE_SERVER_BIN = "/app/ace-server"
 
 # Detect if running on HF Space (ace-server available) vs locally (PyTorch only)
 _is_space = os.path.isfile(ACE_SERVER_BIN) or os.environ.get("SPACE_ID") is not None
-_training_in_progress = False
+_training_lock = threading.Lock()
 
 # HF repo for on-demand GGUF downloads
 GGUF_HF_REPO = "Serveurperso/ACE-Step-1.5-GGUF"
|
@@ -329,25 +346,47 @@ def _stop_ace_server():
     time.sleep(1)
 
 
-def _start_ace_server():
-    """Start ace-server in background and wait for health.
+def _start_ace_server(max_retries: int = 3, retry_delay: float = 5.0):
+    """Start ace-server in background and wait for health.
+
+    Retries up to max_retries times with retry_delay seconds between attempts.
+    """
     global _ace_proc
+    for attempt in range(1, max_retries + 1):
+        logger.info(
+            "[ace-server] Starting (attempt %d/%d) with --adapters %s",
+            attempt, max_retries, ADAPTER_DIR,
-        "--models", MODELS_DIR, "--adapters", ADAPTER_DIR, "--max-batch", "1"],
         )
+        try:
+            _ace_proc = subprocess.Popen(
+                [ACE_SERVER_BIN, "--host", "127.0.0.1", "--port", "8085",
+                 "--models", MODELS_DIR, "--adapters", ADAPTER_DIR, "--max-batch", "1"],
+            )
+        except Exception as exc:
+            logger.error("[ace-server] Failed to start: %s", exc)
+            if attempt < max_retries:
+                time.sleep(retry_delay)
+                continue
+            return False
 
+        for _ in range(30):
+            if _server_ok():
+                logger.info("[ace-server] Healthy")
+                return True
+            time.sleep(2)
+
+        logger.warning("[ace-server] Health check timeout on attempt %d/%d", attempt, max_retries)
+        # Kill the failed process before retrying
+        if _ace_proc and _ace_proc.poll() is None:
+            _ace_proc.kill()
+            try:
+                _ace_proc.wait(timeout=5)
+            except subprocess.TimeoutExpired:
+                pass
+        if attempt < max_retries:
+            time.sleep(retry_delay)
+
+    logger.error("[ace-server] Failed to start after %d attempts", max_retries)
     return False
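The `_server_ok()` probe polled by the retry loop is defined outside these hunks. It is presumably a cheap HTTP check against the server started above; a hypothetical sketch (the endpoint path and timeout are assumptions, not taken from the source):

    import requests

    ACE_SERVER = "http://127.0.0.1:8085"  # matches the --host/--port passed to Popen

    def _server_ok() -> bool:
        # Hypothetical health probe; the real endpoint may differ
        try:
            return requests.get(f"{ACE_SERVER}/health", timeout=2).status_code == 200
        except requests.RequestException:
            return False

With the 30-iteration, 2-second poll above, each start attempt waits up to about a minute before being counted as a timeout, so the default settings allow roughly three minutes before _start_ace_server gives up.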
@@ -449,8 +488,9 @@ def gradio_main():
     def generate_music(caption, lyrics, instrumental, bpm, duration, seed,
                        steps, lora_select, lm_model_select,
                        progress=gr.Progress(track_tqdm=True)):
-        if _training_in_progress:
+        if not _training_lock.acquire(blocking=False):
             return None, "Training in progress. Inference unavailable until training completes. Press Cancel to stop training."
+        _training_lock.release()
         if not _server_ok():
             return None, "ace-server not running. Check logs."
 
@@ -631,8 +671,7 @@
         yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
 
         # Stop ace-server before training (frees memory)
-        global _training_in_progress
-        _training_in_progress = True
+        _training_lock.acquire()
         _log("[INFO] Stopping ace-server for training...")
         yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
         _stop_ace_server()
@@ -720,7 +759,7 @@
             yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File()
 
         finally:
-            _training_in_progress = False
+            _training_lock.release()
             # Always restart ace-server
             _log("[INFO] Restarting ace-server...")
             yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
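Across the three hunks above, the lock replaces the old boolean flag with a probe-then-release pattern: inference only needs to know whether training currently holds the lock, so it acquires non-blocking and releases immediately, while the training handler holds the lock for its whole run and releases it in finally. A condensed sketch of the pattern (hypothetical function names):

    import threading

    _training_lock = threading.Lock()

    def infer():
        # Probe: fail fast if training holds the lock, otherwise release right away
        if not _training_lock.acquire(blocking=False):
            return "busy: training in progress"
        _training_lock.release()
        return "inference result"

    def train():
        _training_lock.acquire()  # held for the entire training run
        try:
            ...  # training work
        finally:
            _training_lock.release()  # always unblocks inference, even on errors

Note the probe is advisory: training could acquire the lock between the release and the actual inference work, so this pattern gates the user-facing error message rather than guaranteeing mutual exclusion.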
@@ -746,6 +785,8 @@
             yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File(value=tmp_out.name, visible=True)
         else:
             yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File()
+        # Clean up training workspace (preprocessed tensors, temp audio, etc.)
+        shutil.rmtree(work_dir, ignore_errors=True)
 
     # -- Cancel handler --
     def _on_cancel():
train_engine.py
CHANGED
@@ -374,28 +374,6 @@ def enable_gradient_checkpointing(decoder: nn.Module) -> bool:
     return enabled
 
 
-def force_disable_cache(decoder: nn.Module) -> None:
-    stack = [decoder]
-    visited = set()
-    while stack:
-        mod = stack.pop()
-        if not isinstance(mod, nn.Module):
-            continue
-        mid = id(mod)
-        if mid in visited:
-            continue
-        visited.add(mid)
-        cfg = getattr(mod, "config", None)
-        if cfg is not None and hasattr(cfg, "use_cache"):
-            try:
-                cfg.use_cache = False
-            except Exception:
-                pass
-        for a in ("_forward_module", "_orig_mod", "base_model", "model", "module"):
-            child = getattr(mod, a, None)
-            if isinstance(child, nn.Module):
-                stack.append(child)
-
 
 # ============================================================================
 # LORA INJECTION (PEFT only -- no DoRA/LoKR/LoHA/OFT)
@@ -464,9 +442,7 @@ def inject_lora(model, lora_cfg: LoRAConfig) -> Tuple[Any, Dict[str, Any]]:
 
 def save_lora_adapter(model, output_dir: str) -> None:
     os.makedirs(output_dir, exist_ok=True)
-    decoder = model
-    while hasattr(decoder, "_forward_module"):
-        decoder = decoder._forward_module
+    decoder = _unwrap_decoder(model)
 
     if hasattr(decoder, "save_pretrained"):
         decoder.save_pretrained(output_dir)
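_unwrap_decoder itself is defined outside the visible hunks. Judging from the three unwrap loops this commit deletes (here, in the resume path, and in generate_audio), it presumably folds those steps into a single helper, along these lines (a sketch reconstructed from the removed code, not the actual definition):

    from torch import nn

    def _unwrap_decoder(model):
        """Peel Lightning/compile/PEFT wrappers off the decoder (sketch)."""
        decoder = getattr(model, "decoder", model)
        while hasattr(decoder, "_forward_module"):
            decoder = decoder._forward_module
        if hasattr(decoder, "base_model"):
            bm = decoder.base_model
            decoder = bm.model if hasattr(bm, "model") else bm
        if hasattr(decoder, "model") and isinstance(decoder.model, nn.Module):
            decoder = decoder.model
        return decoder

The getattr fallback covers both call shapes seen in the diff (save_lora_adapter unwraps starting from the model, generate_audio from its decoder); the real helper may resolve this differently.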
@@ -602,6 +578,10 @@ def load_model_for_training(
     for idx, attn in enumerate(candidates):
         try:
             load_kwargs = dict(
+                # SECURITY: trust_remote_code=True is required because the
+                # ACE-Step model config references custom Python code in its
+                # checkpoint (config.json -> auto_map). Only load checkpoints
+                # from trusted sources (the official ACE-Step HF repo).
                 trust_remote_code=True,
                 attn_implementation=attn,
                 torch_dtype=dtype,
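A common hardening step on top of this comment, not part of this commit, is to pin the checkpoint to a specific revision so that later pushes to the repo cannot swap in different remote code; a hypothetical sketch with transformers:

    from transformers import AutoModel

    # Pinning a commit hash means trust_remote_code can only execute code
    # from that exact, previously audited snapshot of the repo.
    model = AutoModel.from_pretrained(
        "ACE-Step/Ace-Step1.5",
        trust_remote_code=True,
        revision="<commit-sha>",  # placeholder: substitute an audited commit
    )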
@@ -2443,8 +2423,8 @@ def train_lora_generator(
 
     yield f"[OK] LoRA injected: {info['trainable_params']:,} trainable params"
 
-    # Gradient checkpointing + cache disable
-    force_disable_cache(model.decoder)
+    # Gradient checkpointing + cache disable (enable_gradient_checkpointing
+    # also walks the module tree and sets use_cache=False on any config it finds)
     ckpt_ok = enable_gradient_checkpointing(model.decoder)
     force_input_grads = ckpt_ok
     if ckpt_ok:
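The updated body of enable_gradient_checkpointing is not shown in this diff; the new comment implies it absorbed the traversal from the deleted force_disable_cache. A speculative sketch of the merged helper, reusing the removed traversal:

    from torch import nn

    def enable_gradient_checkpointing(decoder: nn.Module) -> bool:
        """Sketch: enable checkpointing, then disable use_cache on all wrapped configs."""
        enabled = False
        if hasattr(decoder, "gradient_checkpointing_enable"):
            decoder.gradient_checkpointing_enable()
            enabled = True
        stack, visited = [decoder], set()
        while stack:
            mod = stack.pop()
            if not isinstance(mod, nn.Module) or id(mod) in visited:
                continue
            visited.add(id(mod))
            cfg = getattr(mod, "config", None)
            if cfg is not None and hasattr(cfg, "use_cache"):
                cfg.use_cache = False  # KV caching is incompatible with checkpointed training
            for a in ("_forward_module", "_orig_mod", "base_model", "model", "module"):
                child = getattr(mod, a, None)
                if isinstance(child, nn.Module):
                    stack.append(child)
        return enabled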
@@ -2511,9 +2491,7 @@
     if aw.exists():
         from safetensors.torch import load_file
         state = load_file(str(aw))
-        decoder = model
-        while hasattr(decoder, "_forward_module"):
-            decoder = decoder._forward_module
+        decoder = _unwrap_decoder(model)
         decoder.load_state_dict(state, strict=False)
 
     # Load training state
@@ -3017,15 +2995,7 @@ def generate_audio(
     logger.info("Loading LoRA adapter from %s (scale=%.2f)...", adapter_path, adapter_scale)
     from peft import PeftModel
 
-    decoder = model.decoder
-    # Unwrap any wrappers
-    while hasattr(decoder, "_forward_module"):
-        decoder = decoder._forward_module
-    if hasattr(decoder, "base_model"):
-        bm = decoder.base_model
-        decoder = bm.model if hasattr(bm, "model") else bm
-    if hasattr(decoder, "model") and isinstance(decoder.model, nn.Module):
-        decoder = decoder.model
+    decoder = _unwrap_decoder(model)
 
     model.decoder = PeftModel.from_pretrained(
         decoder, adapter_path, is_trainable=False,
@@ -3174,6 +3144,37 @@ _LM_UNDERSTAND_INSTRUCTION = (
 )
 
 
+def _sample_next_token(
+    logits: torch.Tensor, temperature: float, top_k: int, top_p: float,
+) -> int:
+    """Sample a single token from logits with temperature, top-k, and top-p.
+
+    Args:
+        logits: 1-D logits tensor (vocab_size,).
+        temperature: Sampling temperature (<=0 for argmax).
+        top_k: Top-K filtering (0 = disabled).
+        top_p: Nucleus sampling cutoff (0 or >=1 = disabled).
+
+    Returns:
+        Selected token ID as int.
+    """
+    if temperature <= 0:
+        return int(logits.argmax().item())
+    scaled = logits.clone() / temperature
+    if top_k > 0:
+        topk_vals, _ = torch.topk(scaled, min(top_k, scaled.shape[0]))
+        scaled[scaled < topk_vals[-1]] = float("-inf")
+    if top_p > 0 and top_p < 1.0:
+        sorted_logits, sorted_idx = torch.sort(scaled, descending=True)
+        probs = torch.softmax(sorted_logits, dim=-1)
+        cumsum = torch.cumsum(probs, dim=-1)
+        nucleus_mask = cumsum - probs > top_p
+        sorted_logits[nucleus_mask] = float("-inf")
+        scaled = torch.zeros_like(scaled).scatter(0, sorted_idx, sorted_logits)
+    probs = torch.softmax(scaled, dim=-1)
+    return int(torch.multinomial(probs, 1).item())
+
+
 def _build_understand_prompt(
     bpe_tokenizer, codes: List[int],
 ) -> List[int]:
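Both sampling call sites below now reduce to a single line. For reference, a minimal standalone exercise of the helper on toy logits (assumes the function above is importable):

    import torch

    logits = torch.tensor([2.0, 1.0, 0.5, -1.0])

    # temperature <= 0 short-circuits to greedy argmax
    assert _sample_next_token(logits, temperature=0.0, top_k=0, top_p=0.0) == 0

    # Stochastic path: top_k=2 restricts sampling to the two largest logits
    torch.manual_seed(0)
    tok = _sample_next_token(logits, temperature=1.0, top_k=2, top_p=0.9)
    assert tok in (0, 1)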
@@ -3357,22 +3358,7 @@ def _generate_codes_with_lm(
         logits[0] = logits[0] + mask
 
         # Sample
-        if temperature <= 0:
-            next_id = int(logits[0].argmax().item())
-        else:
-            scaled = logits[0].clone() / temperature
-            if top_k > 0:
-                topk_vals, _ = torch.topk(scaled, min(top_k, scaled.shape[0]))
-                scaled[scaled < topk_vals[-1]] = float("-inf")
-            if top_p > 0 and top_p < 1.0:
-                sorted_logits, sorted_idx = torch.sort(scaled, descending=True)
-                probs = torch.softmax(sorted_logits, dim=-1)
-                cumsum = torch.cumsum(probs, dim=-1)
-                nucleus_mask = cumsum - probs > top_p
-                sorted_logits[nucleus_mask] = float("-inf")
-                scaled = torch.zeros_like(scaled).scatter(0, sorted_idx, sorted_logits)
-            probs = torch.softmax(scaled, dim=-1)
-            next_id = int(torch.multinomial(probs, 1).item())
+        next_id = _sample_next_token(logits[0], temperature, top_k, top_p)
 
         # Stop on im_end
         if next_id == _TOKEN_IM_END:
@@ -3701,23 +3687,7 @@ def understand_audio(
         logits[0, _AUDIO_CODE_BASE:] = float("-inf")
 
         # Sample
-        if temperature <= 0:
-            next_id = int(logits[0].argmax().item())
-        else:
-            scaled = logits[0].clone() / temperature
-            if top_k > 0:
-                topk_vals, _ = torch.topk(scaled, min(top_k, scaled.shape[0]))
-                scaled[scaled < topk_vals[-1]] = float("-inf")
-            if top_p > 0 and top_p < 1.0:
-                sorted_logits, sorted_idx = torch.sort(scaled, descending=True)
-                probs = torch.softmax(sorted_logits, dim=-1)
-                cumsum = torch.cumsum(probs, dim=-1)
-                mask = cumsum - probs > top_p
-                sorted_logits[mask] = float("-inf")
-                # Scatter masked values back to original positions
-                scaled = torch.zeros_like(scaled).scatter(0, sorted_idx, sorted_logits)
-            probs = torch.softmax(scaled, dim=-1)
-            next_id = int(torch.multinomial(probs, 1).item())
+        next_id = _sample_next_token(logits[0], temperature, top_k, top_p)
 
         if next_id == _TOKEN_IM_END:
             break