|
|
|
import gradio as gr |
|
|
|
import torch |
|
import os |
|
from glob import glob |
|
from pathlib import Path |
|
from typing import Optional |
|
import uuid |
|
import base64 |
|
from io import BytesIO |
|
import tempfile |
|
import numpy as np |
|
import cv2 |
|
import subprocess |
|
|
|
from DeepCache import DeepCacheSDHelper |
|
|
|
from PIL import Image |
|
from diffusers.utils import load_image, export_to_video |
|
from pipeline import StableVideoDiffusionPipeline |
|
|
|
import random |
|
from safetensors import safe_open |
|
from lcm_scheduler import AnimateLCMSVDStochasticIterativeScheduler |
|
|
|
# Shared secret every API caller of `sample` must present. Taken from the
# SECRET_TOKEN environment variable, with a fixed fallback.
# NOTE(review): the fallback value is visible in this source file; make sure
# SECRET_TOKEN is always set in deployment, otherwise anyone can call the API.
SECRET_TOKEN = os.environ.get('SECRET_TOKEN', 'default_secret')

# Frame rate (frames per second) used for the raw generated clip and as the
# default output rate, plus the nominal clip length in seconds used as the
# default target/source duration for interpolation.
hardcoded_fps = 8
hardcoded_duration_sec = 3
|
|
|
def get_safetensors_files(models_dir="./safetensors"):
    """Return the names of the ``.safetensors`` files in *models_dir*.

    Only bare file names are returned (not full paths). The list is sorted so
    the result is deterministic, because ``os.listdir`` order is arbitrary.

    Args:
        models_dir: Directory to scan. Defaults to ``./safetensors``, the same
            directory ``model_select`` loads weights from.

    Returns:
        A sorted list of file names ending in ``.safetensors``.

    Raises:
        FileNotFoundError: If *models_dir* does not exist.
    """
    return sorted(
        name for name in os.listdir(models_dir) if name.endswith(".safetensors")
    )
|
|
|
|
|
def model_select(selected_file):
    """Load UNet weights from ``./safetensors/<selected_file>`` into ``pipe``.

    Uses the module-level ``pipe``. Its UNet is moved to the CPU while the
    checkpoint is read and loaded with ``strict=True``, then moved back to
    CUDA. The move back also happens if loading fails, so the pipeline is
    never left with its UNet stranded on the CPU.

    Args:
        selected_file: File name of a ``.safetensors`` checkpoint located in
            ``./safetensors``.

    Raises:
        RuntimeError: From ``load_state_dict`` when the checkpoint keys do not
            match the UNet exactly.
    """
    print("load model weights", selected_file)
    file_path = os.path.join("./safetensors", selected_file)

    pipe.unet.cpu()
    try:
        with safe_open(file_path, framework="pt", device="cpu") as f:
            state_dict = {key: f.get_tensor(key) for key in f.keys()}
        # strict=True raises on any missing/unexpected key, so the returned
        # (missing, unexpected) lists carry no information here.
        pipe.unet.load_state_dict(state_dict, strict=True)
        # The weights were copied into the UNet; free the CPU copy before the
        # GPU transfer to keep peak host memory down.
        del state_dict
    finally:
        pipe.unet.cuda()
|
|
|
def decode_data_uri_to_image(data_uri):
    """Decode a base64-encoded image into a PIL image.

    Accepts a data URI such as ``data:image/png;base64,<payload>``, where
    everything up to the first comma is treated as the header and ignored. It
    also accepts a bare base64 payload with no header. Base64 never contains
    a comma, so a string without one is taken to be the payload itself.

    Args:
        data_uri: The encoded image as a string.

    Returns:
        PIL.Image.Image: The decoded image.

    Raises:
        binascii.Error: If the payload is not valid base64.
        PIL.UnidentifiedImageError: If the bytes are not a recognized image.
    """
    # partition() never raises, unlike unpacking split(",", 1) on a string
    # with no comma.
    _, separator, payload = data_uri.partition(",")
    if not separator:
        payload = data_uri
    return Image.open(BytesIO(base64.b64decode(payload)))
|
|
|
|
|
|
|
|
|
|
|
|
|
def interpolate_video_frames(
    input_file_path,
    output_file_path,
    output_fps=hardcoded_fps,
    desired_duration=hardcoded_duration_sec,
    original_duration=hardcoded_duration_sec,
    output_width=None,
    output_height=None,
    use_cuda=False,
    verbose=False,
):
    """Retime a video with ffmpeg and motion-interpolate it to ``output_fps``.

    The clip's timestamps are scaled by ``desired_duration /
    original_duration``. ffmpeg's ``minterpolate`` filter then synthesizes
    in-between frames at ``output_fps``. When both ``output_width`` and
    ``output_height`` are given, frames are rescaled first.

    Args:
        input_file_path: Path of the source video.
        output_file_path: Path to write the result to (overwritten if present).
        output_fps: Frame rate of the output video.
        desired_duration: Target duration of the output, in seconds.
        original_duration: Duration of the input, in seconds.
        output_width: Output frame width; scaling needs both dimensions set.
        output_height: Output frame height; scaling needs both dimensions set.
        use_cuda: Use CUDA hardware-accelerated decoding of the input.
        verbose: Print the parameters and the ffmpeg command, and let ffmpeg
            log normally instead of errors only.

    Returns:
        ``output_file_path`` on success. If ffmpeg fails or cannot be started,
        ``input_file_path`` is returned so the caller falls back to the
        un-interpolated video.
    """
    scale_factor = desired_duration / original_duration

    filters = []
    if output_width and output_height:
        filters.append(f'scale={output_width}:{output_height}')
    # Retime BEFORE interpolating. minterpolate then fills the stretched
    # timeline with synthesized frames at output_fps. Running setpts after
    # minterpolate would lower the effective frame rate and make `-r`
    # duplicate frames, which gives stuttering motion when slowing down.
    filters.append(f'setpts={scale_factor}*PTS')
    filters.append(
        'minterpolate=mi_mode=mci:mc_mode=obmc:me=hexbs:vsbmc=1:mb_size=4'
        f':fps={output_fps}:scd=none'
    )
    filter_chain = ','.join(filters)

    # -y: never block on ffmpeg's interactive "overwrite?" prompt.
    cmd = ['ffmpeg', '-y']
    if not verbose:
        cmd += ['-loglevel', 'error']
    if use_cuda:
        # -hwaccel is an INPUT option and must precede -i. Decoded frames are
        # left in system memory (no -hwaccel_output_format cuda) because
        # scale/minterpolate are CPU filters and cannot read GPU frames.
        cmd += ['-hwaccel', 'cuda']
    cmd += [
        '-i', input_file_path,
        '-filter:v', filter_chain,
        '-r', str(output_fps),
        output_file_path,
    ]

    if verbose:
        print("output_fps:", output_fps)
        print("desired_duration:", desired_duration)
        print("original_duration:", original_duration)
        print("cmd:", cmd)

    try:
        subprocess.run(cmd, check=True)
    except (subprocess.CalledProcessError, OSError) as e:
        # Best effort: OSError covers a missing ffmpeg binary. In both cases
        # fall back to the un-interpolated input.
        print("Failed to interpolate video. Error:", e)
        return input_file_path
    return output_file_path
|
|
|
|
|
|
|
|
|
|
|
|
|
def export_to_video_file(video_frames, output_video_path=None, fps=hardcoded_fps):
    """Encode a sequence of RGB frames into a VP9 video file with OpenCV.

    Args:
        video_frames: Non-empty sequence of equally sized frames. The type of
            the first frame decides how all frames are converted:
            numpy arrays of shape (height, width, 3) in RGB order, where
            floating-point arrays are treated as values in [0, 1] (clipped,
            then scaled to 0-255) and other dtypes are assumed to already be
            0-255 and cast to uint8; or PIL images, converted with
            ``np.array``.
        output_video_path: Destination file. When None, a new ``.webm`` file
            is created in the system temporary directory.
        fps: Frame rate written into the video.

    Returns:
        The path of the written video.

    Raises:
        ValueError: If ``video_frames`` is empty.
        RuntimeError: If OpenCV cannot open a VP90 writer for the path.
    """
    if len(video_frames) == 0:
        raise ValueError("video_frames must contain at least one frame")

    if output_video_path is None:
        # mkstemp reserves a unique file that stays on disk. Taking
        # NamedTemporaryFile(...).name deletes the file as soon as the object
        # is collected, leaving only a name that could be reused.
        fd, output_video_path = tempfile.mkstemp(suffix=".webm")
        os.close(fd)

    first = video_frames[0]
    if isinstance(first, np.ndarray):
        frames = [
            # Only float frames live in [0, 1]. Multiplying uint8 frames by
            # 255 overflows and wraps around, so they are cast, not scaled.
            (np.clip(frame, 0.0, 1.0) * 255).astype(np.uint8)
            if np.issubdtype(frame.dtype, np.floating)
            else frame.astype(np.uint8)
            for frame in video_frames
        ]
    elif isinstance(first, Image.Image):
        frames = [np.array(frame) for frame in video_frames]
    else:
        frames = list(video_frames)

    height, width = frames[0].shape[:2]
    fourcc = cv2.VideoWriter_fourcc(*'VP90')
    video_writer = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height), True)
    if not video_writer.isOpened():
        # Otherwise every write() is silently dropped and an empty/missing
        # file would be returned.
        raise RuntimeError(f"Could not open a VP90 video writer for {output_video_path}")
    try:
        for frame in frames:
            # OpenCV expects BGR channel order.
            video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        video_writer.release()

    return output_video_path
|
|
|
# Noise scheduler handed to the pipeline below. The class comes from the
# project-local lcm_scheduler module and these values are passed through
# unchanged; see that module for what each one controls.
noise_scheduler = AnimateLCMSVDStochasticIterativeScheduler(
    num_train_timesteps=40,
    sigma_min=0.002,
    sigma_max=700.0,
    sigma_data=1.0,
    s_noise=1.0,
    rho=7,
    clip_denoised=False,
)
# Load the stabilityai/stable-video-diffusion-img2vid-xt checkpoint (fp16
# weight variant, fp16 dtype). It uses the scheduler above and the
# StableVideoDiffusionPipeline class from the project-local `pipeline` module.
pipe = StableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt",
    scheduler=noise_scheduler,
    torch_dtype=torch.float16,
    variant="fp16",
)
# NOTE(review): in stock diffusers, enable_model_cpu_offload() manages device
# placement itself (moving the pipeline back to the CPU if needed). That would
# make the preceding .to("cuda") a redundant full GPU round-trip. Confirm
# against the custom pipeline class before removing it.
pipe.to("cuda")
pipe.enable_model_cpu_offload()
# Replace the base UNet weights with the AnimateLCM checkpoint stored in
# ./safetensors (model_select loads it with strict=True).
model_select("AnimateLCM-SVD-xt-1.1.safetensors")

# Largest signed 64-bit integer. It is the upper bound for random seeds drawn
# in sample() and the maximum of the Seed slider in the UI.
max_64_bit_int = 2**63 - 1
|
|
|
def sample(
    secret_token: str,
    input_image_base64: str,
    seed: Optional[int] = 42,
    randomize_seed: bool = True,
    motion_bucket_id: int = 33,
    desired_duration: int = hardcoded_duration_sec,
    desired_fps: int = hardcoded_fps,
    max_guidance_scale: float = 1.2,
    min_guidance_scale: float = 1,
    width: int = 832,
    height: int = 448,
    num_inference_steps: int = 4,
    decoding_t: int = 8,
    output_folder: str = "outputs_gradio",
):
    """Animate an image and return the video as a base64 MP4 data URI.

    The image is animated with the module-level ``pipe`` and written to a raw
    MP4 at ``hardcoded_fps``. ffmpeg then retimes and interpolates it to
    ``desired_duration`` seconds at ``desired_fps``. Intermediate video files
    are deleted before returning.

    Args:
        secret_token: Must equal ``SECRET_TOKEN``.
        input_image_base64: Input image as a base64 data URI (decoded by
            ``decode_data_uri_to_image``).
        seed: Seed for torch's RNG; replaced by a random one when
            ``randomize_seed`` is true.
        randomize_seed: Draw a fresh seed in ``[0, max_64_bit_int]``.
        motion_bucket_id: Passed to the pipeline as ``motion_bucket_id``.
        desired_duration: Target output duration in seconds.
        desired_fps: Output frame rate after interpolation.
        max_guidance_scale: Passed to the pipeline as ``max_guidance_scale``.
        min_guidance_scale: Passed to the pipeline as ``min_guidance_scale``.
        width: Generation width, passed to the pipeline.
        height: Generation height, passed to the pipeline.
        num_inference_steps: Passed to the pipeline.
        decoding_t: Passed to the pipeline as ``decode_chunk_size``.
        output_folder: Created if missing. Nothing is written to it; it is
            kept for backward compatibility.

    Returns:
        ``'data:video/mp4;base64,...'``. If interpolation fails, this is the
        raw, un-interpolated clip.

    Raises:
        gr.Error: If ``secret_token`` does not match.
    """
    if secret_token != SECRET_TOKEN:
        raise gr.Error(
            'Invalid secret token. Please fork the original space if you want to use it for yourself.')

    image = decode_data_uri_to_image(input_image_base64)

    print(f"seed={seed}\nrandomize_seed={randomize_seed}\nmotion_bucket_id={motion_bucket_id}\ndesired_duration={desired_duration}\ndesired_fps={desired_fps}\nmax_guidance_scale={max_guidance_scale}\nmin_guidance_scale={min_guidance_scale}\nwidth={width}\nheight={height}\nnum_inference_steps={num_inference_steps}\ndecoding_t={decoding_t}")

    # The pipeline needs 3-channel RGB input. RGBA is not the only mode that
    # breaks it (L, P, LA, CMYK, ... as well).
    if image.mode != "RGB":
        image = image.convert("RGB")

    if randomize_seed:
        seed = random.randint(0, max_64_bit_int)
    generator = torch.manual_seed(seed)

    os.makedirs(output_folder, exist_ok=True)

    video_uuid = uuid.uuid4().hex
    tmp_dir = tempfile.gettempdir()
    raw_video_path = os.path.join(tmp_dir, f"{video_uuid}_raw.mp4")
    enhanced_video_path = os.path.join(tmp_dir, f"{video_uuid}_enhanced.mp4")

    try:
        with torch.autocast("cuda"):
            frames = pipe(
                image,
                decode_chunk_size=decoding_t,
                generator=generator,
                motion_bucket_id=motion_bucket_id,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                min_guidance_scale=min_guidance_scale,
                max_guidance_scale=max_guidance_scale,
            ).frames[0]

        export_to_video(frames, raw_video_path, fps=hardcoded_fps)

        # Pass the real length of the raw clip (frames / fps) so the retiming
        # factor makes the output exactly `desired_duration` seconds long.
        final_video_path = interpolate_video_frames(
            raw_video_path,
            enhanced_video_path,
            output_fps=desired_fps,
            desired_duration=desired_duration,
            original_duration=len(frames) / hardcoded_fps,
        )

        with open(final_video_path, "rb") as video_file:
            video_base64 = base64.b64encode(video_file.read()).decode('utf-8')
    finally:
        # Without this, two video files per request accumulate in the temp
        # directory forever.
        for path in (raw_video_path, enhanced_video_path):
            try:
                os.remove(path)
            except FileNotFoundError:
                pass

    return 'data:video/mp4;base64,' + video_base64
|
|
|
|
|
with gr.Blocks() as demo:
    # Full-page fixed overlay (z-index 100, 100% width/height) covering the
    # widgets below: this Space is meant to be driven through its API.
    gr.HTML("""
<div style="z-index: 100; position: fixed; top: 0px; right: 0px; left: 0px; bottom: 0px; width: 100%; height: 100%; background: white; display: flex; align-items: center; justify-content: center; color: black;">
  <div style="text-align: center; color: black;">
    <p style="color: black;">This space is a headless component of the cloud rendering engine used by AiTube.</p>
    <p style="color: black;">It is not available for public use, but you can use the <a href="https://huggingface.co/spaces/doevent/AnimateLCM-SVD" target="_blank">original space</a>.</p>
  </div>
</div>""")
    with gr.Row():
        secret_token = gr.Textbox()
        image_input_base64 = gr.Textbox()
        generate_btn = gr.Button("Generate")
        video_output_base64 = gr.Textbox()

        seed = gr.Slider(
            label="Seed",
            value=42,
            randomize=False,
            minimum=0,
            maximum=max_64_bit_int,
            step=1,
        )
        randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
        motion_bucket_id = gr.Slider(
            label="Motion bucket id",
            info="Controls how much motion to add/remove from the image",
            value=80,
            minimum=1,
            maximum=255,
        )
        # gr.Slider bounds are `minimum`/`maximum`. The previous
        # `min_value`/`max_value` keywords are not Slider parameters.
        duration_slider = gr.Slider(
            label="Desired Duration (seconds)",
            minimum=1,
            maximum=120,
            value=hardcoded_duration_sec,
            step=0.1,
        )
        fps_slider = gr.Slider(
            label="Desired Frames Per Second",
            minimum=5,
            maximum=60,
            value=hardcoded_fps,
            step=1,
        )
        width = gr.Slider(
            label="Width of input image",
            info="It should be divisible by 64",
            value=832,
            minimum=256,
            maximum=2048,
            step=64,
        )
        height = gr.Slider(
            label="Height of input image",
            info="It should be divisible by 64",
            value=448,
            minimum=256,
            maximum=1152,
            # Enforce the "divisible by 64" requirement, as for width.
            step=64,
        )
        max_guidance_scale = gr.Slider(
            label="Max guidance scale",
            info="classifier-free guidance strength",
            value=1.2,
            minimum=1,
            maximum=2,
        )
        min_guidance_scale = gr.Slider(
            label="Min guidance scale",
            info="classifier-free guidance strength",
            value=1,
            minimum=1,
            maximum=1.5,
        )
        num_inference_steps = gr.Slider(
            label="Num inference steps",
            info="steps for inference",
            value=4,
            minimum=1,
            maximum=20,
            step=1,
        )

    # Input order must match the positional parameters of `sample`.
    generate_btn.click(
        fn=sample,
        inputs=[
            secret_token,
            image_input_base64,
            seed,
            randomize_seed,
            motion_bucket_id,
            duration_slider,
            fps_slider,
            max_guidance_scale,
            min_guidance_scale,
            width,
            height,
            num_inference_steps,
        ],
        outputs=video_output_base64,
    )
|
|
|
if __name__ == "__main__":
    # Turn on request queuing, then start the server; show_error surfaces
    # exceptions raised in handlers (e.g. gr.Error) to the client.
    demo.queue().launch(show_error=True)
|
|