FreeNoise / app.py
Anonymous
fix app
b135a58
raw
history blame
No virus
8 kB
import gradio as gr
import os
import sys
import argparse
import random
from omegaconf import OmegaConf
import torch
import torchvision
from pytorch_lightning import seed_everything
from huggingface_hub import hf_hub_download
sys.path.insert(0, "scripts/evaluation")
from funcs import (
batch_ddim_sampling_freenoise,
load_model_checkpoint,
)
from utils.utils import instantiate_from_config
# Checkpoint locations: one directory (and one weights file) per resolution.
ckpt_dir_1024 = "checkpoints/base_1024_v1"
ckpt_dir_256 = "checkpoints/base_256_v1"
ckpt_path_1024 = f"{ckpt_dir_1024}/model.ckpt"
ckpt_path_256 = f"{ckpt_dir_256}/model.pth"

# Create both directories up front; directories that already exist are kept.
for _ckpt_dir in (ckpt_dir_1024, ckpt_dir_256):
    os.makedirs(_ckpt_dir, exist_ok=True)

# Only the 256x256 weights are fetched at startup; the 1024 download is disabled.
# hf_hub_download(repo_id="VideoCrafter/Text2Video-1024", filename="model.ckpt", local_dir=ckpt_dir_1024)
hf_hub_download(
    repo_id="MoonQiu/LongerCrafter",
    filename="model.pth",
    local_dir=ckpt_dir_256,
)
# Per-resolution settings: model config file, output frame size and the fps
# value handed to the model as conditioning.
_MODEL_SPECS = {
    "576x1024": {
        "config": "configs/inference_t2v_1024_v1.0_freenoise.yaml",
        "width": 1024,
        "height": 576,
        "fps": 24,
    },
    "256x256": {
        "config": "configs/inference_t2v_tconv256_v1.0_freenoise.yaml",
        "width": 256,
        "height": 256,
        "fps": 8,
    },
}

# Models that have already been built, keyed by output size, so repeated
# requests do not rebuild the network and re-read the checkpoint.
_LOADED_MODELS = {}


def _get_model(output_size):
    """Return the eval-mode CUDA model for *output_size*, loading it on first use."""
    model = _LOADED_MODELS.get(output_size)
    if model is None:
        # Looked up at call time so the module-level checkpoint paths are
        # only needed once a model is actually requested.
        ckpt_path = ckpt_path_1024 if output_size == "576x1024" else ckpt_path_256
        config = OmegaConf.load(_MODEL_SPECS[output_size]["config"])
        model_config = config.pop("model", OmegaConf.create())
        model = instantiate_from_config(model_config)
        model = model.cuda()
        model = load_model_checkpoint(model, ckpt_path)
        model.eval()
        # Cache only after every step succeeded, so a failed load is retried.
        _LOADED_MODELS[output_size] = model
    return model


def infer(prompt, *, output_size="256x256", seed=123):
    """Generate a video for *prompt* and return the path of the written file.

    Args:
        prompt: Text prompt passed to the model's conditioning encoder.
        output_size: "256x256" (default) or "576x1024". The 576x1024 model is
            read from ``ckpt_path_1024``; this function does not download it.
        seed: Seed given to ``seed_everything``. ``None`` draws a random
            16-bit seed from ``os.urandom``.

    Returns:
        The string "output.mp4", the file the video was written to.

    Raises:
        ValueError: If *output_size* is not one of the supported resolutions.
    """
    # Validate first so a bad size fails fast instead of surfacing later as
    # an unbound local.
    if output_size not in _MODEL_SPECS:
        raise ValueError(
            f"Unsupported output_size {output_size!r}; "
            f"expected one of {sorted(_MODEL_SPECS)}"
        )

    num_frames = 32
    ddim_steps = 50
    unconditional_guidance_scale = 12.0
    save_fps = 10
    window_size = 16
    window_stride = 4

    spec = _MODEL_SPECS[output_size]
    model = _get_model(output_size)

    if seed is None:
        seed = int.from_bytes(os.urandom(2), "big")
    print(f"Using seed: {seed}")
    seed_everything(seed)

    # The sampler receives these settings as a namespace via ``args=``.
    args = argparse.Namespace(
        mode="base",
        savefps=save_fps,
        n_samples=1,
        ddim_steps=ddim_steps,
        ddim_eta=0.0,
        bs=1,
        height=spec["height"],
        width=spec["width"],
        frames=num_frames,
        fps=spec["fps"],
        unconditional_guidance_scale=unconditional_guidance_scale,
        unconditional_guidance_scale_temporal=None,
        cond_input=None,
        window_size=window_size,
        window_stride=window_stride,
    )

    ## latent noise shape
    h, w = args.height // 8, args.width // 8
    frames = model.temporal_length if args.frames < 0 else args.frames
    channels = model.channels

    # Initial noise laid out as [n_samples, bs, channels, frames, h, w].
    x_T_total = torch.randn(
        [args.n_samples, 1, channels, frames, h, w], device=model.device
    ).repeat(1, args.bs, 1, 1, 1, 1)

    # Frames past the first window do not keep fresh noise: each group of
    # ``window_stride`` frames is overwritten with the noise of the frames
    # ``window_size`` positions earlier, taken in shuffled order.
    for frame_index in range(args.window_size, args.frames, args.window_stride):
        list_index = list(
            range(
                frame_index - args.window_size,
                frame_index + args.window_stride - args.window_size,
            )
        )
        random.shuffle(list_index)
        x_T_total[
            :, :, :, frame_index : frame_index + args.window_stride
        ] = x_T_total[:, :, :, list_index]

    batch_size = 1
    noise_shape = [batch_size, channels, frames, h, w]
    fps = torch.tensor([args.fps] * batch_size).to(model.device).long()
    prompts = [prompt]
    text_emb = model.get_learned_conditioning(prompts)
    cond = {"c_crossattn": [text_emb], "fps": fps}

    ## inference
    batch_samples = batch_ddim_sampling_freenoise(
        model,
        cond,
        noise_shape,
        args.n_samples,
        args.ddim_steps,
        args.ddim_eta,
        args.unconditional_guidance_scale,
        args=args,
        x_T_total=x_T_total,
    )

    # Fixed output path: every call overwrites the previous result.
    video_path = "output.mp4"
    vid_tensor = batch_samples[0]
    video = vid_tensor.detach().cpu()
    video = torch.clamp(video.float(), -1.0, 1.0)
    video = video.permute(2, 0, 1, 3, 4)  # t,n,c,h,w
    # One image grid per time step, with the samples laid out in a row.
    frame_grids = [
        torchvision.utils.make_grid(framesheet, nrow=int(args.n_samples))
        for framesheet in video
    ]
    grid = torch.stack(frame_grids, dim=0)  # stacked along the temporal dim
    # Map [-1, 1] to [0, 255] and move channels last for the video writer.
    grid = (grid + 1.0) / 2.0
    grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
    torchvision.io.write_video(
        video_path,
        grid,
        fps=args.savefps,
        video_codec="h264",
        options={"crf": "10"},
    )
    print(video_path)
    return video_path
# Sample prompts offered in the UI; each entry is a single prompt string
# for the one text input wired to gr.Examples.
examples = [
    "A chihuahua in astronaut suit floating in space, cinematic lighting, glow effect",
    "Campfire at night in a snowy forest with starry sky in the background",
    "A dark knight riding a black horse on the grassland, in sunset",
    "A corgi is swimming quickly",
    "A panda is surfing in the universe",
]
# Custom stylesheet passed to gr.Blocks(css=...).
# - #col-container caps the width of the main column and centres it; it is
#   the only selector here that matches an elem_id set in this file's layout.
# - The #share-btn* and .footer rules style elements that the layout visible
#   in this file does not create (NOTE(review): possibly leftovers from a
#   template; confirm before removing).
css = """
#col-container {max-width: 510px; margin-left: auto; margin-right: auto;}
a {text-decoration-line: underline; font-weight: 600;}
.animate-spin {
animation: spin 1s linear infinite;
}
@keyframes spin {
from {
transform: rotate(0deg);
}
to {
transform: rotate(360deg);
}
}
#share-btn-container {
display: flex;
padding-left: 0.5rem !important;
padding-right: 0.5rem !important;
background-color: #000000;
justify-content: center;
align-items: center;
border-radius: 9999px !important;
max-width: 15rem;
height: 36px;
}
div#share-btn-container > div {
flex-direction: row;
background: black;
align-items: center;
}
#share-btn-container:hover {
background-color: #060606;
}
#share-btn {
all: initial;
color: #ffffff;
font-weight: 600;
cursor:pointer;
font-family: 'IBM Plex Sans', sans-serif;
margin-left: 0.5rem !important;
padding-top: 0.5rem !important;
padding-bottom: 0.5rem !important;
right:0;
}
#share-btn * {
all: unset;
}
#share-btn-container div:nth-child(-n+2){
width: auto !important;
min-height: 0px !important;
}
#share-btn-container .wrap {
display: none !important;
}
#share-btn-container.hidden {
display: none!important;
}
img[src*='#center'] {
display: inline-block;
margin: unset;
}
.footer {
margin-bottom: 45px;
margin-top: 10px;
text-align: center;
border-bottom: 1px solid #e5e5e5;
}
.footer>p {
font-size: .8rem;
display: inline-block;
padding: 0 10px;
transform: translateY(10px);
background: white;
}
.dark .footer {
border-color: #303030;
}
.dark .footer>p {
background: #0b0f19;
}
"""
# Build the single-column UI: title, prompt box, generate button, video
# output and clickable example prompts.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        # Page title and tagline, written as raw HTML inside Markdown.
        gr.Markdown(
            """
<h1 style="text-align: center;">LongerCrafter(FreeNoise) Text-to-Video</h1>
<p style="text-align: center;">
Tuning-Free Longer Video Diffusion via Noise Rescheduling <br />
</p>
"""
        )

        prompt_in = gr.Textbox(
            label="Prompt",
            placeholder="A chihuahua in astronaut suit floating in space, cinematic lighting, glow effect",
            elem_id="prompt-in",
        )
        # Disabled controls, kept for reference:
        #neg_prompt = gr.Textbox(label="Negative prompt", value="text, watermark, copyright, blurry, nsfw", elem_id="neg-prompt-in")
        #inference_steps = gr.Slider(label="Inference Steps", minimum=10, maximum=100, step=1, value=40, interactive=False)
        submit_btn = gr.Button("Generate")
        video_result = gr.Video(label="Video Output", elem_id="video-output")

        # Clicking an example fills the prompt box.
        gr.Examples(examples=examples, inputs=[prompt_in])

    # Run inference on click and show the resulting file in the video player.
    submit_btn.click(
        infer,
        [prompt_in],
        [video_result],
        api_name="zrscp",
    )

# Queue at most 12 pending requests, then start the server.
demo.queue(max_size=12).launch(show_api=True)