# Import spaces before torch so ZeroGPU can manage CUDA initialization.
import spaces

import torch
# The first flag below was False when we tested this script, but True makes A100 inference a lot faster:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

import os
import tempfile

import gradio as gr
from diffusers.models import AutoencoderKL
from huggingface_hub import hf_hub_download

from converter import Generator
from diffusion.rectified_flow import RectifiedFlow
from models import FLAV
from utils import generate_sample, get_wavs, save_multimodal
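# Gradio demo for FLAV: samples synchronized video and audio latents with a
# windowed rectified-flow model, decodes video frames with the SD VAE, renders
# audio with the converter.Generator vocoder, and saves the result as an MP4.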
# Mel-spectrogram timesteps per video frame (1600 audio samples per frame at a 160-sample hop).
AUDIO_T_PER_FRAME = 1600 // 160
#################################################################################
# Global Model Setup #
#################################################################################
# These variables are initialized in setup_models() and used in generate_video().
vae = None
model = None
vocoder = None
# Fixed factor used to rescale generated audio latents before vocoding
# (value presumably fixed at training time).
audio_scale = 3.5009668382765917
def setup_models():
    global vae, model, vocoder
    device = "cuda"

    # Image VAE used to decode the generated video latents.
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema")
    vae.eval()

    model_ckpt = "MaverickAlex/R-FLAV-B-1-AIST"  # alternative: MaverickAlex/R-FLAV-B-1-LS
    model = FLAV.from_pretrained(model_ckpt)

    # Fetch the vocoder config next to the weights so Generator.from_pretrained
    # can load both from the same local directory.
    hf_hub_download(repo_id=model_ckpt, filename="vocoder/config.json")
    vocoder_path = hf_hub_download(repo_id=model_ckpt, filename="vocoder/vocoder.pt")
    vocoder = Generator.from_pretrained(os.path.dirname(vocoder_path))

    vae.to(device)
    model.to(device)
    vocoder.to(device)
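# On ZeroGPU Spaces, @spaces.GPU attaches a GPU to the worker only for the
# duration of each decorated call.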
@spaces.GPU
def generate_video(num_frames=10, steps=2, seed=42):
    global vae, model, vocoder

    # Setup device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.manual_seed(seed)

    # Set up generation parameters. The 10 in the latent shapes matches the
    # RectifiedFlow window_size below (frames per denoising window); the total
    # clip length is passed separately as video_length.
    video_latent_size = (1, 10, 4, 256 // 8, 256 // 8)
    audio_latent_size = (1, 10, 1, 256, AUDIO_T_PER_FRAME)

    rectified_flow = RectifiedFlow(num_timesteps=steps,
                                   warmup_timesteps=10,
                                   window_size=10)

    # Generate sample
    video, audio = generate_sample(
        vae=vae,  # these globals are set by setup_models()
        rectified_flow=rectified_flow,
        forward_fn=model.forward,
        video_length=num_frames,
        video_latent_size=video_latent_size,
        audio_latent_size=audio_latent_size,
        y=None,
        cfg_scale=None,
        device=device,
    )

    # Convert the generated spectrograms to waveforms
    wavs = get_wavs(audio, vocoder, audio_scale, device)

    # Save to temporary files; save_multimodal is expected to write the muxed
    # clip to <temp_dir>/video/generated_video.mp4, the path returned below.
    temp_dir = tempfile.mkdtemp()
    video_path = os.path.join(temp_dir, "video", "generated_video.mp4")

    # Use the first video and wav in the batch
    vid, wav = video[0], wavs[0]
    save_multimodal(vid, wav, temp_dir, "generated")

    return video_path
def ui_generate_video(num_frames, steps, seed):
    try:
        return generate_video(int(num_frames), int(steps), int(seed))
    except Exception as e:
        # Surface the failure in the UI instead of silently returning nothing.
        raise gr.Error(f"Generation failed: {e}")
# Create Gradio interface
with gr.Blocks(title="FLAV Video Generator") as demo:
    gr.Markdown("# FLAV Video Generator")
    gr.Markdown("Generate videos using the FLAV model")

    with gr.Row():
        with gr.Column():
            num_frames = gr.Slider(minimum=5, maximum=30, step=1, value=10, label="Number of Frames")
            steps = gr.Slider(minimum=1, maximum=20, step=1, value=5, label="Number of Steps (multiplied by a factor of 10)")
            seed = gr.Slider(minimum=0, maximum=9999, step=1, value=42, label="Random Seed")
            generate_btn = gr.Button("Generate Video")
        with gr.Column():
            video_output = gr.PlayableVideo(label="Generated Video", width=256, height=256)

    generate_btn.click(
        fn=ui_generate_video,
        inputs=[num_frames, steps, seed],
        outputs=[video_output],
    )
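# Minimal programmatic usage sketch (assumes a CUDA GPU and bypasses the UI;
# the printed path is illustrative):
#
#     setup_models()
#     path = generate_video(num_frames=10, steps=5, seed=0)
#     print(path)  # e.g. /tmp/tmpXXXX/video/generated_video.mp4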
if __name__ == "__main__":
    setup_models()
    demo.launch()