import glob
import logging
import os
import re
from functools import partial
from itertools import chain
from os import PathLike
from pathlib import Path
from typing import Any, Callable, Dict, List, Union

import numpy as np
import torch
from controlnet_aux import LineartAnimeDetector
from controlnet_aux.processor import MODELS
from controlnet_aux.processor import Processor as ControlnetPreProcessor
from controlnet_aux.util import HWC3, ade_palette
from controlnet_aux.util import resize_image as aux_resize_image
from diffusers import (AutoencoderKL, ControlNetModel, DiffusionPipeline,
                       EulerDiscreteScheduler,
                       StableDiffusionControlNetImg2ImgPipeline,
                       StableDiffusionPipeline, StableDiffusionXLPipeline)
from PIL import Image
from torchvision.datasets.folder import IMG_EXTENSIONS
from tqdm.rich import tqdm
from transformers import (AutoImageProcessor, CLIPImageProcessor,
                          CLIPTextConfig, CLIPTextModel,
                          CLIPTextModelWithProjection, CLIPTokenizer,
                          UperNetForSemanticSegmentation)

from animatediff import get_dir
from animatediff.dwpose import DWposeDetector
from animatediff.models.clip import CLIPSkipTextModel
from animatediff.models.unet import UNet3DConditionModel
from animatediff.pipelines import AnimationPipeline, load_text_embeddings
from animatediff.pipelines.lora import load_lcm_lora, load_lora_map
from animatediff.pipelines.pipeline_controlnet_img2img_reference import \
    StableDiffusionControlNetImg2ImgReferencePipeline
from animatediff.schedulers import DiffusionScheduler, get_scheduler
from animatediff.settings import InferenceConfig, ModelConfig
from animatediff.utils.control_net_lllite import (ControlNetLLLite,
                                                  load_controlnet_lllite)
from animatediff.utils.convert_from_ckpt import convert_ldm_vae_checkpoint
from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora
from animatediff.utils.model import (ensure_motion_modules,
                                     get_checkpoint_weights,
                                     get_checkpoint_weights_sdxl)
from animatediff.utils.util import (get_resized_image, get_resized_image2,
                                    get_resized_images,
                                    get_tensor_interpolation_method,
                                    prepare_dwpose, prepare_extra_controlnet,
                                    prepare_ip_adapter,
                                    prepare_ip_adapter_sdxl, prepare_lcm_lora,
                                    prepare_lllite, prepare_motion_module,
                                    save_frames, save_imgs, save_video)

controlnet_address_table = {
    "controlnet_tile": ['lllyasviel/control_v11f1e_sd15_tile'],
    "controlnet_lineart_anime": ['lllyasviel/control_v11p_sd15s2_lineart_anime'],
    "controlnet_ip2p": ['lllyasviel/control_v11e_sd15_ip2p'],
    "controlnet_openpose": ['lllyasviel/control_v11p_sd15_openpose'],
    "controlnet_softedge": ['lllyasviel/control_v11p_sd15_softedge'],
    "controlnet_shuffle": ['lllyasviel/control_v11e_sd15_shuffle'],
    "controlnet_depth": ['lllyasviel/control_v11f1p_sd15_depth'],
    "controlnet_canny": ['lllyasviel/control_v11p_sd15_canny'],
    "controlnet_inpaint": ['lllyasviel/control_v11p_sd15_inpaint'],
    "controlnet_lineart": ['lllyasviel/control_v11p_sd15_lineart'],
    "controlnet_mlsd": ['lllyasviel/control_v11p_sd15_mlsd'],
    "controlnet_normalbae": ['lllyasviel/control_v11p_sd15_normalbae'],
    "controlnet_scribble": ['lllyasviel/control_v11p_sd15_scribble'],
    "controlnet_seg": ['lllyasviel/control_v11p_sd15_seg'],
    "qr_code_monster_v1": ['monster-labs/control_v1p_sd15_qrcode_monster'],
    "qr_code_monster_v2": ['monster-labs/control_v1p_sd15_qrcode_monster', 'v2'],
    "controlnet_mediapipe_face": ['CrucibleAI/ControlNetMediaPipeFace', "diffusion_sd15"],
    "animatediff_controlnet": [None, "data/models/controlnet/animatediff_controlnet/controlnet_checkpoint.ckpt"],
}

# Edit this table if you want to change to another controlnet checkpoint
controlnet_address_table_sdxl = {
    # "controlnet_openpose": ['thibaud/controlnet-openpose-sdxl-1.0'],
    # "controlnet_softedge": ['SargeZT/controlnet-sd-xl-1.0-softedge-dexined'],
    # "controlnet_depth": ['diffusers/controlnet-depth-sdxl-1.0-small'],
    # "controlnet_canny": ['diffusers/controlnet-canny-sdxl-1.0-small'],
    # "controlnet_seg": ['SargeZT/sdxl-controlnet-seg'],
    "qr_code_monster_v1": ['monster-labs/control_v1p_sdxl_qrcode_monster'],
}

# Edit this table if you want to change to another lllite checkpoint
lllite_address_table_sdxl = {
    "controlnet_tile": ['models/lllite/bdsqlsz_controlllite_xl_tile_anime_β.safetensors'],
    "controlnet_lineart_anime": ['models/lllite/bdsqlsz_controlllite_xl_lineart_anime_denoise.safetensors'],
    # "controlnet_ip2p": ('lllyasviel/control_v11e_sd15_ip2p'),
    "controlnet_openpose": ['models/lllite/bdsqlsz_controlllite_xl_dw_openpose.safetensors'],
    # "controlnet_openpose": ['models/lllite/controllllite_v01032064e_sdxl_pose_anime.safetensors'],
    "controlnet_softedge": ['models/lllite/bdsqlsz_controlllite_xl_softedge.safetensors'],
    "controlnet_shuffle": ['models/lllite/bdsqlsz_controlllite_xl_t2i-adapter_color_shuffle.safetensors'],
    "controlnet_depth": ['models/lllite/bdsqlsz_controlllite_xl_depth.safetensors'],
    "controlnet_canny": ['models/lllite/bdsqlsz_controlllite_xl_canny.safetensors'],
    # "controlnet_canny": ['models/lllite/controllllite_v01032064e_sdxl_canny.safetensors'],
    # "controlnet_inpaint": ('lllyasviel/control_v11p_sd15_inpaint'),
    # "controlnet_lineart": ('lllyasviel/control_v11p_sd15_lineart'),
    "controlnet_mlsd": ['models/lllite/bdsqlsz_controlllite_xl_mlsd_V2.safetensors'],
    "controlnet_normalbae": ['models/lllite/bdsqlsz_controlllite_xl_normal.safetensors'],
    "controlnet_scribble": ['models/lllite/bdsqlsz_controlllite_xl_sketch.safetensors'],
    "controlnet_seg": ['models/lllite/bdsqlsz_controlllite_xl_segment_animeface_V2.safetensors'],
    # "qr_code_monster_v1": ['monster-labs/control_v1p_sdxl_qrcode_monster'],
    # "qr_code_monster_v2": ('monster-labs/control_v1p_sd15_qrcode_monster', 'v2'),
    # "controlnet_mediapipe_face": ('CrucibleAI/ControlNetMediaPipeFace', "diffusion_sd15"),
}

try:
    import onnxruntime
    onnxruntime_installed = True
except ImportError:
    onnxruntime_installed = False

logger = logging.getLogger(__name__)

data_dir = get_dir("data")
default_base_path = data_dir.joinpath("models/huggingface/stable-diffusion-v1-5")

re_clean_prompt = re.compile(r"[^\w\-, ]")

controlnet_preprocessor = {}

def load_safetensors_lora(text_encoder, unet, lora_path, alpha=0.75, is_animatediff=True):
    from safetensors.torch import load_file

    from animatediff.utils.lora_diffusers import (LoRANetwork,
                                                  create_network_from_weights)

    sd = load_file(lora_path)

    print("create LoRA network")
    lora_network: LoRANetwork = create_network_from_weights(text_encoder, unet, sd, multiplier=alpha, is_animatediff=is_animatediff)
    print("load LoRA network weights")
    lora_network.load_state_dict(sd, False)
    # lora_network.merge_to(alpha)
    lora_network.apply_to(alpha)
    return lora_network

def load_safetensors_lora2(text_encoder, unet, lora_path, alpha=0.75, is_animatediff=True):
    from safetensors.torch import load_file

    from animatediff.utils.lora_diffusers import (LoRANetwork,
                                                  create_network_from_weights)

    sd = load_file(lora_path)

    print("create LoRA network")
    lora_network: LoRANetwork = create_network_from_weights(text_encoder, unet, sd, multiplier=alpha, is_animatediff=is_animatediff)
    print("load LoRA network weights")
    lora_network.load_state_dict(sd, False)
    lora_network.merge_to(alpha)
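
# Note the difference between the two helpers above: load_safetensors_lora() calls
# apply_to(), which keeps the LoRA as separate modules hooked into the forward pass
# and returns the network, while load_safetensors_lora2() calls merge_to(), which
# writes the scaled deltas into the base weights. Minimal sketch (path hypothetical):
#   lora_net = load_safetensors_lora(pipe.text_encoder, pipe.unet,
#                                    "data/lora/style.safetensors", alpha=0.75)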

def load_tensors(path: Path, framework="pt", device="cpu"):
    tensors = {}
    if path.suffix == ".safetensors":
        from safetensors import safe_open
        with safe_open(path, framework=framework, device=device) as f:
            for k in f.keys():
                tensors[k] = f.get_tensor(k)  # loads the full tensor given a key
    else:
        from torch import load
        tensors = load(path, device)
        if "state_dict" in tensors:
            tensors = tensors["state_dict"]
    return tensors
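
# Example (hypothetical paths): both formats come back as a flat {key: tensor} dict,
# with torch checkpoints unwrapped from their "state_dict" envelope when present.
#   sd = load_tensors(Path("data/models/vae/vae-ft-mse.safetensors"))
#   sd = load_tensors(Path("data/models/motion-module/mm.ckpt"))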

def load_motion_lora(unet, lora_path: Path, alpha=1.0):
    state_dict = load_tensors(lora_path)

    # directly update weight in diffusers model
    for key in state_dict:
        # only process lora down key
        if "up." in key:
            continue

        up_key = key.replace(".down.", ".up.")
        model_key = key.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "")
        model_key = model_key.replace("to_out.", "to_out.0.")
        layer_infos = model_key.split(".")[:-1]

        # walk the module tree to the target layer; skip keys that do not resolve
        curr_layer = unet
        try:
            while len(layer_infos) > 0:
                temp_name = layer_infos.pop(0)
                curr_layer = curr_layer.__getattr__(temp_name)
        except AttributeError:
            logger.info(f"{model_key} not found")
            continue

        weight_down = state_dict[key]
        weight_up = state_dict[up_key]
        curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)
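
# Shape sketch for the merge above (r is the LoRA rank; dims are illustrative):
#   weight_up:   (out_dim, r)
#   weight_down: (r, in_dim)
#   delta_W = alpha * weight_up @ weight_down   # (out_dim, in_dim), added in place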

class SegPreProcessor:

    def __init__(self):
        self.image_processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-convnext-small")
        self.processor = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-convnext-small")

    def __call__(self, input_image, detect_resolution=512, image_resolution=512, output_type="pil", **kwargs):
        input_array = np.array(input_image, dtype=np.uint8)
        input_array = HWC3(input_array)
        input_array = aux_resize_image(input_array, detect_resolution)

        pixel_values = self.image_processor(input_array, return_tensors="pt").pixel_values

        with torch.no_grad():
            outputs = self.processor(pixel_values.to(self.processor.device))
            outputs.loss = outputs.loss.to("cpu") if outputs.loss is not None else outputs.loss
            outputs.logits = outputs.logits.to("cpu") if outputs.logits is not None else outputs.logits
            outputs.hidden_states = outputs.hidden_states.to("cpu") if outputs.hidden_states is not None else outputs.hidden_states
            outputs.attentions = outputs.attentions.to("cpu") if outputs.attentions is not None else outputs.attentions

        seg = self.image_processor.post_process_semantic_segmentation(outputs, target_sizes=[input_image.size[::-1]])[0]
        color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)  # height, width, 3

        for label, color in enumerate(ade_palette()):
            color_seg[seg == label, :] = color

        color_seg = color_seg.astype(np.uint8)
        color_seg = aux_resize_image(color_seg, image_resolution)

        return Image.fromarray(color_seg)

class NullPreProcessor:
    def __call__(self, input_image, **kwargs):
        return input_image

class BlurPreProcessor:
    def __call__(self, input_image, sigma=5.0, **kwargs):
        import cv2

        input_array = np.array(input_image, dtype=np.uint8)
        input_array = HWC3(input_array)

        dst = cv2.GaussianBlur(input_array, (0, 0), sigma)

        return Image.fromarray(dst)

class TileResamplePreProcessor:

    def resize(self, input_image, resolution):
        import cv2

        H, W, C = input_image.shape
        H = float(H)
        W = float(W)
        k = float(resolution) / min(H, W)
        H *= k
        W *= k
        img = cv2.resize(input_image, (int(W), int(H)), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
        return img

    def __call__(self, input_image, down_sampling_rate=1.0, **kwargs):
        input_array = np.array(input_image, dtype=np.uint8)
        input_array = HWC3(input_array)

        H, W, C = input_array.shape
        target_res = min(H, W) / down_sampling_rate

        dst = self.resize(input_array, target_res)

        return Image.fromarray(dst)
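
# Minimal usage sketch (assumes a local "input.png" exists): halve the resolution
# before feeding controlnet_tile, which is what down_sampling_rate=2.0 amounts to.
#   img = Image.open("input.png").convert("RGB")
#   tile_input = TileResamplePreProcessor()(img, down_sampling_rate=2.0)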

def is_valid_controlnet_type(type_str, is_sdxl):
    if not is_sdxl:
        return type_str in controlnet_address_table
    else:
        return (type_str in controlnet_address_table_sdxl) or (type_str in lllite_address_table_sdxl)

def load_controlnet_from_file(file_path, torch_dtype):
    from safetensors.torch import load_file

    prepare_extra_controlnet()

    file_path = Path(file_path)

    if file_path.exists() and file_path.is_file():
        if file_path.suffix.lower() in [".pth", ".pt", ".ckpt"]:
            controlnet_state_dict = torch.load(file_path, map_location="cpu", weights_only=True)
        elif file_path.suffix.lower() == ".safetensors":
            controlnet_state_dict = load_file(file_path, device="cpu")
        else:
            raise RuntimeError(
                f"unknown file format for controlnet weights: {file_path.suffix}"
            )
    else:
        raise FileNotFoundError(f"no controlnet weights found in {file_path}")

    if file_path.parent.name == "animatediff_controlnet":
        model = ControlNetModel(cross_attention_dim=768)
    else:
        model = ControlNetModel()

    # .ckpt checkpoints wrap the weights in a "state_dict" key; safetensors files do not
    if "state_dict" in controlnet_state_dict:
        controlnet_state_dict = controlnet_state_dict["state_dict"]
    missing, _ = model.load_state_dict(controlnet_state_dict, strict=False)
    if len(missing) > 0:
        logger.info(f"ControlNetModel has missing keys: {missing}")
    return model.to(dtype=torch_dtype)

def create_controlnet_model(pipe, type_str, is_sdxl):
    if not is_sdxl:
        if type_str in controlnet_address_table:
            addr = controlnet_address_table[type_str]
            if addr[0] is not None:
                if len(addr) == 1:
                    return ControlNetModel.from_pretrained(addr[0], torch_dtype=torch.float16)
                else:
                    return ControlNetModel.from_pretrained(addr[0], subfolder=addr[1], torch_dtype=torch.float16)
            else:
                return load_controlnet_from_file(addr[1], torch_dtype=torch.float16)
        else:
            raise ValueError(f"unknown controlnet type {type_str}")
    else:
        if type_str in controlnet_address_table_sdxl:
            addr = controlnet_address_table_sdxl[type_str]
            if len(addr) == 1:
                return ControlNetModel.from_pretrained(addr[0], torch_dtype=torch.float16)
            else:
                return ControlNetModel.from_pretrained(addr[0], subfolder=addr[1], torch_dtype=torch.float16)
        elif type_str in lllite_address_table_sdxl:
            addr = lllite_address_table_sdxl[type_str]
            model_path = data_dir.joinpath(addr[0])
            return load_controlnet_lllite(model_path, pipe, torch_dtype=torch.float16)
        else:
            raise ValueError(f"unknown controlnet type {type_str}")

default_preprocessor_table = {
    "controlnet_lineart_anime": "lineart_anime",
    "controlnet_openpose": "dwpose" if onnxruntime_installed else "openpose_full",
    "controlnet_softedge": "softedge_hedsafe",
    "controlnet_shuffle": "shuffle",
    "controlnet_depth": "depth_midas",
    "controlnet_canny": "canny",
    "controlnet_lineart": "lineart_realistic",
    "controlnet_mlsd": "mlsd",
    "controlnet_normalbae": "normal_bae",
    "controlnet_scribble": "scribble_pidsafe",
    "controlnet_seg": "upernet_seg",
    "controlnet_mediapipe_face": "mediapipe_face",
    "qr_code_monster_v1": "depth_midas",
    "qr_code_monster_v2": "depth_midas",
}

def create_preprocessor_from_name(pre_type):
    if pre_type == "dwpose":
        prepare_dwpose()
        return DWposeDetector()
    elif pre_type == "upernet_seg":
        return SegPreProcessor()
    elif pre_type == "blur":
        return BlurPreProcessor()
    elif pre_type == "tile_resample":
        return TileResamplePreProcessor()
    elif pre_type == "none":
        return NullPreProcessor()
    elif pre_type in MODELS:
        return ControlnetPreProcessor(pre_type)
    else:
        raise ValueError(f"unknown controlnet preprocessor type {pre_type}")

def create_default_preprocessor(type_str):
    if type_str in default_preprocessor_table:
        pre_type = default_preprocessor_table[type_str]
    else:
        pre_type = "none"
    return create_preprocessor_from_name(pre_type)

def get_preprocessor(type_str, device_str, preprocessor_map):
    if type_str not in controlnet_preprocessor:
        if preprocessor_map:
            controlnet_preprocessor[type_str] = create_preprocessor_from_name(preprocessor_map["type"])

        if type_str not in controlnet_preprocessor:
            controlnet_preprocessor[type_str] = create_default_preprocessor(type_str)

        if hasattr(controlnet_preprocessor[type_str], "processor"):
            if hasattr(controlnet_preprocessor[type_str].processor, "to"):
                if device_str:
                    controlnet_preprocessor[type_str].processor.to(device_str)
        elif hasattr(controlnet_preprocessor[type_str], "to"):
            if device_str:
                controlnet_preprocessor[type_str].to(device_str)

    return controlnet_preprocessor[type_str]
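
# Sketch of the cache behavior above: the first call instantiates the detector and
# moves it to the requested device; later calls return the same cached instance.
#   pre = get_preprocessor("controlnet_canny", "cuda", preprocessor_map=None)
#   assert pre is get_preprocessor("controlnet_canny", "cuda", None)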

def clear_controlnet_preprocessor(type_str=None):
    global controlnet_preprocessor
    if type_str is None:
        for t in controlnet_preprocessor:
            controlnet_preprocessor[t] = None
        controlnet_preprocessor = {}
        torch.cuda.empty_cache()
    else:
        controlnet_preprocessor[type_str] = None
        torch.cuda.empty_cache()

def get_preprocessed_img(type_str, img, use_preprocessor, device_str, preprocessor_map):
    if use_preprocessor:
        param = {}
        if preprocessor_map:
            param = preprocessor_map["param"] if "param" in preprocessor_map else {}
        return get_preprocessor(type_str, device_str, preprocessor_map)(img, **param)
    else:
        return img

def create_pipeline_sdxl(
    base_model: Union[str, PathLike] = default_base_path,
    model_config: ModelConfig = ...,
    infer_config: InferenceConfig = ...,
    use_xformers: bool = True,
    video_length: int = 16,
    motion_module_path=...,
):
    from animatediff.pipelines.sdxl_animation import AnimationPipeline
    from animatediff.sdxl_models.unet import UNet3DConditionModel

    logger.info("Loading tokenizer...")
    tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(base_model, subfolder="tokenizer")
    logger.info("Loading text encoder...")
    text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained(base_model, subfolder="text_encoder", torch_dtype=torch.float16)
    logger.info("Loading VAE...")
    vae: AutoencoderKL = AutoencoderKL.from_pretrained(base_model, subfolder="vae")
    logger.info("Loading tokenizer two...")
    tokenizer_two = CLIPTokenizer.from_pretrained(base_model, subfolder="tokenizer_2")
    logger.info("Loading text encoder two...")
    text_encoder_two = CLIPTextModelWithProjection.from_pretrained(base_model, subfolder="text_encoder_2", torch_dtype=torch.float16)
    logger.info("Loading UNet...")
    unet: UNet3DConditionModel = UNet3DConditionModel.from_pretrained_2d(
        pretrained_model_path=base_model,
        motion_module_path=motion_module_path,
        subfolder="unet",
        unet_additional_kwargs=infer_config.unet_additional_kwargs,
    )

    # set up scheduler
    sched_kwargs = infer_config.noise_scheduler_kwargs
    scheduler = get_scheduler(model_config.scheduler, sched_kwargs)
    logger.info(f'Using scheduler "{model_config.scheduler}" ({scheduler.__class__.__name__})')

    if model_config.gradual_latent_hires_fix_map:
        if "enable" in model_config.gradual_latent_hires_fix_map:
            if model_config.gradual_latent_hires_fix_map["enable"]:
                if model_config.scheduler not in (DiffusionScheduler.euler_a, DiffusionScheduler.lcm):
                    logger.warning("gradual_latent_hires_fix is enabled")
                    logger.warning(f"{model_config.scheduler=}")
                    logger.warning("If generation aborts with an error, switch the scheduler to euler_a or lcm")

    # Load the checkpoint weights into the pipeline
    if model_config.path is not None:
        model_path = data_dir.joinpath(model_config.path)
        logger.info(f"Loading weights from {model_path}")
        if model_path.is_file():
            logger.debug("Loading from single checkpoint file")
            unet_state_dict, tenc_state_dict, tenc2_state_dict, vae_state_dict = get_checkpoint_weights_sdxl(model_path)
        elif model_path.is_dir():
            logger.debug("Loading from Diffusers model directory")
            temp_pipeline = StableDiffusionXLPipeline.from_pretrained(model_path)
            unet_state_dict, tenc_state_dict, tenc2_state_dict, vae_state_dict = (
                temp_pipeline.unet.state_dict(),
                temp_pipeline.text_encoder.state_dict(),
                temp_pipeline.text_encoder_2.state_dict(),
                temp_pipeline.vae.state_dict(),
            )
            del temp_pipeline
        else:
            raise FileNotFoundError(f"model_path {model_path} is not a file or directory")

        # Load into the unet, TE, and VAE
        logger.info("Merging weights into UNet...")
        _, unet_unex = unet.load_state_dict(unet_state_dict, strict=False)
        if len(unet_unex) > 0:
            raise ValueError(f"UNet has unexpected keys: {unet_unex}")
        tenc_missing, _ = text_encoder.load_state_dict(tenc_state_dict, strict=False)
        if len(tenc_missing) > 0:
            raise ValueError(f"TextEncoder has missing keys: {tenc_missing}")
        tenc2_missing, _ = text_encoder_two.load_state_dict(tenc2_state_dict, strict=False)
        if len(tenc2_missing) > 0:
            raise ValueError(f"TextEncoder2 has missing keys: {tenc2_missing}")
        vae_missing, _ = vae.load_state_dict(vae_state_dict, strict=False)
        if len(vae_missing) > 0:
            raise ValueError(f"VAE has missing keys: {vae_missing}")

        # free the checkpoint dicts (only defined on this branch)
        del unet_state_dict
        del tenc_state_dict
        del tenc2_state_dict
        del vae_state_dict
    else:
        logger.info("Using base model weights (no checkpoint/LoRA)")

    if model_config.vae_path:
        vae_path = data_dir.joinpath(model_config.vae_path)
        logger.info(f"Loading vae from {vae_path}")
        if vae_path.is_dir():
            vae = AutoencoderKL.from_pretrained(vae_path)
        else:
            tensors = load_tensors(vae_path)
            tensors = convert_ldm_vae_checkpoint(tensors, vae.config)
            vae.load_state_dict(tensors)

    unet.to(torch.float16)
    text_encoder.to(torch.float16)
    text_encoder_two.to(torch.float16)

    # enable xformers if available
    if use_xformers:
        logger.info("Enabling xformers memory-efficient attention")
        unet.enable_xformers_memory_efficient_attention()

    # motion lora
    for l in model_config.motion_lora_map:
        lora_path = data_dir.joinpath(l)
        logger.info(f"loading motion lora {lora_path=}")
        if lora_path.is_file():
            logger.info(f"Loading motion lora {lora_path}")
            logger.info(f"alpha = {model_config.motion_lora_map[l]}")
            load_motion_lora(unet, lora_path, alpha=model_config.motion_lora_map[l])
        else:
            raise ValueError(f"{lora_path=} not found")

    logger.info("Creating AnimationPipeline...")
    pipeline = AnimationPipeline(
        vae=vae,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_two,
        tokenizer=tokenizer,
        tokenizer_2=tokenizer_two,
        unet=unet,
        scheduler=scheduler,
        controlnet_map=None,
    )
    del vae
    del text_encoder
    del text_encoder_two
    del tokenizer
    del tokenizer_two
    del unet

    torch.cuda.empty_cache()

    pipeline.lcm = None
    if model_config.lcm_map:
        if model_config.lcm_map["enable"]:
            prepare_lcm_lora()
            load_lcm_lora(pipeline, model_config.lcm_map, is_sdxl=True)

    load_lora_map(pipeline, model_config.lora_map, video_length, is_sdxl=True)

    pipeline.unet = pipeline.unet.half()
    pipeline.text_encoder = pipeline.text_encoder.half()
    pipeline.text_encoder_2 = pipeline.text_encoder_2.half()

    # Load TI embeddings
    pipeline.text_encoder = pipeline.text_encoder.to("cuda")
    pipeline.text_encoder_2 = pipeline.text_encoder_2.to("cuda")

    load_text_embeddings(pipeline, is_sdxl=True)

    pipeline.text_encoder = pipeline.text_encoder.to("cpu")
    pipeline.text_encoder_2 = pipeline.text_encoder_2.to("cpu")

    return pipeline

def create_pipeline(
    base_model: Union[str, PathLike] = default_base_path,
    model_config: ModelConfig = ...,
    infer_config: InferenceConfig = ...,
    use_xformers: bool = True,
    video_length: int = 16,
    is_sdxl: bool = False,
) -> DiffusionPipeline:
    """Create an AnimationPipeline from a pretrained model.
    Uses the base_model argument to load or download the pretrained reference pipeline model."""

    # make sure motion_module is a Path and exists
    logger.info("Checking motion module...")
    motion_module = data_dir.joinpath(model_config.motion_module)
    if not (motion_module.exists() and motion_module.is_file()):
        prepare_motion_module()
        if not (motion_module.exists() and motion_module.is_file()):
            # check for safetensors version
            motion_module = motion_module.with_suffix(".safetensors")
            if not (motion_module.exists() and motion_module.is_file()):
                # download from HuggingFace Hub if not found
                ensure_motion_modules()
            if not (motion_module.exists() and motion_module.is_file()):
                # this should never happen, but just in case...
                raise FileNotFoundError(f"Motion module {motion_module} does not exist or is not a file!")

    if is_sdxl:
        return create_pipeline_sdxl(
            base_model=base_model,
            model_config=model_config,
            infer_config=infer_config,
            use_xformers=use_xformers,
            video_length=video_length,
            motion_module_path=motion_module,
        )

    logger.info("Loading tokenizer...")
    tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(base_model, subfolder="tokenizer")
    logger.info("Loading text encoder...")
    text_encoder: CLIPSkipTextModel = CLIPSkipTextModel.from_pretrained(base_model, subfolder="text_encoder")
    logger.info("Loading VAE...")
    vae: AutoencoderKL = AutoencoderKL.from_pretrained(base_model, subfolder="vae")
    logger.info("Loading UNet...")
    unet: UNet3DConditionModel = UNet3DConditionModel.from_pretrained_2d(
        pretrained_model_path=base_model,
        motion_module_path=motion_module,
        subfolder="unet",
        unet_additional_kwargs=infer_config.unet_additional_kwargs,
    )
    feature_extractor = CLIPImageProcessor.from_pretrained(base_model, subfolder="feature_extractor")

    if model_config.gradual_latent_hires_fix_map:
        if "enable" in model_config.gradual_latent_hires_fix_map:
            if model_config.gradual_latent_hires_fix_map["enable"]:
                if model_config.scheduler not in (DiffusionScheduler.euler_a, DiffusionScheduler.lcm):
                    logger.warning("gradual_latent_hires_fix is enabled")
                    logger.warning(f"{model_config.scheduler=}")
                    logger.warning("If generation aborts with an error, switch the scheduler to euler_a or lcm")

    # set up scheduler
    sched_kwargs = infer_config.noise_scheduler_kwargs
    scheduler = get_scheduler(model_config.scheduler, sched_kwargs)
    logger.info(f'Using scheduler "{model_config.scheduler}" ({scheduler.__class__.__name__})')

    # Load the checkpoint weights into the pipeline
    if model_config.path is not None:
        model_path = data_dir.joinpath(model_config.path)
        logger.info(f"Loading weights from {model_path}")
        if model_path.is_file():
            logger.debug("Loading from single checkpoint file")
            unet_state_dict, tenc_state_dict, vae_state_dict = get_checkpoint_weights(model_path)
        elif model_path.is_dir():
            logger.debug("Loading from Diffusers model directory")
            temp_pipeline = StableDiffusionPipeline.from_pretrained(model_path)
            unet_state_dict, tenc_state_dict, vae_state_dict = (
                temp_pipeline.unet.state_dict(),
                temp_pipeline.text_encoder.state_dict(),
                temp_pipeline.vae.state_dict(),
            )
            del temp_pipeline
        else:
            raise FileNotFoundError(f"model_path {model_path} is not a file or directory")

        # Load into the unet, TE, and VAE
        logger.info("Merging weights into UNet...")
        _, unet_unex = unet.load_state_dict(unet_state_dict, strict=False)
        if len(unet_unex) > 0:
            raise ValueError(f"UNet has unexpected keys: {unet_unex}")
        tenc_missing, _ = text_encoder.load_state_dict(tenc_state_dict, strict=False)
        if len(tenc_missing) > 0:
            raise ValueError(f"TextEncoder has missing keys: {tenc_missing}")
        vae_missing, _ = vae.load_state_dict(vae_state_dict, strict=False)
        if len(vae_missing) > 0:
            raise ValueError(f"VAE has missing keys: {vae_missing}")
    else:
        logger.info("Using base model weights (no checkpoint/LoRA)")

    if model_config.vae_path:
        vae_path = data_dir.joinpath(model_config.vae_path)
        logger.info(f"Loading vae from {vae_path}")
        if vae_path.is_dir():
            vae = AutoencoderKL.from_pretrained(vae_path)
        else:
            tensors = load_tensors(vae_path)
            tensors = convert_ldm_vae_checkpoint(tensors, vae.config)
            vae.load_state_dict(tensors)

    # enable xformers if available
    if use_xformers:
        logger.info("Enabling xformers memory-efficient attention")
        unet.enable_xformers_memory_efficient_attention()

    if False:
        # disabled: legacy LoRA loading path (superseded by load_lora_map below)
        for l in model_config.lora_map:
            lora_path = data_dir.joinpath(l)
            if lora_path.is_file():
                logger.info(f"Loading lora {lora_path}")
                logger.info(f"alpha = {model_config.lora_map[l]}")
                load_safetensors_lora(text_encoder, unet, lora_path, alpha=model_config.lora_map[l])

    # motion lora
    for l in model_config.motion_lora_map:
        lora_path = data_dir.joinpath(l)
        logger.info(f"loading motion lora {lora_path=}")
        if lora_path.is_file():
            logger.info(f"Loading motion lora {lora_path}")
            logger.info(f"alpha = {model_config.motion_lora_map[l]}")
            load_motion_lora(unet, lora_path, alpha=model_config.motion_lora_map[l])
        else:
            raise ValueError(f"{lora_path=} not found")

    logger.info("Creating AnimationPipeline...")
    pipeline = AnimationPipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        feature_extractor=feature_extractor,
        controlnet_map=None,
    )

    pipeline.lcm = None
    if model_config.lcm_map:
        if model_config.lcm_map["enable"]:
            prepare_lcm_lora()
            load_lcm_lora(pipeline, model_config.lcm_map, is_sdxl=False)

    load_lora_map(pipeline, model_config.lora_map, video_length)

    # Load TI embeddings
    pipeline.unet = pipeline.unet.half()
    pipeline.text_encoder = pipeline.text_encoder.half()

    pipeline.text_encoder = pipeline.text_encoder.to("cuda")

    load_text_embeddings(pipeline)

    pipeline.text_encoder = pipeline.text_encoder.to("cpu")

    return pipeline
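
# Minimal usage sketch (config objects are hypothetical; in the CLI they are parsed
# from a prompt JSON and the inference config file via animatediff.settings):
#   pipe = create_pipeline(
#       base_model=default_base_path,
#       model_config=model_config,
#       infer_config=infer_config,
#       video_length=16,
#   )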

def load_controlnet_models(pipe: DiffusionPipeline, model_config: ModelConfig = ..., is_sdxl: bool = False):
    # controlnet
    if is_sdxl:
        prepare_lllite()

    controlnet_map = {}
    if model_config.controlnet_map:
        c_image_dir = data_dir.joinpath(model_config.controlnet_map["input_image_dir"])
        for c in model_config.controlnet_map:
            item = model_config.controlnet_map[c]
            if type(item) is dict:
                if item["enable"] == True:
                    if is_valid_controlnet_type(c, is_sdxl):
                        img_dir = c_image_dir.joinpath(c)
                        cond_imgs = sorted(glob.glob(os.path.join(img_dir, "[0-9]*.png"), recursive=False))
                        if len(cond_imgs) > 0:
                            logger.info(f"loading {c=} model")
                            controlnet_map[c] = create_controlnet_model(pipe, c, is_sdxl)
                    else:
                        logger.info(f"invalid controlnet type for {'sdxl' if is_sdxl else 'sd15'} : {c}")

    if not controlnet_map:
        controlnet_map = None

    pipe.controlnet_map = controlnet_map

def unload_controlnet_models(pipe: AnimationPipeline):
    from animatediff.utils.util import show_gpu

    if pipe.controlnet_map:
        for c in pipe.controlnet_map:
            controlnet = pipe.controlnet_map[c]
            if isinstance(controlnet, ControlNetLLLite):
                controlnet.unapply_to()
                del controlnet

    # show_gpu("before unload controlnet")
    pipe.controlnet_map = None
    torch.cuda.empty_cache()
    # show_gpu("after unload controlnet")

def create_us_pipeline(
    model_config: ModelConfig = ...,
    infer_config: InferenceConfig = ...,
    use_xformers: bool = True,
    use_controlnet_ref: bool = False,
    use_controlnet_tile: bool = False,
    use_controlnet_line_anime: bool = False,
    use_controlnet_ip2p: bool = False,
) -> DiffusionPipeline:

    # set up scheduler
    sched_kwargs = infer_config.noise_scheduler_kwargs
    scheduler = get_scheduler(model_config.scheduler, sched_kwargs)
    logger.info(f'Using scheduler "{model_config.scheduler}" ({scheduler.__class__.__name__})')

    controlnet = []
    if use_controlnet_tile:
        controlnet.append(ControlNetModel.from_pretrained('lllyasviel/control_v11f1e_sd15_tile'))
    if use_controlnet_line_anime:
        controlnet.append(ControlNetModel.from_pretrained('lllyasviel/control_v11p_sd15s2_lineart_anime'))
    if use_controlnet_ip2p:
        controlnet.append(ControlNetModel.from_pretrained('lllyasviel/control_v11e_sd15_ip2p'))

    if len(controlnet) == 1:
        controlnet = controlnet[0]
    elif len(controlnet) == 0:
        controlnet = None

    # Load the checkpoint weights into the pipeline
    pipeline: DiffusionPipeline

    if model_config.path is not None:
        model_path = data_dir.joinpath(model_config.path)
        logger.info(f"Loading weights from {model_path}")
        if model_path.is_file():

            def is_empty_dir(path):
                import os
                return len(os.listdir(path)) == 0

            save_path = data_dir.joinpath("models/huggingface/" + model_path.stem + "_" + str(model_path.stat().st_size))
            save_path.mkdir(exist_ok=True)
            if save_path.is_dir() and is_empty_dir(save_path):
                # StableDiffusionControlNetImg2ImgPipeline.from_single_file does not exist
                # in diffusers 0.18.2, so convert via StableDiffusionPipeline and cache the
                # result as a Diffusers directory
                logger.debug("Loading from single checkpoint file")
                tmp_pipeline = StableDiffusionPipeline.from_single_file(
                    pretrained_model_link_or_path=str(model_path.absolute())
                )
                tmp_pipeline.save_pretrained(save_path, safe_serialization=True)
                del tmp_pipeline

            if use_controlnet_ref:
                pipeline = StableDiffusionControlNetImg2ImgReferencePipeline.from_pretrained(
                    save_path,
                    controlnet=controlnet,
                    local_files_only=False,
                    load_safety_checker=False,
                    safety_checker=None,
                )
            else:
                pipeline = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
                    save_path,
                    controlnet=controlnet,
                    local_files_only=False,
                    load_safety_checker=False,
                    safety_checker=None,
                )
        elif model_path.is_dir():
            logger.debug("Loading from Diffusers model directory")
            if use_controlnet_ref:
                pipeline = StableDiffusionControlNetImg2ImgReferencePipeline.from_pretrained(
                    model_path,
                    controlnet=controlnet,
                    local_files_only=True,
                    load_safety_checker=False,
                    safety_checker=None,
                )
            else:
                pipeline = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
                    model_path,
                    controlnet=controlnet,
                    local_files_only=True,
                    load_safety_checker=False,
                    safety_checker=None,
                )
        else:
            raise FileNotFoundError(f"model_path {model_path} is not a file or directory")
    else:
        raise ValueError("model_config.path is invalid")

    pipeline.scheduler = scheduler

    # enable xformers if available
    if use_xformers:
        logger.info("Enabling xformers memory-efficient attention")
        pipeline.enable_xformers_memory_efficient_attention()

    # lora
    for l in model_config.lora_map:
        lora_path = data_dir.joinpath(l)
        if lora_path.is_file():
            alpha = model_config.lora_map[l]
            if isinstance(alpha, dict):
                alpha = 0.75

            logger.info(f"Loading lora {lora_path}")
            logger.info(f"alpha = {alpha}")
            load_safetensors_lora2(pipeline.text_encoder, pipeline.unet, lora_path, alpha=alpha, is_animatediff=False)

    # Load TI embeddings
    pipeline.unet = pipeline.unet.half()
    pipeline.text_encoder = pipeline.text_encoder.half()

    pipeline.text_encoder = pipeline.text_encoder.to("cuda")

    load_text_embeddings(pipeline)

    pipeline.text_encoder = pipeline.text_encoder.to("cpu")

    return pipeline

def seed_everything(seed):
    import random

    import numpy as np

    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed % (2**32))
    random.seed(seed)
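
# Example: re-seeding makes generation reproducible across the torch/numpy/random RNGs.
#   seed_everything(1234); a = torch.randn(2)
#   seed_everything(1234); b = torch.randn(2)
#   assert torch.equal(a, b)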

def controlnet_preprocess(
    controlnet_map: Dict[str, Any] = None,
    width: int = 512,
    height: int = 512,
    duration: int = 16,
    out_dir: PathLike = ...,
    device_str: str = None,
    is_sdxl: bool = False,
):
    if not controlnet_map:
        return None, None, None, None

    out_dir = Path(out_dir)  # ensure out_dir is a Path

    # { 0 : { "type_str" : IMAGE, "type_str2" : IMAGE } }
    controlnet_image_map = {}

    controlnet_type_map = {}

    c_image_dir = data_dir.joinpath(controlnet_map["input_image_dir"])
    save_detectmap = controlnet_map["save_detectmap"] if "save_detectmap" in controlnet_map else True

    preprocess_on_gpu = controlnet_map["preprocess_on_gpu"] if "preprocess_on_gpu" in controlnet_map else True
    device_str = device_str if preprocess_on_gpu else None

    for c in controlnet_map:
        if c == "controlnet_ref":
            continue

        item = controlnet_map[c]

        processed = False

        if type(item) is dict:
            if item["enable"] == True:
                if is_valid_controlnet_type(c, is_sdxl):
                    preprocessor_map = item["preprocessor"] if "preprocessor" in item else {}

                    img_dir = c_image_dir.joinpath(c)
                    cond_imgs = sorted(glob.glob(os.path.join(img_dir, "[0-9]*.png"), recursive=False))
                    if len(cond_imgs) > 0:
                        controlnet_type_map[c] = {
                            "controlnet_conditioning_scale": item["controlnet_conditioning_scale"],
                            "control_guidance_start": item["control_guidance_start"],
                            "control_guidance_end": item["control_guidance_end"],
                            "control_scale_list": item["control_scale_list"],
                            "guess_mode": item["guess_mode"] if "guess_mode" in item else False,
                            "control_region_list": item["control_region_list"] if "control_region_list" in item else [],
                        }

                        use_preprocessor = item["use_preprocessor"] if "use_preprocessor" in item else True

                        for img_path in tqdm(cond_imgs, desc=f"Preprocessing images ({c})"):
                            frame_no = int(Path(img_path).stem)
                            if frame_no < duration:
                                if frame_no not in controlnet_image_map:
                                    controlnet_image_map[frame_no] = {}
                                controlnet_image_map[frame_no][c] = get_preprocessed_img(c, get_resized_image2(img_path, 512), use_preprocessor, device_str, preprocessor_map)
                                processed = True
                else:
                    logger.info(f"invalid controlnet type for {'sdxl' if is_sdxl else 'sd15'} : {c}")

        if save_detectmap and processed:
            det_dir = out_dir.joinpath(f"{0:02d}_detectmap/{c}")
            det_dir.mkdir(parents=True, exist_ok=True)
            for frame_no in tqdm(controlnet_image_map, desc=f"Saving Preprocessed images ({c})"):
                save_path = det_dir.joinpath(f"{frame_no:08d}.png")
                if c in controlnet_image_map[frame_no]:
                    controlnet_image_map[frame_no][c].save(save_path)

        clear_controlnet_preprocessor(c)

    clear_controlnet_preprocessor()

    controlnet_ref_map = None

    if "controlnet_ref" in controlnet_map:
        r = controlnet_map["controlnet_ref"]
        if r["enable"] == True:
            org_name = data_dir.joinpath(r["ref_image"]).stem
            # ref_image = get_resized_image(data_dir.joinpath(r["ref_image"]), width, height)
            ref_image = get_resized_image2(data_dir.joinpath(r["ref_image"]), 512)

            if ref_image is not None:
                controlnet_ref_map = {
                    "ref_image": ref_image,
                    "style_fidelity": r["style_fidelity"],
                    "attention_auto_machine_weight": r["attention_auto_machine_weight"],
                    "gn_auto_machine_weight": r["gn_auto_machine_weight"],
                    "reference_attn": r["reference_attn"],
                    "reference_adain": r["reference_adain"],
                    "scale_pattern": r["scale_pattern"],
                }

                if save_detectmap:
                    det_dir = out_dir.joinpath(f"{0:02d}_detectmap/controlnet_ref")
                    det_dir.mkdir(parents=True, exist_ok=True)
                    save_path = det_dir.joinpath(f"{org_name}.png")
                    ref_image.save(save_path)

    controlnet_no_shrink = ["controlnet_tile", "animatediff_controlnet", "controlnet_canny", "controlnet_normalbae", "controlnet_depth", "controlnet_lineart", "controlnet_lineart_anime", "controlnet_scribble", "controlnet_seg", "controlnet_softedge", "controlnet_mlsd"]
    if "no_shrink_list" in controlnet_map:
        controlnet_no_shrink = controlnet_map["no_shrink_list"]

    return controlnet_image_map, controlnet_type_map, controlnet_ref_map, controlnet_no_shrink
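
# Sketch of the controlnet_map structure consumed above (keys as read by the code;
# values are illustrative). Per-type entries are dicts keyed by the controlnet name:
#   {
#     "input_image_dir": "controlnet_image/01",
#     "save_detectmap": True,
#     "preprocess_on_gpu": True,
#     "controlnet_tile": {
#       "enable": True,
#       "use_preprocessor": True,
#       "controlnet_conditioning_scale": 1.0,
#       "control_guidance_start": 0.0,
#       "control_guidance_end": 1.0,
#       "control_scale_list": [0.5, 0.4],
#     },
#   }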

def ip_adapter_preprocess(
    ip_adapter_config_map: Dict[str, Any] = None,
    width: int = 512,
    height: int = 512,
    duration: int = 16,
    out_dir: PathLike = ...,
    is_sdxl: bool = False,
):
    ip_adapter_map = {}

    processed = False

    if ip_adapter_config_map:
        if ip_adapter_config_map["enable"] == True:
            resized_to_square = ip_adapter_config_map["resized_to_square"] if "resized_to_square" in ip_adapter_config_map else False
            image_dir = data_dir.joinpath(ip_adapter_config_map["input_image_dir"])
            imgs = sorted(chain.from_iterable([glob.glob(os.path.join(image_dir, f"[0-9]*{ext}")) for ext in IMG_EXTENSIONS]))
            if len(imgs) > 0:
                prepare_ip_adapter_sdxl() if is_sdxl else prepare_ip_adapter()
                ip_adapter_map["images"] = {}
                for img_path in tqdm(imgs, desc="Preprocessing images (ip_adapter)"):
                    frame_no = int(Path(img_path).stem)
                    if frame_no < duration:
                        if resized_to_square:
                            ip_adapter_map["images"][frame_no] = get_resized_image(img_path, 256, 256)
                        else:
                            ip_adapter_map["images"][frame_no] = get_resized_image2(img_path, 256)
                        processed = True

            if processed:
                ip_adapter_config_map["prompt_fixed_ratio"] = max(min(1.0, ip_adapter_config_map["prompt_fixed_ratio"]), 0)
                prompt_fixed_ratio = ip_adapter_config_map["prompt_fixed_ratio"]
                prompt_map = ip_adapter_map["images"]
                prompt_map = dict(sorted(prompt_map.items()))
                key_list = list(prompt_map.keys())
                for k0, k1 in zip(key_list, key_list[1:] + [duration]):
                    k05 = k0 + round((k1 - k0) * prompt_fixed_ratio)
                    if k05 == k1:
                        k05 -= 1
                    if k05 != k0:
                        prompt_map[k05] = prompt_map[k0]
                ip_adapter_map["images"] = prompt_map

        if (ip_adapter_config_map["save_input_image"] == True) and processed:
            det_dir = out_dir.joinpath(f"{0:02d}_ip_adapter/")
            det_dir.mkdir(parents=True, exist_ok=True)
            for frame_no in tqdm(ip_adapter_map["images"], desc="Saving Preprocessed images (ip_adapter)"):
                save_path = det_dir.joinpath(f"{frame_no:08d}.png")
                ip_adapter_map["images"][frame_no].save(save_path)

    return ip_adapter_map if processed else None

def prompt_preprocess(
    prompt_config_map: Dict[str, Any],
    head_prompt: str,
    tail_prompt: str,
    prompt_fixed_ratio: float,
    video_length: int,
):
    prompt_map = {}
    for k in prompt_config_map.keys():
        if int(k) < video_length:
            pr = prompt_config_map[k]
            if head_prompt:
                pr = head_prompt + "," + pr
            if tail_prompt:
                pr = pr + "," + tail_prompt

            prompt_map[int(k)] = pr

    prompt_map = dict(sorted(prompt_map.items()))
    key_list = list(prompt_map.keys())
    for k0, k1 in zip(key_list, key_list[1:] + [video_length]):
        k05 = k0 + round((k1 - k0) * prompt_fixed_ratio)
        if k05 == k1:
            k05 -= 1
        if k05 != k0:
            prompt_map[k05] = prompt_map[k0]

    return prompt_map
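
# Worked example of prompt_fixed_ratio: with prompt_config_map {"0": "a", "8": "b"},
# empty head/tail prompts, prompt_fixed_ratio=0.5 and video_length=16, the loop
# inserts a midpoint keyframe into each segment so the prompt is held for half of it:
#   {0: "a", 8: "b"}  ->  {0: "a", 4: "a", 8: "b", 12: "b"}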

def region_preprocess(
    model_config: ModelConfig = ...,
    width: int = 512,
    height: int = 512,
    duration: int = 16,
    out_dir: PathLike = ...,
    is_init_img_exist: bool = False,
    is_sdxl: bool = False,
):
    is_bg_init_img = False
    if is_init_img_exist:
        if model_config.region_map:
            if "background" in model_config.region_map:
                is_bg_init_img = model_config.region_map["background"]["is_init_img"]

    region_condi_list = []
    region2index = {}

    condi_index = 0

    prev_ip_map = None

    if not is_bg_init_img:
        ip_map = ip_adapter_preprocess(
            model_config.ip_adapter_map,
            width,
            height,
            duration,
            out_dir,
            is_sdxl,
        )

        if ip_map:
            prev_ip_map = ip_map

        condition_map = {
            "prompt_map": prompt_preprocess(
                model_config.prompt_map,
                model_config.head_prompt,
                model_config.tail_prompt,
                model_config.prompt_fixed_ratio,
                duration,
            ),
            "ip_adapter_map": ip_map,
        }

        region_condi_list.append(condition_map)

        bg_src = condi_index
        condi_index += 1
    else:
        bg_src = -1

    region_list = [
        {
            "mask_images": None,
            "src": bg_src,
            "crop_generation_rate": 0,
        }
    ]
    region2index["background"] = bg_src

    if model_config.region_map:
        for r in model_config.region_map:
            if r == "background":
                continue
            if model_config.region_map[r]["enable"] != True:
                continue

            region_dir = out_dir.joinpath(f"region_{int(r):05d}/")
            region_dir.mkdir(parents=True, exist_ok=True)

            mask_map = mask_preprocess(
                model_config.region_map[r],
                width,
                height,
                duration,
                region_dir,
            )

            if not mask_map:
                continue

            if model_config.region_map[r]["is_init_img"] == False:
                ip_map = ip_adapter_preprocess(
                    model_config.region_map[r]["condition"]["ip_adapter_map"],
                    width,
                    height,
                    duration,
                    region_dir,
                    is_sdxl,
                )

                if ip_map:
                    prev_ip_map = ip_map

                condition_map = {
                    "prompt_map": prompt_preprocess(
                        model_config.region_map[r]["condition"]["prompt_map"],
                        model_config.region_map[r]["condition"]["head_prompt"],
                        model_config.region_map[r]["condition"]["tail_prompt"],
                        model_config.region_map[r]["condition"]["prompt_fixed_ratio"],
                        duration,
                    ),
                    "ip_adapter_map": ip_map,
                }

                region_condi_list.append(condition_map)

                src = condi_index
                condi_index += 1
            else:
                if is_init_img_exist == False:
                    logger.warning("'is_init_img' is true, but no init_img exists -> ignoring region")
                    continue
                src = -1

            region_list.append(
                {
                    "mask_images": mask_map,
                    "src": src,
                    "crop_generation_rate": model_config.region_map[r]["crop_generation_rate"] if "crop_generation_rate" in model_config.region_map[r] else 0,
                }
            )
            region2index[r] = src

    ip_adapter_config_map = None

    if prev_ip_map is not None:
        ip_adapter_config_map = {}
        ip_adapter_config_map["scale"] = model_config.ip_adapter_map["scale"]
        ip_adapter_config_map["is_plus"] = model_config.ip_adapter_map["is_plus"]
        ip_adapter_config_map["is_plus_face"] = model_config.ip_adapter_map["is_plus_face"] if "is_plus_face" in model_config.ip_adapter_map else False
        ip_adapter_config_map["is_light"] = model_config.ip_adapter_map["is_light"] if "is_light" in model_config.ip_adapter_map else False
        ip_adapter_config_map["is_full_face"] = model_config.ip_adapter_map["is_full_face"] if "is_full_face" in model_config.ip_adapter_map else False
        for c in region_condi_list:
            if c["ip_adapter_map"] is None:
                logger.info("ip_adapter_map not set for this region; filling with the last processed map")
                c["ip_adapter_map"] = prev_ip_map

    # for c in region_condi_list:
    #     logger.info(f"{c['prompt_map']=}")

    if not region_condi_list:
        raise ValueError("error: there is not a single valid region")

    return region_condi_list, region_list, ip_adapter_config_map, region2index
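
# Sketch of the region_map structure consumed above (keys as read by the code;
# values are illustrative). Region keys are stringified integers plus "background":
#   {
#     "background": {"is_init_img": False},
#     "0": {
#       "enable": True,
#       "is_init_img": False,
#       "crop_generation_rate": 0.0,
#       "mask_dir": "mask/region_0",
#       "save_mask": True,
#       "condition": {
#         "prompt_map": {"0": "1girl, smile"},
#         "head_prompt": "",
#         "tail_prompt": "",
#         "prompt_fixed_ratio": 0.5,
#         "ip_adapter_map": {"enable": False},
#       },
#     },
#   }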

def img2img_preprocess(
    img2img_config_map: Dict[str, Any] = None,
    width: int = 512,
    height: int = 512,
    duration: int = 16,
    out_dir: PathLike = ...,
):
    img2img_map = {}

    processed = False

    if img2img_config_map:
        if img2img_config_map["enable"] == True:
            image_dir = data_dir.joinpath(img2img_config_map["init_img_dir"])
            imgs = sorted(glob.glob(os.path.join(image_dir, "[0-9]*.png"), recursive=False))
            if len(imgs) > 0:
                img2img_map["images"] = {}
                img2img_map["denoising_strength"] = img2img_config_map["denoising_strength"]
                for img_path in tqdm(imgs, desc="Preprocessing images (img2img)"):
                    frame_no = int(Path(img_path).stem)
                    if frame_no < duration:
                        img2img_map["images"][frame_no] = get_resized_image(img_path, width, height)
                        processed = True

        if (img2img_config_map["save_init_image"] == True) and processed:
            det_dir = out_dir.joinpath(f"{0:02d}_img2img_init_img/")
            det_dir.mkdir(parents=True, exist_ok=True)
            for frame_no in tqdm(img2img_map["images"], desc="Saving Preprocessed images (img2img)"):
                save_path = det_dir.joinpath(f"{frame_no:08d}.png")
                img2img_map["images"][frame_no].save(save_path)

    return img2img_map if processed else None

def mask_preprocess(
    region_config_map: Dict[str, Any] = None,
    width: int = 512,
    height: int = 512,
    duration: int = 16,
    out_dir: PathLike = ...,
):
    mask_map = {}

    processed = False
    size = None
    mode = None

    if region_config_map:
        image_dir = data_dir.joinpath(region_config_map["mask_dir"])
        imgs = sorted(glob.glob(os.path.join(image_dir, "[0-9]*.png"), recursive=False))
        if len(imgs) > 0:
            for img_path in tqdm(imgs, desc="Preprocessing images (mask)"):
                frame_no = int(Path(img_path).stem)
                if frame_no < duration:
                    mask_map[frame_no] = get_resized_image(img_path, width, height)
                    if size is None:
                        size = mask_map[frame_no].size
                        mode = mask_map[frame_no].mode
                    processed = True

        if processed:
            if 0 in mask_map:
                prev_img = mask_map[0]
            else:
                prev_img = Image.new(mode, size, color=0)

            for i in range(duration):
                if i in mask_map:
                    prev_img = mask_map[i]
                else:
                    mask_map[i] = prev_img

        if (region_config_map["save_mask"] == True) and processed:
            det_dir = out_dir.joinpath("mask/")
            det_dir.mkdir(parents=True, exist_ok=True)
            for frame_no in tqdm(mask_map, desc="Saving Preprocessed images (mask)"):
                save_path = det_dir.joinpath(f"{frame_no:08d}.png")
                mask_map[frame_no].save(save_path)

    return mask_map if processed else None
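
# Worked example of the fill-forward above: with masks on disk only for frames 0 and
# 8 and duration 16, frames 1-7 reuse the frame-0 mask and frames 9-15 reuse the
# frame-8 mask; if frame 0 itself is missing, an all-black mask of the first loaded
# mask's size and mode is used until the next on-disk mask appears.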

def wild_card_conversion(model_config: ModelConfig = ...):
    from animatediff.utils.wild_card import replace_wild_card

    wild_card_dir = get_dir("wildcards")
    for k in model_config.prompt_map.keys():
        model_config.prompt_map[k] = replace_wild_card(model_config.prompt_map[k], wild_card_dir)

    if model_config.head_prompt:
        model_config.head_prompt = replace_wild_card(model_config.head_prompt, wild_card_dir)
    if model_config.tail_prompt:
        model_config.tail_prompt = replace_wild_card(model_config.tail_prompt, wild_card_dir)

    model_config.prompt_fixed_ratio = max(min(1.0, model_config.prompt_fixed_ratio), 0)

    if model_config.region_map:
        for r in model_config.region_map:
            if r == "background":
                continue

            if "condition" in model_config.region_map[r]:
                c = model_config.region_map[r]["condition"]
                for k in c["prompt_map"].keys():
                    c["prompt_map"][k] = replace_wild_card(c["prompt_map"][k], wild_card_dir)

                if "head_prompt" in c:
                    c["head_prompt"] = replace_wild_card(c["head_prompt"], wild_card_dir)
                if "tail_prompt" in c:
                    c["tail_prompt"] = replace_wild_card(c["tail_prompt"], wild_card_dir)
                if "prompt_fixed_ratio" in c:
                    c["prompt_fixed_ratio"] = max(min(1.0, c["prompt_fixed_ratio"]), 0)

def save_output(
    pipeline_output,
    frame_dir: str,
    out_file: str,
    output_map: Dict[str, Any] = None,
    no_frames: bool = False,
    save_frames=save_frames,
    save_video=None,
):
    output_format = "gif"
    output_fps = 8
    if output_map:
        output_format = output_map["format"] if "format" in output_map else output_format
        output_fps = output_map["fps"] if "fps" in output_map else output_fps

    if output_format == "mp4":
        output_format = "h264"

    if output_format == "gif":
        out_file = out_file.with_suffix(".gif")
        if no_frames is not True:
            if save_frames:
                save_frames(pipeline_output, frame_dir)

        # generate the output filename and save the video
        if save_video:
            save_video(pipeline_output, out_file, output_fps)
        else:
            pipeline_output[0].save(
                fp=out_file, format="GIF", append_images=pipeline_output[1:], save_all=True, duration=(1 / output_fps * 1000), loop=0
            )
    else:
        if save_frames:
            save_frames(pipeline_output, frame_dir)

        from animatediff.rife.ffmpeg import (FfmpegEncoder, VideoCodec,
                                             codec_extn)

        out_file = out_file.with_suffix(f".{codec_extn(output_format)}")

        logger.info("Creating ffmpeg encoder...")
        encoder = FfmpegEncoder(
            frames_dir=frame_dir,
            out_file=out_file,
            codec=output_format,
            in_fps=output_fps,
            out_fps=output_fps,
            lossless=False,
            param=output_map["encode_param"] if output_map and "encode_param" in output_map else {},
        )
        logger.info("Encoding interpolated frames with ffmpeg...")
        result = encoder.encode()
        logger.debug(f"ffmpeg result: {result}")
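
# Example output_map (keys as consumed above; values are illustrative): "mp4" is
# normalized to "h264", and any non-gif format is routed through the ffmpeg encoder.
#   output_map = {"format": "mp4", "fps": 16, "encode_param": {}, "preview_steps": None}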
| def run_inference( | |
| pipeline: DiffusionPipeline, | |
| n_prompt: str = ..., | |
| seed: int = -1, | |
| steps: int = 25, | |
| guidance_scale: float = 7.5, | |
| unet_batch_size: int = 1, | |
| width: int = 512, | |
| height: int = 512, | |
| duration: int = 16, | |
| idx: int = 0, | |
| out_dir: PathLike = ..., | |
| context_frames: int = -1, | |
| context_stride: int = 3, | |
| context_overlap: int = 4, | |
| context_schedule: str = "uniform", | |
| clip_skip: int = 1, | |
| controlnet_map: Dict[str, Any] = None, | |
| controlnet_image_map: Dict[str,Any] = None, | |
| controlnet_type_map: Dict[str,Any] = None, | |
| controlnet_ref_map: Dict[str,Any] = None, | |
| controlnet_no_shrink:List[str]=None, | |
| no_frames :bool = False, | |
| img2img_map: Dict[str,Any] = None, | |
| ip_adapter_config_map: Dict[str,Any] = None, | |
| region_list: List[Any] = None, | |
| region_condi_list: List[Any] = None, | |
| output_map: Dict[str,Any] = None, | |
| is_single_prompt_mode: bool = False, | |
| is_sdxl:bool=False, | |
| apply_lcm_lora:bool=False, | |
| gradual_latent_map: Dict[str,Any] = None, | |
| ): | |
| out_dir = Path(out_dir) # ensure out_dir is a Path | |
| # Trim and clean up the prompt for filename use | |
| prompt_map = region_condi_list[0]["prompt_map"] | |
| prompt_tags = [re_clean_prompt.sub("", tag).strip().replace(" ", "-") for tag in prompt_map[list(prompt_map.keys())[0]].split(",")] | |
| prompt_str = "_".join((prompt_tags[:6]))[:50] | |
| frame_dir = out_dir.joinpath(f"{idx:02d}-{seed}") | |
| out_file = out_dir.joinpath(f"{idx:02d}_{seed}_{prompt_str}") | |
| def preview_callback(i: int, video: torch.Tensor, save_fn: Callable[[torch.Tensor], None], out_file: str) -> None: | |
| save_fn(video, out_file=Path(f"{out_file}_preview@{i}")) | |
| save_fn = partial( | |
| save_output, | |
| frame_dir=frame_dir, | |
| output_map=output_map, | |
| no_frames=no_frames, | |
| save_frames=partial(save_frames, show_progress=False), | |
| save_video=save_video | |
| ) | |
| callback = partial(preview_callback, save_fn=save_fn, out_file=out_file) | |
| seed_everything(seed) | |
| logger.info(f"{len( region_condi_list )=}") | |
| logger.info(f"{len( region_list )=}") | |
| pipeline_output = pipeline( | |
| negative_prompt=n_prompt, | |
| num_inference_steps=steps, | |
| guidance_scale=guidance_scale, | |
| unet_batch_size=unet_batch_size, | |
| width=width, | |
| height=height, | |
| video_length=duration, | |
| return_dict=False, | |
| context_frames=context_frames, | |
| context_stride=context_stride + 1, | |
| context_overlap=context_overlap, | |
| context_schedule=context_schedule, | |
| clip_skip=clip_skip, | |
| controlnet_type_map=controlnet_type_map, | |
| controlnet_image_map=controlnet_image_map, | |
| controlnet_ref_map=controlnet_ref_map, | |
| controlnet_no_shrink=controlnet_no_shrink, | |
| controlnet_max_samples_on_vram=(controlnet_map or {}).get("max_samples_on_vram", 999), | |
| controlnet_max_models_on_vram=(controlnet_map or {}).get("max_models_on_vram", 99), | |
| controlnet_is_loop=(controlnet_map or {}).get("is_loop", True), | |
| img2img_map=img2img_map, | |
| ip_adapter_config_map=ip_adapter_config_map, | |
| region_list=region_list, | |
| region_condi_list=region_condi_list, | |
| interpolation_factor=1, | |
| is_single_prompt_mode=is_single_prompt_mode, | |
| apply_lcm_lora=apply_lcm_lora, | |
| gradual_latent_map=gradual_latent_map, | |
| callback=callback, | |
| callback_steps=(output_map or {}).get("preview_steps"), | |
| ) | |
| logger.info("Generation complete, saving...") | |
| save_fn(pipeline_output, out_file=out_file) | |
| logger.info(f"Saved sample to {out_file}") | |
| return pipeline_output | |
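| # ------------------------------------------------------------------ | |
| # run_upscale: img2img re-render of already generated frames at a | |
| # higher resolution, optionally guided by tile / lineart_anime / ip2p | |
| # controlnets and/or a reference-only pass, with per-frame prompt | |
| # travel between the keyframe embeddings. | |
| # ------------------------------------------------------------------ | |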
| def run_upscale( | |
| org_imgs: List[str], | |
| pipeline: DiffusionPipeline, | |
| prompt_map: Dict[int, str] = None, | |
| n_prompt: str = ..., | |
| seed: int = -1, | |
| steps: int = 25, | |
| strength: float = 0.5, | |
| guidance_scale: float = 7.5, | |
| clip_skip: int = 1, | |
| us_width: int = 512, | |
| us_height: int = 512, | |
| idx: int = 0, | |
| out_dir: PathLike = ..., | |
| upscale_config: Dict[str, Any] = None, | |
| use_controlnet_ref: bool = False, | |
| use_controlnet_tile: bool = False, | |
| use_controlnet_line_anime: bool = False, | |
| use_controlnet_ip2p: bool = False, | |
| no_frames: bool = False, | |
| output_map: Dict[str, Any] = None, | |
| ): | |
| from animatediff.utils.lpw_stable_diffusion import lpw_encode_prompt | |
| pipeline.set_progress_bar_config(disable=True) | |
| images = get_resized_images(org_imgs, us_width, us_height) | |
| steps = steps if "steps" not in upscale_config else upscale_config["steps"] | |
| scheduler = scheduler if "scheduler" not in upscale_config else upscale_config["scheduler"] | |
| guidance_scale = guidance_scale if "guidance_scale" not in upscale_config else upscale_config["guidance_scale"] | |
| clip_skip = clip_skip if "clip_skip" not in upscale_config else upscale_config["clip_skip"] | |
| strength = strength if "strength" not in upscale_config else upscale_config["strength"] | |
| controlnet_conditioning_scale = [] | |
| guess_mode = [] | |
| control_guidance_start = [] | |
| control_guidance_end = [] | |
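| # One entry per enabled controlnet, in the order appended below. The | |
| # conditioning scale and guidance start/end are forwarded as lists | |
| # when several controlnets are stacked, as plain scalars otherwise; | |
| # guess_mode always uses the first entry. | |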
| # for controlnet tile | |
| if use_controlnet_tile: | |
| controlnet_conditioning_scale.append(upscale_config["controlnet_tile"]["controlnet_conditioning_scale"]) | |
| guess_mode.append(upscale_config["controlnet_tile"]["guess_mode"]) | |
| control_guidance_start.append(upscale_config["controlnet_tile"]["control_guidance_start"]) | |
| control_guidance_end.append(upscale_config["controlnet_tile"]["control_guidance_end"]) | |
| # for controlnet line_anime | |
| if use_controlnet_line_anime: | |
| controlnet_conditioning_scale.append(upscale_config["controlnet_line_anime"]["controlnet_conditioning_scale"]) | |
| guess_mode.append(upscale_config["controlnet_line_anime"]["guess_mode"]) | |
| control_guidance_start.append(upscale_config["controlnet_line_anime"]["control_guidance_start"]) | |
| control_guidance_end.append(upscale_config["controlnet_line_anime"]["control_guidance_end"]) | |
| # for controlnet ip2p | |
| if use_controlnet_ip2p: | |
| controlnet_conditioning_scale.append(upscale_config["controlnet_ip2p"]["controlnet_conditioning_scale"]) | |
| guess_mode.append(upscale_config["controlnet_ip2p"]["guess_mode"]) | |
| control_guidance_start.append(upscale_config["controlnet_ip2p"]["control_guidance_start"]) | |
| control_guidance_end.append(upscale_config["controlnet_ip2p"]["control_guidance_end"]) | |
| # for controlnet ref | |
| ref_image = None | |
| if use_controlnet_ref: | |
| if not upscale_config["controlnet_ref"]["use_frame_as_ref_image"] and not upscale_config["controlnet_ref"]["use_1st_frame_as_ref_image"]: | |
| ref_image = get_resized_images([ data_dir.joinpath( upscale_config["controlnet_ref"]["ref_image"] ) ], us_width, us_height)[0] | |
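| # torch.manual_seed returns the seeded default CPU generator, which is | |
| # then reused for every frame of the upscale loop. | |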
| generator = torch.manual_seed(seed) | |
| seed_everything(seed) | |
| prompt_embeds_map = {} | |
| prompt_map = dict(sorted(prompt_map.items())) | |
| do_classifier_free_guidance = guidance_scale > 1.0 | |
| prompt_list = [prompt_map[key_frame] for key_frame in prompt_map.keys()] | |
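| # lpw ("long prompt weighting") encoding lifts CLIP's 77-token limit | |
| # and honors (token:weight) emphasis syntax in the prompts. | |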
| prompt_embeds, neg_embeds = lpw_encode_prompt( | |
| pipe=pipeline, | |
| prompt=prompt_list, | |
| do_classifier_free_guidance=do_classifier_free_guidance, | |
| negative_prompt=n_prompt, | |
| ) | |
| positive = prompt_embeds.chunk(prompt_embeds.shape[0], 0) | |
| negative = neg_embeds.chunk(neg_embeds.shape[0], 0) if do_classifier_free_guidance else [None] | |
| for i, key_frame in enumerate(prompt_map): | |
| prompt_embeds_map[key_frame] = positive[i] | |
| key_first = list(prompt_map.keys())[0] | |
| key_last = list(prompt_map.keys())[-1] | |
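| # Prompt travel: pick the keyframes on either side of center_frame | |
| # (wrapping around the clip, hence the video_length adjustments) and | |
| # linearly interpolate their embeddings by the relative distance. | |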
| def get_current_prompt_embeds( | |
| center_frame: int = 0, | |
| video_length: int = 0 | |
| ): | |
| key_prev = key_last | |
| key_next = key_first | |
| for p in prompt_map.keys(): | |
| if p > center_frame: | |
| key_next = p | |
| break | |
| key_prev = p | |
| dist_prev = center_frame - key_prev | |
| if dist_prev < 0: | |
| dist_prev += video_length | |
| dist_next = key_next - center_frame | |
| if dist_next < 0: | |
| dist_next += video_length | |
| if key_prev == key_next or dist_prev + dist_next == 0: | |
| return prompt_embeds_map[key_prev] | |
| rate = dist_prev / (dist_prev + dist_next) | |
| return get_tensor_interpolation_method()(prompt_embeds_map[key_prev],prompt_embeds_map[key_next], rate) | |
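| # Preprocessor that turns a frame into an anime-style line drawing; | |
| # only consumed when the line_anime controlnet is enabled. | |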
| line_anime_processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators") | |
| out_images = [] | |
| logger.info(f"{use_controlnet_tile=}") | |
| logger.info(f"{use_controlnet_line_anime=}") | |
| logger.info(f"{use_controlnet_ip2p=}") | |
| logger.info(f"{controlnet_conditioning_scale=}") | |
| logger.info(f"{guess_mode=}") | |
| logger.info(f"{control_guidance_start=}") | |
| logger.info(f"{control_guidance_end=}") | |
| for i, org_image in enumerate(tqdm(images, desc="Upscaling...")): | |
| cur_positive = get_current_prompt_embeds(i, len(images)) | |
| condition_image = [] | |
| if use_controlnet_tile: | |
| condition_image.append(org_image) | |
| if use_controlnet_line_anime: | |
| condition_image.append(line_anime_processor(org_image)) | |
| if use_controlnet_ip2p: | |
| condition_image.append(org_image) | |
| if not use_controlnet_ref: | |
| out_image = pipeline( | |
| prompt_embeds=cur_positive, | |
| negative_prompt_embeds=negative[0], | |
| image=org_image, | |
| control_image=condition_image, | |
| width=org_image.size[0], | |
| height=org_image.size[1], | |
| strength=strength, | |
| num_inference_steps=steps, | |
| guidance_scale=guidance_scale, | |
| generator=generator, | |
| controlnet_conditioning_scale=controlnet_conditioning_scale if len(controlnet_conditioning_scale) > 1 else controlnet_conditioning_scale[0], | |
| guess_mode=guess_mode[0], | |
| control_guidance_start=control_guidance_start if len(control_guidance_start) > 1 else control_guidance_start[0], | |
| control_guidance_end=control_guidance_end if len(control_guidance_end) > 1 else control_guidance_end[0], | |
| ).images[0] | |
| else: | |
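| # Reference-only branch: reuse the first frame (or the current frame) | |
| # as the style reference image, per the controlnet_ref config. | |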
| if upscale_config["controlnet_ref"]["use_1st_frame_as_ref_image"]: | |
| if i == 0: | |
| ref_image = org_image | |
| elif upscale_config["controlnet_ref"]["use_frame_as_ref_image"]: | |
| ref_image = org_image | |
| out_image = pipeline( | |
| prompt_embeds=cur_positive, | |
| negative_prompt_embeds=negative[0], | |
| image=org_image, | |
| control_image=condition_image, | |
| width=org_image.size[0], | |
| height=org_image.size[1], | |
| strength=strength, | |
| num_inference_steps=steps, | |
| guidance_scale=guidance_scale, | |
| generator=generator, | |
| controlnet_conditioning_scale=controlnet_conditioning_scale if len(controlnet_conditioning_scale) > 1 else controlnet_conditioning_scale[0], | |
| guess_mode=guess_mode[0], | |
| # control_guidance_start=control_guidance_start, | |
| # control_guidance_end=control_guidance_end, | |
| ### for controlnet ref | |
| ref_image=ref_image, | |
| attention_auto_machine_weight=upscale_config["controlnet_ref"]["attention_auto_machine_weight"], | |
| gn_auto_machine_weight=upscale_config["controlnet_ref"]["gn_auto_machine_weight"], | |
| style_fidelity=upscale_config["controlnet_ref"]["style_fidelity"], | |
| reference_attn=upscale_config["controlnet_ref"]["reference_attn"], | |
| reference_adain=upscale_config["controlnet_ref"]["reference_adain"], | |
| ).images[0] | |
| out_images.append(out_image) | |
| # Trim and clean up the prompt for filename use | |
| prompt_tags = [re_clean_prompt.sub("", tag).strip().replace(" ", "-") for tag in prompt_map[list(prompt_map.keys())[0]].split(",")] | |
| prompt_str = "_".join(prompt_tags[:6])[:50] | |
| # generate the output filename and save the video | |
| out_file = out_dir.joinpath(f"{idx:02d}_{seed}_{prompt_str}") | |
| frame_dir = out_dir.joinpath(f"{idx:02d}-{seed}-upscaled") | |
| save_output(out_images, frame_dir, out_file, output_map, no_frames, save_imgs, None) | |
| logger.info(f"Saved sample to {out_file}") | |
| return out_images | |