Wismut commited on
Commit
7c73742
·
1 Parent(s): 34ab4db

fixed missing cuda option

Browse files
Files changed (1) hide show
  1. app.py +31 -11
app.py CHANGED
@@ -33,8 +33,20 @@ ANNOTATED_FEATURES_INFO = [
33
  "Colloquial | Formal",
34
  ]
35
 
 
36
  nltk.download("punkt_tab")
37
 
 
 
 
 
 
 
 
 
 
 
 
38
  # Load PCA model and annotated features
39
  try:
40
  pca = joblib.load(PCA_MODEL_PATH)
@@ -50,7 +62,9 @@ except FileNotFoundError:
50
  print(f"Error: Annotated features file '{ANNOTATED_FEATURES_PATH}' not found.")
51
  annotated_features = None
52
 
53
- # Utility Functions
 
 
54
 
55
 
56
  def load_voices_json():
@@ -132,8 +146,8 @@ def generate_audio_with_voice(text, voice_key, speed_val):
132
  print(f"Selected Voice: {voice_key}")
133
  print(f"Style Vector (First 6): {style_vector[0][:6]}")
134
 
135
- # Convert to torch tensor
136
- style_vec_torch = torch.from_numpy(style_vector).float()
137
 
138
  # Generate audio using the TTS model
139
  audio_np = tts_with_style_vector(
@@ -148,7 +162,7 @@ def generate_audio_with_voice(text, voice_key, speed_val):
148
 
149
  if audio_np is None:
150
  print("Audio generation failed.")
151
- return None, "Audio generation failed."
152
 
153
  # Prepare audio for Gradio
154
  sr = 24000 # Adjust based on your actual sampling rate
@@ -216,9 +230,9 @@ def generate_custom_audio(text, voice_key, randomize, speed_str, *slider_values)
216
  if random_style_vec is None:
217
  print("Failed to generate randomized style vector.")
218
  return None, None, None
219
- # Ensure the style vector is flat
220
  final_vec = (
221
- random_style_vec.numpy().flatten()
222
  if isinstance(random_style_vec, torch.Tensor)
223
  else np.array(random_style_vec).flatten()
224
  )
@@ -232,8 +246,10 @@ def generate_custom_audio(text, voice_key, randomize, speed_str, *slider_values)
232
  )
233
  return None, None, None
234
 
235
- # Convert to torch tensor
236
- style_vec_torch = torch.from_numpy(reconstructed_vec).float().unsqueeze(0)
 
 
237
 
238
  # Generate audio with the reconstructed style vector
239
  audio_np = tts_with_style_vector(
@@ -471,13 +487,17 @@ def create_combined_interface():
471
  # Save button functionality
472
  def on_save_style_studio(style_vector, style_name):
473
  if not style_name:
474
- return "Please enter a name for the new voice!"
 
 
 
 
475
  result = save_style_to_json(style_vector, style_name)
476
  new_choices = list(load_voices_json().keys())
477
  # Return multiple values to update both dropdowns and show status
478
  return (
479
- gr.Dropdown(choices=new_choices), # Update first dropdown
480
- gr.Dropdown(choices=new_choices), # Update studio dropdown
481
  result, # Status message
482
  )
483
 
 
33
  "Colloquial | Formal",
34
  ]
35
 
36
+ # Download necessary NLTK data
37
  nltk.download("punkt_tab")
38
 
39
+ ##############################################################################
40
+ # DEVICE CONFIGURATION
41
+ ##############################################################################
42
+ # Detect if CUDA is available and set the device accordingly
43
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
44
+ print(f"Using device: {device}")
45
+
46
+ ##############################################################################
47
+ # LOAD PCA MODEL AND ANNOTATED FEATURES
48
+ ##############################################################################
49
+
50
  # Load PCA model and annotated features
51
  try:
52
  pca = joblib.load(PCA_MODEL_PATH)
 
62
  print(f"Error: Annotated features file '{ANNOTATED_FEATURES_PATH}' not found.")
63
  annotated_features = None
64
 
65
+ ##############################################################################
66
+ # UTILITY FUNCTIONS
67
+ ##############################################################################
68
 
69
 
70
  def load_voices_json():
 
146
  print(f"Selected Voice: {voice_key}")
147
  print(f"Style Vector (First 6): {style_vector[0][:6]}")
148
 
149
+ # Convert to torch tensor and move to device
150
+ style_vec_torch = torch.from_numpy(style_vector).float().to(device)
151
 
152
  # Generate audio using the TTS model
153
  audio_np = tts_with_style_vector(
 
162
 
163
  if audio_np is None:
164
  print("Audio generation failed.")
165
+ return None, None, "Audio generation failed."
166
 
167
  # Prepare audio for Gradio
168
  sr = 24000 # Adjust based on your actual sampling rate
 
230
  if random_style_vec is None:
231
  print("Failed to generate randomized style vector.")
232
  return None, None, None
233
+ # Ensure the style vector is flat and on device
234
  final_vec = (
235
+ random_style_vec.cpu().numpy().flatten()
236
  if isinstance(random_style_vec, torch.Tensor)
237
  else np.array(random_style_vec).flatten()
238
  )
 
246
  )
247
  return None, None, None
248
 
249
+ # Convert to torch tensor and move to device
250
+ style_vec_torch = (
251
+ torch.from_numpy(reconstructed_vec).float().unsqueeze(0).to(device)
252
+ )
253
 
254
  # Generate audio with the reconstructed style vector
255
  audio_np = tts_with_style_vector(
 
487
  # Save button functionality
488
  def on_save_style_studio(style_vector, style_name):
489
  if not style_name:
490
+ return (
491
+ "Please enter a name for the new voice!",
492
+ gr.Dropdown.update(),
493
+ gr.Dropdown.update(),
494
+ )
495
  result = save_style_to_json(style_vector, style_name)
496
  new_choices = list(load_voices_json().keys())
497
  # Return multiple values to update both dropdowns and show status
498
  return (
499
+ gr.Dropdown.update(choices=new_choices),
500
+ gr.Dropdown.update(choices=new_choices),
501
  result, # Status message
502
  )
503