multimodalart (HF Staff) committed · verified
Commit c73cdb0 · 1 Parent(s): cda9eef

Update app.py

Files changed (1):
  1. app.py +310 -141
app.py CHANGED
@@ -1,12 +1,15 @@
  import os
  import shutil
  import sys
- import subprocess
- import asyncio
- import uuid
  from typing import Sequence, Mapping, Any, Union
- from huggingface_hub import hf_hub_download
  import spaces

  def hf_hub_download_local(repo_id, filename, local_dir, **kwargs):
      downloaded_path = hf_hub_download(repo_id=repo_id, filename=filename, **kwargs)
@@ -32,156 +35,322 @@ hf_hub_download_local(repo_id="Kijai/WanVideo_comfy", filename="Wan22-Lightning/
  hf_hub_download_local(repo_id="Kijai/WanVideo_comfy", filename="Wan22-Lightning/Wan2.2-Lightning_I2V-A14B-4steps-lora_LOW_fp16.safetensors", local_dir="models/loras")
  print("Downloads complete.")

- # --- 2. Let ComfyUI's main.py handle all initial setup ---
- print("Importing ComfyUI's main.py for setup...")
- import main
- print("ComfyUI main imported.")

- # --- 3. Now we can import the rest of the necessary modules ---
- import torch
- import gradio as gr
- from comfy import model_management
- from PIL import Image
- import random
- import nodes

- # --- 4. Manually trigger the node initialization ---
- print("Initializing ComfyUI nodes...")
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
- loop.run_until_complete(nodes.init_extra_nodes())
- print("Nodes initialized.")

- # --- Helper function ---
- def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
      try:
          return obj[index]
      except KeyError:
-         return obj["result"][index]
-
-
- # --- ZeroGPU: Pre-load models and instantiate nodes globally ---
- cliploader = nodes.NODE_CLASS_MAPPINGS["CLIPLoader"]()
- cliptextencode = nodes.NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
- unetloader = nodes.NODE_CLASS_MAPPINGS["UNETLoader"]()
- vaeloader = nodes.NODE_CLASS_MAPPINGS["VAELoader"]()
- clipvisionloader = nodes.NODE_CLASS_MAPPINGS["CLIPVisionLoader"]()
- loadimage = nodes.NODE_CLASS_MAPPINGS["LoadImage"]()
- clipvisionencode = nodes.NODE_CLASS_MAPPINGS["CLIPVisionEncode"]()
- loraloadermodelonly = nodes.NODE_CLASS_MAPPINGS["LoraLoaderModelOnly"]()
- modelsamplingsd3 = nodes.NODE_CLASS_MAPPINGS["ModelSamplingSD3"]()
- pathchsageattentionkj = nodes.NODE_CLASS_MAPPINGS["PathchSageAttentionKJ"]()
- wanfirstlastframetovideo = nodes.NODE_CLASS_MAPPINGS["WanFirstLastFrameToVideo"]()
- ksampleradvanced = nodes.NODE_CLASS_MAPPINGS["KSamplerAdvanced"]()
- vaedecode = nodes.NODE_CLASS_MAPPINGS["VAEDecode"]()
- createvideo = nodes.NODE_CLASS_MAPPINGS["CreateVideo"]()
- savevideo = nodes.NODE_CLASS_MAPPINGS["SaveVideo"]()
-
- cliploader_38 = cliploader.load_clip(clip_name="umt5_xxl_fp8_e4m3fn_scaled.safetensors", type="wan", device="cpu")
- unetloader_37_low_noise = unetloader.load_unet(unet_name="wan2.2_i2v_low_noise_14B_fp8_scaled.safetensors", weight_dtype="default")
- unetloader_91_high_noise = unetloader.load_unet(unet_name="wan2.2_i2v_high_noise_14B_fp8_scaled.safetensors", weight_dtype="default")
- vaeloader_39 = vaeloader.load_vae(vae_name="wan_2.1_vae.safetensors")
- clipvisionloader_49 = clipvisionloader.load_clip(clip_name="clip_vision_h.safetensors")
-
- loraloadermodelonly_94_high = loraloadermodelonly.load_lora_model_only(lora_name="Wan2.2-Lightning_I2V-A14B-4steps-lora_HIGH_fp16.safetensors", strength_model=0.8, model=get_value_at_index(unetloader_91_high_noise, 0))
- loraloadermodelonly_95_low = loraloadermodelonly.load_lora_model_only(lora_name="Wan2.2-Lightning_I2V-A14B-4steps-lora_LOW_fp16.safetensors", strength_model=0.8, model=get_value_at_index(unetloader_37_low_noise, 0))
- modelsamplingsd3_93_low = modelsamplingsd3.patch(shift=8, model=get_value_at_index(loraloadermodelonly_95_low, 0))
- pathchsageattentionkj_98_low = pathchsageattentionkj.patch(sage_attention="auto", model=get_value_at_index(modelsamplingsd3_93_low, 0))
- modelsamplingsd3_79_high = modelsamplingsd3.patch(shift=8, model=get_value_at_index(loraloadermodelonly_94_high, 0))
- pathchsageattentionkj_96_high = pathchsageattentionkj.patch(sage_attention="auto", model=get_value_at_index(modelsamplingsd3_79_high, 0))
-
- model_loaders = [cliploader_38, unetloader_37_low_noise, unetloader_91_high_noise, vaeloader_39, clipvisionloader_49, loraloadermodelonly_94_high, loraloadermodelonly_95_low]
- valid_models = [getattr(loader[0], 'patcher', loader[0]) for loader in model_loaders if not isinstance(loader[0], dict) and not isinstance(getattr(loader[0], 'patcher', None), dict)]
- model_management.load_models_gpu(valid_models)
-
- # --- App Logic ---
- def calculate_dimensions(width, height):
-     if width == height: return 480, 480
-     if width > height: new_width, new_height = 832, int(height * (832 / width))
-     else: new_height, new_width = 832, int(width * (832 / height))
-     return (new_width // 16) * 16, (new_height // 16) * 16

  @spaces.GPU(duration=120)
- def generate_video(prompt, first_image_path, last_image_path, duration_seconds, progress=gr.Progress(track_tqdm=True)):
-     # Create a temporary directory for resized images
-     temp_dir = "input"
-     os.makedirs(temp_dir, exist_ok=True)

-     with torch.inference_mode():
-
-         # --- Python Image Preprocessing using Pillow ---
-         print("Preprocessing images with Pillow...")
-         with Image.open(first_image_path) as img:
-             orig_width, orig_height = img.size
-
-         target_width, target_height = calculate_dimensions(orig_width, orig_height)

-         # Resize first image
-         with Image.open(first_image_path) as img:
-             img_resized = img.resize((target_width, target_height), Image.Resampling.LANCZOS)
-             resized_first_path = os.path.join(temp_dir, f"first_frame_resized_{uuid.uuid4().hex}.png")
-             print(resized_first_path)
-             img_resized.save(resized_first_path)
-
-         # Resize second image to match the target dimensions
-         with Image.open(last_image_path) as img:
-             img_resized = img.resize((target_width, target_height), Image.Resampling.LANCZOS)
-             resized_last_path = os.path.join(temp_dir, f"last_frame_resized_{uuid.uuid4().hex}.png")
-             print(resized_last_path)
-             img_resized.save(resized_last_path)
-         print(f"Images resized to {target_width}x{target_height} and saved temporarily.")
-         # --- End Preprocessing ---
-
-         FPS, MAX_FRAMES = 16, 81
-         length_in_frames = max(1, min(int(duration_seconds * FPS), MAX_FRAMES))
-         print(f"Requested duration: {duration_seconds}s. Calculated frames: {length_in_frames}")

-         # Load the pre-processed images into ComfyUI
-         loaded_first_image = loadimage.load_image(image=os.path.basename(resized_first_path))
-         loaded_last_image = loadimage.load_image(image=os.path.basename(resized_last_path))

-         cliptextencode_6 = cliptextencode.encode(text=prompt, clip=get_value_at_index(cliploader_38, 0))
-         cliptextencode_7_negative = cliptextencode.encode(text="low quality, worst quality, jpeg artifacts, ugly, deformed, blurry", clip=get_value_at_index(cliploader_38, 0))
-         clipvisionencode_51 = clipvisionencode.encode(crop="none", clip_vision=get_value_at_index(clipvisionloader_49, 0), image=get_value_at_index(loaded_first_image, 0))
-         clipvisionencode_87 = clipvisionencode.encode(crop="none", clip_vision=get_value_at_index(clipvisionloader_49, 0), image=get_value_at_index(loaded_last_image, 0))

-         wanfirstlastframetovideo_83 = wanfirstlastframetovideo.EXECUTE_NORMALIZED(width=target_width, height=target_height, length=length_in_frames, batch_size=1, positive=get_value_at_index(cliptextencode_6, 0), negative=get_value_at_index(cliptextencode_7_negative, 0), vae=get_value_at_index(vaeloader_39, 0), clip_vision_start_image=get_value_at_index(clipvisionencode_51, 0), clip_vision_end_image=get_value_at_index(clipvisionencode_87, 0), start_image=get_value_at_index(loaded_first_image, 0), end_image=get_value_at_index(loaded_last_image, 0))

-         ksampler_positive = get_value_at_index(wanfirstlastframetovideo_83, 0)
-         ksampler_negative = get_value_at_index(wanfirstlastframetovideo_83, 1)
-         ksampler_latent = get_value_at_index(wanfirstlastframetovideo_83, 2)
-
-         ksampleradvanced_101 = ksampleradvanced.sample(add_noise="enable", noise_seed=random.randint(1, 2**64), steps=8, cfg=1, sampler_name="euler", scheduler="simple", start_at_step=0, end_at_step=4, return_with_leftover_noise="enable", model=get_value_at_index(pathchsageattentionkj_96_high, 0), positive=ksampler_positive, negative=ksampler_negative, latent_image=ksampler_latent)
-         ksampleradvanced_102 = ksampleradvanced.sample(add_noise="disable", noise_seed=random.randint(1, 2**64), steps=8, cfg=1, sampler_name="euler", scheduler="simple", start_at_step=4, end_at_step=10000, return_with_leftover_noise="disable", model=get_value_at_index(pathchsageattentionkj_98_low, 0), positive=ksampler_positive, negative=ksampler_negative, latent_image=get_value_at_index(ksampleradvanced_101, 0))
-
-         vaedecode_8 = vaedecode.decode(samples=get_value_at_index(ksampleradvanced_102, 0), vae=get_value_at_index(vaeloader_39, 0))
-         createvideo_104 = createvideo.create_video(fps=16, images=get_value_at_index(vaedecode_8, 0))
-         savevideo_103 = savevideo.save_video(filename_prefix="ComfyUI_Video", format="mp4", codec="h264", video=get_value_at_index(createvideo_104, 0))
-         print("** DEBUG ** ", savevideo_103)
-         return f"output/{savevideo_103['ui']['images'][0]['filename']}"
-
- # --- Gradio Interface ---
- with gr.Blocks() as app:
-     gr.Markdown("# Wan 2.2 First/Last Frame Video Fast")
-     gr.Markdown("Running the [Wan 2.2 First/Last Frame ComfyUI workflow](https://www.reddit.com/r/StableDiffusion/comments/1me4306/psa_wan_22_does_first_frame_last_frame_out_of_the/) on ZeroGPU")
-     with gr.Row():
-         with gr.Column(scale=1):
-             prompt_input = gr.Textbox(label="Prompt", value="a man dancing in the street, cinematic")
-             duration_slider = gr.Slider(minimum=1.0, maximum=5.0, value=2.0, step=0.1, label="Video Duration (seconds)")
-             with gr.Row():
-                 first_image = gr.Image(label="First Frame", type="filepath")
-                 last_image = gr.Image(label="Last Frame", type="filepath")
-             generate_btn = gr.Button("Generate Video")
-         with gr.Column(scale=2):
-             output_video = gr.Video(label="Generated Video")
-     generate_btn.click(fn=generate_video, inputs=[prompt_input, first_image, last_image, duration_slider], outputs=[output_video])
-     gr.Examples(examples=[["a beautiful woman, cinematic", "examples/start.png", "examples/end.png", 2.5]], inputs=[prompt_input, first_image, last_image, duration_slider])

  if __name__ == "__main__":
-     if not os.path.exists("examples"): os.makedirs("examples")
-     if not os.path.exists("examples/start.png"): Image.new('RGB', (512, 512), color='red').save('examples/start.png')
-     if not os.path.exists("examples/end.png"): Image.new('RGB', (512, 512), color='blue').save('examples/end.png')
-     # Set the input directory for LoadImage to find the temp files
-     import folder_paths
-     folder_paths.add_model_folder_path("input", "temp_resized")
-     app.launch()
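
Editor's note: the removed `calculate_dimensions` helper is a pure function (long side snapped to 832, squares to 480x480, both dimensions rounded down to multiples of 16), so its behaviour can be checked standalone without ComfyUI. A quick sketch, copying the function verbatim from the deleted code above:

# Standalone check of the removed resizing helper (pure Python).
def calculate_dimensions(width, height):
    if width == height: return 480, 480
    if width > height: new_width, new_height = 832, int(height * (832 / width))
    else: new_height, new_width = 832, int(width * (832 / height))
    return (new_width // 16) * 16, (new_height // 16) * 16

print(calculate_dimensions(1920, 1080))  # (832, 464)
print(calculate_dimensions(1080, 1920))  # (464, 832)
print(calculate_dimensions(512, 512))    # (480, 480)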
 
  import os
  import shutil
+ import random
  import sys
+ import tempfile
  from typing import Sequence, Mapping, Any, Union
+
  import spaces
+ import torch
+ import gradio as gr
+ from PIL import Image
+ from huggingface_hub import hf_hub_download

  def hf_hub_download_local(repo_id, filename, local_dir, **kwargs):
      downloaded_path = hf_hub_download(repo_id=repo_id, filename=filename, **kwargs)
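
Editor's note: the body of `hf_hub_download_local` is truncated by the hunk context, so the remainder is not part of this diff. A minimal sketch of what a wrapper like this typically does, assuming it copies the downloaded checkpoint into the ComfyUI-style `local_dir` under its basename (an assumption, not the committed code):

import os
import shutil
from huggingface_hub import hf_hub_download

def hf_hub_download_local(repo_id, filename, local_dir, **kwargs):
    # Download (or reuse the cached copy) from the Hub.
    downloaded_path = hf_hub_download(repo_id=repo_id, filename=filename, **kwargs)
    # Hypothetical remainder: flatten any subfolder in `filename` and place the
    # file where ComfyUI loaders expect it, e.g. models/loras/<basename>.
    os.makedirs(local_dir, exist_ok=True)
    target_path = os.path.join(local_dir, os.path.basename(filename))
    if not os.path.exists(target_path):
        shutil.copy(downloaded_path, target_path)
    return target_path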
 
  hf_hub_download_local(repo_id="Kijai/WanVideo_comfy", filename="Wan22-Lightning/Wan2.2-Lightning_I2V-A14B-4steps-lora_LOW_fp16.safetensors", local_dir="models/loras")
  print("Downloads complete.")

+ # --- Boilerplate code from the original script ---

+ def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
+     """Returns the value at the given index of a sequence or mapping.

+     If the object is a sequence (like list or string), returns the value at the given index.
+     If the object is a mapping (like a dictionary), returns the value at the index-th key.

+     Some return a dictionary, in these cases, we look for the "results" key
+
+     Args:
+         obj (Union[Sequence, Mapping]): The object to retrieve the value from.
+         index (int): The index of the value to retrieve.
+
+     Returns:
+         Any: The value at the given index.
+
+     Raises:
+         IndexError: If the index is out of bounds for the object and the object is not a mapping.
+     """
      try:
          return obj[index]
      except KeyError:
+         # This is a fallback for custom node outputs that might be dictionaries
+         if isinstance(obj, Mapping) and "result" in obj:
+             return obj["result"][index]
+         raise
+
+ def find_path(name: str, path: str = None) -> str:
+     """
+     Recursively looks at parent folders starting from the given path until it finds the given name.
+     Returns the path as a Path object if found, or None otherwise.
+     """
+     if path is None:
+         path = os.getcwd()
+
+     if name in os.listdir(path):
+         path_name = os.path.join(path, name)
+         print(f"'{name}' found: {path_name}")
+         return path_name
+
+     parent_directory = os.path.dirname(path)
+     if parent_directory == path:
+         return None
+
+     return find_path(name, parent_directory)
+
+
+ def add_comfyui_directory_to_sys_path() -> None:
+     """
+     Add 'ComfyUI' to the sys.path
+     """
+     # Use a more robust name to find the ComfyUI directory
+     comfyui_path = find_path("ComfyUI")
+     if comfyui_path is not None and os.path.isdir(comfyui_path):
+         sys.path.append(comfyui_path)
+         print(f"'{comfyui_path}' added to sys.path")
+     else:
+         print("Could not find ComfyUI directory. Please run from a parent folder of ComfyUI.")
+
+ def add_extra_model_paths() -> None:
+     """
+     Parse the optional extra_model_paths.yaml file and add the parsed paths to the sys.path.
+     """
+     try:
+         from main import load_extra_path_config
+     except ImportError:
+         print(
+             "Could not import load_extra_path_config from main.py. This might be okay if you don't use it."
+         )
+         return
+
+     extra_model_paths = find_path("extra_model_paths.yaml")
+     if extra_model_paths is not None:
+         load_extra_path_config(extra_model_paths)
+     else:
+         print("Could not find an optional 'extra_model_paths.yaml' config file.")
+
+ def import_custom_nodes() -> None:
+     """Find all custom nodes in the custom_nodes folder and add those node objects to NODE_CLASS_MAPPINGS
+     This function sets up a new asyncio event loop, initializes the PromptServer,
+     creates a PromptQueue, and initializes the custom nodes.
+     """
+     import asyncio
+     import execution
+     from nodes import init_extra_nodes
+     import server
+
+     loop = asyncio.new_event_loop()
+     asyncio.set_event_loop(loop)
+     server_instance = server.PromptServer(loop)
+     execution.PromptQueue(server_instance)
+     loop.run_until_complete(init_extra_nodes(init_custom_nodes=True))
+
+
+ # --- Model Loading and Caching ---
+
+ # Dictionary to hold all loaded models and node instances
+ MODELS_AND_NODES = {}
+
+ print("Setting up ComfyUI paths...")
+ add_comfyui_directory_to_sys_path()
+ add_extra_model_paths()

+ print("Importing custom nodes...")
+ import_custom_nodes()
+
+ # Now that paths are set up, we can import from nodes
+ from nodes import NODE_CLASS_MAPPINGS
+ global folder_paths  # Make folder_paths globally accessible
+ import folder_paths
+
+ print("Loading models into memory. This may take a few minutes...")
+
+ # Load Text-to-Image models (CLIP, UNETs, VAE)
+ cliploader = NODE_CLASS_MAPPINGS["CLIPLoader"]()
+ MODELS_AND_NODES["clip"] = cliploader.load_clip(
+     clip_name="umt5_xxl_fp8_e4m3fn_scaled.safetensors", type="wan", device="cpu"
+ )
+
+ unetloader = NODE_CLASS_MAPPINGS["UNETLoader"]()
+ unet_low_noise = unetloader.load_unet(
+     unet_name="wan2.2_i2v_low_noise_14B_fp8_scaled.safetensors",
+     weight_dtype="default",
+ )
+ unet_high_noise = unetloader.load_unet(
+     unet_name="wan2.2_i2v_high_noise_14B_fp8_scaled.safetensors",
+     weight_dtype="default",
+ )
+
+ vaeloader = NODE_CLASS_MAPPINGS["VAELoader"]()
+ MODELS_AND_NODES["vae"] = vaeloader.load_vae(vae_name="wan_2.1_vae.safetensors")
+
+ # Load LoRAs
+ loraloadermodelonly = NODE_CLASS_MAPPINGS["LoraLoaderModelOnly"]()
+ MODELS_AND_NODES["model_low_noise"] = loraloadermodelonly.load_lora_model_only(
+     lora_name="Wan2.2-Lightning_I2V-A14B-4steps-lora_LOW_fp16.safetensors",
+     strength_model=0.8,
+     model=get_value_at_index(unet_low_noise, 0),
+ )
+ MODELS_AND_NODES["model_high_noise"] = loraloadermodelonly.load_lora_model_only(
+     lora_name="Wan2.2-Lightning_I2V-A14B-4steps-lora_HIGH_fp16.safetensors",
+     strength_model=0.8,
+     model=get_value_at_index(unet_high_noise, 0),
+ )
+
+ # Load Vision model
+ clipvisionloader = NODE_CLASS_MAPPINGS["CLIPVisionLoader"]()
+ MODELS_AND_NODES["clip_vision"] = clipvisionloader.load_clip(
+     clip_name="clip_vision_h.safetensors"
+ )
+
+ # Instantiate all required node classes
+ MODELS_AND_NODES["CLIPTextEncode"] = NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
+ MODELS_AND_NODES["LoadImage"] = NODE_CLASS_MAPPINGS["LoadImage"]()
+ MODELS_AND_NODES["CLIPVisionEncode"] = NODE_CLASS_MAPPINGS["CLIPVisionEncode"]()
+ MODELS_AND_NODES["ModelSamplingSD3"] = NODE_CLASS_MAPPINGS["ModelSamplingSD3"]()
+ MODELS_AND_NODES["PathchSageAttentionKJ"] = NODE_CLASS_MAPPINGS["PathchSageAttentionKJ"]()
+ MODELS_AND_NODES["WanFirstLastFrameToVideo"] = NODE_CLASS_MAPPINGS["WanFirstLastFrameToVideo"]()
+ MODELS_AND_NODES["KSamplerAdvanced"] = NODE_CLASS_MAPPINGS["KSamplerAdvanced"]()
+ MODELS_AND_NODES["VAEDecode"] = NODE_CLASS_MAPPINGS["VAEDecode"]()
+ MODELS_AND_NODES["CreateVideo"] = NODE_CLASS_MAPPINGS["CreateVideo"]()
+ MODELS_AND_NODES["SaveVideo"] = NODE_CLASS_MAPPINGS["SaveVideo"]()
+
+ print("All models loaded successfully!")
+
+ # --- Main Video Generation Logic ---
  @spaces.GPU(duration=120)
+ def generate_video(start_image_pil: Image.Image, end_image_pil: Image.Image, prompt: str, negative_prompt: str, progress=gr.Progress(track_tqdm=True)):
+     """
+     The main function to generate a video based on user inputs.
+     This function is called every time the user clicks the 'Generate' button.
+     """
+     # Use pre-loaded models and nodes from the global dictionary
+     clip = MODELS_AND_NODES["clip"]
+     vae = MODELS_AND_NODES["vae"]
+     model_low_noise = MODELS_AND_NODES["model_low_noise"]
+     model_high_noise = MODELS_AND_NODES["model_high_noise"]
+     clip_vision = MODELS_AND_NODES["clip_vision"]

+     cliptextencode = MODELS_AND_NODES["CLIPTextEncode"]
+     loadimage = MODELS_AND_NODES["LoadImage"]
+     clipvisionencode = MODELS_AND_NODES["CLIPVisionEncode"]
+     modelsamplingsd3 = MODELS_AND_NODES["ModelSamplingSD3"]
+     pathchsageattentionkj = MODELS_AND_NODES["PathchSageAttentionKJ"]
+     wanfirstlastframetovideo = MODELS_AND_NODES["WanFirstLastFrameToVideo"]
+     ksampleradvanced = MODELS_AND_NODES["KSamplerAdvanced"]
+     vaedecode = MODELS_AND_NODES["VAEDecode"]
+     createvideo = MODELS_AND_NODES["CreateVideo"]
+     savevideo = MODELS_AND_NODES["SaveVideo"]
+
+     # Save uploaded images to temporary files
+     with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as start_file, \
+          tempfile.NamedTemporaryFile(suffix=".png", delete=False) as end_file:
+         start_image_pil.save(start_file.name)
+         end_image_pil.save(end_file.name)
+         start_image_path = start_file.name
+         end_image_path = end_file.name
+
+     try:
+         with torch.inference_mode():
+             progress(0.1, desc="Encoding text and images...")
+             # --- Workflow execution ---
+             positive_conditioning = cliptextencode.encode(text=prompt, clip=get_value_at_index(clip, 0))
+             negative_conditioning = cliptextencode.encode(text=negative_prompt, clip=get_value_at_index(clip, 0))
+
+             start_image_loaded = loadimage.load_image(image=start_image_path)
+             end_image_loaded = loadimage.load_image(image=end_image_path)
+
+             clip_vision_encoded_start = clipvisionencode.encode(
+                 crop="none", clip_vision=get_value_at_index(clip_vision, 0), image=get_value_at_index(start_image_loaded, 0)
+             )
+             clip_vision_encoded_end = clipvisionencode.encode(
+                 crop="none", clip_vision=get_value_at_index(clip_vision, 0), image=get_value_at_index(end_image_loaded, 0)
+             )
+
+             progress(0.2, desc="Preparing initial latents...")
+             initial_latents = wanfirstlastframetovideo.EXECUTE_NORMALIZED(
+                 width=480, height=480, length=33, batch_size=1,
+                 positive=get_value_at_index(positive_conditioning, 0),
+                 negative=get_value_at_index(negative_conditioning, 0),
+                 vae=get_value_at_index(vae, 0),
+                 clip_vision_start_image=get_value_at_index(clip_vision_encoded_start, 0),
+                 clip_vision_end_image=get_value_at_index(clip_vision_encoded_end, 0),
+                 start_image=get_value_at_index(start_image_loaded, 0),
+                 end_image=get_value_at_index(end_image_loaded, 0),
+             )
+
+             progress(0.3, desc="Patching models...")
+             model_low_patched = modelsamplingsd3.patch(shift=8, model=get_value_at_index(model_low_noise, 0))
+             model_low_final = pathchsageattentionkj.patch(sage_attention="auto", model=get_value_at_index(model_low_patched, 0))
+
+             model_high_patched = modelsamplingsd3.patch(shift=8, model=get_value_at_index(model_high_noise, 0))
+             model_high_final = pathchsageattentionkj.patch(sage_attention="auto", model=get_value_at_index(model_high_patched, 0))
+
+             progress(0.5, desc="Running KSampler (Step 1/2)...")
+             latent_step1 = ksampleradvanced.sample(
+                 add_noise="enable", noise_seed=random.randint(1, 2**64), steps=8, cfg=1,
+                 sampler_name="euler", scheduler="simple", start_at_step=0, end_at_step=4,
+                 return_with_leftover_noise="enable", model=get_value_at_index(model_high_final, 0),
+                 positive=get_value_at_index(initial_latents, 0),
+                 negative=get_value_at_index(initial_latents, 1),
+                 latent_image=get_value_at_index(initial_latents, 2),
+             )
+
+             progress(0.7, desc="Running KSampler (Step 2/2)...")
+             latent_step2 = ksampleradvanced.sample(
+                 add_noise="disable", noise_seed=random.randint(1, 2**64), steps=8, cfg=1,
+                 sampler_name="euler", scheduler="simple", start_at_step=4, end_at_step=10000,
+                 return_with_leftover_noise="disable", model=get_value_at_index(model_low_final, 0),
+                 positive=get_value_at_index(initial_latents, 0),
+                 negative=get_value_at_index(initial_latents, 1),
+                 latent_image=get_value_at_index(latent_step1, 0),
+             )
+
+             progress(0.8, desc="Decoding VAE...")
+             decoded_images = vaedecode.decode(samples=get_value_at_index(latent_step2, 0), vae=get_value_at_index(vae, 0))
+
+             progress(0.9, desc="Creating and saving video...")
+             video_data = createvideo.create_video(fps=16, images=get_value_at_index(decoded_images, 0))
+
+             # Save the video to ComfyUI's output directory
+             save_result = savevideo.save_video(
+                 filename_prefix="GradioVideo", format="mp4", codec="h264",
+                 video=get_value_at_index(video_data, 0),
+             )
+
+             progress(1.0, desc="Done!")
+             return f"output/{save_result['ui']['images'][0]['filename']}"
+
+     finally:
+         # Clean up the temporary image files
+         os.unlink(start_image_path)
+         os.unlink(end_image_path)
+
+ # --- Gradio UI ---
+
+ def create_gradio_app():
+     with gr.Blocks(theme=gr.themes.Soft()) as app:
+         gr.Markdown("# Image-to-Video Generation App")
+         gr.Markdown("Upload a start and end frame, provide a prompt, and let the AI generate a video transitioning between them.")
+
+         with gr.Row():
+             start_image = gr.Image(type="pil", label="Start Frame")
+             end_image = gr.Image(type="pil", label="End Frame")

+         prompt = gr.Textbox(label="Prompt", value="the guy turns")
+         negative_prompt = gr.Textbox(
+             label="Negative Prompt",
+             value="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走,过曝,"
+         )

+         generate_button = gr.Button("Generate Video", variant="primary")

+         output_video = gr.Video(label="Generated Video")

+         generate_button.click(
+             fn=generate_video,
+             inputs=[start_image, end_image, prompt, negative_prompt],
+             outputs=output_video
+         )
+
+         gr.Examples(
+             examples=[
+                 ["examples/start.png", "examples/end.png", "a beautiful woman smiling"],
+                 ["examples/start.png", "examples/end.png", "a robot walking through a futuristic city"],
+             ],
+             inputs=[start_image, end_image, prompt],
+             outputs=output_video,
+             fn=generate_video,
+             cache_examples=False,  # Set to True if you want to pre-compute examples
+         )
+
+     return app


  if __name__ == "__main__":
+     app = create_gradio_app()
+     app.launch(share=True)
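
Editor's note: once the Space is running, the refactored endpoint can also be exercised programmatically with `gradio_client`. A sketch, assuming the default auto-generated API name `/generate_video` and using `your-username/your-space` as a placeholder for the actual Space id:

from gradio_client import Client, handle_file

client = Client("your-username/your-space")  # placeholder Space id
result = client.predict(
    handle_file("examples/start.png"),   # start_image_pil
    handle_file("examples/end.png"),     # end_image_pil
    "the guy turns",                     # prompt
    "low quality, worst quality",        # negative_prompt
    api_name="/generate_video",          # assumed default endpoint name
)
print(result)  # local path to the generated video file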