mehdi999 committed on
Commit fd1f480 · 1 Parent(s): 6d19e74

back to basics

Files changed (3):
  1. app.py +91 -247
  2. app.py.bak +247 -92
  3. tts/model/simple_gla.py +222 -235
app.py CHANGED
@@ -1,176 +1,54 @@
1
  import os
2
- import re
3
- import json
4
- import sys
5
- import time
6
- import threading
7
- import traceback
8
-
9
  import gradio as gr
10
  import numpy as np
11
- import soundfile as sf
12
  import torch
 
13
  import spaces
14
- from huggingface_hub import login, snapshot_download
15
-
16
- # --------- Environnement / stabilité ----------
17
- os.environ.setdefault("FLA_CONV_BACKEND", "torch") # éviter les kernels Triton
18
- os.environ.setdefault("FLA_USE_FAST_OPS", "0")
19
- os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
20
- torch.backends.cuda.matmul.allow_tf32 = True
21
- try:
22
- torch.set_float32_matmul_precision("high")
23
- except Exception:
24
- pass
25
 
 
26
  from pardi_speech import PardiSpeech, VelocityHeadSamplingParams # présent dans ce repo
27
 
28
  MODEL_REPO_ID = os.environ.get("MODEL_REPO_ID", "theodorr/pardi-speech-enfr-forbidden")
29
- HF_TOKEN = os.environ.get("HF_TOKEN")
30
 
31
- # --------- Cache global (préchargement au démarrage) ----------
32
- _MODEL = {"pardi": None, "sr": 24000, "err": None, "logs": [], "thread": None}
33
-
34
- def _log(msg: str):
35
- _MODEL["logs"].append(str(msg))
36
- # borne la taille
37
- if len(_MODEL["logs"]) > 2000:
38
- _MODEL["logs"] = _MODEL["logs"][-2000:]
39
-
40
- def _env_diag() -> str:
41
- parts = []
42
  try:
43
- parts.append(f"torch={torch.__version__}")
44
- try:
45
- import triton # type: ignore
46
- parts.append(f"triton={getattr(triton, '__version__', 'unknown')}")
47
- except Exception:
48
- parts.append("triton=not_importable")
49
- parts.append(f"cuda.is_available={torch.cuda.is_available()}")
50
- if torch.cuda.is_available():
51
- parts.append(f"cuda.version={torch.version.cuda}")
52
- try:
53
- free, total = torch.cuda.mem_get_info()
54
- parts.append(f"mem_free={free/1e9:.2f}GB/{total/1e9:.2f}GB")
55
- except Exception:
56
- pass
57
  except Exception as e:
58
- parts.append(f"env_diag_error={e}")
59
- return " | ".join(parts)
 
 
60
 
61
  def _normalize_text(s: str, lang_hint: str = "fr") -> str:
62
- s = (s or "").strip()
63
  try:
64
- import re as _re
65
  from num2words import num2words
66
- def repl(m):
67
- try:
68
- return num2words(int(m.group()), lang=lang_hint)
69
- except Exception:
70
- return m.group()
71
- s = _re.sub(r"\d+", repl, s)
72
  except Exception:
73
  pass
74
  return s
75
 
76
  def _to_mono_float32(arr: np.ndarray) -> np.ndarray:
77
- arr = np.asarray(arr)
78
  if arr.ndim == 2:
79
  arr = arr.mean(axis=1)
80
- return arr.astype(np.float32)
81
-
82
- def _extract_repo_ids_from_config(config_path: str):
83
- repo_ids = set()
84
- preview = None
85
- try:
86
- with open(config_path, "r", encoding="utf-8") as f:
87
- cfg = json.load(f)
88
- pattern = re.compile(r"^[\w\-]+\/[\w\.\-]+$") # org/name
89
- def rec(obj):
90
- if isinstance(obj, dict):
91
- for v in obj.values(): rec(v)
92
- elif isinstance(obj, list):
93
- for v in obj: rec(v)
94
- elif isinstance(obj, str):
95
- if pattern.match(obj): repo_ids.add(obj)
96
- rec(cfg)
97
- try:
98
- subset_keys = list(cfg)[:5] if isinstance(cfg, dict) else []
99
- preview = json.dumps({k: cfg[k] for k in subset_keys}, ensure_ascii=False)[:600]
100
- except Exception:
101
- pass
102
- except Exception:
103
- pass
104
- return sorted(repo_ids), preview
105
-
106
- def _prefetch_and_load_cpu():
107
- """Exécuté dans un thread au démarrage du Space (hors worker GPU)."""
108
- try:
109
- _log("[prefetch] snapshot_download (main)...")
110
- local_dir = snapshot_download(
111
- repo_id=MODEL_REPO_ID,
112
- token=HF_TOKEN,
113
- local_dir=None,
114
- local_files_only=False,
115
- )
116
- _log(f"[prefetch] main done -> {local_dir}")
117
-
118
- cfg_path = os.path.join(local_dir, "config.json")
119
- nested, cfg_preview = _extract_repo_ids_from_config(cfg_path)
120
- if cfg_preview:
121
- _log(f"[config] preview: {cfg_preview}")
122
- for rid in nested:
123
- if rid == MODEL_REPO_ID:
124
- continue
125
- _log(f"[prefetch] nested repo: {rid} ...")
126
- snapshot_download(repo_id=rid, token=HF_TOKEN, local_dir=None, local_files_only=False)
127
- _log(f"[prefetch] nested repo: {rid} done")
128
-
129
- # Forcer offline pendant le vrai chargement
130
- old_off = os.environ.get("HF_HUB_OFFLINE")
131
- os.environ["HF_HUB_OFFLINE"] = "1"
132
- try:
133
- _log("[load] from_pretrained(map_location='cpu')...")
134
- m = PardiSpeech.from_pretrained(local_dir, map_location="cpu")
135
- m.eval()
136
- _MODEL["pardi"] = m
137
- _MODEL["sr"] = getattr(m, "sampling_rate", 24000)
138
- _log(f"[load] cpu OK (sr={_MODEL['sr']})")
139
- finally:
140
- if old_off is None:
141
- os.environ.pop("HF_HUB_OFFLINE", None)
142
- else:
143
- os.environ["HF_HUB_OFFLINE"] = old_off
144
-
145
- except BaseException as e:
146
- _MODEL["err"] = e
147
- _log(f"[EXC@preload] {type(e).__name__}: {e}")
148
- _log(traceback.format_exc())
149
 
150
- # Lance le préchargement (hors GPU) dès l’import
151
- if _MODEL["thread"] is None:
152
- _MODEL["thread"] = threading.Thread(target=_prefetch_and_load_cpu, daemon=True)
153
- _MODEL["thread"].start()
154
-
155
- def _move_to_cuda_if_available(m, logs_acc):
156
- def L(msg): logs_acc.append(str(msg))
157
- if torch.cuda.is_available():
158
- L("[move] moving model to cuda...")
159
- try:
160
- m = m.to("cuda") # type: ignore[attr-defined]
161
- L("[move] cuda OK")
162
- except Exception as e:
163
- L(f"[move] .to('cuda') failed: {e}. Keeping on CPU.")
164
- else:
165
- L("[move] cuda not available, keep CPU")
166
- return m
167
-
168
- # --------- UI callback (GPU) ----------
169
- @spaces.GPU(duration=200)
170
  def synthesize(
171
  text: str,
172
- debug: bool,
173
- adv_sampling: bool, # Velocity Head sampling
174
  ref_audio,
175
  ref_text: str,
176
  steps: int,
@@ -179,112 +57,83 @@ def synthesize(
179
  temperature: float,
180
  max_seq_len: int,
181
  seed: int,
182
- lang_hint: str,
183
  ):
184
- logs = []
185
- def LOG(msg: str):
186
- logs.append(str(msg))
187
- joined = "\n".join(logs + _MODEL["logs"][-50:]) # mêle quelques logs de préchargement
188
- if len(joined) > 12000:
189
- joined = joined[-12000:]
190
- return joined
191
-
192
- try:
193
- if HF_TOKEN:
194
- try:
195
- login(token=HF_TOKEN)
196
- yield None, LOG("✅ HF login ok")
197
- except Exception as e:
198
- yield None, LOG(f"⚠️ HF login failed: {e}")
199
-
200
- yield None, LOG("[env] " + _env_diag())
201
- torch.manual_seed(int(seed))
202
- os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "1")
203
-
204
- # Si le modèle n’est pas encore prêt, on attend jusqu’à 180s max ici
205
- t0 = time.perf_counter()
206
- while _MODEL["pardi"] is None and _MODEL["err"] is None:
207
- elapsed = time.perf_counter() - t0
208
- yield None, LOG(f"[init] still loading on CPU… {elapsed:.1f}s")
209
- if elapsed > 180:
210
- # dump de la stack du thread de préchargement pour debug
211
- tid = _MODEL["thread"].ident if _MODEL["thread"] else None
212
- if tid is not None:
213
- frame = sys._current_frames().get(tid)
214
- if frame is not None:
215
- stack_txt = "".join(traceback.format_stack(frame))
216
- yield None, LOG("[stack-final]\n" + stack_txt)
217
- raise TimeoutError("Preload timeout (>180s)")
218
- time.sleep(2.0)
219
-
220
- if _MODEL["err"]:
221
- raise _MODEL["err"]
222
-
223
- pardi = _MODEL["pardi"]
224
- sr_out = _MODEL["sr"]
225
-
226
- # Déplacement vers CUDA si possible
227
- pardi = _move_to_cuda_if_available(pardi, logs)
228
- yield None, LOG(f"[init] model ready on {'cuda' if torch.cuda.is_available() else 'cpu'}, sr={sr_out}")
229
-
230
- # ---- Texte + prefix optionnel ----
231
- txt = _normalize_text(text or "", lang_hint=lang_hint)
232
- yield None, LOG(f"[text] {txt[:120]}{'...' if len(txt) > 120 else ''}")
233
 
234
- steps = int(min(max(1, int(steps)), 16))
235
- max_seq_len = int(min(max(50, int(max_seq_len)), 600))
236
 
237
- prefix = None
238
- if ref_audio is not None:
239
- yield None, LOG("[prefix] encoding reference audio...")
240
- if isinstance(ref_audio, str):
241
- wav, sr = sf.read(ref_audio)
242
- else:
243
- sr, wav = ref_audio
244
- wav = _to_mono_float32(wav)
245
- device = "cuda" if torch.cuda.is_available() else "cpu"
246
- wav_t = torch.from_numpy(wav).to(device).unsqueeze(0)
247
- with torch.inference_mode():
248
- prefix_tokens = pardi.patchvae.encode(wav_t) # type: ignore[attr-defined]
249
- prefix = (ref_text or "", prefix_tokens[0])
250
- yield None, LOG("[prefix] done.")
251
 
252
- yield None, LOG(f"[run] has_prefix={prefix is not None}, steps={steps}, cfg={cfg}, cfg_ref={cfg_ref}, "
253
- f"T={temperature}, max_seq_len={max_seq_len}, seed={seed}, adv_sampling={adv_sampling}")
 
254
 
255
- # ---- Chemin rapide (comme le notebook) ----
 
256
  with torch.inference_mode():
257
- if adv_sampling:
258
- try:
259
- vparams = VelocityHeadSamplingParams(cfg_ref=float(cfg_ref), cfg=float(cfg), num_steps=int(steps))
260
- except TypeError:
261
- vparams = VelocityHeadSamplingParams(cfg_ref=float(cfg_ref), cfg=float(cfg),
262
- num_steps=int(steps), temperature=float(temperature))
263
- wavs, _ = pardi.text_to_speech([txt], prefix, max_seq_len=int(max_seq_len),
264
- velocity_head_sampling_params=vparams)
265
- else:
266
- wavs, _ = pardi.text_to_speech([txt], prefix, max_seq_len=int(max_seq_len))
267
 
268
- wav = wavs[0].detach().cpu().numpy().astype(np.float32)
269
- yield (sr_out, wav), LOG("[ok] done.")
270
271
  except Exception as e:
272
- tb = traceback.format_exc()
273
- yield None, LOG(f"[EXC] {type(e).__name__}: {e}\n{tb}")
 
274
 
275
- # --------- UI ----------
276
  def build_demo():
277
  with gr.Blocks(title="Lina-speech / pardi-speech Demo") as demo:
278
  gr.Markdown(
279
- "### Lina-speech (pardi-speech) – Démo TTS\n"
280
- "Génère de l'audio à partir de texte, avec ou sans prefix (audio de référence).\n"
281
- "Chemin rapide par défaut (comme le notebook)."
282
  )
 
283
  with gr.Row():
284
  text = gr.Textbox(label="Texte à synthétiser", lines=4, placeholder="Tape ton texte ici…")
285
  with gr.Accordion("Prefix (optionnel)", open=False):
286
  ref_audio = gr.Audio(sources=["upload", "microphone"], type="numpy", label="Audio de référence")
287
- ref_text = gr.Textbox(label="Texte du prefix (si connu)", placeholder="Transcription du prefix (optionnel)")
288
  with gr.Accordion("Options avancées", open=False):
289
  with gr.Row():
290
  steps = gr.Slider(1, 50, value=10, step=1, label="num_steps")
@@ -293,26 +142,21 @@ def build_demo():
293
  with gr.Row():
294
  temperature = gr.Slider(0.1, 2.0, value=1.0, step=0.05, label="Température")
295
  max_seq_len = gr.Slider(50, 1200, value=300, step=10, label="max_seq_len (tokens audio)")
296
- seed = gr.Number(value=0, precision=0, label="Seed")
297
- lang_hint = gr.Dropdown(choices=["fr", "en"], value="fr", label="Langue (normalisation)")
298
- with gr.Row():
299
- debug = gr.Checkbox(value=False, label="Mode debug")
300
- adv_sampling = gr.Checkbox(value=False, label="Sampling avancé (Velocity Head)")
301
 
302
  btn = gr.Button("Synthétiser")
303
  out_audio = gr.Audio(label="Sortie audio", type="numpy")
304
- logs_box = gr.Textbox(label="Logs (live)", lines=28)
305
 
306
  demo.queue(default_concurrency_limit=1, max_size=32)
 
307
  btn.click(
308
  fn=synthesize,
309
- inputs=[text, debug, adv_sampling, ref_audio, ref_text, steps, cfg, cfg_ref, temperature, max_seq_len, seed, lang_hint],
310
- outputs=[out_audio, logs_box],
311
- api_name="synthesize",
312
  )
313
  return demo
314
 
315
  if __name__ == "__main__":
316
- build_demo().launch(ssr_mode=False)
317
- # retrigger 2025-10-30T15:17:49+01:00
318
- # retrigger 2025-10-30T16:37:47+01:00
 
1
  import os
 
2
  import gradio as gr
3
  import numpy as np
 
4
  import torch
5
+ import soundfile as sf
6
  import spaces
 
7
 
8
+ from huggingface_hub import login
9
  from pardi_speech import PardiSpeech, VelocityHeadSamplingParams # présent dans ce repo
10
 
11
  MODEL_REPO_ID = os.environ.get("MODEL_REPO_ID", "theodorr/pardi-speech-enfr-forbidden")
 
12
 
13
+ HF_TOKEN = os.environ.get("HF_TOKEN")
14
+ if HF_TOKEN:
 
15
  try:
16
+ login(token=HF_TOKEN)
17
+ print("✅ Logged to Hugging Face Hub.")
 
18
  except Exception as e:
19
+ print("⚠️ HF login failed:", e)
20
+
21
+ _pardi = None
22
+ _sampling_rate = 24000
23
 
24
  def _normalize_text(s: str, lang_hint: str = "fr") -> str:
25
+ s = (s or "").strip().lower()
26
  try:
27
+ import re
28
  from num2words import num2words
29
+ def repl(m): return num2words(int(m.group()), lang=lang_hint)
30
+ s = re.sub(r"\d+", repl, s)
 
31
  except Exception:
32
  pass
33
  return s
34
 
35
+ def _load_model(device: str = "cuda"):
36
+ global _pardi, _sampling_rate
37
+ if _pardi is None:
38
+ _pardi = PardiSpeech.from_pretrained(MODEL_REPO_ID, map_location=device)
39
+ _sampling_rate = getattr(_pardi, "sampling_rate", 24000)
40
+ print(f"✅ PardiSpeech loaded on {device} (sr={_sampling_rate}).")
41
+ return _pardi
42
+
43
  def _to_mono_float32(arr: np.ndarray) -> np.ndarray:
44
+ arr = arr.astype(np.float32)
45
  if arr.ndim == 2:
46
  arr = arr.mean(axis=1)
47
+ return arr
 
48
 
49
+ @spaces.GPU(duration=120)
 
50
  def synthesize(
51
  text: str,
 
 
52
  ref_audio,
53
  ref_text: str,
54
  steps: int,
 
57
  temperature: float,
58
  max_seq_len: int,
59
  seed: int,
60
+ lang_hint: str
61
  ):
62
+ device = "cuda" if torch.cuda.is_available() else "cpu"
63
+ torch.manual_seed(int(seed))
 
64
 
65
+ pardi = _load_model(device)
66
+ txt = _normalize_text(text, lang_hint=lang_hint)
67
 
68
+ cache = pardi.tts.audio_decoder.init_cache(int(max_seq_len), device)
 
69
 
70
+ # --- IMPORTANT : signature de VelocityHeadSamplingParams ---
71
+ # Dans ton notebook d’inférence, la classe attend (cfg_ref, cfg, num_steps) SANS 'temperature'.
72
+ # On essaie d’abord sans temperature, puis fallback si la classe en accepte une.
73
+ try:
74
+ vel_params = VelocityHeadSamplingParams(
75
+ cfg_ref=float(cfg_ref),
76
+ cfg=float(cfg),
77
+ num_steps=int(steps)
78
+ )
79
+ except TypeError:
80
+ vel_params = VelocityHeadSamplingParams(
81
+ cfg_ref=float(cfg_ref),
82
+ cfg=float(cfg),
83
+ num_steps=int(steps),
84
+ temperature=float(temperature)
85
+ )
86
 
87
+ # Prefix optionnel
88
+ prefix = None
89
+ if ref_audio is not None:
90
+ if isinstance(ref_audio, str):
91
+ wav, sr = sf.read(ref_audio)
92
+ else:
93
+ sr, wav = ref_audio
94
+ wav = _to_mono_float32(np.array(wav))
95
+ wav_t = torch.from_numpy(wav).to(device)
96
+ import torchaudio
97
+ if sr != pardi.sampling_rate:
98
+ wav_t = torchaudio.functional.resample(wav_t, sr, pardi.sampling_rate)
99
+ wav_t = wav_t.unsqueeze(0)
100
  with torch.inference_mode():
101
+ prefix_tokens = pardi.patchvae.encode(wav_t)
102
+ prefix = (ref_text or "", prefix_tokens[0])
 
103
 
104
+ print(f"[debug] has_prefix={prefix is not None}, steps={steps}, cfg={cfg}, cfg_ref={cfg_ref}, T={temperature}, max_seq_len={max_seq_len}, seed={seed}")
 
105
 
106
+ try:
107
+ with torch.inference_mode():
108
+ wavs, _ = pardi.text_to_speech(
109
+ [txt],
110
+ prefix,
111
+ max_seq_len=int(max_seq_len),
112
+ velocity_head_sampling_params=vel_params,
113
+ cache=cache
114
+ )
115
  except Exception as e:
116
+ import traceback, sys
117
+ print("❌ text_to_speech failed:", e, file=sys.stderr)
118
+ traceback.print_exc()
119
+ raise gr.Error(f"Synthèse échouée: {type(e).__name__}: {e}")
120
+
121
+ wav = wavs[0].detach().cpu().numpy()
122
+ return (_sampling_rate, wav)
123
 
 
124
  def build_demo():
125
  with gr.Blocks(title="Lina-speech / pardi-speech Demo") as demo:
126
  gr.Markdown(
127
+ "## Lina-speech (pardi-speech) – Démo TTS\n"
128
+ "Génère de l'audio à partir de texte, avec ou sans *prefix* (audio de référence).\n"
129
+ "Paramètres avancés: *num_steps*, *CFG*, *température*, *max_seq_len*, *seed*."
130
  )
131
+
132
  with gr.Row():
133
  text = gr.Textbox(label="Texte à synthétiser", lines=4, placeholder="Tape ton texte ici…")
134
  with gr.Accordion("Prefix (optionnel)", open=False):
135
  ref_audio = gr.Audio(sources=["upload", "microphone"], type="numpy", label="Audio de référence")
136
+ ref_text = gr.Textbox(label="Texte du prefix (si connu)", placeholder="Transcription du prefix (optionnel)")
137
  with gr.Accordion("Options avancées", open=False):
138
  with gr.Row():
139
  steps = gr.Slider(1, 50, value=10, step=1, label="num_steps")
 
142
  with gr.Row():
143
  temperature = gr.Slider(0.1, 2.0, value=1.0, step=0.05, label="Température")
144
  max_seq_len = gr.Slider(50, 1200, value=300, step=10, label="max_seq_len (tokens audio)")
145
+ seed = gr.Number(value=0, precision=0, label="Seed (reproductibilité)")
146
+ lang_hint = gr.Dropdown(choices=["fr", "en"], value="fr", label="Langue (normalisation)")
 
 
 
147
 
148
  btn = gr.Button("Synthétiser")
149
  out_audio = gr.Audio(label="Sortie audio", type="numpy")
 
150
 
151
  demo.queue(default_concurrency_limit=1, max_size=32)
152
+
153
  btn.click(
154
  fn=synthesize,
155
+ inputs=[text, ref_audio, ref_text, steps, cfg, cfg_ref, temperature, max_seq_len, seed, lang_hint],
156
+ outputs=[out_audio]
 
157
  )
158
  return demo
159
 
160
  if __name__ == "__main__":
161
+ demo = build_demo()
162
+ demo.launch()
 
app.py.bak CHANGED
@@ -1,54 +1,176 @@
1
  import os
 
2
  import gradio as gr
3
  import numpy as np
4
- import torch
5
  import soundfile as sf
 
6
  import spaces
 
7
 
8
- from huggingface_hub import login
9
  from pardi_speech import PardiSpeech, VelocityHeadSamplingParams # présent dans ce repo
10
 
11
  MODEL_REPO_ID = os.environ.get("MODEL_REPO_ID", "theodorr/pardi-speech-enfr-forbidden")
12
-
13
  HF_TOKEN = os.environ.get("HF_TOKEN")
14
- if HF_TOKEN:
 
15
  try:
16
- login(token=HF_TOKEN)
17
- print("✅ Logged to Hugging Face Hub.")
 
18
  except Exception as e:
19
- print("⚠️ HF login failed:", e)
20
-
21
- _pardi = None
22
- _sampling_rate = 24000
23
 
24
  def _normalize_text(s: str, lang_hint: str = "fr") -> str:
25
- s = (s or "").strip().lower()
26
  try:
27
- import re
28
  from num2words import num2words
29
- def repl(m): return num2words(int(m.group()), lang=lang_hint)
30
- s = re.sub(r"\d+", repl, s)
 
31
  except Exception:
32
  pass
33
  return s
34
 
35
- def _load_model(device: str = "cuda"):
36
- global _pardi, _sampling_rate
37
- if _pardi is None:
38
- _pardi = PardiSpeech.from_pretrained(MODEL_REPO_ID, map_location=device)
39
- _sampling_rate = getattr(_pardi, "sampling_rate", 24000)
40
- print(f"✅ PardiSpeech loaded on {device} (sr={_sampling_rate}).")
41
- return _pardi
42
-
43
  def _to_mono_float32(arr: np.ndarray) -> np.ndarray:
44
- arr = arr.astype(np.float32)
45
  if arr.ndim == 2:
46
  arr = arr.mean(axis=1)
47
- return arr
 
48
 
49
- @spaces.GPU(duration=120)
 
50
  def synthesize(
51
  text: str,
 
 
52
  ref_audio,
53
  ref_text: str,
54
  steps: int,
@@ -57,83 +179,112 @@ def synthesize(
57
  temperature: float,
58
  max_seq_len: int,
59
  seed: int,
60
- lang_hint: str
61
  ):
62
- device = "cuda" if torch.cuda.is_available() else "cpu"
63
- torch.manual_seed(int(seed))
 
64
 
65
- pardi = _load_model(device)
66
- txt = _normalize_text(text, lang_hint=lang_hint)
 
67
 
68
- cache = pardi.tts.audio_decoder.init_cache(int(max_seq_len), device)
 
 
69
 
70
- # --- IMPORTANT : signature de VelocityHeadSamplingParams ---
71
- # Dans ton notebook d’inférence, la classe attend (cfg_ref, cfg, num_steps) SANS 'temperature'.
72
- # On essaie d’abord sans temperature, puis fallback si la classe en accepte une.
73
- try:
74
- vel_params = VelocityHeadSamplingParams(
75
- cfg_ref=float(cfg_ref),
76
- cfg=float(cfg),
77
- num_steps=int(steps)
78
- )
79
- except TypeError:
80
- vel_params = VelocityHeadSamplingParams(
81
- cfg_ref=float(cfg_ref),
82
- cfg=float(cfg),
83
- num_steps=int(steps),
84
- temperature=float(temperature)
85
- )
86
 
87
- # Prefix optionnel
88
- prefix = None
89
- if ref_audio is not None:
90
- if isinstance(ref_audio, str):
91
- wav, sr = sf.read(ref_audio)
92
- else:
93
- sr, wav = ref_audio
94
- wav = _to_mono_float32(np.array(wav))
95
- wav_t = torch.from_numpy(wav).to(device)
96
- import torchaudio
97
- if sr != pardi.sampling_rate:
98
- wav_t = torchaudio.functional.resample(wav_t, sr, pardi.sampling_rate)
99
- wav_t = wav_t.unsqueeze(0)
100
- with torch.inference_mode():
101
- prefix_tokens = pardi.patchvae.encode(wav_t)
102
- prefix = (ref_text or "", prefix_tokens[0])
103
 
104
- print(f"[debug] has_prefix={prefix is not None}, steps={steps}, cfg={cfg}, cfg_ref={cfg_ref}, T={temperature}, max_seq_len={max_seq_len}, seed={seed}")
 
105
 
106
- try:
 
107
  with torch.inference_mode():
108
- wavs, _ = pardi.text_to_speech(
109
- [txt],
110
- prefix,
111
- max_seq_len=int(max_seq_len),
112
- velocity_head_sampling_params=vel_params,
113
- cache=cache
114
- )
115
- except Exception as e:
116
- import traceback, sys
117
- print("❌ text_to_speech failed:", e, file=sys.stderr)
118
- traceback.print_exc()
119
- raise gr.Error(f"Synthèse échouée: {type(e).__name__}: {e}")
120
 
121
- wav = wavs[0].detach().cpu().numpy()
122
- return (_sampling_rate, wav)
123
 
124
  def build_demo():
125
  with gr.Blocks(title="Lina-speech / pardi-speech Demo") as demo:
126
  gr.Markdown(
127
- "## Lina-speech (pardi-speech) – Démo TTS\n"
128
- "Génère de l'audio à partir de texte, avec ou sans *prefix* (audio de référence).\n"
129
- "Paramètres avancés: *num_steps*, *CFG*, *température*, *max_seq_len*, *seed*."
130
  )
131
-
132
  with gr.Row():
133
  text = gr.Textbox(label="Texte à synthétiser", lines=4, placeholder="Tape ton texte ici…")
134
  with gr.Accordion("Prefix (optionnel)", open=False):
135
  ref_audio = gr.Audio(sources=["upload", "microphone"], type="numpy", label="Audio de référence")
136
- ref_text = gr.Textbox(label="Texte du prefix (si connu)", placeholder="Transcription du prefix (optionnel)")
137
  with gr.Accordion("Options avancées", open=False):
138
  with gr.Row():
139
  steps = gr.Slider(1, 50, value=10, step=1, label="num_steps")
@@ -142,22 +293,26 @@ def build_demo():
142
  with gr.Row():
143
  temperature = gr.Slider(0.1, 2.0, value=1.0, step=0.05, label="Température")
144
  max_seq_len = gr.Slider(50, 1200, value=300, step=10, label="max_seq_len (tokens audio)")
145
- seed = gr.Number(value=0, precision=0, label="Seed (reproductibilité)")
146
- lang_hint = gr.Dropdown(choices=["fr", "en"], value="fr", label="Langue (normalisation)")
 
 
 
147
 
148
  btn = gr.Button("Synthétiser")
149
  out_audio = gr.Audio(label="Sortie audio", type="numpy")
 
150
 
151
  demo.queue(default_concurrency_limit=1, max_size=32)
152
-
153
  btn.click(
154
  fn=synthesize,
155
- inputs=[text, ref_audio, ref_text, steps, cfg, cfg_ref, temperature, max_seq_len, seed, lang_hint],
156
- outputs=[out_audio]
 
157
  )
158
  return demo
159
 
160
  if __name__ == "__main__":
161
- demo = build_demo()
162
- demo.launch()
163
- # retrigger 2025-10-29T16:27:55+01:00
 
1
  import os
2
+ import re
3
+ import json
4
+ import sys
5
+ import time
6
+ import threading
7
+ import traceback
8
+
9
  import gradio as gr
10
  import numpy as np
 
11
  import soundfile as sf
12
+ import torch
13
  import spaces
14
+ from huggingface_hub import login, snapshot_download
15
+
16
+ # --------- Environnement / stabilité ----------
17
+ os.environ.setdefault("FLA_CONV_BACKEND", "torch") # éviter les kernels Triton
18
+ os.environ.setdefault("FLA_USE_FAST_OPS", "0")
19
+ os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
20
+ torch.backends.cuda.matmul.allow_tf32 = True
21
+ try:
22
+ torch.set_float32_matmul_precision("high")
23
+ except Exception:
24
+ pass
25
 
 
26
  from pardi_speech import PardiSpeech, VelocityHeadSamplingParams # présent dans ce repo
27
 
28
  MODEL_REPO_ID = os.environ.get("MODEL_REPO_ID", "theodorr/pardi-speech-enfr-forbidden")
 
29
  HF_TOKEN = os.environ.get("HF_TOKEN")
30
+
31
+ # --------- Cache global (préchargement au démarrage) ----------
32
+ _MODEL = {"pardi": None, "sr": 24000, "err": None, "logs": [], "thread": None}
33
+
34
+ def _log(msg: str):
35
+ _MODEL["logs"].append(str(msg))
36
+ # borne la taille
37
+ if len(_MODEL["logs"]) > 2000:
38
+ _MODEL["logs"] = _MODEL["logs"][-2000:]
39
+
40
+ def _env_diag() -> str:
41
+ parts = []
42
  try:
43
+ parts.append(f"torch={torch.__version__}")
44
+ try:
45
+ import triton # type: ignore
46
+ parts.append(f"triton={getattr(triton, '__version__', 'unknown')}")
47
+ except Exception:
48
+ parts.append("triton=not_importable")
49
+ parts.append(f"cuda.is_available={torch.cuda.is_available()}")
50
+ if torch.cuda.is_available():
51
+ parts.append(f"cuda.version={torch.version.cuda}")
52
+ try:
53
+ free, total = torch.cuda.mem_get_info()
54
+ parts.append(f"mem_free={free/1e9:.2f}GB/{total/1e9:.2f}GB")
55
+ except Exception:
56
+ pass
57
  except Exception as e:
58
+ parts.append(f"env_diag_error={e}")
59
+ return " | ".join(parts)
 
 
60
 
61
  def _normalize_text(s: str, lang_hint: str = "fr") -> str:
62
+ s = (s or "").strip()
63
  try:
64
+ import re as _re
65
  from num2words import num2words
66
+ def repl(m):
67
+ try:
68
+ return num2words(int(m.group()), lang=lang_hint)
69
+ except Exception:
70
+ return m.group()
71
+ s = _re.sub(r"\d+", repl, s)
72
  except Exception:
73
  pass
74
  return s
75
76
  def _to_mono_float32(arr: np.ndarray) -> np.ndarray:
77
+ arr = np.asarray(arr)
78
  if arr.ndim == 2:
79
  arr = arr.mean(axis=1)
80
+ return arr.astype(np.float32)
81
+
82
+ def _extract_repo_ids_from_config(config_path: str):
83
+ repo_ids = set()
84
+ preview = None
85
+ try:
86
+ with open(config_path, "r", encoding="utf-8") as f:
87
+ cfg = json.load(f)
88
+ pattern = re.compile(r"^[\w\-]+\/[\w\.\-]+$") # org/name
89
+ def rec(obj):
90
+ if isinstance(obj, dict):
91
+ for v in obj.values(): rec(v)
92
+ elif isinstance(obj, list):
93
+ for v in obj: rec(v)
94
+ elif isinstance(obj, str):
95
+ if pattern.match(obj): repo_ids.add(obj)
96
+ rec(cfg)
97
+ try:
98
+ subset_keys = list(cfg)[:5] if isinstance(cfg, dict) else []
99
+ preview = json.dumps({k: cfg[k] for k in subset_keys}, ensure_ascii=False)[:600]
100
+ except Exception:
101
+ pass
102
+ except Exception:
103
+ pass
104
+ return sorted(repo_ids), preview
105
+
106
+ def _prefetch_and_load_cpu():
107
+ """Exécuté dans un thread au démarrage du Space (hors worker GPU)."""
108
+ try:
109
+ _log("[prefetch] snapshot_download (main)...")
110
+ local_dir = snapshot_download(
111
+ repo_id=MODEL_REPO_ID,
112
+ token=HF_TOKEN,
113
+ local_dir=None,
114
+ local_files_only=False,
115
+ )
116
+ _log(f"[prefetch] main done -> {local_dir}")
117
+
118
+ cfg_path = os.path.join(local_dir, "config.json")
119
+ nested, cfg_preview = _extract_repo_ids_from_config(cfg_path)
120
+ if cfg_preview:
121
+ _log(f"[config] preview: {cfg_preview}")
122
+ for rid in nested:
123
+ if rid == MODEL_REPO_ID:
124
+ continue
125
+ _log(f"[prefetch] nested repo: {rid} ...")
126
+ snapshot_download(repo_id=rid, token=HF_TOKEN, local_dir=None, local_files_only=False)
127
+ _log(f"[prefetch] nested repo: {rid} done")
128
+
129
+ # Forcer offline pendant le vrai chargement
130
+ old_off = os.environ.get("HF_HUB_OFFLINE")
131
+ os.environ["HF_HUB_OFFLINE"] = "1"
132
+ try:
133
+ _log("[load] from_pretrained(map_location='cpu')...")
134
+ m = PardiSpeech.from_pretrained(local_dir, map_location="cpu")
135
+ m.eval()
136
+ _MODEL["pardi"] = m
137
+ _MODEL["sr"] = getattr(m, "sampling_rate", 24000)
138
+ _log(f"[load] cpu OK (sr={_MODEL['sr']})")
139
+ finally:
140
+ if old_off is None:
141
+ os.environ.pop("HF_HUB_OFFLINE", None)
142
+ else:
143
+ os.environ["HF_HUB_OFFLINE"] = old_off
144
+
145
+ except BaseException as e:
146
+ _MODEL["err"] = e
147
+ _log(f"[EXC@preload] {type(e).__name__}: {e}")
148
+ _log(traceback.format_exc())
149
 
150
+ # Lance le préchargement (hors GPU) dès l’import
151
+ if _MODEL["thread"] is None:
152
+ _MODEL["thread"] = threading.Thread(target=_prefetch_and_load_cpu, daemon=True)
153
+ _MODEL["thread"].start()
154
+
155
+ def _move_to_cuda_if_available(m, logs_acc):
156
+ def L(msg): logs_acc.append(str(msg))
157
+ if torch.cuda.is_available():
158
+ L("[move] moving model to cuda...")
159
+ try:
160
+ m = m.to("cuda") # type: ignore[attr-defined]
161
+ L("[move] cuda OK")
162
+ except Exception as e:
163
+ L(f"[move] .to('cuda') failed: {e}. Keeping on CPU.")
164
+ else:
165
+ L("[move] cuda not available, keep CPU")
166
+ return m
167
+
168
+ # --------- UI callback (GPU) ----------
169
+ @spaces.GPU(duration=200)
170
  def synthesize(
171
  text: str,
172
+ debug: bool,
173
+ adv_sampling: bool, # Velocity Head sampling
174
  ref_audio,
175
  ref_text: str,
176
  steps: int,
 
179
  temperature: float,
180
  max_seq_len: int,
181
  seed: int,
182
+ lang_hint: str,
183
  ):
184
+ logs = []
185
+ def LOG(msg: str):
186
+ logs.append(str(msg))
187
+ joined = "\n".join(logs + _MODEL["logs"][-50:]) # mêle quelques logs de préchargement
188
+ if len(joined) > 12000:
189
+ joined = joined[-12000:]
190
+ return joined
191
 
192
+ try:
193
+ if HF_TOKEN:
194
+ try:
195
+ login(token=HF_TOKEN)
196
+ yield None, LOG("✅ HF login ok")
197
+ except Exception as e:
198
+ yield None, LOG(f"⚠️ HF login failed: {e}")
199
 
200
+ yield None, LOG("[env] " + _env_diag())
201
+ torch.manual_seed(int(seed))
202
+ os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "1")
203
 
204
+ # Si le modèle n’est pas encore prêt, on attend jusqu’à 180s max ici
205
+ t0 = time.perf_counter()
206
+ while _MODEL["pardi"] is None and _MODEL["err"] is None:
207
+ elapsed = time.perf_counter() - t0
208
+ yield None, LOG(f"[init] still loading on CPU… {elapsed:.1f}s")
209
+ if elapsed > 180:
210
+ # dump de la stack du thread de préchargement pour debug
211
+ tid = _MODEL["thread"].ident if _MODEL["thread"] else None
212
+ if tid is not None:
213
+ frame = sys._current_frames().get(tid)
214
+ if frame is not None:
215
+ stack_txt = "".join(traceback.format_stack(frame))
216
+ yield None, LOG("[stack-final]\n" + stack_txt)
217
+ raise TimeoutError("Preload timeout (>180s)")
218
+ time.sleep(2.0)
 
219
 
220
+ if _MODEL["err"]:
221
+ raise _MODEL["err"]
 
222
 
223
+ pardi = _MODEL["pardi"]
224
+ sr_out = _MODEL["sr"]
225
 
226
+ # Déplacement vers CUDA si possible
227
+ pardi = _move_to_cuda_if_available(pardi, logs)
228
+ yield None, LOG(f"[init] model ready on {'cuda' if torch.cuda.is_available() else 'cpu'}, sr={sr_out}")
229
+
230
+ # ---- Texte + prefix optionnel ----
231
+ txt = _normalize_text(text or "", lang_hint=lang_hint)
232
+ yield None, LOG(f"[text] {txt[:120]}{'...' if len(txt) > 120 else ''}")
233
+
234
+ steps = int(min(max(1, int(steps)), 16))
235
+ max_seq_len = int(min(max(50, int(max_seq_len)), 600))
236
+
237
+ prefix = None
238
+ if ref_audio is not None:
239
+ yield None, LOG("[prefix] encoding reference audio...")
240
+ if isinstance(ref_audio, str):
241
+ wav, sr = sf.read(ref_audio)
242
+ else:
243
+ sr, wav = ref_audio
244
+ wav = _to_mono_float32(wav)
245
+ device = "cuda" if torch.cuda.is_available() else "cpu"
246
+ wav_t = torch.from_numpy(wav).to(device).unsqueeze(0)
247
+ with torch.inference_mode():
248
+ prefix_tokens = pardi.patchvae.encode(wav_t) # type: ignore[attr-defined]
249
+ prefix = (ref_text or "", prefix_tokens[0])
250
+ yield None, LOG("[prefix] done.")
251
+
252
+ yield None, LOG(f"[run] has_prefix={prefix is not None}, steps={steps}, cfg={cfg}, cfg_ref={cfg_ref}, "
253
+ f"T={temperature}, max_seq_len={max_seq_len}, seed={seed}, adv_sampling={adv_sampling}")
254
+
255
+ # ---- Chemin rapide (comme le notebook) ----
256
  with torch.inference_mode():
257
+ if adv_sampling:
258
+ try:
259
+ vparams = VelocityHeadSamplingParams(cfg_ref=float(cfg_ref), cfg=float(cfg), num_steps=int(steps))
260
+ except TypeError:
261
+ vparams = VelocityHeadSamplingParams(cfg_ref=float(cfg_ref), cfg=float(cfg),
262
+ num_steps=int(steps), temperature=float(temperature))
263
+ wavs, _ = pardi.text_to_speech([txt], prefix, max_seq_len=int(max_seq_len),
264
+ velocity_head_sampling_params=vparams)
265
+ else:
266
+ wavs, _ = pardi.text_to_speech([txt], prefix, max_seq_len=int(max_seq_len))
 
 
267
 
268
+ wav = wavs[0].detach().cpu().numpy().astype(np.float32)
269
+ yield (sr_out, wav), LOG("[ok] done.")
270
 
271
+ except Exception as e:
272
+ tb = traceback.format_exc()
273
+ yield None, LOG(f"[EXC] {type(e).__name__}: {e}\n{tb}")
274
+
275
+ # --------- UI ----------
276
  def build_demo():
277
  with gr.Blocks(title="Lina-speech / pardi-speech Demo") as demo:
278
  gr.Markdown(
279
+ "### Lina-speech (pardi-speech) – Démo TTS\n"
280
+ "Génère de l'audio à partir de texte, avec ou sans prefix (audio de référence).\n"
281
+ "Chemin rapide par défaut (comme le notebook)."
282
  )
 
283
  with gr.Row():
284
  text = gr.Textbox(label="Texte à synthétiser", lines=4, placeholder="Tape ton texte ici…")
285
  with gr.Accordion("Prefix (optionnel)", open=False):
286
  ref_audio = gr.Audio(sources=["upload", "microphone"], type="numpy", label="Audio de référence")
287
+ ref_text = gr.Textbox(label="Texte du prefix (si connu)", placeholder="Transcription du prefix (optionnel)")
288
  with gr.Accordion("Options avancées", open=False):
289
  with gr.Row():
290
  steps = gr.Slider(1, 50, value=10, step=1, label="num_steps")
 
293
  with gr.Row():
294
  temperature = gr.Slider(0.1, 2.0, value=1.0, step=0.05, label="Température")
295
  max_seq_len = gr.Slider(50, 1200, value=300, step=10, label="max_seq_len (tokens audio)")
296
+ seed = gr.Number(value=0, precision=0, label="Seed")
297
+ lang_hint = gr.Dropdown(choices=["fr", "en"], value="fr", label="Langue (normalisation)")
298
+ with gr.Row():
299
+ debug = gr.Checkbox(value=False, label="Mode debug")
300
+ adv_sampling = gr.Checkbox(value=False, label="Sampling avancé (Velocity Head)")
301
 
302
  btn = gr.Button("Synthétiser")
303
  out_audio = gr.Audio(label="Sortie audio", type="numpy")
304
+ logs_box = gr.Textbox(label="Logs (live)", lines=28)
305
 
306
  demo.queue(default_concurrency_limit=1, max_size=32)
 
307
  btn.click(
308
  fn=synthesize,
309
+ inputs=[text, debug, adv_sampling, ref_audio, ref_text, steps, cfg, cfg_ref, temperature, max_seq_len, seed, lang_hint],
310
+ outputs=[out_audio, logs_box],
311
+ api_name="synthesize",
312
  )
313
  return demo
314
 
315
  if __name__ == "__main__":
316
+ build_demo().launch(ssr_mode=False)
317
+ # retrigger 2025-10-30T15:17:49+01:00
318
+ # retrigger 2025-10-30T16:37:47+01:00
tts/model/simple_gla.py CHANGED
@@ -1,304 +1,291 @@
1
- """
2
- Patched Simple GLA decoder for HF Spaces (ZeroGPU) — safe PyTorch-only paths.
3
-
4
- - Forces FLA (flash-linear-attention) to avoid fused/Triton kernels during __init__ & forward
5
- - Adds tolerant construction of SimpleGatedLinearAttention (backend="torch", fused=False)
6
- - Falls back to a no-op GLA stub if FLA construction fails (for demo resilience)
7
- - Keeps cache handling defensive to avoid NoneType unpack errors
8
-
9
- Drop-in replacement for: tts/model/simple_gla.py
10
- """
11
-
12
  import os
13
- from typing import Optional, Dict, Any, Tuple, List, Union
14
-
15
- # ---- Force safe runtime defaults (no Triton / no compile) ----
16
- os.environ.setdefault("FLA_CONV_BACKEND", "torch")
17
- os.environ.setdefault("FLA_USE_FAST_OPS", "0")
18
- os.environ.setdefault("FLA_DISABLE_TRITON", "1") # ignored if not recognized
19
- os.environ.setdefault("TORCH_COMPILE_DISABLE", "1")
20
- os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
21
 
22
  import torch
23
  import torch.nn.functional as F
24
- from torch import nn
25
-
26
  from einops import rearrange
 
27
 
28
- # ---------- Try importing FLA; otherwise define a stub ----------
29
- try:
30
- from fla.layers.simple_gla import SimpleGatedLinearAttention # type: ignore
31
- from fla.models.utils import Cache # type: ignore
32
- _FLA_AVAILABLE = True
33
- except Exception:
34
- _FLA_AVAILABLE = False
35
-
36
- class SimpleGatedLinearAttention(nn.Module): # minimal stub (identity)
37
- def __init__(self, *args, **kwargs):
38
- super().__init__()
39
-
40
- def forward(self, x, past_key_values=None, use_cache: bool = False, **kwargs):
41
- # Match tuple output convention used by callers: (x, kv)
42
- return x, None
43
-
44
- # Fallback Cache typing
45
- Cache = Dict[str, Any] # type: ignore
46
-
47
-
48
- # Local layers / utils
49
  from tts.layers.attention import CrossAttention
50
  from tts.layers.ffn import SwiGLU
51
 
52
  from .cache_utils import FLACache
53
  from .config import SimpleGLADecoderConfig
54
  from .registry import register_decoder
 
55
 
 
56
 
57
- def _force_safe_fla_impl(m: SimpleGatedLinearAttention) -> None:
58
- """Force SimpleGatedLinearAttention to use non-fused, PyTorch-only kernels.
59
-
60
- On Hugging Face Spaces (ZeroGPU) + Python 3.10/Triton 3.1, fused kernels can hang at import/first call.
61
- We harden the module to avoid fused/triton implementations.
62
- """
63
- # Prefer explicit mode to avoid backend auto-selection
64
- try:
65
- if hasattr(m, "mode"):
66
- m.mode = "chunk" # safer than "recurrent" fused paths
67
- except Exception:
68
- pass
69
-
70
- # For recent versions exposing implementation switches:
71
- for attr, val in (("recurrent_impl", "naive"),
72
- ("chunk_impl", "naive"),
73
- ("fused", False),
74
- ("backend", "torch")):
75
- if hasattr(m, attr):
76
- try:
77
- setattr(m, attr, val)
78
- except Exception:
79
- pass
80
-
81
-
82
- def _make_tmix(dim: int, num_heads: int) -> SimpleGatedLinearAttention:
83
- """
84
- Construct SimpleGatedLinearAttention using the safest available signature.
85
- Falls back gracefully if kwargs are not supported by the installed FLA version.
86
- """
87
- # Try most explicit signature first
88
- try:
89
- tmix = SimpleGatedLinearAttention(
90
- hidden_size=dim,
91
- num_heads=num_heads,
92
- causal=True,
93
- backend="torch", # key to avoid Triton
94
- fused=False,
95
- )
96
- _force_safe_fla_impl(tmix)
97
- return tmix
98
- except TypeError:
99
- pass
100
- except Exception:
101
- # If constructing with explicit kwargs fails for another reason,
102
- # we will try progressively simpler signatures below.
103
- pass
104
-
105
- # Try without fused/backends but keep causal if supported
106
- try:
107
- tmix = SimpleGatedLinearAttention(
108
- hidden_size=dim,
109
- num_heads=num_heads,
110
- causal=True,
111
- )
112
- _force_safe_fla_impl(tmix)
113
- return tmix
114
- except TypeError:
115
- pass
116
- except Exception:
117
- pass
118
-
119
- # Try minimal signature
120
- try:
121
- tmix = SimpleGatedLinearAttention(
122
- hidden_size=dim,
123
- num_heads=num_heads,
124
- )
125
- _force_safe_fla_impl(tmix)
126
- return tmix
127
- except Exception:
128
- # Last resort: identity stub
129
- return SimpleGatedLinearAttention()
130
-
131
-
132
- def _cache_for_layer(cache: Optional[Cache], idx: int) -> Optional[Cache]:
133
- """
134
- Extract per-layer cache if present; return None if structure is not compatible.
135
- FLA expects either:
136
- - cache["layers"][i]["conv_state"] being a tuple/list
137
- - or a top-level cache dict with "conv_state" key
138
- """
139
- if isinstance(cache, dict):
140
- if "layers" in cache and isinstance(cache["layers"], (list, tuple)):
141
- if idx < len(cache["layers"]) and isinstance(cache["layers"][idx], dict):
142
- # Layer-specific cache entry
143
- c = cache["layers"][idx]
144
- # Validate conv_state shape
145
- if isinstance(c.get("conv_state", None), (list, tuple)):
146
- return c
147
- # If not valid, ignore layer cache to prevent NoneType errors
148
- return None
149
- # Some layouts put conv states directly at top-level
150
- if isinstance(cache.get("conv_state", None), (list, tuple)):
151
- return cache
152
- return None
153
 
 
 
154
 
155
- class SimpleGLABlock(nn.Module):
156
- """One decoder block with GLA time-mixing + feed-forward + (optional) norm/shortconv."""
157
 
 
 
158
  def __init__(
159
  self,
160
  dim: int,
161
  num_heads: int,
162
- layer_idx: int = 0,
163
- expand_k: float = 1.0,
164
- expand_v: float = 1.0,
165
- use_short_conv: bool = False,
166
- ffn_expansion_factor: int = 4,
167
  ):
168
  super().__init__()
169
- # Time-mixing (GLA) — robust construction
170
- self.tmix = _make_tmix(dim=dim, num_heads=num_heads)
171
-
172
- # Feed-forward
173
- hidden_ff = int(dim * ffn_expansion_factor)
174
- self.cmix = SwiGLU(dim, hidden_ff)
175
-
176
- # Norms
177
  self.norm1 = nn.LayerNorm(dim)
178
  self.norm2 = nn.LayerNorm(dim)
179
 
180
- # (Optional) short conv placeholder
181
- self.use_short_conv = use_short_conv
182
-
183
  def forward(
184
  self,
185
- x: torch.Tensor,
186
- cache: Optional[Cache] = None,
187
- **kwargs,
188
- ) -> torch.Tensor:
189
- # Extract a valid cache view for this layer (if any)
190
- pkv = _cache_for_layer(cache, idx=getattr(self, "layer_idx", 0))
191
-
192
- # Some FLA versions want explicit flags
193
- use_cache_flag = isinstance(pkv, dict) and isinstance(pkv.get("conv_state", None), (list, tuple))
194
-
195
- y, _ = self.tmix(
196
- self.norm1(x),
197
- past_key_values=pkv,
198
- use_cache=use_cache_flag,
199
  )
200
- x = y + x
201
  x = self.cmix(self.norm2(x)) + x
202
  return x
203
 
204
 
205
  class DecoderBlockWithOptionalCrossAttention(nn.Module):
206
- """Wrap a GLABlock and add cross-attention (encoder-decoder attention) if provided."""
207
-
208
- def __init__(self, decoder_block: nn.Module, crossatt: Optional[nn.Module] = None):
209
  super().__init__()
 
210
  self.decoder_block = decoder_block
211
  self.crossatt = crossatt
212
 
213
  def forward(
214
  self,
215
  x: torch.Tensor,
216
- encoder_output: Optional[torch.Tensor] = None,
217
- text_freqs: Optional[torch.Tensor] = None,
218
- cache: Optional[Cache] = None,
219
- crossatt_mask: Optional[torch.Tensor] = None,
 
 
220
  ) -> torch.Tensor:
221
- if self.crossatt is not None and encoder_output is not None:
222
- # Standard cross-attention (keys/values from encoder_output)
223
- x = self.crossatt(
 
224
  x,
225
- context=encoder_output,
 
226
  mask=crossatt_mask,
 
227
  )
228
- x = self.decoder_block(x, cache=cache)
229
  return x
230
 
231
 
232
  @register_decoder("simple_gla")
233
  class SimpleGLADecoder(nn.Module):
234
  config = SimpleGLADecoderConfig
235
- """Decoder composed of a stack of SimpleGLABlock (+ optional cross-attention)."""
236
 
237
- def __init__(self, config: SimpleGLADecoderConfig):
238
  super().__init__()
239
- self.config = config
240
-
241
- dim = getattr(config, "hidden_size", getattr(config, "dim", 512))
242
- num_heads = getattr(config, "num_heads", 8)
243
- num_layers = getattr(config, "num_layers", 12)
244
- ffn_expansion_factor = getattr(config, "ffn_expansion_factor", 4)
245
- expand_k = getattr(config, "expand_k", 1.0)
246
- expand_v = getattr(config, "expand_v", 1.0)
247
- use_short_conv = getattr(config, "use_short_conv", False)
248
- cross_every = getattr(config, "cross_every", 1) # add cross-att every N layers (1 = every layer)
249
-
250
- decoder_layers: List[nn.Module] = []
251
- for i in range(num_layers):
252
- block = SimpleGLABlock(
253
- dim=dim,
254
- num_heads=num_heads,
255
- layer_idx=i,
256
- expand_k=expand_k,
257
- expand_v=expand_v,
258
- use_short_conv=use_short_conv,
259
- ffn_expansion_factor=ffn_expansion_factor,
 
260
  )
261
- cross = None
262
- if cross_every and (i % cross_every == 0):
263
- # CrossAttention(dim, num_heads=num_heads) -> module expects (x, context, mask)
264
- cross = CrossAttention(dim, num_heads=num_heads)
265
- decoder_layers.append(DecoderBlockWithOptionalCrossAttention(block, cross))
266
 
267
- self.decoder_layers = nn.ModuleList(decoder_layers)
 
268
 
269
- # Backward compatibility with code expecting "prefill" API
270
  def prefill(
271
  self,
272
- encoder_output: Optional[torch.Tensor],
273
  decoder_input: torch.Tensor,
274
- cache: Optional[Cache],
275
- text_freqs: Optional[torch.Tensor] = None,
276
- crossatt_mask: Optional[torch.Tensor] = None,
277
- ) -> torch.Tensor:
278
- return self(
279
- encoder_output=encoder_output,
280
- decoder_input=decoder_input,
281
- cache=cache,
282
- text_freqs=text_freqs,
283
- crossatt_mask=crossatt_mask,
284
- )
285
 
286
- def forward(
287
  self,
288
- encoder_output: Optional[torch.Tensor],
289
  decoder_input: torch.Tensor,
290
- cache: Optional[Cache],
291
- text_freqs: Optional[torch.Tensor] = None,
292
- crossatt_mask: Optional[torch.Tensor] = None,
293
- ) -> torch.Tensor:
294
  x = decoder_input
295
- for idx, layer in enumerate(self.decoder_layers):
296
- layer_cache = _cache_for_layer(cache, idx)
297
  x = layer(
298
  x,
299
- encoder_output=encoder_output,
300
  text_freqs=text_freqs,
301
- cache=layer_cache,
302
  crossatt_mask=crossatt_mask,
303
  )
304
  return x
 
1
  import os
 
2
 
3
  import torch
4
  import torch.nn.functional as F
 
 
5
  from einops import rearrange
6
+ from fla.layers.simple_gla import SimpleGatedLinearAttention
7
+ from fla.models.utils import Cache
8
+ from sympy import num_digits
9
+ from torch import nn
10
 
11
  from tts.layers.attention import CrossAttention
12
  from tts.layers.ffn import SwiGLU
13
 
14
  from .cache_utils import FLACache
15
  from .config import SimpleGLADecoderConfig
16
  from .registry import register_decoder
17
+ from .shortconv import ShortConvBlock
18
 
19
+ if "GRAD_CKPT" in os.environ:
20
 
21
+ def maybe_grad_ckpt(f):
22
+ def grad_ckpt_f(*args, **kwargs):
23
+ return torch.utils.checkpoint.checkpoint(
24
+ f, *args, **kwargs, use_reentrant=False
25
+ )
 
26
 
27
+ return grad_ckpt_f
28
+ else:
29
 
30
+ def maybe_grad_ckpt(f):
31
+ return f
32
 
33
+
34
+ class SimpleGLABlock(nn.Module):
35
  def __init__(
36
  self,
37
  dim: int,
38
  num_heads: int,
39
+ layer_idx: int,
40
+ expand_k: float,
41
+ expand_v: float,
42
+ use_short_conv: bool,
43
+ ffn_expansion_factor: int,
44
  ):
45
  super().__init__()
46
+ self.tmix = SimpleGatedLinearAttention(
47
+ hidden_size=dim,
48
+ num_heads=num_heads,
49
+ layer_idx=layer_idx,
50
+ )
51
+ self.cmix = SwiGLU(dim, ffn_expansion_factor)
 
 
52
  self.norm1 = nn.LayerNorm(dim)
53
  self.norm2 = nn.LayerNorm(dim)
54
 
55
  def forward(
56
  self,
57
+ x,
58
+ freqs: torch.Tensor | None = None,
59
+ text_freqs: torch.Tensor | None = None,
60
+ cache: Cache | None = None,
61
+ ):
62
+ x = (
63
+ self.tmix(
64
+ self.norm1(x),
65
+ past_key_values=cache,
66
+ use_cache=cache is not None,
67
+ )[0]
68
+ + x
 
 
69
  )
 
70
  x = self.cmix(self.norm2(x)) + x
71
  return x
72
 
73
 
74
  class DecoderBlockWithOptionalCrossAttention(nn.Module):
75
+ def __init__(self, decoder_block: nn.Module, crossatt: nn.Module | None = None):
 
 
76
  super().__init__()
77
+
78
  self.decoder_block = decoder_block
79
  self.crossatt = crossatt
80
 
81
  def forward(
82
  self,
83
  x: torch.Tensor,
84
+ encoder_output: torch.Tensor | None = None,
85
+ freqs: torch.Tensor | None = None,
86
+ text_freqs: torch.Tensor | None = None,
87
+ cache: Cache | None = None,
88
+ selfatt_mask: torch.Tensor | None = None,
89
+ crossatt_mask: torch.Tensor | list[torch.Tensor] | None = None,
90
  ) -> torch.Tensor:
91
+ x = self.decoder_block(
92
+ x,
93
+ freqs=freqs,
94
+ cache=cache,
95
+ )
96
+ if type(crossatt_mask) is list:
97
+ crossatt_mask = crossatt_mask[self.decoder_block.tmix.layer_idx]
98
+ if self.crossatt is not None:
99
+ x = x + self.crossatt(
100
  x,
101
+ k=encoder_output,
102
+ text_freqs=text_freqs,
103
  mask=crossatt_mask,
104
+ cache=cache,
105
  )
106
+
107
  return x
108
 
109
 
110
  @register_decoder("simple_gla")
111
  class SimpleGLADecoder(nn.Module):
112
  config = SimpleGLADecoderConfig
 
113
 
114
+ def __init__(self, cfg: SimpleGLADecoderConfig):
115
  super().__init__()
116
+
117
+ assert cfg.dim % cfg.num_heads == 0, "num_heads should divide dim"
118
+ assert cfg.blind_crossatt + (cfg.listen_read_crossatt is not None) < 2, (
119
+ "at most one specialized cross-attention"
120
+ )
121
+
122
+ self.head_dim = cfg.dim // cfg.num_heads
123
+ self.num_heads = cfg.num_heads
124
+
125
+ def simple_gla_block(i):
126
+ conv_layers = [] if cfg.conv_layers is None else cfg.conv_layers
127
+ if i in conv_layers:
128
+ return ShortConvBlock(
129
+ dim=cfg.dim,
130
+ kernel_size=4,
131
+ ffn_expansion_factor=cfg.ffn_expansion_factor,
132
+ layer_idx=i,
133
+ use_fast_conv1d=True,
134
+ )
135
+
136
+ else:
137
+ return SimpleGLABlock(
138
+ dim=cfg.dim,
139
+ num_heads=cfg.num_heads,
140
+ layer_idx=i,
141
+ expand_k=cfg.expand_k,
142
+ expand_v=cfg.expand_v,
143
+ use_short_conv=cfg.use_short_conv,
144
+ ffn_expansion_factor=cfg.ffn_expansion_factor,
145
+ )
146
+
147
+ def crossatt_block(i):
148
+ if i in cfg.crossatt_layer_idx:
149
+ return CrossAttention(
150
+ dim=cfg.dim,
151
+ num_heads=cfg.crossatt_num_heads,
152
+ dropout=cfg.crossatt_dropout,
153
+ layer_idx=i,
154
+ )
155
+ else:
156
+ return None
157
+
158
+ self.decoder_layers = nn.ModuleList(
159
+ [
160
+ DecoderBlockWithOptionalCrossAttention(
161
+ simple_gla_block(i),
162
+ crossatt_block(i),
163
+ )
164
+ for i in range(cfg.num_layers)
165
+ ]
166
+ )
167
+
168
+ def forward(
169
+ self,
170
+ encoder_output: torch.Tensor,
171
+ decoder_input: torch.Tensor,
172
+ crossatt_mask: torch.Tensor | list[torch.Tensor] | None = None,
173
+ text_ids: torch.Tensor | None = None,
174
+ cache: FLACache | None = None,
175
+ ):
176
+ x = decoder_input
177
+ text_freqs = None
178
+
179
+ for layer in self.decoder_layers:
180
+ x = maybe_grad_ckpt(layer)(
181
+ x,
182
+ encoder_output,
183
+ text_freqs=text_freqs,
184
+ cache=cache,
185
+ crossatt_mask=crossatt_mask,
186
  )
187
+ return x
 
188
 
189
+ def init_cache(self, max_seq_len, device):
190
+ return FLACache(num_states=len(self.decoder_layers) + 1)
191
+
192
+ def init_initial_state(self, batch_size=1, scale=1e-2, device="cpu"):
193
+ return tuple(
194
+ nn.Parameter(
195
+ torch.randn(
196
+ batch_size,
197
+ self.num_heads,
198
+ self.head_dim,
199
+ self.head_dim,
200
+ device=device,
201
+ )
202
+ * scale
203
+ )
204
+ for _ in range(len(self.decoder_layers))
205
+ )
206
+ def init_initial_state_lora(self, lora:int=1, batch_size: int = 1, scale: float=1e-2, device: str="cpu"):
207
+ return tuple(
208
+ (
209
+ nn.Parameter(
210
+ torch.randn(
211
+ batch_size,
212
+ self.num_heads,
213
+ self.head_dim,
214
+ lora,
215
+ device=device,
216
+ )
217
+ * scale
218
+ ),
219
+ nn.Parameter(
220
+ torch.randn(
221
+ batch_size,
222
+ self.num_heads,
223
+ lora,
224
+ self.head_dim,
225
+ device=device,
226
+ )
227
+ * scale
228
+ )
229
+ )
230
+ for _ in range(len(self.decoder_layers))
231
+ )
232
+
233
+ def _get_query(self, audio_inputs: torch.Tensor, layer_idx: int):
234
+ assert self.decoder_layers[layer_idx].crossatt is not None
235
+ x = audio_inputs
236
+ for _, layer in zip(range(layer_idx - 1), self.decoder_layers):
237
+ x = layer(x, None)
238
+ return self.decoder_layers[layer_idx].crossatt._query(x)
239
+
240
+ def forward_first_n_layers(
241
+ self,
242
+ encoder_output: torch.Tensor,
243
+ decoder_input: torch.Tensor,
244
+ n_first_layers: int,
245
+ crossatt_mask: torch.Tensor | None = None,
246
+ cache: FLACache | None = None,
247
+ ):
248
+ x = decoder_input
249
+ if self.text_freqs_embd is not None:
250
+ text_freqs = torch.arange(encoder_output.shape[1], device=x.device)[None, :]
251
+ text_freqs = self.text_freqs_embd(text_freqs)
252
+ else:
253
+ text_freqs = None
254
+
255
+ for layer in self.decoder_layers[:n_first_layers]:
256
+ x = maybe_grad_ckpt(layer)(
257
+ x,
258
+ encoder_output,
259
+ text_freqs=text_freqs,
260
+ cache=cache,
261
+ crossatt_mask=crossatt_mask,
262
+ )
263
+ return x
264
 
 
265
  def prefill(
266
  self,
267
+ encoder_output: torch.Tensor,
268
  decoder_input: torch.Tensor,
269
+ crossatt_mask: torch.Tensor | None = None,
270
+ cache: FLACache | None = None,
271
+ ):
272
+ return self(encoder_output, decoder_input, cache=cache, crossatt_mask=crossatt_mask)
 
273
 
274
+ def decode_one(
275
  self,
276
+ encoder_output: torch.Tensor,
277
  decoder_input: torch.Tensor,
278
+ cache: Cache,
279
+ text_freqs: torch.Tensor | None = None,
280
+ crossatt_mask: torch.Tensor | None = None,
281
+ ):
282
  x = decoder_input
283
+ for layer in self.decoder_layers:
 
284
  x = layer(
285
  x,
286
+ encoder_output,
287
  text_freqs=text_freqs,
288
+ cache=cache,
289
  crossatt_mask=crossatt_mask,
290
  )
291
  return x