File size: 5,534 Bytes
b7658fb
6b64262
3f97053
4324db0
138e306
4324db0
6b64262
 
 
3f97053
6b64262
 
b338394
6b64262
 
 
 
 
 
 
 
d082ce1
6b64262
 
4324db0
8a1f431
6b64262
 
 
b338394
6b64262
 
 
 
 
 
c5f9e51
4324db0
b338394
21a831f
4324db0
2877f43
c5f9e51
4324db0
c5f9e51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b338394
 
 
 
 
 
138e306
b338394
 
 
 
c5f9e51
b338394
 
 
 
 
2877f43
c5f9e51
 
 
 
 
 
 
b338394
21a831f
2877f43
 
c5f9e51
2877f43
b338394
4324db0
 
 
b338394
 
4324db0
21a831f
4324db0
6d6254a
d082ce1
6b64262
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import os
import subprocess
import sys
from pathlib import Path

# --- 1. Clone the VibeVoice Repository ---
# The sources are fetched only when no "VibeVoice" path exists yet; an existing
# path is reused as-is.
repo_dir = "VibeVoice"
if os.path.exists(repo_dir):
    print("Repository already exists. Skipping clone.")
else:
    print("Cloning the VibeVoice repository...")
    clone_cmd = ["git", "clone", "https://github.com/microsoft/VibeVoice.git"]
    try:
        # Output is captured so git's stderr can be reported if the clone fails.
        subprocess.run(clone_cmd, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as clone_error:
        print(f"Error cloning repository: {clone_error.stderr}")
        sys.exit(1)
    else:
        print("Repository cloned successfully.")

# --- 2. Install Dependencies ---
# From here on the script runs with the cloned repository as its working directory.
os.chdir(repo_dir)
print(f"Changed directory to: {os.getcwd()}")

print("Installing the VibeVoice package in editable mode...")
# pip is invoked through the interpreter that is running this script.
pip_install_cmd = [sys.executable, "-m", "pip", "install", "-e", "."]
try:
    # Output is captured so pip's stderr can be reported if the install fails.
    subprocess.run(pip_install_cmd, check=True, capture_output=True, text=True)
except subprocess.CalledProcessError as install_error:
    print(f"Error installing package: {install_error.stderr}")
    sys.exit(1)
else:
    print("Package installed successfully.")

# --- 3. Refactor the demo script using a direct replacement strategy ---
# The demo is patched by exact text substitution:
#   * 'import spaces' is prepended when missing,
#   * the self.load_model() call is replaced by None placeholders, and
#   * the generation method gets a @spaces.GPU decorator plus a block that loads
#     the processor and model on first use.
# Each patch is skipped when its result is already present in the file, so the
# script can be re-run against a checkout that was patched by an earlier run.
demo_script_path = Path("demo/gradio_demo.py")
print(f"Refactoring {demo_script_path} for ZeroGPU lazy loading...")

try:
    # The file is Python source; read and write it as UTF-8 explicitly instead
    # of relying on the platform's default text encoding.
    with open(demo_script_path, 'r', encoding='utf-8') as f:
        original_content = f.read()
    modified_content = original_content

    # --- Add 'import spaces' at the top ---
    if "import spaces" not in modified_content:
        modified_content = "import spaces\n" + modified_content

    # --- Patch 1: Defer model loading in __init__ ---
    original_init_call = "        self.load_model()"
    replacement_init_block = (
        "        # self.load_model() # Patched: Defer model loading\n"
        "        self.model = None\n"
        "        self.processor = None"
    )
    if replacement_init_block in modified_content:
        # A previous run already applied this patch; do not treat that as an error.
        print("__init__ is already patched. Skipping.")
    elif original_init_call in modified_content:
        modified_content = modified_content.replace(original_init_call, replacement_init_block, 1)
        print("Successfully patched __init__ to prevent startup model load.")
    else:
        print(f"\033[91mError: Could not find '{original_init_call}' to patch. Startup patch failed.\033[0m")
        sys.exit(1)

    # --- Patch 2: Add decorator and lazy-loading logic to the generation method ---
    # The search text spans the full method signature down to the 'try:' and is
    # sensitive to whitespace. The parameter lines are shared verbatim by the
    # search text and the replacement, so they are defined once.
    signature_params = """\
                                 num_speakers: int,
                                 script: str,
                                 speaker_1: str = None,
                                 speaker_2: str = None,
                                 speaker_3: str = None,
                                 speaker_4: str = None,
                                 cfg_scale: float = 1.3) -> Iterator[tuple]:
"""

    # NOTE: the search text has a trailing space after 'self,'. It is written as
    # an explicit one-line literal so that whitespace-stripping editors cannot
    # silently remove it.
    original_method_header = (
        "    def generate_podcast_streaming(self, \n"
        + signature_params
        + "        try:"
    )

    # Code inserted between the signature and the method's 'try:'.
    lazy_load_block = """\
        # Patched: Lazy-load model and processor on the GPU worker
        if self.model is None or self.processor is None:
            print("Loading processor & model for the first time on GPU worker...")
            self.processor = VibeVoiceProcessor.from_pretrained(self.model_path)
            self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
                self.model_path,
                torch_dtype=torch.bfloat16, # Use 16-bit precision for quality
                device_map="auto",
            )
            self.model.eval()
            self.model.model.noise_scheduler = self.model.model.noise_scheduler.from_config(
                self.model.model.noise_scheduler.config,
                algorithm_type='sde-dpmsolver++',
                beta_schedule='squaredcos_cap_v2'
            )
            self.model.set_ddpm_inference_steps(num_steps=self.inference_steps)
            print("Model and processor loaded successfully on GPU worker.")

"""

    # Full replacement: decorator, signature (without the trailing space),
    # lazy-loading block, then the original 'try:'.
    replacement_method_header = (
        "    @spaces.GPU(duration=120)\n"
        "    def generate_podcast_streaming(self,\n"
        + signature_params
        + lazy_load_block
        + "        try:"
    )

    if replacement_method_header in modified_content:
        # A previous run already applied this patch; do not treat that as an error.
        print("Generation method is already patched. Skipping.")
    elif original_method_header in modified_content:
        modified_content = modified_content.replace(original_method_header, replacement_method_header, 1)
        print("Successfully patched generation method for lazy loading.")
    else:
        print("\033[91mError: Could not find the method definition for 'generate_podcast_streaming' to patch. This is likely due to a whitespace mismatch. Please check the demo script.\033[0m")
        sys.exit(1)

    # --- Write the modified content back to the file ---
    # Skip the write when every patch was already in place.
    if modified_content != original_content:
        with open(demo_script_path, 'w', encoding='utf-8') as f:
            f.write(modified_content)

    print("Script patching complete.")

except Exception as e:
    # Top-level boundary: report any unexpected failure with a traceback and
    # abort. The sys.exit() calls above raise SystemExit, which is not an
    # Exception subclass and therefore passes through this handler.
    print(f"An error occurred while modifying the script: {e}")
    import traceback
    traceback.print_exc()
    sys.exit(1)

# --- 4. Launch the Gradio Demo ---
# The demo is started with sys.executable, the same interpreter that pip used
# above to install the package, rather than whichever "python" is first on PATH.
model_id = "microsoft/VibeVoice-1.5B"
command = [sys.executable, str(demo_script_path), "--model_path", model_id, "--share"]
print(f"Launching Gradio demo with command: {' '.join(command)}")
demo_process = subprocess.run(command)
# Propagate a failing demo's exit status instead of always exiting with 0.
if demo_process.returncode != 0:
    sys.exit(demo_process.returncode)