File size: 2,275 Bytes
3facca5
0d3ff24
3facca5
 
 
 
 
 
 
 
0d3ff24
3facca5
 
 
 
 
 
549018e
3facca5
549018e
 
3facca5
549018e
3facca5
 
 
0d3ff24
c33e25c
0d3ff24
 
 
 
 
 
 
 
 
 
 
3facca5
0d3ff24
3facca5
549018e
3facca5
 
0d3ff24
549018e
3facca5
 
549018e
ace3238
 
0d3ff24
 
549018e
1d7aaae
549018e
 
 
c3dbd3d
3facca5
 
0d3ff24
 
 
 
 
 
 
 
 
 
549018e
3facca5
768e553
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import contextlib
import functools
import os
import tempfile

if os.environ.get("SPACES_ZERO_GPU") is not None:
    import spaces
else:
    class spaces:
        """Local no-op stand-in for the ``spaces`` package.

        Used only when the ``SPACES_ZERO_GPU`` environment variable is unset,
        so that ``@spaces.GPU`` decorations still work outside ZeroGPU.
        """

        @staticmethod
        def GPU(func=None, **kwargs):
            """Return ``func`` wrapped in a pass-through decorator.

            Supports both the bare form (``@spaces.GPU``) and the
            parameterized form (``@spaces.GPU(duration=...)``); any keyword
            arguments are accepted and ignored.

            Args:
                func: The function to decorate, or ``None`` when the decorator
                    is being called with keyword arguments first.
                **kwargs: Ignored; accepted for call compatibility.

            Returns:
                The wrapped function, or a decorator when ``func`` is ``None``.
            """

            def decorate(fn):
                # functools.wraps keeps fn's name, docstring and (via
                # __wrapped__) its signature visible to introspection.
                @functools.wraps(fn)
                def wrapper(*args, **kw):
                    return fn(*args, **kw)

                return wrapper

            if func is None:
                return decorate
            return decorate(func)

import torch
from diffusers import MochiPipeline
from diffusers.utils import export_to_video
import gradio as gr
import config as cfg

# Build the Mochi text-to-video pipeline from the checkpoint id held in the
# project config, requesting the "bf16" weight variant and bfloat16 tensors.
pipe = MochiPipeline.from_pretrained(
    cfg.MODEL_PRE_TRAINED_ID,
    variant="bf16",
    torch_dtype=torch.bfloat16,
)

# Lower peak GPU memory use: per the diffusers docs, model CPU offload keeps
# sub-models on the CPU until they run, and VAE tiling decodes in tiles.
pipe.enable_model_cpu_offload()
pipe.enable_vae_tiling()


@spaces.GPU
def generate_video(prompt, num_frames=84, fps=30, high_quality=False):
    """Generate a video from a text prompt and return the MP4 file path.

    Args:
        prompt: Text description passed to the Mochi pipeline.
        num_frames: Number of frames requested from the pipeline. Coerced to
            ``int`` because the Gradio slider may deliver a float.
        fps: Playback frame rate used when encoding the MP4.
        high_quality: When true, run the pipeline under CUDA bfloat16
            autocast (the UI labels this option as needing 42GB VRAM).

    Returns:
        Path of a newly created, uniquely named ``.mp4`` file.

    Raises:
        RuntimeError: If ``high_quality`` is requested while the
            ``SPACES_ZERO_GPU`` environment variable is set.
    """
    # NOTE(review): Gradio documents Slider values as floats; the pipeline
    # needs an integer frame count.
    num_frames = int(num_frames)

    if high_quality:
        print("High quality option selected. Requires 42GB VRAM.")
        # Check if running on ZeroGPU; the high-quality path is refused there.
        if os.environ.get("SPACES_ZERO_GPU") is not None:
            raise RuntimeError("High quality option may fail on ZeroGPU environments.")
        precision = torch.autocast("cuda", torch.bfloat16, cache_enabled=False)
    else:
        print("Standard quality option selected.")
        precision = contextlib.nullcontext()

    with precision:
        frames = pipe(prompt, num_frames=num_frames).frames[0]

    # Export frames to a unique temp file so that overlapping requests never
    # overwrite each other's output (a fixed "mochi.mp4" would be shared).
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as handle:
        video_path = handle.name
    export_to_video(frames, video_path, fps=fps)
    return video_path


# Custom CSS intended to center the title and description.
# NOTE(review): confirm these class names exist in the rendered Gradio markup;
# if they do not, this stylesheet has no visible effect.
_CENTER_TEXT_CSS = """
    .interface-title {
        text-align: center;
    }
    .interface-description {
        text-align: center;
    }
"""

# Create the Gradio interface: prompt, frame count, FPS and quality toggle
# are passed positionally to generate_video, whose returned path is shown
# in a video player.
interface = gr.Interface(
    fn=generate_video,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter your text prompt here... 💡"),
        gr.Slider(minimum=1, maximum=240, value=84, label="Number of frames 🎞️"),
        gr.Slider(minimum=1, maximum=60, value=30, label="FPS (Frames per second) ⏱️"),
        gr.Checkbox(label="High Quality Output (requires 42GB VRAM, may fail on ZeroGPU)"),
    ],
    outputs=gr.Video(),
    title=cfg.TITLE,
    description=cfg.DESCRIPTION,
    examples=cfg.EXAMPLES,
    article=cfg.BUY_ME_A_COFFE,
    # Supply CSS through the documented constructor argument rather than
    # assigning the attribute after the interface has been built.
    css=_CENTER_TEXT_CSS,
)

# Launch the application only when run as a script; ssr_mode=False turns
# off Gradio's server-side rendering.
if __name__ == "__main__":
    interface.launch(ssr_mode=False)