# Gradio demo: text-to-video generation with Wan2.1-T2V-14B, 4-bit quantized
import spaces  # Hugging Face ZeroGPU helper (provides the @spaces.GPU decorator)
import gradio as gr
import torch
import numpy as np
import os
import tempfile
from diffusers import DiffusionPipeline
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers.utils.export_utils import export_to_video

# Constants
LANDSCAPE_WIDTH = 832   # Wan2.1's native landscape resolution (not wired to the UI below)
LANDSCAPE_HEIGHT = 480
MAX_SEED = np.iinfo(np.int32).max
FIXED_FPS = 16          # frame rate used to convert the duration slider into a frame count
MIN_FRAMES_MODEL = 8    # frame-count limits supported by the model
MAX_FRAMES_MODEL = 81
T2V_FIXED_FPS = 16      # frame rate of the exported MP4
MIN_DURATION = round(MIN_FRAMES_MODEL / FIXED_FPS, 1)  # 0.5 s
MAX_DURATION = round(MAX_FRAMES_MODEL / FIXED_FPS, 1)  # 5.1 s
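
# Requested durations are converted to frames in generate_video below as
#   num_frames = clip(round(duration_seconds * FIXED_FPS), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
# so anything outside roughly 0.5-5.1 s is silently clamped to the model's limits.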

# Checkpoint: Wan2.1 14B text-to-video weights in Diffusers format
ckpt_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"

# Quantization config: NF4 4-bit weights via bitsandbytes, computing in bfloat16.
# Only the memory-heavy transformer and text encoder are quantized.
quant_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_4bit",
    quant_kwargs={
        "load_in_4bit": True,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_compute_dtype": torch.bfloat16,
    },
    components_to_quantize=["transformer", "text_encoder"],
)
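
# Rough rationale (an assumption, not measured here): NF4 stores weights in
# 4 bits while matmuls still run in bfloat16, so the 14B transformer fits in a
# fraction of the full-precision VRAM at a modest quality cost.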

# Load pipeline
pipe = DiffusionPipeline.from_pretrained(
    ckpt_id,
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()  # keep idle components in CPU RAM to cut peak VRAM

# ZeroGPU duration estimator: tells @spaces.GPU how many seconds of GPU time to
# reserve, scaling with the step count (longer clips get a larger per-step budget).
def get_duration(prompt, height, width, negative_prompt, duration_seconds, guidance_scale, steps, seed, randomize_seed, progress):
    return steps * 18 if duration_seconds <= 2.5 else steps * 25
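
# For example, at the default 20 steps this reserves 20 * 18 = 360 s of GPU time
# for clips up to 2.5 s, and 20 * 25 = 500 s for longer ones.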

# Inference function
@spaces.GPU(duration=get_duration)
def generate_video(prompt, height, width, negative_prompt, duration_seconds,
                   guidance_scale, steps, seed, randomize_seed,
                   progress=gr.Progress(track_tqdm=True)):
    
    num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)),
                         MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
    current_seed = np.random.randint(0, MAX_SEED) if randomize_seed else int(seed)

    output_frames_list = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=int(height),
        width=int(width),
        num_frames=num_frames,
        guidance_scale=float(guidance_scale),
        num_inference_steps=int(steps),
        generator=torch.Generator().manual_seed(current_seed),  # seed a dedicated RNG instead of torch's global one
    ).frames[0]  # first (and only) video in the returned batch

    temp_dir = tempfile.mkdtemp()
    video_path = os.path.join(temp_dir, "t2v_output.mp4")
    export_to_video(output_frames_list, video_path, fps=T2V_FIXED_FPS)

    print(f"✅ Video saved to: {video_path}")
    return video_path  # path consumed by the gr.Video output

# Gradio UI
with gr.Blocks(css="body { max-width: 100vw; overflow-x: hidden; }") as demo:
    gr.Markdown("## 🚀 Wan2.1 T2V - Text to Video Generator (Quantized, Smart Duration)")
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Prompt", lines=3, value="A futuristic cityscape with flying cars and neon lights.")
            negative_prompt_input = gr.Textbox(label="Negative Prompt", lines=3, value="")
            height_input = gr.Slider(256, 1024, step=8, value=512, label="Height")
            width_input = gr.Slider(256, 1024, step=8, value=512, label="Width")
            # Bound the slider by what the model can produce; values outside
            # [MIN_DURATION, MAX_DURATION] would be clamped to 8-81 frames anyway.
            duration_input = gr.Slider(MIN_DURATION, MAX_DURATION, value=2, step=0.1, label="Duration (seconds)")
            steps_input = gr.Slider(1, 50, value=20, step=1, label="Inference Steps")
            guidance_scale_input = gr.Slider(0.0, 20.0, step=0.5, value=7.5, label="Guidance Scale")
            seed_input = gr.Number(value=42, label="Seed (optional)")
            randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True)
            run_btn = gr.Button("Generate Video")
        with gr.Column():
            output_video = gr.Video(label="Generated Video")

    ui_inputs = [
        prompt_input, height_input, width_input, negative_prompt_input,
        duration_input, guidance_scale_input, steps_input, seed_input,
        randomize_seed_checkbox
    ]
    run_btn.click(fn=generate_video, inputs=ui_inputs, outputs=output_video)

# Launch
demo.launch()
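
# Note: `spaces` is Hugging Face's ZeroGPU helper; outside a Space the PyPI
# package (`pip install spaces`) makes @spaces.GPU a no-op, so the script
# should also run on a local GPU. With multiple concurrent users,
# `demo.queue().launch()` would serialize generation jobs.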