Spaces:
Running
on
L40S
Running
on
L40S
File size: 5,120 Bytes
464d284 3a757b0 464d284 f261a00 464d284 77b7165 464d284 77b7165 464d284 a9e7aa9 464d284 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
import os
import gc
import time
import random
import imageio
import torch
import gradio as gr
from diffusers.utils import load_image
from skyreels_v2_infer.modules import download_model
from skyreels_v2_infer.pipelines import Image2VideoPipeline, Text2VideoPipeline, PromptEnhancer, resizecrop
# Model repository IDs selectable in the app, grouped by generation task.
MODEL_ID_CONFIG = dict(
    text2video=[
        "Skywork/SkyReels-V2-T2V-14B-540P",
        "Skywork/SkyReels-V2-T2V-14B-720P",
    ],
    image2video=[
        "Skywork/SkyReels-V2-I2V-1.3B-540P",
        "Skywork/SkyReels-V2-I2V-14B-540P",
        "Skywork/SkyReels-V2-I2V-14B-720P",
    ],
)
def generate_video(
    prompt,
    model_id,
    resolution,
    num_frames,
    image=None,
    guidance_scale=6.0,
    shift=8.0,
    inference_steps=30,
    use_usp=False,
    offload=False,
    fps=24,
    seed=None,
    prompt_enhancer=False,
    teacache=False,
    teacache_thresh=0.2,
    use_ret_steps=False,
):
    """Generate a video from a prompt (and optionally an image) and save it.

    A ``Text2VideoPipeline`` is used when ``image`` is None, otherwise an
    ``Image2VideoPipeline``. The resulting frames are written as an MP4 file
    under ``gradio_videos/``.

    Args:
        prompt: Text prompt. Its first 50 characters also become part of the
            output file name.
        model_id: Model identifier passed to ``download_model``; the value it
            returns is used as the model, DiT and TeaCache checkpoint path.
        resolution: ``"540P"`` (544x960) or ``"720P"`` (720x1280).
        num_frames: Number of frames requested from the pipeline.
        image: Optional image reference accepted by ``load_image``. For a
            portrait image (height > width) the target height and width are
            swapped before cropping.
        guidance_scale: Forwarded to the pipeline as ``guidance_scale``.
        shift: Forwarded to the pipeline as ``shift``.
        inference_steps: Forwarded as ``num_inference_steps`` and, when
            TeaCache is enabled, as its ``num_steps``.
        use_usp: Forwarded to the pipeline constructor.
        offload: Forwarded to the pipeline constructor.
        fps: Frame rate of the written video.
        seed: Seed for the CUDA generator; a random one is drawn when None.
        prompt_enhancer: Run ``PromptEnhancer`` on the prompt. Only applied
            when no image is given.
        teacache: Enable TeaCache on the pipeline transformer.
        teacache_thresh: TeaCache threshold.
        use_ret_steps: Forwarded to ``initialize_teacache``.

    Returns:
        Path of the written ``.mp4`` file.

    Raises:
        ValueError: If ``resolution`` is not ``"540P"`` or ``"720P"``.
    """
    # Validate the cheap argument first so that a bad resolution is rejected
    # before any model download is started.
    if resolution == "540P":
        height, width = 544, 960
    elif resolution == "720P":
        height, width = 720, 1280
    else:
        raise ValueError(f"Invalid resolution: {resolution}")

    model_id = download_model(model_id)

    if seed is None:
        # Reseed from the clock so every call draws a fresh seed.
        random.seed(time.time())
        seed = int(random.randrange(4294967294))
    else:
        # torch.Generator.manual_seed needs an int; UI widgets may hand over
        # a float-valued number.
        seed = int(seed)

    if image is not None:
        image = load_image(image).convert("RGB")
        image_width, image_height = image.size
        if image_height > image_width:
            # Portrait input: generate a portrait video.
            height, width = width, height
        image = resizecrop(image, height, width)

    negative_prompt = (
        "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, "
        "overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, "
        "poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, "
        "three legs, many people in the background, walking backwards"
    )

    prompt_input = prompt
    if prompt_enhancer and image is None:
        enhancer = PromptEnhancer()
        prompt_input = enhancer(prompt_input)
        # Drop the enhancer before the video pipeline is loaded.
        del enhancer
        gc.collect()
        torch.cuda.empty_cache()

    # NOTE(review): nothing checks that model_id is a T2V checkpoint when no
    # image is given (or an I2V checkpoint when one is) - confirm whether a
    # mismatch should be rejected here.
    pipe = None
    try:
        if image is None:
            pipe = Text2VideoPipeline(
                model_path=model_id, dit_path=model_id, use_usp=use_usp, offload=offload
            )
        else:
            pipe = Image2VideoPipeline(
                model_path=model_id, dit_path=model_id, use_usp=use_usp, offload=offload
            )

        if teacache:
            pipe.transformer.initialize_teacache(
                enable_teacache=True,
                num_steps=inference_steps,
                teacache_thresh=teacache_thresh,
                use_ret_steps=use_ret_steps,
                ckpt_dir=model_id,
            )

        kwargs = {
            "prompt": prompt_input,
            "negative_prompt": negative_prompt,
            "num_frames": num_frames,
            "num_inference_steps": inference_steps,
            "guidance_scale": guidance_scale,
            "shift": shift,
            "generator": torch.Generator(device="cuda").manual_seed(seed),
            "height": height,
            "width": width,
        }
        if image is not None:
            kwargs["image"] = image.convert("RGB")

        with torch.amp.autocast("cuda", dtype=pipe.transformer.dtype), torch.no_grad():
            video_frames = pipe(**kwargs)[0]
    finally:
        # Release the pipeline on success and on failure alike, mirroring the
        # enhancer cleanup above, so repeated requests do not pile up models.
        del pipe
        gc.collect()
        torch.cuda.empty_cache()

    os.makedirs("gradio_videos", exist_ok=True)
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    # The prompt is user input: keep it from injecting path separators or
    # control characters (NUL, newlines, ...) into the file name.
    prompt_stem = "".join(
        ch for ch in prompt[:50] if ch.isprintable() and ch not in "/\\"
    )
    output_path = f"gradio_videos/{prompt_stem}_{seed}_{timestamp}.mp4"
    imageio.mimwrite(output_path, video_frames, fps=fps, quality=8, output_params=["-loglevel", "error"])
    return output_path
# Gradio UI
resolution_options = ["540P", "720P"]
model_options = [*MODEL_ID_CONFIG["text2video"], *MODEL_ID_CONFIG["image2video"]]

# One widget per generate_video parameter, in the function's positional order.
input_components = [
    # prompt
    gr.Textbox(label="Prompt"),
    # model_id (shown but not editable)
    gr.Dropdown(choices=model_options, value="Skywork/SkyReels-V2-I2V-1.3B-540P", label="Model ID", interactive=False),
    # resolution (shown but not editable)
    gr.Radio(choices=resolution_options, value="540P", label="Resolution", interactive=False),
    # num_frames
    gr.Slider(minimum=16, maximum=200, value=97, step=1, label="Number of Frames"),
    # image
    gr.Image(type="filepath", label="Input Image (optional)"),
    # guidance_scale
    gr.Slider(minimum=1.0, maximum=20.0, value=5.0, step=0.1, label="Guidance Scale"),
    # shift
    gr.Slider(minimum=0.0, maximum=20.0, value=3.0, step=0.1, label="Shift"),
    # inference_steps
    gr.Slider(minimum=1, maximum=100, value=30, step=1, label="Inference Steps"),
    # use_usp
    gr.Checkbox(label="Use USP"),
    # offload (checked and not editable)
    gr.Checkbox(label="Offload", value=True, interactive=False),
    # fps
    gr.Slider(minimum=1, maximum=60, value=24, step=1, label="FPS"),
    # seed
    gr.Number(label="Seed (optional, random if empty)", precision=0),
    # prompt_enhancer
    gr.Checkbox(label="Prompt Enhancer"),
    # teacache
    gr.Checkbox(label="Use TeaCache"),
    # teacache_thresh
    gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.01, label="TeaCache Threshold"),
    # use_ret_steps
    gr.Checkbox(label="Use Retention Steps"),
]
output_component = gr.Video(label="Generated Video")

app = gr.Interface(
    fn=generate_video,
    inputs=input_components,
    outputs=output_component,
    title="SkyReels V2 Video Generator",
)

# Start the web server as soon as the module is executed.
app.launch(show_api=False, show_error=True, ssr_mode=False)
|