# VideoCoF / app.py
import os
import sys
import time
import torch
import gradio as gr
import numpy as np
import imageio
import spaces
from PIL import Image
# Add project root to path
# current_file_path = os.path.abspath(__file__)
# project_root = os.path.dirname(os.path.dirname(current_file_path))
# if project_root not in sys.path:
# sys.path.insert(0, project_root)
from videox_fun.ui.wan_ui import Wan_Controller, css
from videox_fun.ui.ui import (
create_model_type, create_model_checkpoints, create_finetune_models_checkpoints,
create_teacache_params, create_cfg_skip_params, create_cfg_riflex_k,
create_prompts, create_samplers, create_height_width,
create_generation_methods_and_video_length, create_generation_method,
create_cfg_and_seedbox, create_ui_outputs
)
from videox_fun.data.dataset_image_video import derive_ground_object_from_instruction
from videox_fun.utils.lora_utils import merge_lora, unmerge_lora
from videox_fun.utils.utils import save_videos_grid, timer
# The library's create_height_width helper hard-codes defaults and Chinese labels.
# Rather than modify the library file, define a custom version here that hides the
# resolution controls, since inference always runs at the input video's resolution.
def create_height_width_english(default_height, default_width, maximum_height, maximum_width):
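    """English-only stand-in for the library's create_height_width.

    All controls are created hidden because inference always uses the input
    video's resolution; returning them keeps the stock UI wiring intact.
    """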
resize_method = gr.Radio(
["Generate by", "Resize according to Reference"],
value="Generate by",
show_label=False,
visible=False # Hide since we force input resolution
)
    # Inference always runs at the input video's resolution, so the width/height
    # sliders and base-resolution radio are created hidden; this keeps the return
    # signature compatible with the stock UI wiring.
width_slider = gr.Slider(label="Width", value=default_width, minimum=128, maximum=maximum_width, step=16, visible=False)
height_slider = gr.Slider(label="Height", value=default_height, minimum=128, maximum=maximum_height, step=16, visible=False)
base_resolution = gr.Radio(label="Base Resolution", value=512, choices=[512, 640, 768, 896, 960, 1024], visible=False)
return resize_method, width_slider, height_slider, base_resolution
def load_video_frames(video_path: str, source_frames: int):
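    """Sample `source_frames` frames from the video at a uniform stride.

    A random start offset is used (mirroring inference.py); if the clip runs
    short, the last frame (or a black frame) is repeated as padding. Returns a
    (1, C, T, H, W) float tensor in [-1, 1] plus the original height and width.
    """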
assert source_frames is not None, "source_frames is required"
reader = imageio.get_reader(video_path)
try:
total_frames = reader.count_frames()
    except Exception:
        # Some readers don't implement count_frames(); count by iterating, then reopen.
        total_frames = sum(1 for _ in reader)
        reader.close()
        reader = imageio.get_reader(video_path)
stride = max(1, total_frames // source_frames)
# Using random start frame as in inference.py
start_frame = torch.randint(0, max(1, total_frames - stride * source_frames), (1,))[0].item()
frames = []
original_height, original_width = None, None
for i in range(source_frames):
idx = start_frame + i * stride
if idx >= total_frames:
break
try:
frame = reader.get_data(idx)
pil_frame = Image.fromarray(frame)
if original_height is None:
original_width, original_height = pil_frame.size
frames.append(pil_frame)
except IndexError:
break
reader.close()
while len(frames) < source_frames:
if frames:
frames.append(frames[-1].copy())
else:
w, h = (original_width, original_height) if original_width else (832, 480)
frames.append(Image.new('RGB', (w, h), (0, 0, 0)))
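    # Stack the PIL frames into a (1, C, T, H, W) float tensor scaled to [-1, 1].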
input_video = torch.from_numpy(np.array(frames))
input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0).float()
input_video = input_video * (2.0 / 255.0) - 1.0
return input_video, original_height, original_width
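# Illustrative usage (with one of the bundled example clips): sample 33 strided
# frames and obtain a (1, 3, 33, H, W) tensor in [-1, 1] plus the native size:
#   input_video, h, w = load_video_frames("assets/two_man.mp4", source_frames=33)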
class VideoCoF_Controller(Wan_Controller):
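    """Wan_Controller subclass that runs VideoCoF editing: given a source video and an
    edit instruction, it generates source, grounding/reasoning, and edited frames in
    a single sequence."""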
@spaces.GPU(duration=300)
@timer
def generate(
self,
diffusion_transformer_dropdown,
base_model_dropdown,
lora_model_dropdown,
lora_alpha_slider,
prompt_textbox,
negative_prompt_textbox,
sampler_dropdown,
sample_step_slider,
resize_method,
width_slider,
height_slider,
base_resolution,
generation_method,
length_slider,
overlap_video_length,
partial_video_length,
cfg_scale_slider,
start_image,
end_image,
validation_video,
validation_video_mask,
control_video,
denoise_strength,
seed_textbox,
ref_image=None,
enable_teacache=None,
teacache_threshold=None,
num_skip_start_steps=None,
teacache_offload=None,
cfg_skip_ratio=None,
enable_riflex=None,
riflex_k=None,
# Custom args
source_frames_slider=33,
reasoning_frames_slider=4,
repeat_rope_checkbox=True,
# New arg for acceleration
enable_acceleration=False,
fps=8,
is_api=False,
):
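        """Run VideoCoF editing on the uploaded video.

        Merges the VideoCoF LoRA (and optionally the FusionX acceleration LoRA),
        builds the CoT prompt from the edit instruction, runs the pipeline at the
        input video's own resolution, then unmerges the LoRAs and saves the result.
        """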
self.clear_cache()
        print("VideoCoF generation started.")

        # Inside the ZeroGPU-decorated call, point `self.device` at CUDA explicitly:
        # Wan_Controller places tensors via `self.device`, while offloading/accelerate
        # takes care of moving the pipeline weights themselves.
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
if self.diffusion_transformer_dropdown != diffusion_transformer_dropdown:
self.update_diffusion_transformer(diffusion_transformer_dropdown)
if self.base_model_path != base_model_dropdown:
self.update_base_model(base_model_dropdown)
if self.lora_model_path != lora_model_dropdown:
self.update_lora_model(lora_model_dropdown)
# Scheduler setup
scheduler_config = self.pipeline.scheduler.config
if sampler_dropdown in ["Flow_Unipc", "Flow_DPM++"]:
scheduler_config['shift'] = 1
self.pipeline.scheduler = self.scheduler_dict[sampler_dropdown].from_config(scheduler_config)
# LoRA merging
# 1. Merge VideoCoF LoRA
if self.lora_model_path != "none":
print(f"Merge VideoCoF Lora: {self.lora_model_path}")
self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
# 2. Merge Acceleration LoRA (FusionX) if enabled
acc_lora_path = os.path.join(self.personalized_model_dir, "Wan2.1_Text_to_Video_14B_FusionX_LoRA.safetensors")
if enable_acceleration:
if os.path.exists(acc_lora_path):
print(f"Merge Acceleration LoRA: {acc_lora_path}")
# FusionX LoRA generally uses multiplier 1.0
self.pipeline = merge_lora(self.pipeline, acc_lora_path, multiplier=1.0)
else:
print(f"Warning: Acceleration LoRA not found at {acc_lora_path}")
# Seed
        if seed_textbox != "" and int(seed_textbox) != -1:
torch.manual_seed(int(seed_textbox))
else:
seed_textbox = np.random.randint(0, 1e10)
generator = torch.Generator(device=self.device).manual_seed(int(seed_textbox))
try:
# VideoCoF logic
# Use validation_video as source if provided (UI standard for Video-to-Video)
input_video_path = validation_video
if input_video_path is None:
# Fallback to control_video if set, but standard UI uses validation_video
input_video_path = control_video
if input_video_path is None:
raise ValueError("Please upload a video for VideoCoF generation.")
# CoT Prompt Construction
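            # derive_ground_object_from_instruction presumably extracts the object phrase
            # to be grounded from the edit instruction; its output is spliced into the
            # CoT prompt below.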
edit_text = prompt_textbox
ground_instr = derive_ground_object_from_instruction(edit_text)
prompt = (
"A video sequence showing three parts: first the original scene, "
f"then grounded {ground_instr}, and finally the same scene but {edit_text}"
)
print(f"Constructed prompt: {prompt}")
# Load video frames
input_video_tensor, video_height, video_width = load_video_frames(
input_video_path,
source_frames=source_frames_slider
)
# Using loaded video dimensions
h, w = video_height, video_width
print(f"Input video dimensions: {w}x{h}")
print(f"Running pipeline with frames={length_slider}, source={source_frames_slider}, reasoning={reasoning_frames_slider}")
sample = self.pipeline(
video=input_video_tensor,
prompt=prompt,
num_frames=length_slider,
source_frames=source_frames_slider,
reasoning_frames=reasoning_frames_slider,
negative_prompt=negative_prompt_textbox,
height=h,
width=w,
generator=generator,
guidance_scale=cfg_scale_slider,
num_inference_steps=sample_step_slider,
repeat_rope=repeat_rope_checkbox,
cot=True,
).videos
final_video = sample
except Exception as e:
print(f"Error: {e}")
# Unmerge in case of error (LIFO order)
if enable_acceleration and os.path.exists(acc_lora_path):
print("Unmerging Acceleration LoRA (due to error)")
self.pipeline = unmerge_lora(self.pipeline, acc_lora_path, multiplier=1.0)
if self.lora_model_path != "none":
print("Unmerging VideoCoF LoRA (due to error)")
self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
return gr.update(), gr.update(), f"Error: {str(e)}"
# Unmerge LoRAs (LIFO order)
if enable_acceleration and os.path.exists(acc_lora_path):
print("Unmerging Acceleration LoRA")
self.pipeline = unmerge_lora(self.pipeline, acc_lora_path, multiplier=1.0)
if self.lora_model_path != "none":
print("Unmerging VideoCoF LoRA")
self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
# Save output
save_sample_path = self.save_outputs(
False, length_slider, final_video, fps=fps
)
        # Note: `generate` returns (result_image, result_video, infer_progress), so the
        # output area only ever shows the generated video. The earlier report that the
        # "loaded original video didn't display" refers to the input `validation_video`
        # component, which Gradio populates automatically on upload or example selection.
return gr.Image(visible=False, value=None), gr.Video(value=save_sample_path, visible=True), "Success"
def ui(GPU_memory_mode, scheduler_dict, config_path, compile_dit, weight_dtype):
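    """Build the VideoCoF Gradio demo: create the controller, pre-download weights,
    lay out the controls, and wire the Generate button to VideoCoF_Controller.generate."""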
controller = VideoCoF_Controller(
GPU_memory_mode, scheduler_dict, model_name=None, model_type="Inpaint",
config_path=config_path, compile_dit=compile_dit,
weight_dtype=weight_dtype
)
with gr.Blocks() as demo:
gr.Markdown("# VideoCoF Demo")
with gr.Column(variant="panel"):
# Hide model selection
diffusion_transformer_dropdown, _ = create_model_checkpoints(controller, visible=False, default_model="Wan-AI/Wan2.1-T2V-14B")
            # Pre-download weights: the full Wan2.1-T2V-14B base model via snapshot_download,
            # plus the VideoCoF LoRA and the FusionX acceleration LoRA via hf_hub_download.
try:
from huggingface_hub import snapshot_download, hf_hub_download
print("Downloading Wan2.1-T2V-14B weights...")
snapshot_download(repo_id="Wan-AI/Wan2.1-T2V-14B", local_dir="Wan-AI/Wan2.1-T2V-14B")
os.makedirs("models/Personalized_Model", exist_ok=True)
print("Downloading VideoCoF weights...")
hf_hub_download(repo_id="XiangpengYang/VideoCoF", filename="videocof.safetensors", local_dir="models/Personalized_Model")
print("Downloading FusionX Acceleration LoRA...")
hf_hub_download(repo_id="MonsterMMORPG/Wan_GGUF", filename="Wan2.1_Text_to_Video_14B_FusionX_LoRA.safetensors", local_dir="models/Personalized_Model")
except Exception as e:
print(f"Warning: Failed to pre-download weights: {e}")
base_model_dropdown, lora_model_dropdown, lora_alpha_slider, _ = create_finetune_models_checkpoints(controller, visible=False, default_lora="videocof.safetensors")
# Set default LoRA alpha to 1.0 (matching inference.py)
lora_alpha_slider.value = 1.0
with gr.Row():
# Disable teacache by default
enable_teacache, teacache_threshold, num_skip_start_steps, teacache_offload = create_teacache_params(False, 0.10, 5, False)
cfg_skip_ratio = create_cfg_skip_params(0)
enable_riflex, riflex_k = create_cfg_riflex_k(False, 6)
with gr.Column(variant="panel"):
prompt_textbox, negative_prompt_textbox = create_prompts(prompt="Remove the young man with short black hair wearing black shirt on the left.")
with gr.Row():
with gr.Column():
sampler_dropdown, sample_step_slider = create_samplers(controller)
# Default steps lowered to 4 for acceleration
sample_step_slider.value = 4
# Custom VideoCoF Params
with gr.Group():
gr.Markdown("### VideoCoF Parameters")
source_frames_slider = gr.Slider(label="Source Frames", minimum=1, maximum=100, value=33, step=1)
reasoning_frames_slider = gr.Slider(label="Reasoning Frames", minimum=1, maximum=20, value=4, step=1)
repeat_rope_checkbox = gr.Checkbox(label="Repeat RoPE", value=True)
# Add Acceleration Checkbox
enable_acceleration = gr.Checkbox(label="Enable 4-step Acceleration (FusionX LoRA)", value=True)
# Use custom height/width creation to hide/customize
resize_method, width_slider, height_slider, base_resolution = create_height_width_english(
default_height=480, default_width=832, maximum_height=1344, maximum_width=1344
)
# Default video length 65
generation_method, length_slider, overlap_video_length, partial_video_length = \
create_generation_methods_and_video_length(
["Video Generation"],
default_video_length=65,
maximum_video_length=161
)
# Simplified input for VideoCoF - mainly Video to Video.
image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image = create_generation_method(
["Video to Video"], prompt_textbox, support_end_image=False, default_video="assets/two_man.mp4",
video_examples=[
["assets/two_man.mp4", "Remove the young man with short black hair wearing black shirt on the left."],
["assets/sign.mp4", "Replace the yellow \"SCHOOL\" sign with a red hospital sign, featuring a white hospital emblem on the top and the word \"HOSPITAL\" below."]
]
)
# Ensure validation_video is visible and interactive
validation_video.visible = True
validation_video.interactive = True
                    # Defaults: seed 0, CFG scale 1.0.
cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(True)
seed_textbox.value = "0"
cfg_scale_slider.value = 1.0
generate_button = gr.Button(value="Generate", variant='primary')
result_image, result_video, infer_progress = create_ui_outputs()
# Event handlers
generate_button.click(
fn=controller.generate,
inputs=[
diffusion_transformer_dropdown,
base_model_dropdown,
lora_model_dropdown,
lora_alpha_slider,
prompt_textbox,
negative_prompt_textbox,
sampler_dropdown,
sample_step_slider,
resize_method,
width_slider,
height_slider,
base_resolution,
generation_method,
length_slider,
overlap_video_length,
partial_video_length,
cfg_scale_slider,
start_image,
end_image,
validation_video,
validation_video_mask,
control_video,
denoise_strength,
seed_textbox,
ref_image,
enable_teacache,
teacache_threshold,
num_skip_start_steps,
teacache_offload,
cfg_skip_ratio,
enable_riflex,
riflex_k,
# New inputs
source_frames_slider,
reasoning_frames_slider,
repeat_rope_checkbox,
enable_acceleration
],
outputs=[result_image, result_video, infer_progress]
)
return demo, controller
if __name__ == "__main__":
from videox_fun.ui.controller import flow_scheduler_dict
GPU_memory_mode = "sequential_cpu_offload"
compile_dit = False
weight_dtype = torch.bfloat16
server_name = "0.0.0.0"
server_port = 7860
config_path = "config/wan2.1/wan_civitai.yaml"
demo, controller = ui(GPU_memory_mode, flow_scheduler_dict, config_path, compile_dit, weight_dtype)
demo.queue(status_update_rate=1).launch(
server_name=server_name,
server_port=server_port,
prevent_thread_lock=True,
share=False
)
while True:
time.sleep(5)