# Spaces: Running
import os
import random
from glob import glob
from typing import Optional

import torch
from diffusers import StableVideoDiffusionPipeline
from diffusers.utils import load_image
from PIL import Image

from .tdd_svd_scheduler import TDDSVDStochasticIterativeScheduler
from .utils import load_lora_weights, save_video
# Base model id passed to StableVideoDiffusionPipeline.from_pretrained, and the
# repo / file name handed to load_lora_weights for the TDD LoRA weights.
svd_path = 'stabilityai/stable-video-diffusion-img2vid-xt-1-1'
lora_repo_path = 'RED-AIGC/TDD'
lora_weight_name = 'svd-xt-1-1_tdd_lora_weights.safetensors'

# The pipeline is loaded in fp16 and moved to 'cuda', so it is only built when a
# CUDA device is available. Otherwise ``pipeline`` is bound to None so that the
# module still imports and callers see an explicit "not loaded" state instead of
# a NameError on first use.
if torch.cuda.is_available():
    noise_scheduler = TDDSVDStochasticIterativeScheduler(
        num_train_timesteps=250,
        sigma_min=0.002,
        sigma_max=700.0,
        sigma_data=1.0,
        s_noise=1.0,
        rho=7,
        clip_denoised=False,
    )
    pipeline = StableVideoDiffusionPipeline.from_pretrained(
        svd_path,
        scheduler=noise_scheduler,
        torch_dtype=torch.float16,
        variant="fp16",
    ).to('cuda')
    # Apply the TDD LoRA weights to the pipeline's UNet.
    load_lora_weights(pipeline.unet, lora_repo_path, weight_name=lora_weight_name)
else:
    pipeline = None
def Video( | |
image: Image, | |
seed: Optional[int] = 1, | |
randomize_seed: bool = False, | |
num_inference_steps: int = 4, | |
eta: float = 0.3, | |
min_guidance_scale: float = 1.0, | |
max_guidance_scale: float = 1.0, | |
fps: int = 7, | |
width: int = 512, | |
height: int = 512, | |
num_frames: int = 25, | |
motion_bucket_id: int = 127, | |
output_folder: str = "outputs_gradio", | |
): | |
pipeline.scheduler.set_eta(eta) | |
if randomize_seed: | |
seed = random.randint(0, max_64_bit_int) | |
generator = torch.manual_seed(seed) | |
os.makedirs(output_folder, exist_ok=True) | |
base_count = len(glob(os.path.join(output_folder, "*.mp4"))) | |
video_path = os.path.join(output_folder, f"{base_count:06d}.mp4") | |
with torch.autocast("cuda"): | |
frames = pipeline( | |
image, height = height, width = width, | |
num_inference_steps = num_inference_steps, | |
min_guidance_scale = min_guidance_scale, | |
max_guidance_scale = max_guidance_scale, | |
num_frames = num_frames, fps = fps, motion_bucket_id = motion_bucket_id, | |
decode_chunk_size = 8, | |
noise_aug_strength = 0.02, | |
generator = generator, | |
).frames[0] | |
save_video(frames, video_path, fps = fps, quality = 5.0) | |
torch.manual_seed(seed) | |
return video_path, seed |