import os
import sys
import torch
import diffusers
import transformers
import argparse
import peft
import copy
import cv2
import spaces
import gc
import gradio as gr
import numpy as np

from flash_attn import flash_attn_func
from peft import LoraConfig
from omegaconf import OmegaConf
from safetensors.torch import safe_open
from PIL import Image, ImageDraw, ImageFilter
from huggingface_hub import hf_hub_download
from transformers import pipeline

from models import HunyuanVideoTransformer3DModel
from pipelines import HunyuanVideoImageToVideoPipeline

header = """
# DRA-Ctrl Gradio App
<div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
<a href="https://arxiv.org/pdf/2505.23325"><img src="https://img.shields.io/badge/arXiv-Paper-A42C25.svg" alt="arXiv"></a>
<a href="https://arxiv.org/abs/2505.23325"><img src="https://img.shields.io/badge/arXiv-Page-A42C25.svg" alt="arXiv"></a>
<a href="https://huggingface.co/Kunbyte/DRA-Ctrl"><img src="https://img.shields.io/badge/🤗-Model-ffbd45.svg" alt="HuggingFace"></a>
<a href="https://huggingface.co/spaces/Kunbyte/DRA-Ctrl"><img src="https://img.shields.io/badge/🤗-Space-ffbd45.svg" alt="HuggingFace"></a>
<a href="https://github.com/Kunbyte-AI/DRA-Ctrl"><img src="https://img.shields.io/badge/GitHub-Code-blue.svg?logo=github&" alt="GitHub"></a>
<a href="https://dra-ctrl-2025.github.io/DRA-Ctrl/"><img src="https://img.shields.io/badge/Project-Page-blue" alt="Project"></a>
</div>
"""

notice = """
For easier testing of the spatially-aligned image generation tasks, only the original image needs to be passed to
`gradio_app` as the condition image; there is no need to supply edge maps, depth maps, or other condition images
manually. The corresponding condition image is extracted automatically.
"""
def init_basemodel():
    global transformer, scheduler, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2, image_processor, pipe, current_task
    pipe = None
    current_task = None

    # init models
    device = "cuda" if torch.cuda.is_available() else "cpu"
    weight_dtype = torch.bfloat16

    transformer = HunyuanVideoTransformer3DModel.from_pretrained(
        'hunyuanvideo-community/HunyuanVideo-I2V',
        subfolder="transformer",
        inference_subject_driven=False,
        low_cpu_mem_usage=True,
    ).requires_grad_(False).to(device, dtype=weight_dtype)
    torch.cuda.empty_cache()
    gc.collect()

    scheduler = diffusers.FlowMatchEulerDiscreteScheduler()

    vae = diffusers.AutoencoderKLHunyuanVideo.from_pretrained(
        'hunyuanvideo-community/HunyuanVideo-I2V',
        subfolder="vae",
        low_cpu_mem_usage=True,
    ).requires_grad_(False).to(device, dtype=weight_dtype)
    torch.cuda.empty_cache()
    gc.collect()

    text_encoder = transformers.LlavaForConditionalGeneration.from_pretrained(
        'hunyuanvideo-community/HunyuanVideo-I2V',
        subfolder="text_encoder",
        low_cpu_mem_usage=True,
    ).requires_grad_(False).to(device, dtype=weight_dtype)
    torch.cuda.empty_cache()
    gc.collect()

    text_encoder_2 = transformers.CLIPTextModel.from_pretrained(
        'hunyuanvideo-community/HunyuanVideo-I2V',
        subfolder="text_encoder_2",
        low_cpu_mem_usage=True,
    ).requires_grad_(False).to(device, dtype=weight_dtype)
    torch.cuda.empty_cache()
    gc.collect()

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        'hunyuanvideo-community/HunyuanVideo-I2V', subfolder="tokenizer")
    tokenizer_2 = transformers.CLIPTokenizer.from_pretrained(
        'hunyuanvideo-community/HunyuanVideo-I2V', subfolder="tokenizer_2")
    image_processor = transformers.CLIPImageProcessor.from_pretrained(
        'hunyuanvideo-community/HunyuanVideo-I2V', subfolder="image_processor")

    vae.enable_tiling()
    vae.enable_slicing()

    pipe = HunyuanVideoImageToVideoPipeline(
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        transformer=transformer,
        vae=vae,
        scheduler=copy.deepcopy(scheduler),
        text_encoder_2=text_encoder_2,
        tokenizer_2=tokenizer_2,
        image_processor=image_processor,
    )
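
# Assumption: this Space runs on ZeroGPU (the `spaces` package is imported above), so the
# inference entry point is wrapped with `spaces.GPU` to have GPU hardware attached for the
# duration of each call.
@spaces.GPU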
def process_image_and_text(condition_image, target_prompt, condition_image_prompt, task, random_seed,
                           num_steps, inpainting, fill_x1, fill_x2, fill_y1, fill_y2):
    # set up the model
    global pipe, current_task, transformer
    if current_task != task:
        if current_task is None:
            # insert LoRA
            lora_config = LoraConfig(
                r=16,
                lora_alpha=16,
                init_lora_weights="gaussian",
                target_modules=[
                    'attn.to_k', 'attn.to_q', 'attn.to_v', 'attn.to_out.0',
                    'attn.add_k_proj', 'attn.add_q_proj', 'attn.add_v_proj', 'attn.to_add_out',
                    'ff.net.0.proj', 'ff.net.2',
                    'ff_context.net.0.proj', 'ff_context.net.2',
                    'norm1_context.linear', 'norm1.linear',
                    'norm.linear', 'proj_mlp', 'proj_out',
                ]
            )
            transformer.add_adapter(lora_config)
        else:
            def restore_forward(module):
                def restored_forward(self, x, *args, **kwargs):
                    return module.original_forward(x, *args, **kwargs)
                return restored_forward.__get__(module, type(module))

            for n, m in transformer.named_modules():
                if isinstance(m, peft.tuners.lora.layer.Linear):
                    m.forward = restore_forward(m)

        current_task = task
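
        # DRA-Ctrl applies the LoRA adapter selectively along the token sequence: condition-image
        # and text-encoder tokens go through the LoRA branch, while the generated-image tokens fall
        # back to the frozen base layer. The patched forward below dispatches on the sequence length
        # to tell these cases apart.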
        # hack LoRA forward
        def create_hacked_forward(module):
            if not hasattr(module, 'original_forward'):
                module.original_forward = module.forward
            lora_forward = module.forward
            non_lora_forward = module.base_layer.forward
            img_sequence_length = int((512 / 8 / 2) ** 2)
            encoder_sequence_length = 144 + 252  # encoder sequence: 144 img 252 txt
            num_imgs = 4
            num_generated_imgs = 3
            num_encoder_sequences = 2 if task in ['subject_driven', 'style_transfer'] else 1

            def hacked_lora_forward(self, x, *args, **kwargs):
                if x.shape[1] == img_sequence_length * num_imgs and len(x.shape) > 2:
                    return torch.cat((
                        lora_forward(x[:, :-img_sequence_length*num_generated_imgs], *args, **kwargs),
                        non_lora_forward(x[:, -img_sequence_length*num_generated_imgs:], *args, **kwargs)
                    ), dim=1)
                elif x.shape[1] == encoder_sequence_length * num_encoder_sequences or x.shape[1] == encoder_sequence_length:
                    return lora_forward(x, *args, **kwargs)
                elif x.shape[1] == img_sequence_length * num_imgs + encoder_sequence_length * num_encoder_sequences:
                    return torch.cat((
                        lora_forward(x[:, :(num_imgs - num_generated_imgs)*img_sequence_length], *args, **kwargs),
                        non_lora_forward(x[:, (num_imgs - num_generated_imgs)*img_sequence_length:-num_encoder_sequences*encoder_sequence_length], *args, **kwargs),
                        lora_forward(x[:, -num_encoder_sequences*encoder_sequence_length:], *args, **kwargs)
                    ), dim=1)
                elif x.shape[1] == 3072:
                    return non_lora_forward(x, *args, **kwargs)
                else:
                    raise ValueError(
                        f"hacked_lora_forward receives unexpected sequence length: {x.shape[1]}, input shape: {x.shape}!"
                    )

            return hacked_lora_forward.__get__(module, type(module))

        for n, m in transformer.named_modules():
            if isinstance(m, peft.tuners.lora.layer.Linear):
                m.forward = create_hacked_forward(m)
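
        # The task-specific LoRA checkpoint is fetched from the Hub and loaded non-strictly;
        # checkpoint keys ending in ".weight" are renamed to ".default.weight" to match PEFT's
        # naming for the "default" adapter registered above.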
        # load LoRA weights
        model_root = hf_hub_download(
            repo_id="Kunbyte/DRA-Ctrl",
            filename=f"{task}.safetensors",
            resume_download=True)
        try:
            with safe_open(model_root, framework="pt") as f:
                lora_weights = {}
                for k in f.keys():
                    param = f.get_tensor(k)
                    if k.endswith(".weight"):
                        k = k.replace('.weight', '.default.weight')
                    lora_weights[k] = param
                transformer.load_state_dict(lora_weights, strict=False)
        except Exception as e:
            raise ValueError(f"Failed to load LoRA weights from {model_root}: {e}")

        transformer.requires_grad_(False)
    # start generation
    c_txt = None if condition_image_prompt == "" else condition_image_prompt
    c_img = condition_image.resize((512, 512))
    t_txt = target_prompt
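
    # Derive the task-specific condition image from the uploaded image (edge map, grayscale,
    # blurred, depth, masked, or low-resolution version), as described in the notice above.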
    if task not in ['subject_driven', 'style_transfer']:
        if task == "canny":
            def get_canny_edge(img):
                img_np = np.array(img)
                img_gray = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
                edges = cv2.Canny(img_gray, 100, 200)
                edges[edges == 0] = 128  # lift the black background to mid-gray
                return Image.fromarray(edges).convert("RGB")
            c_img = get_canny_edge(c_img)
        elif task == "coloring":
            c_img = (
                c_img.resize((512, 512))
                .convert("L")
                .convert("RGB")
            )
        elif task == "deblurring":
            blur_radius = 10
            c_img = (
                c_img.convert("RGB")
                .filter(ImageFilter.GaussianBlur(blur_radius))
                .resize((512, 512))
                .convert("RGB")
            )
        elif task == "depth":
            def get_depth_map(img):
                depth_pipe = pipeline(
                    task="depth-estimation",
                    model="LiheYoung/depth-anything-small-hf",
                    device="cpu",
                )
                return depth_pipe(img)["depth"].convert("RGB").resize((512, 512))
            c_img = get_depth_map(c_img)
            # remap depth values from [0, 255] to [128, 255]
            k = (255 - 128) / 255
            b = 128
            c_img = c_img.point(lambda x: k * x + b)
        elif task == "depth_pred":
            # depth prediction conditions on the RGB image itself
            pass
        elif task == "fill":
            c_img = c_img.resize((512, 512)).convert("RGB")
            x1, x2 = fill_x1, fill_x2
            y1, y2 = fill_y1, fill_y2
            mask = Image.new("L", (512, 512), 0)
            draw = ImageDraw.Draw(mask)
            draw.rectangle((x1, y1, x2, y2), fill=255)
            if inpainting:
                mask = Image.eval(mask, lambda a: 255 - a)
                c_img = Image.composite(
                    c_img,
                    Image.new("RGB", (512, 512), (255, 255, 255)),
                    mask
                )
            # gray out the region to be generated
            c_img = Image.composite(
                c_img,
                Image.new("RGB", (512, 512), (128, 128, 128)),
                mask
            )
        elif task == "sr":
            c_img = c_img.resize((int(512 / 4), int(512 / 4))).convert("RGB")
            c_img = c_img.resize((512, 512))
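
    # Run the image-to-video pipeline once; the generated video frames are repurposed as
    # candidate output images and shown in the gallery.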
    gen_img = pipe(
        image=c_img,
        prompt=[t_txt.strip()],
        prompt_condition=[c_txt.strip()] if c_txt is not None else None,
        prompt_2=[t_txt],
        height=512,
        width=512,
        num_frames=5,
        num_inference_steps=num_steps,
        guidance_scale=6.0,
        num_videos_per_prompt=1,
        generator=torch.Generator(device=pipe.transformer.device).manual_seed(random_seed),
        output_type='pt',
        image_embed_interleave=4,
        frame_gap=48,
        mixup=True,
        mixup_num_imgs=2,
        enhance_tp=task in ['subject_driven'],
    ).frames

    output_images = []
    for i in range(10):
        # slice out frame i, move it to the CPU, and convert the CHW float tensor to a PIL image
        out = gen_img[:, i:i+1, :, :, :]
        out = out.squeeze(0).squeeze(0).cpu().to(torch.float32).numpy()
        out = np.transpose(out, (1, 2, 0))
        out = (out * 255).astype(np.uint8)
        out = Image.fromarray(out)
        output_images.append(out)

    # rotate the list so the first frame is shown last
    return output_images[1:] + [output_images[0]]
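
# Gradio UI: condition image, task selection, prompts, and sampling options on the left;
# a gallery of generated images on the right.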
def create_app():
    with gr.Blocks() as app:
        gr.Markdown(header, elem_id="header")
        with gr.Row(equal_height=False):
            with gr.Column(variant="panel", elem_classes="inputPanel"):
                condition_image = gr.Image(
                    type="pil", label="Condition Image", width=300, elem_id="input"
                )
                task = gr.Radio(
                    [
                        ("Subject-driven Image Generation", "subject_driven"),
                        ("Canny-to-Image", "canny"),
                        ("Colorization", "coloring"),
                        ("Deblurring", "deblurring"),
                        ("Depth-to-Image", "depth"),
                        ("Depth Prediction", "depth_pred"),
                        ("In/Out-Painting", "fill"),
                        ("Super-Resolution", "sr"),
                        ("Style Transfer", "style_transfer"),
                    ],
                    label="Task Selection",
                    value="subject_driven",
                    interactive=True,
                    elem_id="task_selection"
                )
                gr.Markdown(notice, elem_id="notice")
                target_prompt = gr.Textbox(lines=2, label="Target Prompt", elem_id="tp")
                condition_image_prompt = gr.Textbox(
                    lines=2,
                    label="Condition Image Prompt (Only required by Subject-driven Image Generation and Style Transfer tasks)",
                    elem_id="cp")
                random_seed = gr.Number(label="Random Seed", precision=0, value=0, elem_id="seed")
                num_steps = gr.Number(label="Diffusion Inference Steps", precision=0, value=50, elem_id="steps")
                inpainting = gr.Checkbox(label="Inpainting", value=False, elem_id="inpainting")
                fill_x1 = gr.Number(label="In/Out-painting Box Left Boundary", precision=0, value=128, elem_id="fill_x1")
                fill_x2 = gr.Number(label="In/Out-painting Box Right Boundary", precision=0, value=384, elem_id="fill_x2")
                fill_y1 = gr.Number(label="In/Out-painting Box Top Boundary", precision=0, value=128, elem_id="fill_y1")
                fill_y2 = gr.Number(label="In/Out-painting Box Bottom Boundary", precision=0, value=384, elem_id="fill_y2")
                submit_btn = gr.Button("Run", elem_id="submit_btn")
            with gr.Column(variant="panel", elem_classes="outputPanel"):
                # output_image = gr.Image(type="pil", elem_id="output")
                output_images = gr.Gallery(
                    label="Output Images",
                    show_label=True,
                    elem_id="output_gallery",
                    columns=1,
                    rows=10,
                    object_fit="contain",
                    height="auto",
                )
        submit_btn.click(
            fn=process_image_and_text,
            inputs=[condition_image, target_prompt, condition_image_prompt, task, random_seed, num_steps,
                    inpainting, fill_x1, fill_x2, fill_y1, fill_y2],
            outputs=output_images,
        )
    return app


if __name__ == "__main__":
    init_basemodel()
    create_app().launch(debug=True, ssr_mode=False, max_threads=1)