Steveeeeeeen HF staff commited on
Commit
b8a3553
·
verified ·
1 Parent(s): 22bde2c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -18
app.py CHANGED
@@ -4,7 +4,7 @@ import gradio as gr
4
  import spaces
5
 
6
  from zonos.model import Zonos
7
- from zonos.conditioning import make_cond_dict, supported_language_codes
8
 
9
  # We'll keep a global dictionary of loaded models to avoid reloading
10
  MODELS_CACHE = {}
@@ -13,6 +13,15 @@ device = "cuda"
13
  banner_url = "https://huggingface.co/datasets/Steveeeeeeen/random_images/resolve/main/ZonosHeader.png"
14
  BANNER = f'<div style="display: flex; justify-content: space-around;"><img src="{banner_url}" alt="Banner" style="width: 40vw; min-width: 150px; max-width: 300px;"> </div>'
15
 
 
 
 
 
 
 
 
 
 
16
  def load_model(model_name: str):
17
  """
18
  Loads or retrieves a cached Zonos model, sets it to eval and bfloat16.
@@ -28,15 +37,20 @@ def load_model(model_name: str):
28
  return MODELS_CACHE[model_name]
29
 
30
  @spaces.GPU(duration=90)
31
- def tts(text, speaker_audio, selected_language, model_choice):
32
  """
33
  text: str (Text prompt to synthesize)
34
  speaker_audio: (sample_rate, numpy_array) from Gradio if type="numpy"
35
- selected_language: str (language code)
36
  model_choice: str (which Zonos model to use, e.g., "Zyphra/Zonos-v0.1-hybrid")
37
 
38
  Returns (sr_out, wav_out_numpy).
39
  """
 
 
 
 
 
40
  model = load_model(model_choice)
41
 
42
  if not text:
@@ -52,12 +66,11 @@ def tts(text, speaker_audio, selected_language, model_choice):
52
  # Convert to Torch tensor
53
  wav_tensor = torch.from_numpy(wav_np).float()
54
 
55
- # If stereo (shape [channels, samples]) or multi-channel, downmix to mono
56
- # e.g. shape (2, samples) -> shape (samples,) by averaging
57
  if wav_tensor.ndim == 2 and wav_tensor.shape[0] > 1:
58
- wav_tensor = wav_tensor.mean(dim=0) # shape => (samples,)
59
 
60
- # Now add a batch dimension => shape (1, samples)
61
  wav_tensor = wav_tensor.unsqueeze(0)
62
 
63
  # Get speaker embedding
@@ -66,12 +79,12 @@ def tts(text, speaker_audio, selected_language, model_choice):
66
  spk_embedding = spk_embedding.to(device, dtype=torch.bfloat16)
67
 
68
  # Prepare conditioning dictionary
69
- cond_dict = make_cond_dict(
70
- text=text,
71
- speaker=spk_embedding,
72
- language=selected_language,
73
- device=device,
74
- )
75
  conditioning = model.prepare_conditioning(cond_dict)
76
 
77
  # Generate codes
@@ -106,8 +119,6 @@ def build_demo():
106
  ref_audio_input = gr.Audio(
107
  label="Reference Audio (Speaker Cloning)",
108
  type="numpy"
109
- # Optionally add mono=True if you want Gradio to always downmix automatically:
110
- # mono=True
111
  )
112
 
113
  model_dropdown = gr.Dropdown(
@@ -116,10 +127,12 @@ def build_demo():
116
  value="Zyphra/Zonos-v0.1-hybrid",
117
  interactive=True,
118
  )
 
 
119
  language_dropdown = gr.Dropdown(
120
- label="Language Code",
121
- choices=supported_language_codes,
122
- value="en-us",
123
  interactive=True,
124
  )
125
 
 
4
  import spaces
5
 
6
  from zonos.model import Zonos
7
+ from zonos.conditioning import make_cond_dict # Keep this; remove supported_language_codes
8
 
9
  # We'll keep a global dictionary of loaded models to avoid reloading
10
  MODELS_CACHE = {}
 
13
  banner_url = "https://huggingface.co/datasets/Steveeeeeeen/random_images/resolve/main/ZonosHeader.png"
14
  BANNER = f'<div style="display: flex; justify-content: space-around;"><img src="{banner_url}" alt="Banner" style="width: 40vw; min-width: 150px; max-width: 300px;"> </div>'
15
 
16
+ # Define a list of tuples: (Display Label, Language Code)
17
+ LANGUAGES = [
18
+ ("English", "en-us"),
19
+ ("Japanese", "ja"),
20
+ ("Chinese", "cmn"),
21
+ ("French", "fr-fr"),
22
+ ("German", "de"),
23
+ ]
24
+
25
  def load_model(model_name: str):
26
  """
27
  Loads or retrieves a cached Zonos model, sets it to eval and bfloat16.
 
37
  return MODELS_CACHE[model_name]
38
 
39
  @spaces.GPU(duration=90)
40
+ def tts(text, speaker_audio, selected_language_label, model_choice):
41
  """
42
  text: str (Text prompt to synthesize)
43
  speaker_audio: (sample_rate, numpy_array) from Gradio if type="numpy"
44
+ selected_language_label: str (the display name from the dropdown, e.g. "Chinese")
45
  model_choice: str (which Zonos model to use, e.g., "Zyphra/Zonos-v0.1-hybrid")
46
 
47
  Returns (sr_out, wav_out_numpy).
48
  """
49
+ # Map from label -> actual language code
50
+ label_to_code = dict(LANGUAGES)
51
+ # Convert the human-readable label back to the code
52
+ selected_language = label_to_code[selected_language_label]
53
+
54
  model = load_model(model_choice)
55
 
56
  if not text:
 
66
  # Convert to Torch tensor
67
  wav_tensor = torch.from_numpy(wav_np).float()
68
 
69
+ # If stereo or multi-channel, downmix to mono
 
70
  if wav_tensor.ndim == 2 and wav_tensor.shape[0] > 1:
71
+ wav_tensor = wav_tensor.mean(dim=0) # => (samples,)
72
 
73
+ # Add batch dimension => (1, samples)
74
  wav_tensor = wav_tensor.unsqueeze(0)
75
 
76
  # Get speaker embedding
 
79
  spk_embedding = spk_embedding.to(device, dtype=torch.bfloat16)
80
 
81
  # Prepare conditioning dictionary
82
+ cond_dict = {
83
+ "text": text,
84
+ "speaker": spk_embedding,
85
+ "language": selected_language, # Use the code here
86
+ "device": device,
87
+ }
88
  conditioning = model.prepare_conditioning(cond_dict)
89
 
90
  # Generate codes
 
119
  ref_audio_input = gr.Audio(
120
  label="Reference Audio (Speaker Cloning)",
121
  type="numpy"
 
 
122
  )
123
 
124
  model_dropdown = gr.Dropdown(
 
127
  value="Zyphra/Zonos-v0.1-hybrid",
128
  interactive=True,
129
  )
130
+
131
+ # For the language dropdown, we display only the friendly label
132
  language_dropdown = gr.Dropdown(
133
+ label="Language",
134
+ choices=[label for (label, code) in LANGUAGES],
135
+ value="English", # default display
136
  interactive=True,
137
  )
138