Spaces: Running on Zero

Update app.py

app.py CHANGED
@@ -3,29 +3,48 @@ import numpy as np
 import torch
 from chatterbox.src.chatterbox.tts import ChatterboxTTS
 import gradio as gr
+import spaces  # <<< IMPORT THIS
 
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+print(f"🚀 Running on device: {DEVICE}")  # Good to log this
 
+# Global model variable, to load only once if not using gr.State for the model object
+# global_model = None
 
 def set_seed(seed: int):
     torch.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
-    torch.cuda.manual_seed_all(seed)
+    if DEVICE == "cuda":  # Only seed CUDA if it is actually available
+        torch.cuda.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)
     random.seed(seed)
     np.random.seed(seed)
 
-
-
-def generate(model, text, audio_prompt_path, exaggeration, pace, temperature, seed_num, cfgw):
-    if model is None:
-        model = ChatterboxTTS.from_pretrained(DEVICE)
+# Optional: decorate model loading if it happens on first use inside a GPU function.
+# It is usually better, though, to load the model once globally or to manage it with
+# gr.State, and to make sure the function CALLING the model is the one decorated.
+
+@spaces.GPU  # <<< ADD THIS DECORATOR
+def generate(model_obj, text, audio_prompt_path, exaggeration, pace, temperature, seed_num, cfgw):
+    # Prefer loading the model once (when gr.State is initialized, or globally)
+    # over checking `model_obj is None` on every call. For ZeroGPU, the decorated
+    # function provides the GPU context; assume model_obj arrives already on
+    # DEVICE, or is moved there by ChatterboxTTS internally.
+
+    if model_obj is None:
+        print("Model is None, attempting to load...")
+        # This load should ideally happen on DEVICE and be efficient. If
+        # ChatterboxTTS.from_pretrained(DEVICE) is slow, the load runs inside
+        # the GPU-allocated time.
+        model_obj = ChatterboxTTS.from_pretrained(DEVICE)
+        print(f"Model loaded on device: {model_obj.device if hasattr(model_obj, 'device') else 'unknown'}")
 
     if seed_num != 0:
         set_seed(int(seed_num))
 
-    wav = model.generate(
+    print(f"Generating audio for text: '{text}' on device: {DEVICE}")
+    wav = model_obj.generate(
         text,
         audio_prompt_path=audio_prompt_path,
         exaggeration=exaggeration,
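A note on the hunk above: on ZeroGPU, the GPU is attached only while a @spaces.GPU-decorated function runs, so a common alternative is to build the model once at module import and decorate only the inference call. A minimal sketch of that pattern, assuming ChatterboxTTS.from_pretrained(DEVICE) can run at Space startup; the duration value and the tts name are illustrative assumptions, not part of this commit:

import spaces
import torch
from chatterbox.src.chatterbox.tts import ChatterboxTTS

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL = ChatterboxTTS.from_pretrained(DEVICE)  # loaded once at startup, shared by all requests

@spaces.GPU(duration=120)  # assumed time budget; the GPU is held only while this call runs
def tts(text: str):
    # MODEL comes from module scope, so no gr.State round-trip is needed.
    wav = MODEL.generate(text)
    return MODEL.sr, wav.squeeze(0).numpy()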
@@ -33,13 +52,29 @@ def generate(model, text, audio_prompt_path, exaggeration, pace, temperature, seed_num, cfgw):
         temperature=temperature,
         cfg_weight=cfgw,
     )
-    return (model, (model.sr, wav.squeeze(0).numpy()))
+    print("Audio generation complete.")
+    # The model state is passed back out, which is correct for gr.State.
+    return (model_obj, (model_obj.sr, wav.squeeze(0).numpy()))
 
 
 with gr.Blocks() as demo:
-    model_state = gr.State(None)
+    # To make sure the model loads at app start and uses DEVICE correctly, pre-load
+    # it here if you want one global load per Space instance. With gr.State(None)
+    # and loading inside `generate`, the first user to hit "Generate" triggers the
+    # load instead; that is fine as long as ChatterboxTTS.from_pretrained(DEVICE)
+    # correctly uses the GPU inside the @spaces.GPU-decorated `generate` function.
+
+    # For better clarity on model loading with ZeroGPU, consider a dedicated loader
+    # function that initializes gr.State, or make sure the first call to `generate`
+    # handles the load robustly within the GPU context. Loading when model_state is
+    # None inside `generate` is okay, as long as `generate` itself is decorated.
+
+    model_state = gr.State(None)
 
     with gr.Row():
+        # ... (rest of your UI code is fine) ...
         with gr.Column():
             text = gr.Textbox(value="What does the fox say?", label="Text to synthesize")
             ref_wav = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Reference Audio File", value="https://storage.googleapis.com/chatterbox-demo-samples/prompts/wav7604828.wav")
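One caveat on the gr.State discussion in the hunk above: Gradio deep-copies a State's initial value for each user session, so initializing gr.State with an already-loaded torch model can be costly. The module-level global hinted at by the "# global_model = None" comment avoids both that copy and the lazy-load branch. A sketch under that assumption, reusing set_seed and the generate arguments from this diff; the model parameter and the model_state output simply disappear:

MODEL = ChatterboxTTS.from_pretrained(DEVICE)  # one shared instance per Space replica

@spaces.GPU
def generate(text, audio_prompt_path, exaggeration, pace, temperature, seed_num, cfgw):
    if seed_num != 0:
        set_seed(int(seed_num))
    # pace is accepted but not forwarded, mirroring the original call in this diff.
    wav = MODEL.generate(
        text,
        audio_prompt_path=audio_prompt_path,
        exaggeration=exaggeration,
        temperature=temperature,
        cfg_weight=cfgw,
    )
    # No model_state to return; only the audio tuple goes back to gr.Audio.
    return MODEL.sr, wav.squeeze(0).numpy()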
@@ -72,8 +107,9 @@ with gr.Blocks() as demo:
         outputs=[model_state, audio_output],
     )
 
+    # share=True in launch() triggers a UserWarning on Spaces and is not needed:
+    # Hugging Face Spaces provides the public link automatically.
     demo.queue(
         max_size=50,
-        default_concurrency_limit=1,
-    ).launch(share=True
-    )
+        default_concurrency_limit=1,  # Good for a single model instance on GPU
+    ).launch()  # Removed share=True
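The outputs=[model_state, audio_output] line above belongs to a click binding that sits outside the hunk, so the diff never shows it. For orientation, a hypothetical reconstruction of that wiring with the two-element return; every component name below except model_state, audio_output, text, and ref_wav is a guess, since the UI section is elided:

# run_btn and the slider components are assumed names; the real file defines
# them in the UI section that this diff does not show.
run_btn.click(
    fn=generate,
    inputs=[model_state, text, ref_wav, exaggeration, pace, temperature, seed_num, cfg_weight],
    outputs=[model_state, audio_output],
)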