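"""Instant⚡ Text to Video — a Streamlit app that generates short clips on CPU with
AnimateDiff-Lightning, using a selectable base model, motion LoRA, and step count."""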
import streamlit as st
import torch
import os
import uuid
from diffusers import AnimateDiffPipeline, EulerDiscreteScheduler, MotionAdapter
from diffusers.utils import export_to_video
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import time
# Constants
bases = {
"Cartoon": "frankjoshua/toonyou_beta6",
"Realistic": "emilianJR/epiCRealism",
"3d": "Lykon/DreamShaper",
"Anime": "Yntec/mistoonAnime2"
}
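# Track which checkpoint step count, base model, and motion LoRA are currently
# loaded so repeated generations can skip redundant downloads and reloads.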
step_loaded = None
base_loaded = "Realistic"
motion_loaded = None
# Set up the CPU device and data type
device = "cpu"
dtype = torch.float32  # Use float32 instead of float16 on CPU
# Page config must be set at the very top of the script
st.set_page_config(page_title="Instant⚡ Text to Video", layout="centered")
# Initialize the pipeline
@st.cache_resource
def init_pipeline():
    # A MotionAdapter is required so the UNet is built as a motion-aware model;
    # its weights are replaced by the AnimateDiff-Lightning checkpoint loaded later.
    adapter = MotionAdapter().to(device, dtype)
    pipe = AnimateDiffPipeline.from_pretrained(
        bases[base_loaded], motion_adapter=adapter, torch_dtype=dtype
    ).to(device)
    pipe.scheduler = EulerDiscreteScheduler.from_config(
        pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear"
    )
    # Disable the safety checker to speed things up
    pipe.safety_checker = None
    return pipe

pipe = init_pipeline()
# Video generation function
def generate_image(prompt, base="Realistic", motion="", step=1):
    global step_loaded, base_loaded, motion_loaded
    step = int(step)
    st.write(f"Generating video with prompt: {prompt}, base: {base}, steps: {step}")

    # Load the AnimateDiff-Lightning checkpoint for the requested step count
    if step_loaded != step:
        repo = "ByteDance/AnimateDiff-Lightning"
        ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
        pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device), strict=False)
        step_loaded = step
    # Load the base model's UNet weights
    if base_loaded != base:
        pipe.unet.load_state_dict(
            torch.load(hf_hub_download(bases[base], "unet/diffusion_pytorch_model.bin"), map_location=device),
            strict=False
        )
        base_loaded = base

    # Load the motion LoRA (optional)
    if motion_loaded != motion:
        pipe.unload_lora_weights()
        if motion != "":
            pipe.load_lora_weights(motion, adapter_name="motion")
            pipe.set_adapters(["motion"], [0.7])
        motion_loaded = motion
    # Progress bar updated once per denoising step
    progress_bar = st.progress(0)

    def progress_callback(i, t, z):
        progress_bar.progress((i + 1) / step)

    # Run inference without gradients to save memory
    with torch.no_grad():
        output = pipe(
            prompt=prompt,
            guidance_scale=1.2,
            num_inference_steps=step,
            callback=progress_callback,
            callback_steps=1
        )
    # Export the generated frames to an MP4 file
    name = str(uuid.uuid4()).replace("-", "")
    path = f"/tmp/{name}.mp4"
    export_to_video(output.frames[0], path, fps=10)
    return path
# Streamlit UI
st.title("Instant⚡ Text to Video")
# Custom CSS
st.markdown("""
<style>
body {font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; background-color: #f4f4f9; color: #333;}
.stApp {max-width: 800px; margin: auto; padding: 20px; background: #fff; box-shadow: 0px 0px 20px rgba(0,0,0,0.1); border-radius: 10px;}
.stButton>button {width: 100%; background-color: #4CAF50; color: white; border: none; padding: 10px 20px; border-radius: 5px; cursor: pointer;}
.stButton>button:hover {background-color: #45a049;}
.stVideo {margin-top: 20px;}
</style>
""", unsafe_allow_html=True)
# Inputs
prompt = st.text_input("Prompt", placeholder="Enter text to generate video...")
base = st.selectbox("Base model", ["Cartoon", "Realistic", "3d", "Anime"], index=1)
motion = st.selectbox(
    "Motion",
    [
        ("Default", ""),
        ("Zoom in", "guoyww/animatediff-motion-lora-zoom-in"),
        ("Zoom out", "guoyww/animatediff-motion-lora-zoom-out"),
        ("Tilt up", "guoyww/animatediff-motion-lora-tilt-up"),
        ("Tilt down", "guoyww/animatediff-motion-lora-tilt-down"),
        ("Pan left", "guoyww/animatediff-motion-lora-pan-left"),
        ("Pan right", "guoyww/animatediff-motion-lora-pan-right"),
        ("Roll left", "guoyww/animatediff-motion-lora-rolling-anticlockwise"),
        ("Roll right", "guoyww/animatediff-motion-lora-rolling-clockwise"),
    ],
    format_func=lambda x: x[0],
    index=1
)[1]  # keep only the repo id from the selected (label, repo) tuple
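# The available step counts correspond to the distilled AnimateDiff-Lightning
# checkpoints published by ByteDance (1-, 2-, 4-, and 8-step variants).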
step = st.selectbox("Inference steps", [1, 2, 4, 8], index=0)
# Generate button
if st.button("Generate Video"):
    if prompt:
        with st.spinner("Generating video..."):
            start_time = time.time()
            video_path = generate_image(prompt, base, motion, step)
            end_time = time.time()
            st.success(f"Video generated in {end_time - start_time:.2f} seconds!")
            st.video(video_path)
    else:
        st.error("Please enter a prompt!")
# Examples
st.subheader("Examples")
examples = [
    "Focus: Eiffel Tower (Animate: Clouds moving)",
    "Focus: Trees In forest (Animate: Lion running)",
    "Focus: Astronaut in Space",
    "Focus: Group of Birds in sky (Animate: Birds Moving) (Shot From distance)",
    "Focus: Statue of liberty (Shot from Drone) (Animate: Drone coming toward statue)",
    "Focus: Panda in Forest (Animate: Drinking Tea)",
    "Focus: Kids Playing (Season: Winter)",
    "Focus: Cars in Street (Season: Rain, Daytime) (Shot from Distance) (Movement: Cars running)"
]
for example in examples:
    if st.button(example, key=example):
        with st.spinner("Generating video..."):
            video_path = generate_image(example, base, motion, step)
            st.video(video_path)
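# Run locally with: streamlit run app.py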