File size: 5,683 Bytes
3fd57dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174364e
 
 
3fd57dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import streamlit as st
import torch
import os
import uuid
from diffusers import AnimateDiffPipeline, EulerDiscreteScheduler
from diffusers.utils import export_to_video
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import time

# Base models: display name shown in the UI -> Hugging Face repo id that the
# pipeline / UNet weights are fetched from.
bases = {
    "Cartoon": "frankjoshua/toonyou_beta6",
    "Realistic": "emilianJR/epiCRealism",
    "3d": "Lykon/DreamShaper",
    "Anime": "Yntec/mistoonAnime2"
}
# Bookkeeping for what is currently loaded into the pipeline; read and updated
# by generate_image(). `base_loaded` starts at the base that init_pipeline()
# loads, the other two start as "nothing loaded yet".
step_loaded = None
base_loaded = "Realistic"
motion_loaded = None

# Run on the CPU and pick a matching data type
device = "cpu"
dtype = torch.float32  # Use float32 instead of float16 on CPU

# Configure the page at the very top of the script, before any other Streamlit
# call is made.
st.set_page_config(page_title="Instant⚡ Text to Video", layout="centered")

# Build the pipeline
@st.cache_resource
def init_pipeline():
    """Create the AnimateDiff pipeline for the initially selected base model.

    The pipeline is loaded from ``bases[base_loaded]`` with the module-level
    ``dtype``, moved to ``device`` and given an Euler discrete scheduler
    configured with trailing timestep spacing and a linear beta schedule.

    Returns:
        The configured ``AnimateDiffPipeline`` instance.
    """
    base_repo = bases[base_loaded]
    pipeline = AnimateDiffPipeline.from_pretrained(base_repo, torch_dtype=dtype)
    pipeline = pipeline.to(device)

    # Rebuild the scheduler from the existing config with the overrides below.
    scheduler_overrides = {"timestep_spacing": "trailing", "beta_schedule": "linear"}
    pipeline.scheduler = EulerDiscreteScheduler.from_config(
        pipeline.scheduler.config, **scheduler_overrides
    )

    # Clear the safety checker to speed things up
    pipeline.safety_checker = None
    return pipeline

# Module-level pipeline instance: generate_image() loads weights into it and
# runs inference with it.
pipe = init_pipeline()

# Video generation
def generate_image(prompt, base="Realistic", motion="", step=1):
    """Generate a short video for *prompt* and return the path of the MP4 file.

    Before inference the shared pipeline is brought in line with the requested
    settings: the AnimateDiff-Lightning checkpoint for ``step``, the UNet
    weights of ``base`` and, optionally, a motion LoRA. Each of the three is
    only (re)loaded when it differs from what the pipeline currently holds.

    Args:
        prompt: Text prompt passed to the pipeline.
        base: Key of ``bases`` selecting the base model.
        motion: Repo id of a motion LoRA, or ``""`` for no LoRA.
        step: Number of inference steps. It also selects the Lightning
            checkpoint file name, so a checkpoint for that step count must
            exist in the checkpoint repo (the UI offers 1, 2, 4 and 8).

    Returns:
        Path of the exported ``.mp4`` file inside the system temp directory.
    """
    global step_loaded, base_loaded, motion_loaded
    import tempfile  # local import: only needed for the output location

    step = int(step)
    st.write(f"Generating video with prompt: {prompt}, base: {base}, steps: {step}")

    # `pipe` is created by an st.cache_resource function and is reused across
    # script reruns, whereas this module's globals are re-initialised on every
    # rerun. Keep the record of what is loaded on the pipeline object itself so
    # it cannot fall out of sync with the weights actually held by the UNet
    # (otherwise switching back to the initial base would silently be skipped).
    loaded = getattr(pipe, "_loaded_state", None)
    if loaded is None:
        loaded = {"step": step_loaded, "base": base_loaded, "motion": motion_loaded}
        pipe._loaded_state = loaded

    # Load the AnimateDiff-Lightning checkpoint
    if loaded["step"] != step:
        repo = "ByteDance/AnimateDiff-Lightning"
        ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
        pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device), strict=False)
        # The module globals are kept as a mirror for any code that reads them.
        loaded["step"] = step_loaded = step

    # Load the base model
    if loaded["base"] != base:
        unet_weights = hf_hub_download(bases[base], "unet/diffusion_pytorch_model.bin")
        # NOTE(security): torch.load unpickles the downloaded file; only point
        # `bases` at repositories you trust.
        pipe.unet.load_state_dict(torch.load(unet_weights, map_location=device), strict=False)
        loaded["base"] = base_loaded = base

    # Load the motion LoRA (optional)
    if loaded["motion"] != motion:
        pipe.unload_lora_weights()
        if motion != "":
            pipe.load_lora_weights(motion, adapter_name="motion")
            pipe.set_adapters(["motion"], [0.7])
        loaded["motion"] = motion_loaded = motion

    # Create the progress bar
    progress_bar = st.progress(0)

    def progress_callback(i, t, z):
        # st.progress only accepts fractions up to 1.0, so clamp in case the
        # pipeline reports more callbacks than `step`.
        progress_bar.progress(min((i + 1) / step, 1.0))

    # Run inference without gradient tracking to save memory.
    # NOTE(review): `callback` / `callback_steps` is the older per-step callback
    # API; newer diffusers releases appear to prefer `callback_on_step_end` —
    # confirm before upgrading the dependency.
    with torch.no_grad():
        output = pipe(
            prompt=prompt,
            guidance_scale=1.2,
            num_inference_steps=step,
            callback=progress_callback,
            callback_steps=1
        )

    # Export the video under a unique name in the platform's temp directory
    path = os.path.join(tempfile.gettempdir(), f"{uuid.uuid4().hex}.mp4")
    export_to_video(output.frames[0], path, fps=10)
    return path

# Streamlit user interface
st.title("Instant⚡ Text to Video")

# Custom CSS
page_css = """
    <style>
        body {font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; background-color: #f4f4f9; color: #333;}
        .stApp {max-width: 800px; margin: auto; padding: 20px; background: #fff; box-shadow: 0px 0px 20px rgba(0,0,0,0.1); border-radius: 10px;}
        .stButton>button {width: 100%; background-color: #4CAF50; color: white; border: none; padding: 10px 20px; border-radius: 5px; cursor: pointer;}
        .stButton>button:hover {background-color: #45a049;}
        .stVideo {margin-top: 20px;}
    </style>
"""
st.markdown(page_css, unsafe_allow_html=True)

# Inputs
prompt = st.text_input("Prompt", placeholder="Enter text to generate video...")
base = st.selectbox("Base model", ["Cartoon", "Realistic", "3d", "Anime"], index=1)

# Each option is a (label shown in the dropdown, motion LoRA repo id) pair;
# the empty repo id means "no motion LoRA".
motion_choices = [
    ("Default", ""),
    ("Zoom in", "guoyww/animatediff-motion-lora-zoom-in"),
    ("Zoom out", "guoyww/animatediff-motion-lora-zoom-out"),
    ("Tilt up", "guoyww/animatediff-motion-lora-tilt-up"),
    ("Tilt down", "guoyww/animatediff-motion-lora-tilt-down"),
    ("Pan left", "guoyww/animatediff-motion-lora-pan-left"),
    ("Pan right", "guoyww/animatediff-motion-lora-pan-right"),
    ("Roll left", "guoyww/animatediff-motion-lora-rolling-anticlockwise"),
    ("Roll right", "guoyww/animatediff-motion-lora-rolling-clockwise"),
]
selected_motion = st.selectbox(
    "Motion",
    motion_choices,
    format_func=lambda choice: choice[0],
    index=1,
)
motion = selected_motion[1]
step = st.selectbox("Inference steps", [1, 2, 4, 8], index=0)

# Generate button
if st.button("Generate Video"):
    if not prompt:
        st.error("Please enter a prompt!")
    else:
        with st.spinner("Generating video..."):
            started_at = time.time()
            video_path = generate_image(prompt, base, motion, step)
            elapsed = time.time() - started_at
            st.success(f"Video generated in {elapsed:.2f} seconds!")
            st.video(video_path)

# Examples: one button per sample prompt; clicking it generates a video with
# the currently selected base, motion and step settings.
st.subheader("Examples")
examples = [
    "Focus: Eiffel Tower (Animate: Clouds moving)",
    "Focus: Trees In forest (Animate: Lion running)",
    "Focus: Astronaut in Space",
    "Focus: Group of Birds in sky (Animate: Birds Moving) (Shot From distance)",
    "Focus: Statue of liberty (Shot from Drone) (Animate: Drone coming toward statue)",
    "Focus: Panda in Forest (Animate: Drinking Tea)",
    "Focus: Kids Playing (Season: Winter)",
    "Focus: Cars in Street (Season: Rain, Daytime) (Shot from Distance) (Movement: Cars running)"
]
for example in examples:
    example_clicked = st.button(example, key=example)
    if not example_clicked:
        continue
    with st.spinner("Generating video..."):
        video_path = generate_image(example, base, motion, step)
        st.video(video_path)