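# Streamlit app: Instant⚡ Text to Video
# Runs AnimateDiff-Lightning distilled checkpoints on CPU to turn a text prompt into a short clip.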
import streamlit as st
import torch
import os
import uuid
from diffusers import AnimateDiffPipeline, EulerDiscreteScheduler, MotionAdapter
from diffusers.utils import export_to_video
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import time
# Constants
bases = {
"Cartoon": "frankjoshua/toonyou_beta6",
"Realistic": "emilianJR/epiCRealism",
"3d": "Lykon/DreamShaper",
"Anime": "Yntec/mistoonAnime2"
}
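# Module-level cache of what is currently loaded in the pipeline, so repeated runs
# only reload weights when the step count, base model, or motion LoRA changes.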
step_loaded = None
base_loaded = "Realistic"
motion_loaded = None
# Device and data type: CPU only
device = "cpu"
dtype = torch.float32  # Use float32 instead of float16 on CPU
# Page config must be set at the very top of the script
st.set_page_config(page_title="Instant⚡ Text to Video", layout="centered")
# Initialize the pipeline
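# st.cache_resource keeps a single pipeline instance alive across Streamlit reruns.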
@st.cache_resource
def init_pipeline():
    # An empty MotionAdapter is required to assemble the AnimateDiff pipeline;
    # its actual weights come from the Lightning checkpoint loaded in generate_image()
    adapter = MotionAdapter().to(device=device, dtype=dtype)
    pipe = AnimateDiffPipeline.from_pretrained(bases[base_loaded], motion_adapter=adapter, torch_dtype=dtype).to(device)
    pipe.scheduler = EulerDiscreteScheduler.from_config(
        pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear"
    )
    # Disable the safety checker to speed up inference
    pipe.safety_checker = None
    return pipe
pipe = init_pipeline()
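# generate_image() reuses the cached pipeline and only swaps UNet weights / LoRAs when the
# requested settings differ from what is already loaded.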
# Video generation function
def generate_image(prompt, base="Realistic", motion="", step=1):
    global step_loaded, base_loaded, motion_loaded
    step = int(step)
    st.write(f"Generating video with prompt: {prompt}, base: {base}, steps: {step}")
    # Load the AnimateDiff-Lightning checkpoint for the requested step count
    if step_loaded != step:
        repo = "ByteDance/AnimateDiff-Lightning"
        ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
        # strict=False: the Lightning checkpoint only overrides a subset of the UNet weights
        pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device), strict=False)
        step_loaded = step
    # Load the base model's UNet (the VAE and text encoder are kept from the initial pipeline)
    if base_loaded != base:
        pipe.unet.load_state_dict(
            torch.load(hf_hub_download(bases[base], "unet/diffusion_pytorch_model.bin"), map_location=device),
            strict=False
        )
        base_loaded = base
    # Load the motion LoRA (optional)
    if motion_loaded != motion:
        pipe.unload_lora_weights()
        if motion != "":
            pipe.load_lora_weights(motion, adapter_name="motion")
            pipe.set_adapters(["motion"], [0.7])
        motion_loaded = motion
    # Progress bar, updated after every denoising step
    progress_bar = st.progress(0)
    def progress_callback(i, t, z):
        # diffusers passes (step_index, timestep, latents) to the legacy callback
        progress_bar.progress((i + 1) / step)
    # Run inference without gradient tracking to save memory
    with torch.no_grad():
        output = pipe(
            prompt=prompt,
            guidance_scale=1.2,
            num_inference_steps=step,
            callback=progress_callback,
            callback_steps=1
        )
    # Export the frames to an MP4 file
    name = str(uuid.uuid4()).replace("-", "")
    path = f"/tmp/{name}.mp4"
    export_to_video(output.frames[0], path, fps=10)
    return path
# Streamlit UI
st.title("Instant⚡ Text to Video")
# Custom CSS
st.markdown("""
<style>
body {font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; background-color: #f4f4f9; color: #333;}
.stApp {max-width: 800px; margin: auto; padding: 20px; background: #fff; box-shadow: 0px 0px 20px rgba(0,0,0,0.1); border-radius: 10px;}
.stButton>button {width: 100%; background-color: #4CAF50; color: white; border: none; padding: 10px 20px; border-radius: 5px; cursor: pointer;}
.stButton>button:hover {background-color: #45a049;}
.stVideo {margin-top: 20px;}
</style>
""", unsafe_allow_html=True)
# Inputs
prompt = st.text_input("Prompt", placeholder="Enter text to generate video...")
base = st.selectbox("Base model", ["Cartoon", "Realistic", "3d", "Anime"], index=1)
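# Each motion option is a (label, Hugging Face repo id) pair; [1] keeps only the repo id
# that is passed to load_lora_weights() in generate_image().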
motion = st.selectbox(
    "Motion",
    [
        ("Default", ""),
        ("Zoom in", "guoyww/animatediff-motion-lora-zoom-in"),
        ("Zoom out", "guoyww/animatediff-motion-lora-zoom-out"),
        ("Tilt up", "guoyww/animatediff-motion-lora-tilt-up"),
        ("Tilt down", "guoyww/animatediff-motion-lora-tilt-down"),
        ("Pan left", "guoyww/animatediff-motion-lora-pan-left"),
        ("Pan right", "guoyww/animatediff-motion-lora-pan-right"),
        ("Roll left", "guoyww/animatediff-motion-lora-rolling-anticlockwise"),
        ("Roll right", "guoyww/animatediff-motion-lora-rolling-clockwise"),
    ],
    format_func=lambda x: x[0],
    index=1
)[1]
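# AnimateDiff-Lightning ships distilled checkpoints for 1, 2, 4, and 8 inference steps.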
step = st.selectbox("Inference steps", [1, 2, 4, 8], index=0)
# Generate button
if st.button("Generate Video"):
    if prompt:
        with st.spinner("Generating video..."):
            start_time = time.time()
            video_path = generate_image(prompt, base, motion, step)
            end_time = time.time()
        st.success(f"Video generated in {end_time - start_time:.2f} seconds!")
        st.video(video_path)
    else:
        st.error("Please enter a prompt!")
# Examples
st.subheader("Examples")
examples = [
    "Focus: Eiffel Tower (Animate: Clouds moving)",
    "Focus: Trees In forest (Animate: Lion running)",
    "Focus: Astronaut in Space",
    "Focus: Group of Birds in sky (Animate: Birds Moving) (Shot From distance)",
    "Focus: Statue of liberty (Shot from Drone) (Animate: Drone coming toward statue)",
    "Focus: Panda in Forest (Animate: Drinking Tea)",
    "Focus: Kids Playing (Season: Winter)",
    "Focus: Cars in Street (Season: Rain, Daytime) (Shot from Distance) (Movement: Cars running)"
]
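# One button per example prompt; clicking it generates a video with the currently selected
# base model, motion, and step settings.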
for example in examples:
    if st.button(example, key=example):
        with st.spinner("Generating video..."):
            video_path = generate_image(example, base, motion, step)
        st.video(video_path)