# Kandinsky / app.py — Hugging Face Space by rahul7star
# Last commit: "Update app.py" (d09a395, verified) · 5.86 kB
# Keep `spaces` first: ZeroGPU needs to be imported before anything initialises CUDA.
import spaces

import importlib
import logging
import os
import re
import site
import subprocess
import sys
import tempfile
import uuid
import warnings

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from PIL import Image

from kandinsky import get_T2V_pipeline
# ============================================================
# 1️⃣ FlashAttention setup
# ============================================================
# Best-effort: any failure (download, install, import-path refresh) is printed
# and the app continues without FlashAttention.
try:
    print("Attempting to download and install FlashAttention wheel...")
    flash_attention_wheel = hf_hub_download(
        repo_id="rahul7star/flash-attn-3",
        repo_type="model",
        filename="128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl",
    )
    # Install with the interpreter that is running this app, so the wheel lands
    # in *its* site-packages rather than in whatever `pip` is first on PATH.
    subprocess.run(
        [sys.executable, "-m", "pip", "install", flash_attention_wheel],
        check=True,
    )
    # Make the freshly installed package importable in this running process.
    site.addsitedir(site.getsitepackages()[0])
    importlib.invalidate_caches()
    print("βœ… FlashAttention installed successfully.")
except Exception as e:
    print(f"⚠️ Could not install FlashAttention: {e}")
    print("Continuing without FlashAttention...")
# ============================================================
# 2️⃣ Torch + logging config
# ============================================================
warnings.filterwarnings("ignore")
logging.getLogger("torch").setLevel(logging.ERROR)
torch._logging.set_logs(
dynamo=logging.ERROR,
dynamic=logging.ERROR,
aot=logging.ERROR,
inductor=logging.ERROR,
guards=False,
recompiles=False,
)
# ============================================================
# 3️⃣ Ensure models are downloaded
# ============================================================
# The marker is written only after download_models.py exits successfully
# (check=True raises otherwise), so an interrupted download is retried on the
# next start instead of being silently treated as complete.
if not os.path.exists("./models_downloaded.marker"):
    print("πŸ“¦ Models not found. Running download_models.py...")
    # Run the script with the current interpreter, not whatever `python` is on PATH.
    subprocess.run([sys.executable, "download_models.py"], check=True)
    with open("./models_downloaded.marker", "w", encoding="utf-8") as f:
        f.write("done")
    print("βœ… Models downloaded successfully.")
else:
    print("βœ… Models already downloaded (marker found).")
# ============================================================
# 4️⃣ Load pipeline to CUDA (like Wan example)
# ============================================================
# Load once at import time. On any failure the error is printed and `pipe` is
# left as None, so the UI still starts and generate_output can report it.
print("πŸ”§ Loading Kandinsky 5.0 T2V pipeline to CUDA...")
try:
    pipe = get_T2V_pipeline(
        conf_path="./configs/config_5s_sft.yaml",
        device_map=dict(dit="cuda:0", vae="cuda:0", text_embedder="cuda:0"),
    )
    # Explicitly move all components to CUDA as well, when the pipeline supports `.to()`.
    if hasattr(pipe, "to"):
        pipe.to("cuda")
    print("βœ… Pipeline successfully loaded and moved to CUDA.")
except Exception as e:
    print(f"❌ Pipeline load failed: {e}")
    pipe = None
# ============================================================
# 5️⃣ Generation function
# ============================================================
def _build_output_path(prompt, mode, max_len=50):
    """Return a unique, filesystem-safe output path in the temp directory.

    The prompt comes straight from the web UI, so it is reduced to word
    characters and hyphens. This removes path separators and ``..``. The result
    is truncated so the filename stays under the usual 255-byte limit even for
    non-ASCII prompts. A random suffix is appended so concurrent requests with
    the same prompt never overwrite each other's files.

    Args:
        prompt: User prompt text; ``None`` or a prompt with no word characters
            yields the stem ``"output"``.
        mode: ``"video"`` selects an ``.mp4`` extension, anything else ``.png``.
        max_len: Maximum number of characters kept from the prompt.

    Returns:
        A path string, e.g. ``/tmp/A_dog_in_red_boots_1a2b3c4d.mp4`` on Linux.
    """
    slug = re.sub(r"[^\w-]+", "_", prompt or "").strip("_")[:max_len] or "output"
    extension = "mp4" if mode == "video" else "png"
    filename = f"{slug}_{uuid.uuid4().hex[:8]}.{extension}"
    return os.path.join(tempfile.gettempdir(), filename)


@spaces.GPU(duration=40)
def generate_output(prompt, mode, duration, width, height, steps, guidance, scheduler):
    """Generate an image or a video for *prompt* with the module-level pipeline.

    Args:
        prompt: Text prompt from the UI.
        mode: ``"image"`` or ``"video"``.
        duration: Passed as ``time_length`` (video mode only; image mode uses 0).
        width, height: Output resolution passed to the pipeline.
        steps, guidance, scheduler: Forwarded as ``num_steps``,
            ``guidance_weight`` and ``scheduler_scale`` (video mode only).

    Returns:
        ``(output_path, status_message)``. ``output_path`` is ``None`` on any
        failure. Errors are reported through the message instead of being
        raised, so the UI always receives both outputs.
    """
    print(f"Generation request ({mode}): {prompt}")
    if pipe is None:
        return None, "❌ Pipeline not initialized."
    if mode not in ("image", "video"):
        # Previously fell through and returned a bare None, which does not
        # match the two outputs wired up in the UI.
        return None, f"❌ Unknown mode: {mode}"
    try:
        output_path = _build_output_path(prompt, mode)
        if mode == "image":
            print(f"πŸ–ΌοΈ Generating image: {prompt}")
            pipe(
                prompt,
                time_length=0,
                width=width,
                height=height,
                save_path=output_path,
            )
            return output_path, f"βœ… Image saved: {output_path}"
        print(f"🎬 Generating {duration}s video: {prompt}")
        pipe(
            prompt,
            time_length=duration,
            width=width,
            height=height,
            num_steps=steps,
            guidance_weight=guidance,
            scheduler_scale=scheduler,
            save_path=output_path,
        )
        return output_path, f"βœ… Video saved: {output_path}"
    except torch.cuda.OutOfMemoryError:
        return None, "⚠️ CUDA OOM β€” try smaller size or shorter duration."
    except Exception as e:
        return None, f"❌ Error during generation: {e}"
# ============================================================
# 6️⃣ Gradio UI
# ============================================================
def _route_generation(prompt, mode, duration, width, height, steps, guidance, scheduler):
    """Call generate_output and send its file to the matching output component.

    generate_output returns ``(path, status)``. *path* is a ``.png`` in image
    mode, an ``.mp4`` in video mode, or ``None`` on failure. ``gr.Video`` is
    the wrong component for a still image, so PNG results go to a separate
    ``gr.Image`` output.

    Returns:
        ``(video_path, image_path, status)`` with at most one path set.
    """
    path, message = generate_output(
        prompt, mode, duration, width, height, steps, guidance, scheduler
    )
    if path is None:
        return None, None, message
    if path.lower().endswith(".png"):
        return None, path, message
    return path, None, message


with gr.Blocks(theme=gr.themes.Soft(), title="Kandinsky 5.0 T2V Lite (CUDA)") as demo:
    gr.Markdown("## 🎞️ Kandinsky 5.0 β€” Text & Image to Video Generator")
    with gr.Row():
        with gr.Column(scale=2):
            mode = gr.Radio(["video", "image"], value="video", label="Mode")
            prompt = gr.Textbox(label="Prompt", value="A dog in red boots")
            duration = gr.Slider(1, 10, step=1, value=5, label="Video Duration (seconds)")
            width = gr.Radio([512, 768], value=768, label="Width (px)")
            height = gr.Radio([512, 768], value=512, label="Height (px)")
            steps = gr.Slider(10, 50, step=5, value=25, label="Sampling Steps")
            guidance = gr.Slider(0.0, 10.0, step=0.5, value=1.0, label="Guidance Weight")
            scheduler = gr.Slider(1.0, 10.0, step=0.5, value=5.0, label="Scheduler Scale")
            btn = gr.Button("πŸš€ Generate", variant="primary")
        with gr.Column(scale=3):
            output_display = gr.Video(label="Generated Video")
            image_display = gr.Image(label="Generated Image", type="filepath")
            status = gr.Markdown()
    # Event listeners must be registered inside the Blocks context.
    btn.click(
        fn=_route_generation,
        inputs=[prompt, mode, duration, width, height, steps, guidance, scheduler],
        outputs=[output_display, image_display, status],
    )
# ============================================================
# 7️⃣ Launch app normally (no GPU decorator)
# ============================================================
if __name__ == "__main__":
    # Listen on all interfaces on port 7860 (Gradio's default port).
    demo.launch(server_port=7860, server_name="0.0.0.0")