# FreeNoise / app.py
# Anonymous
# add 512
# 186fae0
# raw
# history blame
# No virus
# 10.3 kB
import gradio as gr
import os
import sys
import argparse
import random
from omegaconf import OmegaConf
import torch
import torchvision
from pytorch_lightning import seed_everything
from huggingface_hub import hf_hub_download
sys.path.insert(0, "scripts/evaluation")
from funcs import (
batch_ddim_sampling_freenoise,
load_model_checkpoint,
)
from utils.utils import instantiate_from_config
# Fetch the 320x512 FreeNoise checkpoint at startup so infer() can load it from disk.
ckpt_dir_512 = "checkpoints/base_512_v1"
ckpt_path_512 = f"{ckpt_dir_512}/model_512.ckpt"
os.makedirs(ckpt_dir_512, exist_ok=True)
hf_hub_download(
    repo_id="MoonQiu/LongerCrafter",
    filename="model_512.ckpt",
    local_dir=ckpt_dir_512,
)
# Other resolutions are not fetched here. For reference, their checkpoints are
# VideoCrafter/Text2Video-1024 "model.ckpt" (into checkpoints/base_1024_v1) and
# MoonQiu/LongerCrafter "model_256.pth" (into checkpoints/base_256_v1).
# Per-resolution settings: output_size -> (width, height, inference config path, conditioning fps).
_RESOLUTION_SETTINGS = {
    "320x512": (512, 320, "configs/inference_t2v_tconv512_v1.0_freenoise.yaml", 8),
    "576x1024": (1024, 576, "configs/inference_t2v_1024_v1.0_freenoise.yaml", 28),
}


def _checkpoint_path(output_size):
    """Return the local checkpoint path for ``output_size``.

    The 320x512 checkpoint is downloaded at module import (``ckpt_path_512``).
    The 576x1024 checkpoint is not, so it is downloaded here on first use.
    """
    if output_size == "320x512":
        return ckpt_path_512
    ckpt_dir_1024 = "checkpoints/base_1024_v1"
    ckpt_path_1024 = f"{ckpt_dir_1024}/model.ckpt"
    if not os.path.exists(ckpt_path_1024):
        os.makedirs(ckpt_dir_1024, exist_ok=True)
        hf_hub_download(
            repo_id="VideoCrafter/Text2Video-1024",
            filename="model.ckpt",
            local_dir=ckpt_dir_1024,
        )
    return ckpt_path_1024


def _load_model(config_path, ckpt_path):
    """Build the model from ``config_path``, move it to CUDA, load ``ckpt_path`` and set eval mode."""
    config = OmegaConf.load(config_path)
    model_config = config.pop("model", OmegaConf.create())
    model = instantiate_from_config(model_config)
    model = model.cuda()
    model = load_model_checkpoint(model, ckpt_path)
    model.eval()
    return model


def _reschedule_noise(x_T_total, frames, window_size, window_stride):
    """Overwrite later frames of ``x_T_total`` with shuffled copies of earlier noise, in place.

    Dimension 3 of ``x_T_total`` is the frame axis. For every ``frame_index`` in
    ``range(window_size, frames, window_stride)``, the frames
    ``[frame_index, frame_index + window_stride)`` are replaced by a random
    permutation of the frames ``window_size`` positions earlier.
    """
    for frame_index in range(window_size, frames, window_stride):
        # Clip the final stride so a frame count that is not a multiple of
        # window_stride does not cause a source/destination length mismatch.
        end = min(frame_index + window_stride, frames)
        list_index = list(range(frame_index - window_size, end - window_size))
        random.shuffle(list_index)
        x_T_total[:, :, :, frame_index:end] = x_T_total[:, :, :, list_index]


def _save_video(batch_samples, n_samples, video_path, save_fps):
    """Write ``batch_samples[0]`` to ``video_path`` as an H.264 MP4 and return the path."""
    # assumes batch_samples[0] is (n, c, t, h, w) with values in [-1, 1] — TODO confirm
    video = batch_samples[0].detach().cpu()
    video = torch.clamp(video.float(), -1.0, 1.0)
    video = video.permute(2, 0, 1, 3, 4)  # -> (t, n, c, h, w)
    # One grid per frame with the n samples side by side: [3, h, n*w]
    # (plus make_grid padding when n > 1).
    frame_grids = [
        torchvision.utils.make_grid(framesheet, nrow=int(n_samples))
        for framesheet in video
    ]
    grid = torch.stack(frame_grids, dim=0)  # [t, 3, H, W]
    grid = (grid + 1.0) / 2.0  # [-1, 1] -> [0, 1]
    # write_video expects uint8 frames laid out as [t, H, W, 3].
    grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
    torchvision.io.write_video(
        video_path,
        grid,
        fps=save_fps,
        video_codec="h264",
        options={"crf": "10"},
    )
    return video_path


def infer(prompt, output_size, seed, num_frames, ddim_steps, unconditional_guidance_scale, save_fps):
    """Generate a FreeNoise text-to-video clip and return the path of the written MP4.

    Args:
        prompt: Text prompt used for conditioning.
        output_size: Either "320x512" or "576x1024".
        seed: Random seed. If None, a random 16-bit seed is drawn.
        num_frames: Number of frames to generate.
        ddim_steps: Number of DDIM sampling steps.
        unconditional_guidance_scale: Classifier-free guidance scale.
        save_fps: Frame rate of the written video file.

    Returns:
        The path "output.mp4".

    Raises:
        ValueError: If ``output_size`` is not a supported resolution.
    """
    if output_size not in _RESOLUTION_SETTINGS:
        raise ValueError(
            f"Unsupported output size {output_size!r}; "
            f"expected one of {sorted(_RESOLUTION_SETTINGS)}"
        )
    width, height, config_path, fps = _RESOLUTION_SETTINGS[output_size]
    window_size = 16
    window_stride = 4
    # UI/API values may arrive as floats; range() and tensor shapes need ints.
    num_frames = int(num_frames)
    ddim_steps = int(ddim_steps)

    # NOTE(review): the model is rebuilt and its checkpoint is reloaded on every request.
    # Caching it would save that cost if sampling does not mutate model state.
    # Confirm that before adding a cache.
    model = _load_model(config_path, _checkpoint_path(output_size))

    if seed is None:
        seed = int.from_bytes(os.urandom(2), "big")
    print(f"Using seed: {seed}")
    seed_everything(seed)

    args = argparse.Namespace(
        mode="base",
        savefps=save_fps,
        n_samples=1,
        ddim_steps=ddim_steps,
        ddim_eta=0.0,
        bs=1,
        height=height,
        width=width,
        frames=num_frames,
        fps=fps,
        unconditional_guidance_scale=unconditional_guidance_scale,
        unconditional_guidance_scale_temporal=None,
        cond_input=None,
        window_size=window_size,
        window_stride=window_stride,
    )

    ## latent noise shape
    h, w = args.height // 8, args.width // 8
    frames = model.temporal_length if args.frames < 0 else args.frames
    channels = model.channels
    x_T_total = torch.randn(
        [args.n_samples, 1, channels, frames, h, w], device=model.device
    ).repeat(1, args.bs, 1, 1, 1, 1)
    # Use the resolved `frames` so the loop agrees with the tensor's frame axis.
    _reschedule_noise(x_T_total, frames, args.window_size, args.window_stride)

    batch_size = 1
    noise_shape = [batch_size, channels, frames, h, w]
    fps_tensor = torch.tensor([args.fps] * batch_size).to(model.device).long()

    ## inference (no autograd graph is needed)
    with torch.no_grad():
        text_emb = model.get_learned_conditioning([prompt])
        cond = {"c_crossattn": [text_emb], "fps": fps_tensor}
        batch_samples = batch_ddim_sampling_freenoise(
            model,
            cond,
            noise_shape,
            args.n_samples,
            args.ddim_steps,
            args.ddim_eta,
            args.unconditional_guidance_scale,
            args=args,
            x_T_total=x_T_total,
        )

    video_path = _save_video(batch_samples, args.n_samples, "output.mp4", args.savefps)
    print(video_path)
    return video_path
# Prompt-only example rows. Each row supplies a value for the first input
# passed to gr.Examples (the prompt textbox).
examples = [
    [example_prompt]
    for example_prompt in (
        "A chihuahua in astronaut suit floating in space, cinematic lighting, glow effect",
        "A corgi is swimming quickly",
        "A bigfoot walking in the snowstorm",
        "Campfire at night in a snowy forest with starry sky in the background",
        "A panda is surfing in the universe",
    )
]
# Custom stylesheet passed to gr.Blocks(css=css).
# The UI code sets elem_id="col-container", which the first rule centres and caps at 640px.
# NOTE(review): no visible component uses the #share-btn*, .footer or
# img[src*='#center'] selectors. They may be leftovers from a template; confirm
# before removing them.
css = """
#col-container {max-width: 640px; margin-left: auto; margin-right: auto;}
a {text-decoration-line: underline; font-weight: 600;}
.animate-spin {
animation: spin 1s linear infinite;
}
@keyframes spin {
from {
transform: rotate(0deg);
}
to {
transform: rotate(360deg);
}
}
#share-btn-container {
display: flex;
padding-left: 0.5rem !important;
padding-right: 0.5rem !important;
background-color: #000000;
justify-content: center;
align-items: center;
border-radius: 9999px !important;
max-width: 15rem;
height: 36px;
}
div#share-btn-container > div {
flex-direction: row;
background: black;
align-items: center;
}
#share-btn-container:hover {
background-color: #060606;
}
#share-btn {
all: initial;
color: #ffffff;
font-weight: 600;
cursor:pointer;
font-family: 'IBM Plex Sans', sans-serif;
margin-left: 0.5rem !important;
padding-top: 0.5rem !important;
padding-bottom: 0.5rem !important;
right:0;
}
#share-btn * {
all: unset;
}
#share-btn-container div:nth-child(-n+2){
width: auto !important;
min-height: 0px !important;
}
#share-btn-container .wrap {
display: none !important;
}
#share-btn-container.hidden {
display: none!important;
}
img[src*='#center'] {
display: inline-block;
margin: unset;
}
.footer {
margin-bottom: 45px;
margin-top: 10px;
text-align: center;
border-bottom: 1px solid #e5e5e5;
}
.footer>p {
font-size: .8rem;
display: inline-block;
padding: 0 10px;
transform: translateY(10px);
background: white;
}
.dark .footer {
border-color: #303030;
}
.dark .footer>p {
background: #0b0f19;
}
"""
# Build the demo UI: a prompt box, a collapsible parameter panel, a generate
# button, the video output, and prompt examples. The button is wired to infer().
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(
            """
<h1 style="text-align: center;">LongerCrafter(FreeNoise) Text-to-Video</h1>
<p style="text-align: center;">
Tuning-Free Longer Video Diffusion via Noise Rescheduling <br />
</p>
"""
        )
        prompt_in = gr.Textbox(
            label="Prompt",
            placeholder="A chihuahua in astronaut suit floating in space, cinematic lighting, glow effect",
        )
        with gr.Row():
            with gr.Accordion(
                'FreeNoise Parameters (feel free to adjust these parameters based on your prompt): ',
                open=False,
            ):
                with gr.Row():
                    output_size = gr.Dropdown(
                        ["320x512", "576x1024"],
                        value="320x512",
                        label="Output Size",
                        info="576x1024 will cost around 900s",
                    )
                with gr.Row():
                    num_frames = gr.Slider(minimum=16, maximum=36, step=4, value=32,
                                           label='Frames (a multiple of 4)')
                    ddim_steps = gr.Slider(minimum=5, maximum=200, step=1, value=50,
                                           label='DDIM Steps')
                with gr.Row():
                    unconditional_guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, step=0.1, value=12.0,
                                                             label='Unconditional Guidance Scale')
                    save_fps = gr.Slider(minimum=1, maximum=30, step=1, value=10,
                                         label='Save FPS')
                with gr.Row():
                    seed = gr.Slider(minimum=0, maximum=10000, step=1, value=123,
                                     label='Random Seed')
        submit_btn = gr.Button("Generate")
        video_result = gr.Video(label="Video Output")
        # Components listed in the same order as infer()'s positional parameters.
        infer_inputs = [
            prompt_in,
            output_size,
            seed,
            num_frames,
            ddim_steps,
            unconditional_guidance_scale,
            save_fps,
        ]
        gr.Examples(examples=examples, inputs=infer_inputs)
    submit_btn.click(
        fn=infer,
        inputs=infer_inputs,
        outputs=[video_result],
        api_name="zrscp",
    )
# Enable request queuing (at most 12 pending jobs), then start the server with show_api=True.
demo.queue(max_size=12)
demo.launch(show_api=True)