import torch
# TF32 matmul is disabled by default; enabling it (together with the cuDNN flag) makes Ampere GPUs such as the A100 considerably faster:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

import os
import tempfile

import gradio as gr
import spaces
from diffusers.models import AutoencoderKL
from huggingface_hub import hf_hub_download

from converter import Generator
from diffusion.rectified_flow import RectifiedFlow
from models import FLAV
from utils import *
# 1600 audio samples per video frame / hop length 160 = 10 spectrogram timesteps per frame
# (presumably 16 kHz audio at 10 fps)
AUDIO_T_PER_FRAME = 1600 // 160

#################################################################################
#                               Global Model Setup                              #
#################################################################################

# Initialized in setup_models() and used by generate_video()
vae = None
model = None
vocoder = None
audio_scale = 3.5009668382765917  # fixed (de)normalization scale passed to get_wavs


def setup_models():
    global vae, model, vocoder
    
    device = "cuda"
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema")
    vae.eval()
    

    model_ckpt = "MaverickAlex/R-FLAV-B-1-AIST"  # alternative checkpoint: MaverickAlex/R-FLAV-B-1-LS

    model = FLAV.from_pretrained(model_ckpt)

    # Download the vocoder config alongside the weights, then strip the filename so
    # Generator.from_pretrained can load from the local vocoder/ directory
    hf_hub_download(repo_id=model_ckpt, filename="vocoder/config.json")
    vocoder_path = hf_hub_download(repo_id=model_ckpt, filename="vocoder/vocoder.pt")
    vocoder_dir = vocoder_path.replace("vocoder.pt", "")
    vocoder = Generator.from_pretrained(vocoder_dir)

    vae.to(device)
    model.to(device)
    vocoder.to(device)
    

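# A possible alternative (a sketch; it assumes Generator.from_pretrained only needs the
# directory containing config.json and vocoder.pt): fetch the whole vocoder/ folder in
# one call with huggingface_hub.snapshot_download instead of two hf_hub_download calls:
#
#   from huggingface_hub import snapshot_download
#   repo_dir = snapshot_download(repo_id="MaverickAlex/R-FLAV-B-1-AIST",
#                                allow_patterns=["vocoder/*"])
#   vocoder = Generator.from_pretrained(os.path.join(repo_dir, "vocoder/"))
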
@spaces.GPU
def generate_video(num_frames=10, steps=2, seed=42):
    global vae, model, vocoder
    # Setup device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.manual_seed(seed)
    
    # Latent shapes: (batch, window, channels, height, width) for video
    # (4-channel SD-VAE latents at 256/8 = 32 spatial resolution) and
    # (batch, window, channels, frequency bins, timesteps per frame) for audio
    video_latent_size = (1, 10, 4, 256 // 8, 256 // 8)
    audio_latent_size = (1, 10, 1, 256, AUDIO_T_PER_FRAME)

    # The 10-frame window is why the UI notes that steps are
    # "multiplied by a factor of 10"
    rectified_flow = RectifiedFlow(
        num_timesteps=steps,
        warmup_timesteps=10,
        window_size=10,
    )
    
    # Generate sample
    video, audio = generate_sample(
        vae=vae,  # These globals are set by setup_models
        rectified_flow=rectified_flow,
        forward_fn=model.forward,
        video_length=num_frames,
        video_latent_size=video_latent_size,
        audio_latent_size=audio_latent_size,
        y=None,
        cfg_scale=None,
        device=device
    )
    
    # Convert to wav
    wavs = get_wavs(audio, vocoder, audio_scale, device)
    
    # Save to a temporary directory
    temp_dir = tempfile.mkdtemp()
    video_path = os.path.join(temp_dir, "video", "generated_video.mp4")

    # Use the first sample from the batch; save_multimodal (from utils) muxes the
    # video and audio and is expected to write <temp_dir>/video/generated_video.mp4
    vid, wav = video[0], wavs[0]
    save_multimodal(vid, wav, temp_dir, "generated")
    
    return video_path
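
# Minimal headless usage sketch (an assumption: the spaces package makes @spaces.GPU a
# no-op outside Hugging Face Spaces, so this should also run on a local CUDA machine):
#
#   setup_models()
#   path = generate_video(num_frames=10, steps=5, seed=0)
#   print(f"Saved generated clip to {path}")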

def ui_generate_video(num_frames, steps, seed):
    try:
        return generate_video(int(num_frames), int(steps), int(seed))
    except Exception as e:
        print(f"Generation failed: {e}")
        return None

# Create Gradio interface
with gr.Blocks(title="FLAV Video Generator") as demo:
    gr.Markdown("# FLAV Video Generator")
    gr.Markdown("Generate videos using the FLAV model")

    with gr.Row():
        with gr.Column():
            num_frames = gr.Slider(minimum=5, maximum=30, step=1, value=10, label="Number of Frames")
            steps = gr.Slider(minimum=1, maximum=20, step=1, value=5, label="Number of Steps (multiplied by a factor of 10)")
            seed = gr.Slider(minimum=0, maximum=9999, step=1, value=42, label="Random Seed")
            generate_btn = gr.Button("Generate Video")
        
        with gr.Column():
            video_output = gr.PlayableVideo(label="Generated Video", width=256, height=256)
    generate_btn.click(
        fn=ui_generate_video,
        inputs=[num_frames, steps, seed],
        outputs=[video_output]
    )

if __name__ == "__main__":
    setup_models()
    demo.launch()
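
# Launch locally with `python app.py`; Gradio serves the UI at http://127.0.0.1:7860 by
# default (the filename is an assumption; use whatever this script is saved as).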