utkarshshukla2912 commited on
Commit
8b08d3c
·
1 Parent(s): ac7c607

added distill model

Browse files
Files changed (3) hide show
  1. app.py +188 -65
  2. generation_counter.json +1 -1
  3. vertex_client.py +125 -7
app.py CHANGED
@@ -5,6 +5,7 @@ from pathlib import Path
5
  import uuid
6
  import fcntl
7
  import time
 
8
  from vertex_client import get_vertex_client
9
 
10
  # gr.NO_RELOAD = False
@@ -152,8 +153,9 @@ def synthesize_speech(text, voice_id):
152
 
153
  if success and audio_bytes:
154
  print("✅ Synthesized audio using Vertex AI")
155
- # Save binary audio to temp file
156
- audio_file = f"/tmp/ringg_{str(uuid.uuid4())}.wav"
 
157
  with open(audio_file, "wb") as f:
158
  f.write(audio_bytes)
159
 
@@ -170,7 +172,7 @@ def synthesize_speech(text, voice_id):
170
  rtf_no_vocoder
171
  ) = ""
172
 
173
- status_msg = "✅ Audio generated successfully!"
174
 
175
  return (
176
  audio_file,
@@ -220,7 +222,7 @@ with gr.Blocks(
220
 
221
  # Best Practices Section
222
  gr.Markdown("""
223
- ### 📝 Best Practices for Best Results
224
  - **Supported Languages:** Hindi and English only
225
  - **Check spelling carefully:** Misspelled words may be mispronounced
226
  - **Punctuation matters:** Use proper punctuation for natural pauses and intonation
@@ -228,41 +230,62 @@ with gr.Blocks(
228
  - **Numbers & dates:** Write numbers as words for better pronunciation (e.g., "twenty-five" instead of "25")
229
  """)
230
 
231
- # Text Input
232
- text_input = gr.Textbox(
233
- label="Text (max 300 characters)",
234
- placeholder="Type or paste your text here (max 300 characters)...",
235
- lines=6,
236
- max_lines=10,
237
- max_length=300,
238
- )
239
-
240
- # Character count display
241
- char_count = gr.Markdown("**Character count:** 0 / 300")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242
 
 
 
243
  with gr.Row():
244
  with gr.Column(scale=1):
245
- # Voice Selection
246
- voices = get_voices()
247
- voice_choices = {display: vid for display, vid in voices}
248
-
249
- voice_dropdown = gr.Dropdown(
250
- choices=list(voice_choices.keys()),
251
- label="Choose a voice style",
252
- info=f"{len(voices)} voices available",
253
- value=list(voice_choices.keys())[0] if voices else None,
254
  )
255
 
256
  with gr.Column(scale=1):
257
- audio_output = gr.Audio(label="Listen to your audio", type="filepath")
258
- metrics_header = gr.Markdown("### 📊 Generation Metrics", visible=False)
259
- metrics_output = gr.Code(
260
- label="Metrics", language="json", interactive=False, visible=False
 
 
 
 
 
 
 
261
  )
262
 
263
  generate_btn = gr.Button("🎬 Generate Speech", variant="primary", size="lg")
264
 
265
- gr.Markdown("#### 🎯 Try these examples:")
266
  with gr.Row():
267
  example_btn1 = gr.Button("English Example", size="sm")
268
  example_btn2 = gr.Button("Hindi Example", size="sm")
@@ -280,52 +303,148 @@ with gr.Blocks(
280
  def update_char_count(text):
281
  """Update character count as user types"""
282
  count = len(text) if text else 0
283
- return f"**Character count:** {count} / 300"
284
 
285
  def load_example_text(example_text):
286
  """Load example text and update character count"""
287
  count = len(example_text)
288
- return example_text, f"**Character count:** {count} / 300"
289
 
290
  def clear_text():
291
  """Clear text input"""
292
- return "", "**Character count:** 0 / 300"
293
 
294
  def on_generate(text, voice_display):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
295
  voice_id = voice_choices.get(voice_display)
296
- audio_file, _status, t_time, rtf, wav_dur, voc_time, no_voc_time, rtf_no_voc = (
297
- synthesize_speech(text, voice_id)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
298
  )
299
 
300
- # Get fresh counter from file
301
- new_count = load_counter()
302
- if audio_file:
303
- # Atomically increment the UNIVERSAL counter
304
- new_count = increment_counter()
305
-
306
- # Format metrics as JSON string (only if available)
307
- has_metrics = any([t_time, rtf, wav_dur, voc_time, no_voc_time, rtf_no_voc])
308
- metrics_json = ""
309
- if has_metrics:
310
- metrics_json = json.dumps(
311
- {
312
- "total_time": t_time,
313
- "rtf": rtf,
314
- "audio_duration": wav_dur,
315
- "vocoder_time": voc_time,
316
- "no_vocoder_time": no_voc_time,
317
- "rtf_no_vocoder": rtf_no_voc,
318
- },
319
- indent=2,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
320
  )
321
 
322
- return (
323
- audio_file,
324
- gr.update(visible=has_metrics),
325
- gr.update(value=metrics_json, visible=has_metrics),
326
- f"**🌍 Generations:** {new_count}",
327
- )
328
-
329
  def refresh_counter_on_load():
330
  """Refresh the universal generation counter when the UI loads/reloads"""
331
  return f"**🌍 Generations since last reload:** {load_counter()}"
@@ -356,10 +475,14 @@ with gr.Blocks(
356
  fn=on_generate,
357
  inputs=[text_input, voice_dropdown],
358
  outputs=[
359
- audio_output,
360
- # status_output,
361
- metrics_header,
362
- metrics_output,
 
 
 
 
363
  generation_counter,
364
  ],
365
  concurrency_limit=2,
 
5
  import uuid
6
  import fcntl
7
  import time
8
+ import tempfile
9
  from vertex_client import get_vertex_client
10
 
11
  # gr.NO_RELOAD = False
 
153
 
154
  if success and audio_bytes:
155
  print("✅ Synthesized audio using Vertex AI")
156
+ # Save binary audio to temp file in system temp directory
157
+ temp_dir = tempfile.gettempdir()
158
+ audio_file = os.path.join(temp_dir, f"ringg_{str(uuid.uuid4())}.wav")
159
  with open(audio_file, "wb") as f:
160
  f.write(audio_bytes)
161
 
 
172
  rtf_no_vocoder
173
  ) = ""
174
 
175
+ status_msg = ""
176
 
177
  return (
178
  audio_file,
 
222
 
223
  # Best Practices Section
224
  gr.Markdown("""
225
+ ## 📝 Best Practices for Best Results
226
  - **Supported Languages:** Hindi and English only
227
  - **Check spelling carefully:** Misspelled words may be mispronounced
228
  - **Punctuation matters:** Use proper punctuation for natural pauses and intonation
 
230
  - **Numbers & dates:** Write numbers as words for better pronunciation (e.g., "twenty-five" instead of "25")
231
  """)
232
 
233
+ # Input Section - Text, Voice, and Character Count grouped together
234
+ with gr.Group():
235
+ # Text Input
236
+ text_input = gr.Textbox(
237
+ label="Text (max 500 characters)",
238
+ placeholder="Type or paste your text here (max 500 characters)...",
239
+ lines=6,
240
+ max_lines=10,
241
+ max_length=500,
242
+ )
243
+ # Voice Selection
244
+ voices = get_voices()
245
+ voice_choices = {display: vid for display, vid in voices}
246
+
247
+ voice_dropdown = gr.Dropdown(
248
+ choices=list(voice_choices.keys()),
249
+ label="Choose a voice style",
250
+ info=f"{len(voices)} voices available",
251
+ value=list(voice_choices.keys())[0] if voices else None,
252
+ show_label=False,
253
+ )
254
+ # Character count display
255
+ char_count = gr.Code(
256
+ "Character count: 0 / 500",
257
+ show_line_numbers=False,
258
+ show_label=False,
259
+ )
260
 
261
+ # Side-by-side comparison of Base and Distill models
262
+ gr.Markdown("### 🎧 Audio Results Comparison")
263
  with gr.Row():
264
  with gr.Column(scale=1):
265
+ # gr.Markdown("#### Base Model")
266
+ audio_output_base = gr.Audio(label="Base Model Audio", type="filepath")
267
+ status_base = gr.Markdown("", visible=True)
268
+ metrics_header_base = gr.Markdown("**📊 Metrics**", visible=False)
269
+ metrics_output_base = gr.Code(
270
+ label="Base Metrics", language="json", interactive=False, visible=False
 
 
 
271
  )
272
 
273
  with gr.Column(scale=1):
274
+ # gr.Markdown("#### Distill Model")
275
+ audio_output_distill = gr.Audio(
276
+ label="Distill Model Audio", type="filepath"
277
+ )
278
+ status_distill = gr.Markdown("", visible=True)
279
+ metrics_header_distill = gr.Markdown("**📊 Metrics**", visible=False)
280
+ metrics_output_distill = gr.Code(
281
+ label="Distill Metrics",
282
+ language="json",
283
+ interactive=False,
284
+ visible=False,
285
  )
286
 
287
  generate_btn = gr.Button("🎬 Generate Speech", variant="primary", size="lg")
288
 
 
289
  with gr.Row():
290
  example_btn1 = gr.Button("English Example", size="sm")
291
  example_btn2 = gr.Button("Hindi Example", size="sm")
 
303
def update_char_count(text):
    """Return the character-count caption shown under the text box."""
    length = 0 if not text else len(text)
    return f"Character count: {length} / 500"
307
 
308
def load_example_text(example_text):
    """Fill the text box with an example and refresh its character count."""
    caption = f"Character count: {len(example_text)} / 500"
    return example_text, caption
312
 
313
def clear_text():
    """Reset the text box and its character-count caption to the empty state."""
    reset_caption = "Character count: 0 / 500"
    return "", reset_caption
316
 
317
def on_generate(text, voice_display):
    """Generate speech using both base and distill models in parallel.

    Generator handler for the Gradio "Generate" button. Yields 9-tuples that
    map positionally onto the registered outputs:
    (base audio, base status, base metrics header update, base metrics update,
     distill audio, distill status, distill metrics header update,
     distill metrics update, counter markdown).
    Intermediate yields stream loading/partial states as each model finishes.
    """
    # Validate inputs: reject empty/whitespace-only text before doing any work.
    if not text or not text.strip():
        error_msg = "⚠️ Please enter some text"
        yield (
            None,
            error_msg,
            gr.update(visible=False),
            gr.update(visible=False),
            None,
            error_msg,
            gr.update(visible=False),
            gr.update(visible=False),
            f"**🌍 Generations:** {load_counter()}",
        )
        return

    # voice_choices maps the dropdown display string to a backend voice id
    # (closure from the surrounding UI setup).
    voice_id = voice_choices.get(voice_display)
    if not voice_id:
        error_msg = "⚠️ Please select a voice"
        yield (
            None,
            error_msg,
            gr.update(visible=False),
            gr.update(visible=False),
            None,
            error_msg,
            gr.update(visible=False),
            gr.update(visible=False),
            f"**🌍 Generations:** {load_counter()}",
        )
        return

    # Initialize state for both models; each entry is updated in place as
    # results arrive from the parallel synthesis generator.
    results = {
        "base": {"audio": None, "status": "⏳ Loading...", "metrics": None},
        "distill": {"audio": None, "status": "⏳ Loading...", "metrics": None},
    }

    # Show loading state initially so both panels render a spinner-like status.
    yield (
        None,
        results["base"]["status"],
        gr.update(visible=False),
        gr.update(visible=False),
        None,
        results["distill"]["status"],
        gr.update(visible=False),
        gr.update(visible=False),
        f"**🌍 Generations:** {load_counter()}",
    )

    # Use parallel synthesis: synthesize_parallel yields per-model results as
    # they complete, in arbitrary order ("base" or "distill" first).
    vertex_client = get_vertex_client()
    counter_incremented = False

    for (
        model_type,
        success,
        audio_bytes,
        metrics,
    ) in vertex_client.synthesize_parallel(text, voice_id):
        if success and audio_bytes:
            # Save audio file in system temp directory; uuid avoids collisions
            # across concurrent sessions. NOTE(review): these temp files are
            # never deleted here — presumably cleaned up by the OS/platform.
            temp_dir = tempfile.gettempdir()
            audio_file = os.path.join(
                temp_dir, f"ringg_{model_type}_{str(uuid.uuid4())}.wav"
            )
            with open(audio_file, "wb") as f:
                f.write(audio_bytes)

            # Increment counter only once (for the first successful result),
            # so one click counts as one generation even with two models.
            if not counter_incremented:
                new_count = increment_counter()
                counter_incremented = True
            else:
                new_count = load_counter()

            # Format metrics as pretty-printed JSON for the gr.Code widget.
            # Assumes metrics keys 't', 'rtf', 'wav_seconds', 't_vocoder',
            # 't_no_vocoder', 'rtf_no_vocoder' — TODO confirm against backend.
            metrics_json = ""
            has_metrics = False
            if metrics:
                has_metrics = True
                metrics_json = json.dumps(
                    {
                        "total_time": f"{metrics.get('t', 0):.3f}s",
                        "rtf": f"{metrics.get('rtf', 0):.4f}",
                        "audio_duration": f"{metrics.get('wav_seconds', 0):.2f}s",
                        "vocoder_time": f"{metrics.get('t_vocoder', 0):.3f}s",
                        "no_vocoder_time": f"{metrics.get('t_no_vocoder', 0):.3f}s",
                        "rtf_no_vocoder": f"{metrics.get('rtf_no_vocoder', 0):.4f}",
                    },
                    indent=2,
                )

            # Update the corresponding model result (empty status = success).
            results[model_type] = {
                "audio": audio_file,
                "status": "",
                "metrics": metrics_json,
                "has_metrics": has_metrics,
            }
        else:
            # Update failed model; keeps the other model's state untouched.
            results[model_type] = {
                "audio": None,
                "status": "❌ Failed to generate",
                "metrics": "",
                "has_metrics": False,
            }

        # Yield updated state for both models after each completion, so the UI
        # shows the finished model while the other may still be loading.
        yield (
            results["base"]["audio"],
            results["base"]["status"],
            gr.update(visible=results["base"].get("has_metrics", False)),
            gr.update(
                value=results["base"]["metrics"],
                visible=results["base"].get("has_metrics", False),
            ),
            results["distill"]["audio"],
            results["distill"]["status"],
            gr.update(visible=results["distill"].get("has_metrics", False)),
            gr.update(
                value=results["distill"]["metrics"],
                visible=results["distill"].get("has_metrics", False),
            ),
            f"**🌍 Generations:** {new_count if counter_incremented else load_counter()}",
        )
447
 
 
 
 
 
 
 
 
448
def refresh_counter_on_load():
    """Refresh the universal generation counter when the UI loads/reloads."""
    count = load_counter()
    return f"**🌍 Generations since last reload:** {count}"
 
475
  fn=on_generate,
476
  inputs=[text_input, voice_dropdown],
477
  outputs=[
478
+ audio_output_base,
479
+ status_base,
480
+ metrics_header_base,
481
+ metrics_output_base,
482
+ audio_output_distill,
483
+ status_distill,
484
+ metrics_header_distill,
485
+ metrics_output_distill,
486
  generation_counter,
487
  ],
488
  concurrency_limit=2,
generation_counter.json CHANGED
@@ -1 +1 @@
1
- {"count": 3, "last_updated": 1762495500.191227}
 
1
+ {"count": 10, "last_updated": 1762780862.430711}
vertex_client.py CHANGED
@@ -5,7 +5,8 @@ import os
5
  import json
6
  import logging
7
  import requests
8
- from typing import Optional, Dict, Any, Tuple
 
9
  from google.cloud import aiplatform
10
  from google.oauth2 import service_account
11
  from dotenv import load_dotenv
@@ -24,6 +25,7 @@ class VertexAIClient:
24
  def __init__(self):
25
  """Initialize the Vertex AI client."""
26
  self.endpoint = None
 
27
  self.credentials = None
28
  self.initialized = False
29
 
@@ -57,7 +59,7 @@ class VertexAIClient:
57
 
58
  def initialize(self) -> bool:
59
  """
60
- Initialize Vertex AI and find the zipvoice endpoint.
61
 
62
  Returns:
63
  True if initialization successful, False otherwise
@@ -80,16 +82,26 @@ class VertexAIClient:
80
  )
81
  logger.info("Vertex AI initialized for project desivocalprod01")
82
 
83
- # Find the zipvoice endpoint
84
  for endpoint in aiplatform.Endpoint.list():
85
  if endpoint.display_name == "zipvoice":
86
  self.endpoint = endpoint
87
- self.initialized = True
88
  logger.info(f"Found zipvoice endpoint: {endpoint.resource_name}")
89
- return True
 
 
90
 
91
- logger.error("zipvoice endpoint not found in Vertex AI")
92
- return False
 
 
 
 
 
 
 
 
 
93
 
94
  except Exception as e:
95
  logger.error(f"Failed to initialize Vertex AI: {e}")
@@ -185,6 +197,112 @@ class VertexAIClient:
185
  logger.error(f"Failed to synthesize speech with Vertex AI: {e}")
186
  return False, None, None
187
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
 
189
  # Global instance
190
  _vertex_client = None
 
5
  import json
6
  import logging
7
  import requests
8
+ from typing import Optional, Dict, Any, Tuple, Generator
9
+ from concurrent.futures import ThreadPoolExecutor, as_completed
10
  from google.cloud import aiplatform
11
  from google.oauth2 import service_account
12
  from dotenv import load_dotenv
 
25
  def __init__(self):
26
  """Initialize the Vertex AI client."""
27
  self.endpoint = None
28
+ self.endpoint_distill = None
29
  self.credentials = None
30
  self.initialized = False
31
 
 
59
 
60
  def initialize(self) -> bool:
61
  """
62
+ Initialize Vertex AI and find the zipvoice and zipvoice_base_distill endpoints.
63
 
64
  Returns:
65
  True if initialization successful, False otherwise
 
82
  )
83
  logger.info("Vertex AI initialized for project desivocalprod01")
84
 
85
+ # Find both endpoints
86
  for endpoint in aiplatform.Endpoint.list():
87
  if endpoint.display_name == "zipvoice":
88
  self.endpoint = endpoint
 
89
  logger.info(f"Found zipvoice endpoint: {endpoint.resource_name}")
90
+ elif endpoint.display_name == "zipvoice_base_distill":
91
+ self.endpoint_distill = endpoint
92
+ logger.info(f"Found zipvoice_base_distill endpoint: {endpoint.resource_name}")
93
 
94
+ # Check if at least the base endpoint is found
95
+ if not self.endpoint:
96
+ logger.error("zipvoice endpoint not found in Vertex AI")
97
+ return False
98
+
99
+ # Warn if distill endpoint is not found but continue
100
+ if not self.endpoint_distill:
101
+ logger.warning("zipvoice_base_distill endpoint not found - distill model will not be available")
102
+
103
+ self.initialized = True
104
+ return True
105
 
106
  except Exception as e:
107
  logger.error(f"Failed to initialize Vertex AI: {e}")
 
197
  logger.error(f"Failed to synthesize speech with Vertex AI: {e}")
198
  return False, None, None
199
 
200
def synthesize_distill(self, text: str, voice_id: str, timeout: int = 60) -> Tuple[bool, Optional[bytes], Optional[Dict[str, Any]]]:
    """
    Synthesize speech from text using Vertex AI distill endpoint.

    Lazily initializes the client on first use. The endpoint is expected to
    return JSON with keys "success", "audio_url", "metrics" and optionally
    "message"; the audio itself is downloaded from "audio_url".

    Args:
        text: Text to synthesize
        voice_id: Voice ID to use
        timeout: Request timeout in seconds (applies to the audio download;
            the raw_predict call itself uses the endpoint's own timeout)

    Returns:
        Tuple of (success, audio_bytes, metrics)
    """
    # Lazy init: a failed initialize() leaves the client unusable for this call.
    if not self.initialized:
        if not self.initialize():
            return False, None, None

    # Distill endpoint is optional; initialize() only warns when it is absent.
    if not self.endpoint_distill:
        logger.error("Distill endpoint not available")
        return False, None, None

    try:
        logger.info(f"Synthesizing text (length: {len(text)}) with voice {voice_id} using distill model")
        response = self.endpoint_distill.raw_predict(
            body=json.dumps({
                "text": text,
                "voice_id": voice_id,
                "model_type": "distill",
            }),
            headers={"Content-Type": "application/json"},
        )

        # Parse JSON response; raw_predict may return an object with .text or
        # an already-parsed payload — NOTE(review): confirm which SDK version
        # this targets.
        result = json.loads(response.text) if hasattr(response, 'text') else response
        logger.info(f"Vertex AI distill response: {result}")

        # Check if synthesis was successful
        if result.get("success"):
            audio_url = result.get("audio_url")
            metrics = result.get("metrics")

            if not audio_url:
                logger.error("No audio_url in successful response")
                return False, None, None

            # Download audio from URL (presumably a signed/temporary URL).
            logger.info(f"Downloading audio from: {audio_url}")
            audio_response = requests.get(audio_url, timeout=timeout)

            if audio_response.status_code == 200:
                audio_data = audio_response.content
                logger.info(f"Successfully downloaded audio ({len(audio_data)} bytes)")
                return True, audio_data, metrics
            else:
                logger.error(f"Failed to download audio: HTTP {audio_response.status_code}")
                return False, None, None
        else:
            error_msg = result.get("message", "Unknown error")
            logger.error(f"Synthesis failed: {error_msg}")
            return False, None, None

    except Exception as e:
        # Broad catch by design: any failure maps to a (False, None, None)
        # result so the UI can show a generic error instead of crashing.
        logger.error(f"Failed to synthesize speech with Vertex AI distill: {e}")
        return False, None, None
263
+
264
+ def synthesize_parallel(self, text: str, voice_id: str, timeout: int = 60) -> Generator[Tuple[str, bool, Optional[bytes], Optional[Dict[str, Any]]], None, None]:
265
+ """
266
+ Synthesize speech from text using both base and distill endpoints in parallel.
267
+
268
+ Yields results as they arrive (doesn't wait for both to complete).
269
+
270
+ Args:
271
+ text: Text to synthesize
272
+ voice_id: Voice ID to use
273
+ timeout: Request timeout in seconds
274
+
275
+ Yields:
276
+ Tuple of (model_type, success, audio_bytes, metrics)
277
+ model_type is either "base" or "distill"
278
+ """
279
+ if not self.initialized:
280
+ if not self.initialize():
281
+ logger.error("Failed to initialize client for parallel synthesis")
282
+ return
283
+
284
+ # Create executor for parallel execution
285
+ with ThreadPoolExecutor(max_workers=2) as executor:
286
+ # Submit both tasks
287
+ futures = {}
288
+
289
+ # Always submit base model
290
+ futures[executor.submit(self.synthesize, text, voice_id, timeout)] = "base"
291
+
292
+ # Submit distill model if available
293
+ if self.endpoint_distill:
294
+ futures[executor.submit(self.synthesize_distill, text, voice_id, timeout)] = "distill"
295
+
296
+ # Yield results as they complete
297
+ for future in as_completed(futures):
298
+ model_type = futures[future]
299
+ try:
300
+ success, audio_bytes, metrics = future.result()
301
+ yield model_type, success, audio_bytes, metrics
302
+ except Exception as e:
303
+ logger.error(f"Error in parallel synthesis for {model_type}: {e}")
304
+ yield model_type, False, None, None
305
+
306
 
307
  # Global instance
308
  _vertex_client = None