import os
import gc
import time
import random
import torch
import imageio
import gradio as gr
from diffusers.utils import load_image

from skyreels_v2_infer import DiffusionForcingPipeline
from skyreels_v2_infer.modules import download_model
from skyreels_v2_infer.pipelines import PromptEnhancer, resizecrop

# Detect whether this is the shared Hugging Face Space (restricted settings) or a duplicated/local instance
is_shared_ui = "fffiloni/SkyReels-V2" in os.environ.get("SPACE_ID", "")
#is_shared_ui = False

# On the shared Space, pre-download the 1.3B checkpoint at startup
model_id = None
if is_shared_ui:
    model_id = download_model("Skywork/SkyReels-V2-DF-1.3B-540P")

def generate_diffusion_forced_video(
    prompt,
    image=None,
    target_length="10",
    model_id="Skywork/SkyReels-V2-DF-1.3B-540P",
    resolution="540P",
    num_frames=257,
    ar_step=0,
    causal_attention=False,
    causal_block_size=1,
    base_num_frames=97,
    overlap_history=17,
    addnoise_condition=20,
    guidance_scale=6.0,
    shift=8.0,
    inference_steps=30,
    use_usp=False,
    offload=True,
    fps=24,
    seed=None,
    prompt_enhancer=False,
    teacache=True,
    teacache_thresh=0.2,
    use_ret_steps=True,
):
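    """Run the SkyReels-V2 diffusion-forcing pipeline and return the path of the saved MP4."""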
    model_id = download_model(model_id)

    if resolution == "540P":
        height, width = 544, 960
    elif resolution == "720P":
        height, width = 720, 1280
    else:
        raise ValueError(f"Invalid resolution: {resolution}")

    if seed is None:
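        # No seed given: draw a random one (it is also embedded in the output filename)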
        random.seed(time.time())
        seed = int(random.randrange(4294967294))

    if num_frames > base_num_frames and overlap_history is None:
        raise ValueError("Specify `overlap_history` for long video generation. Try 17 or 37.")
    if addnoise_condition > 60:
        print("Warning: Large `addnoise_condition` may reduce consistency. Recommended: 20.")

    if image is not None:
        image = load_image(image).convert("RGB")
        image_width, image_height = image.size
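        # For portrait inputs, swap to a portrait canvas before resize-and-crop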
        if image_height > image_width:
            height, width = width, height
        image = resizecrop(image, height, width)

    negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"

    prompt_input = prompt
    if prompt_enhancer and image is None:
        enhancer = PromptEnhancer()
        prompt_input = enhancer(prompt_input)
        del enhancer
        gc.collect()
        torch.cuda.empty_cache()

    # Build the diffusion-forcing pipeline; offloading trades speed for lower GPU memory use
    pipe = DiffusionForcingPipeline(
        model_id,
        dit_path=model_id,
        device=torch.device("cuda"),
        weight_dtype=torch.bfloat16,
        use_usp=use_usp,
        offload=offload,
    )

    if causal_attention:
        pipe.transformer.set_ar_attention(causal_block_size)

    if teacache:
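        # Total denoising steps: with asynchronous generation (ar_step > 0), each additional latent block adds `ar_step` steps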
        if ar_step > 0:
            num_steps = (
                inference_steps + (((base_num_frames - 1) // 4 + 1) // causal_block_size - 1) * ar_step
            )
        else:
            num_steps = inference_steps
        pipe.transformer.initialize_teacache(
            enable_teacache=True,
            num_steps=num_steps,
            teacache_thresh=teacache_thresh,
            use_ret_steps=use_ret_steps,
            ckpt_dir=model_id,
        )

    # Run the denoising loop without gradients, autocast to the transformer's dtype
    with torch.amp.autocast("cuda", dtype=pipe.transformer.dtype), torch.no_grad():
        video_frames = pipe(
            prompt=prompt_input,
            negative_prompt=negative_prompt,
            image=image,
            height=height,
            width=width,
            num_frames=num_frames,
            num_inference_steps=inference_steps,
            shift=shift,
            guidance_scale=guidance_scale,
            generator=torch.Generator(device="cuda").manual_seed(seed),
            overlap_history=overlap_history,
            addnoise_condition=addnoise_condition,
            base_num_frames=base_num_frames,
            ar_step=ar_step,
            causal_block_size=causal_block_size,
            fps=fps,
        )[0]

    # Save the frames as an MP4 named after the prompt, seed, and timestamp
    os.makedirs("gradio_df_videos", exist_ok=True)
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    output_path = f"gradio_df_videos/{prompt[:50].replace('/', '')}_{seed}_{timestamp}.mp4"
    imageio.mimwrite(output_path, video_frames, fps=fps, quality=8, output_params=["-loglevel", "error"])
    return output_path


# Gradio UI
resolution_options = ["540P", "720P"]
model_options = ["Skywork/SkyReels-V2-DF-1.3B-540P"]  # Update if there are more

if not is_shared_ui:
    model_options = [
        "Skywork/SkyReels-V2-DF-1.3B-540P",
        "Skywork/SkyReels-V2-DF-14B-540P",
        "Skywork/SkyReels-V2-DF-14B-720P"
    ]

if is_shared_ui:
    length_options = ["4", "10"]
else:
    length_options = ["4", "10", "15", "30", "60"]

with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# SkyReels V2: Infinite-Length Film Generation")
        gr.Markdown("The first open-source video generative model employing AutoRegressive Diffusion-Forcing architecture that achieves the SOTA performance among publicly available models.")

        gr.HTML("""
            <div style="display:flex;column-gap:4px;">
                <a href="https://github.com/SkyworkAI/SkyReels-V2">
                    <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
                </a> 
    			<a href="https://arxiv.org/pdf/2504.13074">
                    <img src='https://img.shields.io/badge/ArXiv-Paper-red'>
                </a>
                <a href="https://huggingface.co/spaces/fffiloni/SkyReels-V2?duplicate=true">
    				<img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space">
    			</a>	
            </div>
        """)
        with gr.Row():

            with gr.Column():

                prompt = gr.Textbox(label="Prompt")

                with gr.Row():
                    if is_shared_ui:
                        target_length = gr.Radio(label="Video length target", choices=length_options, value="4")
                        forbidden_length = gr.Radio(label="Available target on duplicated instance", choices=["15","30","60"], value=None, interactive=False)
                    else:
                        target_length = gr.Radio(label="Video length target", choices=length_options, value="4")
                
                num_frames = gr.Slider(minimum=17, maximum=257, value=97, step=20, label="Number of Frames", interactive=False)
                image = gr.Image(type="filepath", label="Input Image (optional)")
                
                with gr.Accordion("Advanced Settings", open=False):
                    model_id = gr.Dropdown(choices=model_options, value=model_options[0], label="Model ID")
                    resolution = gr.Radio(choices=resolution_options, value="540P", label="Resolution", interactive=not is_shared_ui)
                    ar_step = gr.Number(label="AR Step", value=0)
                    causal_attention = gr.Checkbox(label="Causal Attention")
                    causal_block_size = gr.Number(label="Causal Block Size", value=1)
                    base_num_frames = gr.Number(label="Base Num Frames", value=97)
                    overlap_history = gr.Number(label="Overlap History (set for long videos)", value=None)
                    addnoise_condition = gr.Number(label="AddNoise Condition", value=0)
                    guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, value=6.0, step=0.1, label="Guidance Scale")
                    shift = gr.Slider(minimum=0.0, maximum=20.0, value=8.0, step=0.1, label="Shift")
                    inference_steps = gr.Slider(minimum=1, maximum=100, value=30, step=1, label="Inference Steps")
                    use_usp = gr.Checkbox(label="Use USP", visible=not is_shared_ui)
                    offload = gr.Checkbox(label="Offload", value=True, interactive=not is_shared_ui)
                    fps = gr.Slider(minimum=1, maximum=60, value=24, step=1, label="FPS")
                    seed = gr.Number(label="Seed (optional)", precision=0)
                    prompt_enhancer = gr.Checkbox(label="Prompt Enhancer", visible=not is_shared_ui)
                    use_teacache = gr.Checkbox(label="Use TeaCache", value=True)
                    teacache_thresh = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.01, label="TeaCache Threshold")
                    use_ret_steps = gr.Checkbox(label="Use Retention Steps", value=True)

                submit_btn = gr.Button("Generate")

            with gr.Column():

                output_video = gr.Video(label="Generated Video")

                gr.Examples(
                    examples = [
                        ["A graceful white swan with a curved neck and delicate feathers swimming in a serene lake at dawn, its reflection perfectly mirrored in the still water as mist rises from the surface, with the swan occasionally dipping its head into the water to feed.", "./examples/swan.jpeg", "10"],
                       # ["A graceful white swan with a curved neck and delicate feathers swimming in a serene lake at dawn, its reflection perfectly mirrored in the still water as mist rises from the surface, with the swan occasionally dipping its head into the water to feed.", None],
                        ["A sea turtle swimming near a shipwreck", "./examples/turtle.jpeg", "10"],
                       # ["A sea turtle swimming near a shipwreck", None],
                    ],
                    fn = generate_diffusion_forced_video,
                    inputs = [prompt, image, target_length],
                    outputs = [output_video],
                    cache_examples = True,
                    cache_mode = "lazy"
                )

    def set_num_frames(target_l):
        # Map the selected clip length (in seconds) to generation settings.
        # Clips longer than the base window need `overlap_history` and
        # `addnoise_condition` to stay temporally consistent.
        n_frames = 0
        overlap_history = 0
        addnoise_condition = 0
        ar_step = 0
        causal_attention = False
        causal_block_size = 1
        use_teacache = True
        teacache_thresh = 0.2
        use_ret_steps = True

        if target_l == "4":
            n_frames = 97
        elif target_l == "10":
            n_frames = 257
            overlap_history = 17
            addnoise_condition = 20
        elif target_l == "15":
            n_frames = 377
            overlap_history = 17
            addnoise_condition = 20
            teacache_thresh = 0.3
        elif target_l == "30":
            n_frames = 737
            overlap_history = 17
            addnoise_condition = 20
            teacache_thresh = 0.3
        elif target_l == "60":
            n_frames = 1457
            overlap_history = 17
            addnoise_condition = 20
            teacache_thresh = 0.3

        return n_frames, overlap_history, addnoise_condition, ar_step, causal_attention, causal_block_size, use_teacache, teacache_thresh, use_ret_steps

    # Keep frame count and consistency settings in sync with the selected target length
    target_length.change(
        fn = set_num_frames,
        inputs = [target_length],
        outputs = [num_frames, overlap_history, addnoise_condition, ar_step, causal_attention, causal_block_size, use_teacache, teacache_thresh, use_ret_steps],
        queue = False
    )

    submit_btn.click(
        fn = generate_diffusion_forced_video,
        inputs = [
            prompt,
            image,
            target_length,
            model_id,
            resolution,
            num_frames,
            ar_step,
            causal_attention,
            causal_block_size,
            base_num_frames,
            overlap_history,
            addnoise_condition,
            guidance_scale,
            shift,
            inference_steps,
            use_usp,
            offload,
            fps,
            seed,
            prompt_enhancer,
            use_teacache,
            teacache_thresh,
            use_ret_steps
        ],
        outputs = [
            output_video
        ]
    )

demo.launch(show_error=True, show_api=False, share=False)