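# Gradio demo for VideoCoF, built on the VideoX-Fun / Wan2.1 stack: the user uploads a
# source video plus an edit instruction, and the pipeline generates a sequence showing
# the original scene, grounded reasoning frames, and the edited result.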
import os
import sys
import time
import torch
import gradio as gr
import numpy as np
import imageio
from PIL import Image

# Add project root to path
# current_file_path = os.path.abspath(__file__)
# project_root = os.path.dirname(os.path.dirname(current_file_path))
# if project_root not in sys.path:
#     sys.path.insert(0, project_root)

from videox_fun.ui.wan_ui import Wan_Controller, css
from videox_fun.ui.ui import (
    create_model_type, create_model_checkpoints, create_finetune_models_checkpoints,
    create_teacache_params, create_cfg_skip_params, create_cfg_riflex_k,
    create_prompts, create_samplers, create_height_width,
    create_generation_methods_and_video_length, create_generation_method,
    create_cfg_and_seedbox, create_ui_outputs
)
from videox_fun.data.dataset_image_video import derive_ground_object_from_instruction
from videox_fun.utils.lora_utils import merge_lora, unmerge_lora
from videox_fun.utils.utils import save_videos_grid, timer


# Custom replacement for the imported `create_height_width`: labels are English-only
# and the resize controls are hidden, since VideoCoF always runs inference at the
# input video's resolution.
def create_height_width_english(default_height, default_width, maximum_height, maximum_width):
    resize_method = gr.Radio(
        ["Generate by", "Resize according to Reference"],
        value="Generate by",
        show_label=False,
        visible=False  # Hide since we force input resolution
    )
    # The sliders are still created (generate() receives them) but kept hidden:
    # inference always uses the uploaded video's resolution, so the default
    # 480x832 is never forced on the user.
    width_slider = gr.Slider(label="Width", value=default_width, minimum=128, maximum=maximum_width, step=16, visible=False)
    height_slider = gr.Slider(label="Height", value=default_height, minimum=128, maximum=maximum_height, step=16, visible=False)
    base_resolution = gr.Radio(label="Base Resolution", value=512, choices=[512, 640, 768, 896, 960, 1024], visible=False)
    return resize_method, width_slider, height_slider, base_resolution


def load_video_frames(video_path: str, source_frames: int):
    assert source_frames is not None, "source_frames is required"
    reader = imageio.get_reader(video_path)
    try:
        total_frames = reader.count_frames()
    except Exception:
        total_frames = sum(1 for _ in reader)
        reader = imageio.get_reader(video_path)
    stride = max(1, total_frames // source_frames)
    # Using random start frame as in inference.py
    start_frame = torch.randint(0, max(1, total_frames - stride * source_frames), (1,))[0].item()

    frames = []
    original_height, original_width = None, None
    for i in range(source_frames):
        idx = start_frame + i * stride
        if idx >= total_frames:
            break
        try:
            frame = reader.get_data(idx)
            pil_frame = Image.fromarray(frame)
            if original_height is None:
                original_width, original_height = pil_frame.size
            frames.append(pil_frame)
        except IndexError:
            break
    reader.close()

    # Pad with the last frame (or black frames) if the clip is shorter than source_frames.
    while len(frames) < source_frames:
        if frames:
            frames.append(frames[-1].copy())
        else:
            w, h = (original_width, original_height) if original_width else (832, 480)
            frames.append(Image.new('RGB', (w, h), (0, 0, 0)))

    # Stack into a (1, C, T, H, W) float tensor scaled from [0, 255] to [-1, 1].
    input_video = torch.from_numpy(np.array(frames))
    input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0).float()
    input_video = input_video * (2.0 / 255.0) - 1.0
    return input_video, original_height, original_width


class VideoCoF_Controller(Wan_Controller):
    def generate(
        self,
        diffusion_transformer_dropdown,
        base_model_dropdown,
        lora_model_dropdown,
        lora_alpha_slider,
        prompt_textbox,
        negative_prompt_textbox,
        sampler_dropdown,
        sample_step_slider,
        resize_method,
        width_slider,
        height_slider,
        base_resolution,
        generation_method,
        length_slider,
        overlap_video_length,
        partial_video_length,
        cfg_scale_slider,
        start_image,
        end_image,
        validation_video,
        validation_video_mask,
        control_video,
        denoise_strength,
        seed_textbox,
        ref_image=None,
        enable_teacache=None,
        teacache_threshold=None,
        num_skip_start_steps=None,
        teacache_offload=None,
        cfg_skip_ratio=None,
        enable_riflex=None,
        riflex_k=None,
        # Custom args
        source_frames_slider=33,
        reasoning_frames_slider=4,
        repeat_rope_checkbox=True,
        fps=10,
        is_api=False,
    ):
        self.clear_cache()
        print("VideoCoF generation started.")

        if self.diffusion_transformer_dropdown != diffusion_transformer_dropdown:
            self.update_diffusion_transformer(diffusion_transformer_dropdown)
        if self.base_model_path != base_model_dropdown:
            self.update_base_model(base_model_dropdown)
        if self.lora_model_path != lora_model_dropdown:
            self.update_lora_model(lora_model_dropdown)

        # Scheduler setup
        scheduler_config = self.pipeline.scheduler.config
        if sampler_dropdown in ["Flow_Unipc", "Flow_DPM++"]:
            scheduler_config['shift'] = 1
        self.pipeline.scheduler = self.scheduler_dict[sampler_dropdown].from_config(scheduler_config)

        # LoRA merging
        if self.lora_model_path != "none":
            print("Merge LoRA.")
            self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)

        # Seed (check for an empty textbox before casting to int)
        if seed_textbox != "" and int(seed_textbox) != -1:
            torch.manual_seed(int(seed_textbox))
        else:
            seed_textbox = np.random.randint(0, 1e10)
        generator = torch.Generator(device=self.device).manual_seed(int(seed_textbox))

        try:
            # VideoCoF logic
            # Use validation_video as source if provided (UI standard for Video-to-Video)
            input_video_path = validation_video
            if input_video_path is None:
                # Fallback to control_video if set, but standard UI uses validation_video
                input_video_path = control_video
            if input_video_path is None:
                raise ValueError("Please upload a video for VideoCoF generation.")

            # CoT prompt construction
            edit_text = prompt_textbox
            ground_instr = derive_ground_object_from_instruction(edit_text)
            prompt = (
                "A video sequence showing three parts: first the original scene, "
                f"then grounded {ground_instr}, and finally the same scene but {edit_text}"
            )
            print(f"Constructed prompt: {prompt}")

            # Load video frames
            input_video_tensor, video_height, video_width = load_video_frames(
                input_video_path,
                source_frames=source_frames_slider
            )
            # Use the loaded video's dimensions
            h, w = video_height, video_width
            print(f"Input video dimensions: {w}x{h}")
            print(f"Running pipeline with frames={length_slider}, source={source_frames_slider}, reasoning={reasoning_frames_slider}")

            sample = self.pipeline(
                video=input_video_tensor,
                prompt=prompt,
                num_frames=length_slider,
                source_frames=source_frames_slider,
                reasoning_frames=reasoning_frames_slider,
                negative_prompt=negative_prompt_textbox,
                height=h,
                width=w,
                generator=generator,
                guidance_scale=cfg_scale_slider,
                num_inference_steps=sample_step_slider,
                repeat_rope=repeat_rope_checkbox,
                cot=True,
            ).videos
            final_video = sample
        except Exception as e:
            print(f"Error: {e}")
            if self.lora_model_path != "none":
                self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
            return gr.update(), gr.update(), f"Error: {str(e)}"

        # Unmerge LoRA
        if self.lora_model_path != "none":
            self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)

        # Save output
        save_sample_path = self.save_outputs(
            False, length_slider, final_video, fps=fps
        )

        # generate() returns [result_image, result_video, infer_progress], so the input
        # video cannot be echoed back here without changing the UI structure. The report
        # "loaded original video didn't display" most likely refers to the validation_video
        # input component, which Gradio updates automatically once a file is uploaded or an
        # example is selected, so only the generated result is returned below.
        return gr.Image(visible=False, value=None), gr.Video(value=save_sample_path, visible=True), "Success"


def ui(GPU_memory_mode, scheduler_dict, config_path, compile_dit, weight_dtype):
    controller = VideoCoF_Controller(
        GPU_memory_mode, scheduler_dict, model_name=None, model_type="Inpaint",
        config_path=config_path, compile_dit=compile_dit,
        weight_dtype=weight_dtype
    )

    with gr.Blocks() as demo:
        gr.Markdown("# VideoCoF Demo")

        with gr.Column(variant="panel"):
            # Hide model selection
            diffusion_transformer_dropdown, _ = create_model_checkpoints(controller, visible=False, default_model="Wan-AI/Wan2.1-T2V-14B")
            base_model_dropdown, lora_model_dropdown, lora_alpha_slider, _ = create_finetune_models_checkpoints(controller, visible=False, default_lora="XiangpengYang/VideoCoF")
            # Set default LoRA alpha to 1.0 (matching inference.py)
            lora_alpha_slider.value = 1.0

            with gr.Row():
                # Disable teacache by default
                enable_teacache, teacache_threshold, num_skip_start_steps, teacache_offload = create_teacache_params(False, 0.10, 5, False)
                cfg_skip_ratio = create_cfg_skip_params(0)
                enable_riflex, riflex_k = create_cfg_riflex_k(False, 6)

        with gr.Column(variant="panel"):
            prompt_textbox, negative_prompt_textbox = create_prompts(prompt="Remove the young man with short black hair wearing black shirt on the left.")

            with gr.Row():
                with gr.Column():
                    sampler_dropdown, sample_step_slider = create_samplers(controller)

                    # Custom VideoCoF params
                    with gr.Group():
                        gr.Markdown("### VideoCoF Parameters")
                        source_frames_slider = gr.Slider(label="Source Frames", minimum=1, maximum=100, value=33, step=1)
                        reasoning_frames_slider = gr.Slider(label="Reasoning Frames", minimum=1, maximum=20, value=4, step=1)
                        repeat_rope_checkbox = gr.Checkbox(label="Repeat RoPE", value=True)

                    # Use custom height/width creation to hide the resize controls
                    resize_method, width_slider, height_slider, base_resolution = create_height_width_english(
                        default_height=480, default_width=832, maximum_height=1344, maximum_width=1344
                    )

                    # Default video length 65
                    generation_method, length_slider, overlap_video_length, partial_video_length = \
                        create_generation_methods_and_video_length(
                            ["Video Generation"],
                            default_video_length=65,
                            maximum_video_length=161
                        )

                    # Simplified input for VideoCoF - mainly Video to Video.
                    image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image = create_generation_method(
                        ["Video to Video"], prompt_textbox, support_end_image=False, default_video="assets/two_man.mp4",
                        video_examples=[
                            ["assets/two_man.mp4", "Remove the young man with short black hair wearing black shirt on the left."],
                            ["assets/sign.mp4", "Replace the yellow \"SCHOOL\" sign with a red hospital sign, featuring a white hospital emblem on the top and the word \"HOSPITAL\" below."]
                        ]
                    )
                    # Ensure validation_video is visible and interactive
                    validation_video.visible = True
                    validation_video.interactive = True

                    # Set default seed to 0
                    cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(True)
                    seed_textbox.value = "0"

                    generate_button = gr.Button(value="Generate", variant='primary')

                result_image, result_video, infer_progress = create_ui_outputs()

        # Event handlers
        generate_button.click(
            fn=controller.generate,
            inputs=[
                diffusion_transformer_dropdown,
                base_model_dropdown,
                lora_model_dropdown,
                lora_alpha_slider,
                prompt_textbox,
                negative_prompt_textbox,
                sampler_dropdown,
                sample_step_slider,
                resize_method,
                width_slider,
                height_slider,
                base_resolution,
                generation_method,
                length_slider,
                overlap_video_length,
                partial_video_length,
                cfg_scale_slider,
                start_image,
                end_image,
                validation_video,
                validation_video_mask,
                control_video,
                denoise_strength,
                seed_textbox,
                ref_image,
                enable_teacache,
                teacache_threshold,
                num_skip_start_steps,
                teacache_offload,
                cfg_skip_ratio,
                enable_riflex,
                riflex_k,
                # New inputs
                source_frames_slider,
                reasoning_frames_slider,
                repeat_rope_checkbox
            ],
            outputs=[result_image, result_video, infer_progress]
        )

    return demo, controller


if __name__ == "__main__":
    from videox_fun.ui.controller import flow_scheduler_dict

    GPU_memory_mode = "sequential_cpu_offload"
    compile_dit = False
    weight_dtype = torch.bfloat16
    server_name = "0.0.0.0"
    server_port = 7860
    config_path = "config/wan2.1/wan_civitai.yaml"

    demo, controller = ui(GPU_memory_mode, flow_scheduler_dict, config_path, compile_dit, weight_dtype)
    demo.queue(status_update_rate=1).launch(
        server_name=server_name,
        server_port=server_port,
        prevent_thread_lock=True,
        share=False
    )
    # Keep the process alive, since launch() does not block when prevent_thread_lock=True.
    while True:
        time.sleep(5)
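
# To run locally (assuming this script is saved as app.py at the repository root):
#   python app.py
# then open http://localhost:7860. On Hugging Face Spaces the demo is served on the
# same port by the Spaces runtime.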