FreeNoise / app.py
Anonymous
add 512
2afd1df
raw
history blame
10.7 kB
import gradio as gr
import os
import sys
import argparse
import random
from omegaconf import OmegaConf
import torch
import torchvision
from pytorch_lightning import seed_everything
from huggingface_hub import hf_hub_download
sys.path.insert(0, "scripts/evaluation")
from funcs import (
batch_ddim_sampling_freenoise,
load_model_checkpoint,
)
from utils.utils import instantiate_from_config
# Per-resolution settings for the models offered in the UI dropdown.
# ``fps`` is the value placed in the model's conditioning dict (cond["fps"]),
# not the frame rate of the written file (that is ``save_fps``).
# ``max_frames`` caps the requested frame count (None = no cap).
_MODEL_SPECS = {
    "320x512": {
        "width": 512,
        "height": 320,
        "ckpt_dir": "checkpoints/base_512_v1",
        "ckpt_path": "checkpoints/base_512_v1/model_512.ckpt",
        "ckpt_filename": "model_512.ckpt",
        "repo_id": "MoonQiu/LongerCrafter",
        "config": "configs/inference_t2v_tconv512_v1.0_freenoise.yaml",
        "fps": 8,
        "max_frames": None,
    },
    "576x1024": {
        "width": 1024,
        "height": 576,
        "ckpt_dir": "checkpoints/base_1024_v1",
        "ckpt_path": "checkpoints/base_1024_v1/model.ckpt",
        "ckpt_filename": "model.ckpt",
        "repo_id": "VideoCrafter/Text2Video-1024",
        "config": "configs/inference_t2v_1024_v1.0_freenoise.yaml",
        "fps": 28,
        "max_frames": 36,
    },
    # A 256x256 variant is currently disabled; to re-enable, uncomment it and
    # add "256x256" to the UI dropdown.
    # "256x256": {
    #     "width": 256,
    #     "height": 256,
    #     "ckpt_dir": "checkpoints/base_256_v1",
    #     "ckpt_path": "checkpoints/base_256_v1/model_256.pth",
    #     "ckpt_filename": "model_256.pth",
    #     "repo_id": "MoonQiu/LongerCrafter",
    #     "config": "configs/inference_t2v_tconv256_v1.0_freenoise.yaml",
    #     "fps": 8,
    #     "max_frames": None,
    # },
}

# Single-slot model cache keyed by output size. Keeping the most recently used
# model resident lets repeated requests skip re-instantiation and checkpoint
# loading; holding only one bounds GPU memory use.
_loaded_models = {}


def _load_model(spec):
    """Build the model described by ``spec``, move it to CUDA and load its weights.

    The checkpoint is downloaded from the Hugging Face Hub into
    ``spec["ckpt_dir"]`` when ``spec["ckpt_path"]`` does not exist yet.
    """
    config = OmegaConf.load(spec["config"])
    model_config = config.pop("model", OmegaConf.create())
    model = instantiate_from_config(model_config)
    model = model.cuda()
    if not os.path.exists(spec["ckpt_path"]):
        os.makedirs(spec["ckpt_dir"], exist_ok=True)
        hf_hub_download(
            repo_id=spec["repo_id"],
            filename=spec["ckpt_filename"],
            local_dir=spec["ckpt_dir"],
        )
    model = load_model_checkpoint(model, spec["ckpt_path"])
    model.eval()
    return model


def _get_model(output_size):
    """Return the eval-mode model for ``output_size``, loading it on first use."""
    model = _loaded_models.get(output_size)
    if model is None:
        # Drop the previously cached model (and release cached CUDA blocks)
        # before loading another one, so it can be freed first.
        _loaded_models.clear()
        torch.cuda.empty_cache()
        model = _load_model(_MODEL_SPECS[output_size])
        _loaded_models[output_size] = model
    return model


def init_freenoise_latents(n_samples, bs, channels, frames, h, w, window_size, window_stride, device):
    """Sample initial latent noise with FreeNoise-style noise rescheduling.

    Draws Gaussian noise of shape ``[n_samples, bs, channels, frames, h, w]``
    (identical along the ``bs`` axis). Then, walking forward from frame
    ``window_size`` in steps of ``window_stride``, each chunk of
    ``window_stride`` frames is overwritten with a random permutation of the
    chunk located ``window_size`` frames earlier. Because the walk is in order,
    later chunks copy from chunks that may already have been rescheduled.

    Uses ``torch`` and ``random`` RNG state, so seed both for reproducibility.
    If ``frames > window_size`` and ``frames - window_size`` is not a multiple
    of ``window_stride``, the final chunk assignment fails with a shape-mismatch
    error (same as the original inline code).
    """
    x_T_total = torch.randn(
        [n_samples, 1, channels, frames, h, w], device=device
    ).repeat(1, bs, 1, 1, 1, 1)
    for frame_index in range(window_size, frames, window_stride):
        list_index = list(
            range(frame_index - window_size, frame_index + window_stride - window_size)
        )
        random.shuffle(list_index)
        # Advanced indexing on the right-hand side makes a copy, so the
        # in-place write cannot alias the source frames.
        x_T_total[:, :, :, frame_index : frame_index + window_stride] = x_T_total[:, :, :, list_index]
    return x_T_total


def _save_video(batch_samples, video_path, n_samples, save_fps):
    """Write the samples for the first prompt to ``video_path`` as an H.264 MP4."""
    # NOTE(review): assumes batch_samples[0] is (n_samples, c, t, h, w) with
    # values in [-1, 1]; the permute below relies on that layout.
    video = batch_samples[0].detach().cpu()
    video = torch.clamp(video.float(), -1.0, 1.0)
    video = video.permute(2, 0, 1, 3, 4)  # t, n, c, h, w
    # One image per time step with the n samples laid out in a single row
    # (make_grid adds padding between them when n_samples > 1).
    frame_grids = [
        torchvision.utils.make_grid(framesheet, nrow=int(n_samples))
        for framesheet in video
    ]
    grid = torch.stack(frame_grids, dim=0)  # [t, C, H, W]
    grid = (grid + 1.0) / 2.0  # [-1, 1] -> [0, 1]
    # write_video expects a uint8 tensor laid out as [T, H, W, C].
    grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
    torchvision.io.write_video(
        video_path,
        grid,
        fps=save_fps,
        video_codec="h264",
        options={"crf": "10"},
    )


def infer(prompt, output_size, seed, num_frames, ddim_steps, unconditional_guidance_scale, save_fps):
    """Generate a text-to-video clip with FreeNoise and return the MP4 path.

    Args:
        prompt: Text prompt used for cross-attention conditioning.
        output_size: Resolution key, "320x512" or "576x1024".
        seed: RNG seed; ``None`` picks a random seed in [0, 65535].
        num_frames: Requested frame count (capped at 36 for "576x1024").
        ddim_steps: Number of DDIM sampling steps.
        unconditional_guidance_scale: Classifier-free guidance scale.
        save_fps: Frame rate of the written video file.

    Returns:
        The path of the written video, "output.mp4" in the working directory.

    Raises:
        ValueError: If ``output_size`` is not a supported resolution.
    """
    if output_size not in _MODEL_SPECS:
        raise ValueError(
            f"Unsupported output_size {output_size!r}; expected one of {sorted(_MODEL_SPECS)}"
        )
    spec = _MODEL_SPECS[output_size]
    model = _get_model(output_size)

    # UI/API callers may deliver numbers as floats; tensor shapes and range() need ints.
    num_frames = int(num_frames)
    if spec["max_frames"] is not None:
        num_frames = min(num_frames, spec["max_frames"])

    if seed is None:
        seed = int.from_bytes(os.urandom(2), "big")
    print(f"Using seed: {seed}")
    seed_everything(seed)

    args = argparse.Namespace(
        mode="base",
        savefps=save_fps,
        n_samples=1,
        ddim_steps=int(ddim_steps),
        ddim_eta=0.0,
        bs=1,
        height=spec["height"],
        width=spec["width"],
        frames=num_frames,
        fps=spec["fps"],
        unconditional_guidance_scale=unconditional_guidance_scale,
        unconditional_guidance_scale_temporal=None,
        cond_input=None,
        # FreeNoise window parameters, read by the sampler through ``args``.
        window_size=16,
        window_stride=4,
    )

    # Latent noise shape: spatial dims are the output size divided by 8.
    h, w = args.height // 8, args.width // 8
    frames = model.temporal_length if args.frames < 0 else args.frames
    channels = model.channels
    batch_size = 1

    # Pure inference: avoid building an autograd graph (e.g. for the text encoder).
    with torch.no_grad():
        x_T_total = init_freenoise_latents(
            args.n_samples,
            args.bs,
            channels,
            frames,
            h,
            w,
            args.window_size,
            args.window_stride,
            model.device,
        )
        noise_shape = [batch_size, channels, frames, h, w]
        fps = torch.tensor([args.fps] * batch_size).to(model.device).long()
        text_emb = model.get_learned_conditioning([prompt])
        cond = {"c_crossattn": [text_emb], "fps": fps}
        batch_samples = batch_ddim_sampling_freenoise(
            model,
            cond,
            noise_shape,
            args.n_samples,
            args.ddim_steps,
            args.ddim_eta,
            args.unconditional_guidance_scale,
            args=args,
            x_T_total=x_T_total,
        )

    video_path = "output.mp4"
    _save_video(batch_samples, video_path, args.n_samples, args.savefps)
    print(video_path)
    return video_path
# Example prompts shown in the gr.Examples gallery. Each row is a
# one-element list holding just the prompt text.
examples = [
    [example_prompt]
    for example_prompt in (
        "A chihuahua in astronaut suit floating in space, cinematic lighting, glow effect",
        "A corgi is swimming quickly",
        "A bigfoot walking in the snowstorm",
        "Campfire at night in a snowy forest with starry sky in the background",
        "A panda is surfing in the universe",
    )
]
# Custom stylesheet passed to gr.Blocks(css=...). "#col-container" centers and
# width-limits the gr.Column(elem_id="col-container") that wraps the UI.
# NOTE(review): the spinner, share-button and ".footer" rules have no matching
# component in this file — possibly template leftovers; confirm before removing.
css = """
#col-container {max-width: 640px; margin-left: auto; margin-right: auto;}
a {text-decoration-line: underline; font-weight: 600;}
.animate-spin {
animation: spin 1s linear infinite;
}
@keyframes spin {
from {
transform: rotate(0deg);
}
to {
transform: rotate(360deg);
}
}
#share-btn-container {
display: flex;
padding-left: 0.5rem !important;
padding-right: 0.5rem !important;
background-color: #000000;
justify-content: center;
align-items: center;
border-radius: 9999px !important;
max-width: 15rem;
height: 36px;
}
div#share-btn-container > div {
flex-direction: row;
background: black;
align-items: center;
}
#share-btn-container:hover {
background-color: #060606;
}
#share-btn {
all: initial;
color: #ffffff;
font-weight: 600;
cursor:pointer;
font-family: 'IBM Plex Sans', sans-serif;
margin-left: 0.5rem !important;
padding-top: 0.5rem !important;
padding-bottom: 0.5rem !important;
right:0;
}
#share-btn * {
all: unset;
}
#share-btn-container div:nth-child(-n+2){
width: auto !important;
min-height: 0px !important;
}
#share-btn-container .wrap {
display: none !important;
}
#share-btn-container.hidden {
display: none!important;
}
img[src*='#center'] {
display: inline-block;
margin: unset;
}
.footer {
margin-bottom: 45px;
margin-top: 10px;
text-align: center;
border-bottom: 1px solid #e5e5e5;
}
.footer>p {
font-size: .8rem;
display: inline-block;
padding: 0 10px;
transform: translateY(10px);
background: white;
}
.dark .footer {
border-color: #303030;
}
.dark .footer>p {
background: #0b0f19;
}
"""
# Gradio UI: a prompt box, an accordion of FreeNoise sampling parameters, a
# Generate button wired to ``infer`` (exposed as API endpoint "zrscp") and a
# video output. The queue holds at most 12 pending requests.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        # Page header. Links are written as <a> tags because Markdown link
        # syntax is not parsed inside a raw HTML block such as <p>...</p>.
        gr.Markdown(
            """
<h1 style="text-align: center;">LongerCrafter(FreeNoise) Text-to-Video</h1>
<p style="text-align: center;">
Tuning-Free Longer Video Diffusion via Noise Rescheduling
</p>
<p style="text-align: center;">
<a href="https://arxiv.org/abs/2310.15169">Arxiv</a> | <a href="http://haonanqiu.com/projects/FreeNoise.html">Project Page</a> | <a href="https://github.com/arthur-qiu/LongerCrafter">Github</a> <br />
</p>
"""
        )
        prompt_in = gr.Textbox(label="Prompt", placeholder="A chihuahua in astronaut suit floating in space, cinematic lighting, glow effect")
        with gr.Row():
            with gr.Accordion('FreeNoise Parameters (feel free to adjust these parameters based on your prompt): ', open=False):
                with gr.Row():
                    output_size = gr.Dropdown(["320x512", "576x1024"], value="320x512", label="Output Size", info="around 350s for 320x512, around 900s for 576x1024")
                with gr.Row():
                    # Step of 4 keeps the frame count aligned with the FreeNoise window stride.
                    num_frames = gr.Slider(label='Frames (a multiple of 4), max 36 for 1024 model',
                                           minimum=16,
                                           maximum=64,
                                           step=4,
                                           value=32)
                    ddim_steps = gr.Slider(label='DDIM Steps',
                                           minimum=5,
                                           maximum=200,
                                           step=1,
                                           value=50)
                with gr.Row():
                    unconditional_guidance_scale = gr.Slider(label='Unconditional Guidance Scale',
                                                             minimum=1.0,
                                                             maximum=20.0,
                                                             step=0.1,
                                                             value=12.0)
                    save_fps = gr.Slider(label='Save FPS',
                                         minimum=1,
                                         maximum=30,
                                         step=1,
                                         value=10)
                with gr.Row():
                    seed = gr.Slider(label='Random Seed',
                                     minimum=0,
                                     maximum=10000,
                                     step=1,
                                     value=123)
        submit_btn = gr.Button("Generate", variant='primary')
        video_result = gr.Video(label="Video Output")
        # Each example row contains only a prompt, so bind it to the prompt box alone.
        gr.Examples(examples=examples, inputs=[prompt_in])
    submit_btn.click(fn=infer,
                     inputs=[prompt_in, output_size, seed, num_frames, ddim_steps, unconditional_guidance_scale, save_fps],
                     outputs=[video_result],
                     api_name="zrscp")
demo.queue(max_size=12).launch(show_api=True)