# RuinedFooocus: modules/wan_video_pipeline.py
import numpy as np
import os
import torch
import einops
import traceback
import cv2
import modules.async_worker as worker
from modules.util import generate_temp_filename
from PIL import Image
from comfy.model_base import WAN21
import shared
from shared import path_manager, settings
from pathlib import Path
import random
from modules.pipleline_utils import (
clean_prompt_cond_caches,
get_previewer,
)
import comfy.utils
import comfy.model_management
import comfy.sample
import comfy.clip_vision
from comfy.sd import load_checkpoint_guess_config
from calcuis_gguf.pig import load_gguf_sd, GGMLOps, GGUFModelPatcher
from nodes import (
CLIPTextEncode,
VAEDecodeTiled,
)
from comfy_extras.nodes_hunyuan import EmptyHunyuanLatentVideo
from comfy_extras.nodes_wan import WanImageToVideo
from comfy_extras.nodes_model_advanced import ModelSamplingSD3
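# WAN 2.1 text-to-video / image-to-video pipeline for RuinedFooocus.
# Loads a WAN UNet (safetensors or GGUF), the UMT5-XXL text encoder, the WAN VAE
# and a CLIP Vision model, then samples a short clip and saves it as GIF and MP4.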
class pipeline:
pipeline_type = ["wan_video"]
class StableDiffusionModel:
def __init__(self, unet, vae, clip, clip_vision):
self.unet = unet
self.vae = vae
self.clip = clip
self.clip_vision = clip_vision
def to_meta(self):
if self.unet is not None:
self.unet.model.to("meta")
if self.clip is not None:
self.clip.cond_stage_model.to("meta")
if self.vae is not None:
self.vae.first_stage_model.to("meta")
model_hash = ""
model_base = None
model_hash_patched = ""
model_base_patched = None
conditions = None
ggml_ops = GGMLOps()
# Optional function
def parse_gen_data(self, gen_data):
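        # WAN's temporal VAE works on frame counts of the form 4n+1, so round the
        # requested image_number up to the next such value and generate a single
        # clip instead of a batch of images.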
gen_data["original_image_number"] = 1 + ((int(gen_data["image_number"] / 4.0) + 1) * 4)
gen_data["image_number"] = 1
return gen_data
def load_base_model(self, name, unet_only=True): # Wan_Video never has the clip and vae models?
# Check if model is already loaded
if self.model_hash == name:
return
self.model_base = None
self.model_hash = ""
self.model_base_patched = None
self.model_hash_patched = ""
self.conditions = None
filename = str(shared.models.get_file("checkpoints", name))
print(f"Loading WAN video {'unet' if unet_only else 'model'}: {name}")
if filename.endswith(".gguf") or unet_only:
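            # Standalone UNet path: GGUF-quantised or plain diffusion-model
            # checkpoints. The text encoder, VAE and CLIP Vision are loaded
            # separately below from the configured defaults.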
            with torch.inference_mode():
try:
if filename.endswith(".gguf"):
sd = load_gguf_sd(filename)
unet = comfy.sd.load_diffusion_model_state_dict(
sd, model_options={"custom_operations": self.ggml_ops}
)
unet = GGUFModelPatcher.clone(unet)
unet.patch_on_device = True
else:
model_options = {}
model_options["dtype"] = torch.float8_e4m3fn # FIXME should be a setting
unet = comfy.sd.load_diffusion_model(filename, model_options=model_options)
clip_paths = []
clip_names = []
if isinstance(unet.model, WAN21):
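                        # WAN 2.1 pairs with the UMT5-XXL text encoder and its own
                        # 2.1 VAE; use the configured filenames, falling back to the
                        # repackaged Comfy-Org defaults.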
clip_name = settings.default_settings.get("clip_umt5", "umt5_xxl_fp8_e4m3fn_scaled.safetensors")
clip_names.append(str(clip_name))
clip_path = path_manager.get_folder_file_path(
"clip",
clip_name,
default = os.path.join(path_manager.model_paths["clip_path"], clip_name)
)
clip_paths.append(str(clip_path))
clip_type = comfy.sd.CLIPType.WAN
# https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged
vae_name = settings.default_settings.get("vae_wan", "wan_2.1_vae.safetensors")
else:
print(f"ERROR: Not a Wan Video model?")
unet = None
return
print(f"Loading CLIP: {clip_names}")
clip = comfy.sd.load_clip(ckpt_paths=clip_paths, clip_type=clip_type, model_options={})
vae_path = path_manager.get_folder_file_path(
"vae",
vae_name,
default = os.path.join(path_manager.model_paths["vae_path"], vae_name)
)
print(f"Loading VAE: {vae_name}")
sd = comfy.utils.load_torch_file(str(vae_path))
vae = comfy.sd.VAE(sd=sd)
clip_vision_name = settings.default_settings.get("clip_vision", "clip_vision_h_fp8_e4m3fn.safetensors")
clip_vision_path = path_manager.get_folder_file_path(
"clip_vision",
clip_vision_name,
default = os.path.join(path_manager.model_paths["clip_vision_path"], clip_vision_name)
)
print(f"Loading CLIP Vision: {clip_vision_name}")
sd = comfy.utils.load_torch_file(str(clip_vision_path))
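                    # Distinguish raw OpenCLIP-style vision weights (keys prefixed
                    # with "visual.") from state dicts already in ComfyUI's
                    # CLIP Vision layout.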
if "visual.transformer.resblocks.0.attn.in_proj_weight" in sd:
clip_vision = comfy.clip_vision.load_clipvision_from_sd(sd, prefix="visual.", convert_keys=True)
else:
clip_vision = comfy.clip_vision.load_clipvision_from_sd(sd=sd)
except Exception as e:
unet = None
traceback.print_exc()
else:
try:
                with torch.inference_mode():
unet, clip, vae, clip_vision = load_checkpoint_guess_config(filename)
                    if clip is None or vae is None:
                        raise ValueError("Checkpoint is missing CLIP or VAE")
            except Exception:
                print("Failed. Trying to load as unet.")
self.load_base_model(
filename,
unet_only=True
)
return
        if unet is None:
print(f"Failed to load {name}")
self.model_base = None
self.model_hash = ""
else:
self.model_base = self.StableDiffusionModel(
unet=unet, clip=clip, vae=vae, clip_vision=clip_vision
)
            if not isinstance(self.model_base.unet.model, WAN21):
print(
f"Model {type(self.model_base.unet.model)} not supported. Expected Wan Video model."
)
self.model_base = None
if self.model_base is not None:
self.model_hash = name
print(f"Base model loaded: {self.model_hash}")
return
def load_keywords(self, lora):
filename = lora.replace(".safetensors", ".txt")
try:
with open(filename, "r") as file:
data = file.read()
return data
except FileNotFoundError:
return " "
def load_loras(self, loras):
loaded_loras = []
model = self.model_base
for name, weight in loras:
if name == "None" or weight == 0:
continue
filename = str(shared.models.get_file("loras", name))
print(f"Loading LoRAs: {name}")
try:
lora = comfy.utils.load_torch_file(filename, safe_load=True)
unet, clip = comfy.sd.load_lora_for_models(
model.unet, model.clip, lora, weight, weight
)
model = self.StableDiffusionModel(
unet=unet,
clip=clip,
vae=model.vae,
clip_vision=model.clip_vision,
)
loaded_loras += [(name, weight)]
            except Exception:
                print(f"Failed to load LoRA: {name}")
                traceback.print_exc()
self.model_base_patched = model
self.model_hash_patched = str(loras)
print(f"LoRAs loaded: {loaded_loras}")
return
def refresh_controlnet(self, name=None):
return
def clean_prompt_cond_caches(self):
return
def textencode(self, id, text, clip_skip):
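        # Cache the encoded conditioning per prompt id ("+"/"-"); only re-encode
        # when the prompt text or clip_skip changes.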
update = False
hash = f"{text} {clip_skip}"
if hash != self.conditions[id]["text"]:
self.conditions[id]["cache"] = CLIPTextEncode().encode(
clip=self.model_base_patched.clip, text=text
)[0]
self.conditions[id]["text"] = hash
update = True
return update
@torch.no_grad()
def vae_decode_fake(self, latents):
        # FIXME: This should probably just be imported from ComfyUI
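        # Cheap linear projection from the 16-channel WAN latent space to RGB,
        # used only for step previews; the factors mirror ComfyUI's WAN latent
        # preview table.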
latent_rgb_factors = [
[-0.1299, -0.1692, 0.2932],
[ 0.0671, 0.0406, 0.0442],
[ 0.3568, 0.2548, 0.1747],
[ 0.0372, 0.2344, 0.1420],
[ 0.0313, 0.0189, -0.0328],
[ 0.0296, -0.0956, -0.0665],
[-0.3477, -0.4059, -0.2925],
[ 0.0166, 0.1902, 0.1975],
[-0.0412, 0.0267, -0.1364],
[-0.1293, 0.0740, 0.1636],
[ 0.0680, 0.3019, 0.1128],
[ 0.0032, 0.0581, 0.0639],
[-0.1251, 0.0927, 0.1699],
[ 0.0060, -0.0633, 0.0005],
[ 0.3477, 0.2275, 0.2950],
[ 0.1984, 0.0913, 0.1861]
]
latent_rgb_factors_bias = [-0.1835, -0.0868, -0.3360]
weight = torch.tensor(latent_rgb_factors, device=latents.device, dtype=latents.dtype).transpose(0, 1)[:, :, None, None, None]
bias = torch.tensor(latent_rgb_factors_bias, device=latents.device, dtype=latents.dtype)
images = torch.nn.functional.conv3d(latents, weight, bias=bias, stride=1, padding=0, dilation=1, groups=1)
images = images.clamp(0.0, 1.0)
return images
@torch.inference_mode()
def process(
self,
gen_data=None,
callback=None,
):
shared.state["preview_total"] = 1
seed = gen_data["seed"] if isinstance(gen_data["seed"], int) else random.randint(1, 2**32)
if callback is not None:
worker.add_result(
gen_data["task_id"],
"preview",
(-1, f"Processing text encoding ...", "html/generate_video.jpeg")
)
if self.conditions is None:
self.conditions = clean_prompt_cond_caches()
positive_prompt = gen_data["positive_prompt"]
negative_prompt = gen_data["negative_prompt"]
clip_skip = 1
self.textencode("+", positive_prompt, clip_skip)
self.textencode("-", negative_prompt, clip_skip)
pbar = comfy.utils.ProgressBar(gen_data["steps"])
def callback_function(step, x0, x, total_steps):
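            # Per-step preview: project the current latent to RGB, tile the frames
            # into one strip, downscale to fit 1920x1080 and push it to the UI.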
y = self.vae_decode_fake(x0)
y = (y * 255.0).detach().cpu().numpy().clip(0, 255).astype(np.uint8)
y = einops.rearrange(y, 'b c t h w -> (b h) (t w) c')
            # Skip callback() since we'll just confuse the preview grid and push updates ourselves
status = "Generating video"
maxw = 1920
maxh = 1080
image = Image.fromarray(y)
ow, oh = image.size
scale = min(maxh / oh, maxw / ow)
image = image.resize((int(ow * scale), int(oh * scale)), Image.LANCZOS)
worker.add_result(
gen_data["task_id"],
"preview",
(
int(100 * (step / total_steps)),
f"{status} - {step}/{total_steps}",
image
)
)
# pbar.update_absolute(step + 1, total_steps, None)
# ModelSamplingSD3
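        # Apply an SD3-style sigma shift to the sampling schedule; shift=8.0 is
        # hard-coded here (a common choice for WAN video models).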
model_sampling = ModelSamplingSD3().patch(
model = self.model_base_patched.unet,
shift = 8.0,
)[0]
# t2v or i2v?
if gen_data["input_image"]:
image = np.array(gen_data["input_image"]).astype(np.float32) / 255.0
image = torch.from_numpy(image)[None,]
clip_vision_output = self.model_base_patched.clip_vision.encode_image(image)
(positive, negative, latent_image) = WanImageToVideo().encode(
positive = self.conditions["+"]["cache"],
negative = self.conditions["-"]["cache"],
vae = self.model_base_patched.vae,
width = gen_data["width"],
height = gen_data["height"],
length = gen_data["original_image_number"],
batch_size = 1,
start_image = image,
clip_vision_output = clip_vision_output,
)
else:
            # Text-to-video: start from an empty latent video of the requested size and length
latent_image = EmptyHunyuanLatentVideo().generate(
width = gen_data["width"],
height = gen_data["height"],
length = gen_data["original_image_number"],
batch_size = 1,
)[0]
positive = self.conditions["+"]["cache"]
negative = self.conditions["-"]["cache"]
worker.add_result(
gen_data["task_id"],
"preview",
(-1, f"Generating ...", "html/generate_video.jpeg")
)
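        # Build deterministic noise from the seed and run the sampler with the
        # shift-patched model; callback_function streams a preview every step.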
noise = comfy.sample.prepare_noise(latent_image["samples"], seed)
sampled = comfy.sample.sample(
model = model_sampling,
noise = noise,
steps = gen_data["steps"],
cfg = gen_data["cfg"],
sampler_name = gen_data["sampler_name"],
scheduler = gen_data["scheduler"],
positive = positive,
negative = negative,
latent_image = latent_image["samples"],
denoise = 1,
callback = callback_function,
)
if callback is not None:
worker.add_result(
gen_data["task_id"],
"preview",
(-1, f"VAE Decoding ...", None)
)
latent_image["samples"] = sampled
decoded_latent = VAEDecodeTiled().decode(
samples=latent_image,
tile_size=128,
overlap=64,
vae=self.model_base_patched.vae,
)[0]
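        # decoded_latent holds frames as (H, W, C) float tensors in 0..1;
        # convert each frame to an 8-bit PIL image.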
pil_images = []
for image in decoded_latent:
i = 255. * image.cpu().numpy()
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
pil_images.append(img)
if callback is not None:
worker.add_result(
gen_data["task_id"],
"preview",
(-1, f"Saving ...", None)
)
file = generate_temp_filename(
folder=path_manager.model_paths["temp_outputs_path"], extension="gif"
)
os.makedirs(os.path.dirname(file), exist_ok=True)
        fps = 12.0
        compress_level = 4  # Min = 0, Max = 9
# Save GIF
pil_images[0].save(
file,
compress_level=compress_level,
save_all=True,
duration=int(1000.0/fps),
append_images=pil_images[1:],
optimize=True,
loop=0,
)
# Save mp4
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
mp4_file = file.with_suffix(".mp4")
        out = cv2.VideoWriter(str(mp4_file), fourcc, fps, (gen_data["width"], gen_data["height"]))
for frame in pil_images:
            out.write(cv2.cvtColor(np.asarray(frame), cv2.COLOR_RGB2BGR))
out.release()
return [file]