import glob
import logging
import os
import re
from functools import partial
from itertools import chain
from os import PathLike
from pathlib import Path
from typing import Any, Callable, Dict, List, Union

import numpy as np
import torch
from controlnet_aux import LineartAnimeDetector
from controlnet_aux.processor import MODELS
from controlnet_aux.processor import Processor as ControlnetPreProcessor
from controlnet_aux.util import HWC3, ade_palette
from controlnet_aux.util import resize_image as aux_resize_image
from diffusers import (AutoencoderKL, ControlNetModel, DiffusionPipeline,
                       EulerDiscreteScheduler,
                       StableDiffusionControlNetImg2ImgPipeline,
                       StableDiffusionPipeline, StableDiffusionXLPipeline)
from PIL import Image
from torchvision.datasets.folder import IMG_EXTENSIONS
from tqdm.rich import tqdm
from transformers import (AutoImageProcessor, CLIPImageProcessor,
                          CLIPTextConfig, CLIPTextModel,
                          CLIPTextModelWithProjection, CLIPTokenizer,
                          UperNetForSemanticSegmentation)

from animatediff import get_dir
from animatediff.dwpose import DWposeDetector
from animatediff.models.clip import CLIPSkipTextModel
from animatediff.models.unet import UNet3DConditionModel
from animatediff.pipelines import AnimationPipeline, load_text_embeddings
from animatediff.pipelines.lora import load_lcm_lora, load_lora_map
from animatediff.pipelines.pipeline_controlnet_img2img_reference import \
    StableDiffusionControlNetImg2ImgReferencePipeline
from animatediff.schedulers import DiffusionScheduler, get_scheduler
from animatediff.settings import InferenceConfig, ModelConfig
from animatediff.utils.control_net_lllite import (ControlNetLLLite,
                                                  load_controlnet_lllite)
from animatediff.utils.convert_from_ckpt import convert_ldm_vae_checkpoint
from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora
from animatediff.utils.model import (ensure_motion_modules,
                                     get_checkpoint_weights,
                                     get_checkpoint_weights_sdxl)
from animatediff.utils.util import (get_resized_image, get_resized_image2,
                                    get_resized_images,
                                    get_tensor_interpolation_method,
                                    prepare_dwpose, prepare_extra_controlnet,
                                    prepare_ip_adapter,
                                    prepare_ip_adapter_sdxl, prepare_lcm_lora,
                                    prepare_lllite, prepare_motion_module,
                                    save_frames, save_imgs, save_video)
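
# This module assembles the AnimateDiff generation machinery: pipeline and
# LoRA loading, ControlNet / IP-Adapter / mask preprocessing, prompt and
# region scheduling, the main inference entry point, and the img2img upscale
# pass.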

controlnet_address_table = {
    "controlnet_tile": ['lllyasviel/control_v11f1e_sd15_tile'],
    "controlnet_lineart_anime": ['lllyasviel/control_v11p_sd15s2_lineart_anime'],
    "controlnet_ip2p": ['lllyasviel/control_v11e_sd15_ip2p'],
    "controlnet_openpose": ['lllyasviel/control_v11p_sd15_openpose'],
    "controlnet_softedge": ['lllyasviel/control_v11p_sd15_softedge'],
    "controlnet_shuffle": ['lllyasviel/control_v11e_sd15_shuffle'],
    "controlnet_depth": ['lllyasviel/control_v11f1p_sd15_depth'],
    "controlnet_canny": ['lllyasviel/control_v11p_sd15_canny'],
    "controlnet_inpaint": ['lllyasviel/control_v11p_sd15_inpaint'],
    "controlnet_lineart": ['lllyasviel/control_v11p_sd15_lineart'],
    "controlnet_mlsd": ['lllyasviel/control_v11p_sd15_mlsd'],
    "controlnet_normalbae": ['lllyasviel/control_v11p_sd15_normalbae'],
    "controlnet_scribble": ['lllyasviel/control_v11p_sd15_scribble'],
    "controlnet_seg": ['lllyasviel/control_v11p_sd15_seg'],
    "qr_code_monster_v1": ['monster-labs/control_v1p_sd15_qrcode_monster'],
    "qr_code_monster_v2": ['monster-labs/control_v1p_sd15_qrcode_monster', 'v2'],
    "controlnet_mediapipe_face": ['CrucibleAI/ControlNetMediaPipeFace', "diffusion_sd15"],
    "animatediff_controlnet": [None, "data/models/controlnet/animatediff_controlnet/controlnet_checkpoint.ckpt"],
}
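
# Each entry is one of: [hf_repo_id], [hf_repo_id, subfolder], or
# [None, local_checkpoint_path]; create_controlnet_model() below dispatches on
# this shape.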

controlnet_address_table_sdxl = {
    "qr_code_monster_v1": ['monster-labs/control_v1p_sdxl_qrcode_monster'],
}

lllite_address_table_sdxl = {
    "controlnet_tile": ['models/lllite/bdsqlsz_controlllite_xl_tile_anime_β.safetensors'],
    "controlnet_lineart_anime": ['models/lllite/bdsqlsz_controlllite_xl_lineart_anime_denoise.safetensors'],
    "controlnet_openpose": ['models/lllite/bdsqlsz_controlllite_xl_dw_openpose.safetensors'],
    "controlnet_softedge": ['models/lllite/bdsqlsz_controlllite_xl_softedge.safetensors'],
    "controlnet_shuffle": ['models/lllite/bdsqlsz_controlllite_xl_t2i-adapter_color_shuffle.safetensors'],
    "controlnet_depth": ['models/lllite/bdsqlsz_controlllite_xl_depth.safetensors'],
    "controlnet_canny": ['models/lllite/bdsqlsz_controlllite_xl_canny.safetensors'],
    "controlnet_mlsd": ['models/lllite/bdsqlsz_controlllite_xl_mlsd_V2.safetensors'],
    "controlnet_normalbae": ['models/lllite/bdsqlsz_controlllite_xl_normal.safetensors'],
    "controlnet_scribble": ['models/lllite/bdsqlsz_controlllite_xl_sketch.safetensors'],
    "controlnet_seg": ['models/lllite/bdsqlsz_controlllite_xl_segment_animeface_V2.safetensors'],
}

try:
    import onnxruntime
    onnxruntime_installed = True
except ImportError:
    onnxruntime_installed = False
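
# DWPose needs onnxruntime; this flag decides between "dwpose" and
# "openpose_full" in default_preprocessor_table below.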

logger = logging.getLogger(__name__)

data_dir = get_dir("data")
default_base_path = data_dir.joinpath("models/huggingface/stable-diffusion-v1-5")

re_clean_prompt = re.compile(r"[^\w\-, ]")

# Module-level cache of instantiated preprocessors, keyed by controlnet type.
controlnet_preprocessor = {}


def load_safetensors_lora(text_encoder, unet, lora_path, alpha=0.75, is_animatediff=True):
    from safetensors.torch import load_file

    from animatediff.utils.lora_diffusers import (LoRANetwork,
                                                  create_network_from_weights)

    sd = load_file(lora_path)

    print("create LoRA network")
    lora_network: LoRANetwork = create_network_from_weights(text_encoder, unet, sd, multiplier=alpha, is_animatediff=is_animatediff)
    print("load LoRA network weights")
    lora_network.load_state_dict(sd, False)
    lora_network.apply_to(alpha)
    return lora_network


def load_safetensors_lora2(text_encoder, unet, lora_path, alpha=0.75, is_animatediff=True):
    from safetensors.torch import load_file

    from animatediff.utils.lora_diffusers import (LoRANetwork,
                                                  create_network_from_weights)

    sd = load_file(lora_path)

    print("create LoRA network")
    lora_network: LoRANetwork = create_network_from_weights(text_encoder, unet, sd, multiplier=alpha, is_animatediff=is_animatediff)
    print("load LoRA network weights")
    lora_network.load_state_dict(sd, False)
    lora_network.merge_to(alpha)
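
# Unlike load_safetensors_lora, which keeps the LoRANetwork applied as live
# hooks (apply_to) and returns it to the caller, this variant bakes the LoRA
# weights directly into the model (merge_to), so nothing has to be kept alive.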


def load_tensors(path: Path, framework="pt", device="cpu"):
    tensors = {}
    if path.suffix == ".safetensors":
        from safetensors import safe_open
        with safe_open(path, framework=framework, device=device) as f:
            for k in f.keys():
                tensors[k] = f.get_tensor(k)
    else:
        from torch import load
        tensors = load(path, map_location=device)
        if "state_dict" in tensors:
            tensors = tensors["state_dict"]
    return tensors


def load_motion_lora(unet, lora_path: Path, alpha=1.0):
    state_dict = load_tensors(lora_path)

    for key in state_dict:
        # Each LoRA pair is processed once, keyed by its "down" weight.
        if "up." in key:
            continue

        up_key = key.replace(".down.", ".up.")
        model_key = key.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "")
        model_key = model_key.replace("to_out.", "to_out.0.")
        layer_infos = model_key.split(".")[:-1]

        # Walk the attribute path to find the matching layer in the UNet.
        curr_layer = unet
        try:
            while len(layer_infos) > 0:
                temp_name = layer_infos.pop(0)
                curr_layer = curr_layer.__getattr__(temp_name)
        except AttributeError:
            logger.info(f"{model_key} not found")
            continue

        weight_down = state_dict[key]
        weight_up = state_dict[up_key]
        curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)
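
# The merge above is the standard low-rank update: W <- W + alpha * (up @ down),
# where up and down are the LoRA factor matrices for the matching UNet layer.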


class SegPreProcessor:

    def __init__(self):
        self.image_processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-convnext-small")
        self.processor = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-convnext-small")

    def __call__(self, input_image, detect_resolution=512, image_resolution=512, output_type="pil", **kwargs):
        input_array = np.array(input_image, dtype=np.uint8)
        input_array = HWC3(input_array)
        input_array = aux_resize_image(input_array, detect_resolution)

        pixel_values = self.image_processor(input_array, return_tensors="pt").pixel_values

        with torch.no_grad():
            outputs = self.processor(pixel_values.to(self.processor.device))

        # Move the outputs back to the CPU so the model can stay on the GPU
        # without pinning the results there.
        outputs.loss = outputs.loss.to("cpu") if outputs.loss is not None else outputs.loss
        outputs.logits = outputs.logits.to("cpu") if outputs.logits is not None else outputs.logits
        outputs.hidden_states = outputs.hidden_states.to("cpu") if outputs.hidden_states is not None else outputs.hidden_states
        outputs.attentions = outputs.attentions.to("cpu") if outputs.attentions is not None else outputs.attentions

        seg = self.image_processor.post_process_semantic_segmentation(outputs, target_sizes=[input_image.size[::-1]])[0]
        color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)

        # Paint each semantic class with its ADE20K palette color.
        for label, color in enumerate(ade_palette()):
            color_seg[seg == label, :] = color

        color_seg = color_seg.astype(np.uint8)
        color_seg = aux_resize_image(color_seg, image_resolution)
        color_seg = Image.fromarray(color_seg)

        return color_seg


class NullPreProcessor:
    def __call__(self, input_image, **kwargs):
        return input_image


class BlurPreProcessor:
    def __call__(self, input_image, sigma=5.0, **kwargs):
        import cv2

        input_array = np.array(input_image, dtype=np.uint8)
        input_array = HWC3(input_array)

        dst = cv2.GaussianBlur(input_array, (0, 0), sigma)

        return Image.fromarray(dst)


class TileResamplePreProcessor:

    def resize(self, input_image, resolution):
        import cv2

        H, W, C = input_image.shape
        H = float(H)
        W = float(W)
        k = float(resolution) / min(H, W)
        H *= k
        W *= k
        # Lanczos for upscaling, area averaging for downscaling.
        img = cv2.resize(input_image, (int(W), int(H)), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
        return img

    def __call__(self, input_image, down_sampling_rate=1.0, **kwargs):
        input_array = np.array(input_image, dtype=np.uint8)
        input_array = HWC3(input_array)

        H, W, C = input_array.shape

        target_res = min(H, W) / down_sampling_rate

        dst = self.resize(input_array, target_res)

        return Image.fromarray(dst)
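
# Worked example: a 1024x768 input with down_sampling_rate=2.0 gives
# target_res = 768 / 2 = 384, so the image is resized to 512x384 (shorter side
# 384); the tile ControlNet can then reconstruct detail from the softened input.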


def is_valid_controlnet_type(type_str, is_sdxl):
    if not is_sdxl:
        return type_str in controlnet_address_table
    else:
        return (type_str in controlnet_address_table_sdxl) or (type_str in lllite_address_table_sdxl)


def load_controlnet_from_file(file_path, torch_dtype):
    from safetensors.torch import load_file

    prepare_extra_controlnet()

    file_path = Path(file_path)

    if file_path.exists() and file_path.is_file():
        if file_path.suffix.lower() in [".pth", ".pt", ".ckpt"]:
            controlnet_state_dict = torch.load(file_path, map_location="cpu", weights_only=True)
        elif file_path.suffix.lower() == ".safetensors":
            controlnet_state_dict = load_file(file_path, device="cpu")
        else:
            raise RuntimeError(
                f"unknown file format for controlnet weights: {file_path.suffix}"
            )
    else:
        raise FileNotFoundError(f"no controlnet weights found in {file_path}")

    if file_path.parent.name == "animatediff_controlnet":
        model = ControlNetModel(cross_attention_dim=768)
    else:
        model = ControlNetModel()

    # Torch checkpoints may nest the weights under "state_dict"; safetensors
    # files are always flat.
    if "state_dict" in controlnet_state_dict:
        controlnet_state_dict = controlnet_state_dict["state_dict"]

    missing, _ = model.load_state_dict(controlnet_state_dict, strict=False)
    if len(missing) > 0:
        logger.info(f"ControlNetModel has missing keys: {missing}")

    return model.to(dtype=torch_dtype)


def create_controlnet_model(pipe, type_str, is_sdxl):
    if not is_sdxl:
        if type_str in controlnet_address_table:
            addr = controlnet_address_table[type_str]
            if addr[0] is not None:
                if len(addr) == 1:
                    return ControlNetModel.from_pretrained(addr[0], torch_dtype=torch.float16)
                else:
                    return ControlNetModel.from_pretrained(addr[0], subfolder=addr[1], torch_dtype=torch.float16)
            else:
                return load_controlnet_from_file(addr[1], torch_dtype=torch.float16)
        else:
            raise ValueError(f"unknown controlnet type {type_str}")
    else:
        if type_str in controlnet_address_table_sdxl:
            addr = controlnet_address_table_sdxl[type_str]
            if len(addr) == 1:
                return ControlNetModel.from_pretrained(addr[0], torch_dtype=torch.float16)
            else:
                return ControlNetModel.from_pretrained(addr[0], subfolder=addr[1], torch_dtype=torch.float16)
        elif type_str in lllite_address_table_sdxl:
            addr = lllite_address_table_sdxl[type_str]
            model_path = data_dir.joinpath(addr[0])
            return load_controlnet_lllite(model_path, pipe, torch_dtype=torch.float16)
        else:
            raise ValueError(f"unknown controlnet type {type_str}")


default_preprocessor_table = {
    "controlnet_lineart_anime": "lineart_anime",
    "controlnet_openpose": "dwpose" if onnxruntime_installed else "openpose_full",
    "controlnet_softedge": "softedge_hedsafe",
    "controlnet_shuffle": "shuffle",
    "controlnet_depth": "depth_midas",
    "controlnet_canny": "canny",
    "controlnet_lineart": "lineart_realistic",
    "controlnet_mlsd": "mlsd",
    "controlnet_normalbae": "normal_bae",
    "controlnet_scribble": "scribble_pidsafe",
    "controlnet_seg": "upernet_seg",
    "controlnet_mediapipe_face": "mediapipe_face",
    "qr_code_monster_v1": "depth_midas",
    "qr_code_monster_v2": "depth_midas",
}


def create_preprocessor_from_name(pre_type):
    if pre_type == "dwpose":
        prepare_dwpose()
        return DWposeDetector()
    elif pre_type == "upernet_seg":
        return SegPreProcessor()
    elif pre_type == "blur":
        return BlurPreProcessor()
    elif pre_type == "tile_resample":
        return TileResamplePreProcessor()
    elif pre_type == "none":
        return NullPreProcessor()
    elif pre_type in MODELS:
        return ControlnetPreProcessor(pre_type)
    else:
        raise ValueError(f"unknown controlnet preprocessor type {pre_type}")


def create_default_preprocessor(type_str):
    if type_str in default_preprocessor_table:
        pre_type = default_preprocessor_table[type_str]
    else:
        pre_type = "none"

    return create_preprocessor_from_name(pre_type)


def get_preprocessor(type_str, device_str, preprocessor_map):
    if type_str not in controlnet_preprocessor:
        if preprocessor_map:
            controlnet_preprocessor[type_str] = create_preprocessor_from_name(preprocessor_map["type"])

    if type_str not in controlnet_preprocessor:
        controlnet_preprocessor[type_str] = create_default_preprocessor(type_str)

    # Move the underlying model (if any) to the requested device.
    if hasattr(controlnet_preprocessor[type_str], "processor"):
        if hasattr(controlnet_preprocessor[type_str].processor, "to"):
            if device_str:
                controlnet_preprocessor[type_str].processor.to(device_str)
    elif hasattr(controlnet_preprocessor[type_str], "to"):
        if device_str:
            controlnet_preprocessor[type_str].to(device_str)

    return controlnet_preprocessor[type_str]


def clear_controlnet_preprocessor(type_str=None):
    global controlnet_preprocessor
    if type_str is None:
        for t in controlnet_preprocessor:
            controlnet_preprocessor[t] = None
        controlnet_preprocessor = {}
        torch.cuda.empty_cache()
    else:
        controlnet_preprocessor[type_str] = None
        torch.cuda.empty_cache()


def get_preprocessed_img(type_str, img, use_preprocessor, device_str, preprocessor_map):
    if use_preprocessor:
        param = {}
        if preprocessor_map:
            param = preprocessor_map["param"] if "param" in preprocessor_map else {}
        return get_preprocessor(type_str, device_str, preprocessor_map)(img, **param)
    else:
        return img


def create_pipeline_sdxl(
    base_model: Union[str, PathLike] = default_base_path,
    model_config: ModelConfig = ...,
    infer_config: InferenceConfig = ...,
    use_xformers: bool = True,
    video_length: int = 16,
    motion_module_path=...,
):
    from animatediff.pipelines.sdxl_animation import AnimationPipeline
    from animatediff.sdxl_models.unet import UNet3DConditionModel

    logger.info("Loading tokenizer...")
    tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(base_model, subfolder="tokenizer")
    logger.info("Loading text encoder...")
    text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained(base_model, subfolder="text_encoder", torch_dtype=torch.float16)
    logger.info("Loading VAE...")
    vae: AutoencoderKL = AutoencoderKL.from_pretrained(base_model, subfolder="vae")
    logger.info("Loading tokenizer two...")
    tokenizer_two = CLIPTokenizer.from_pretrained(base_model, subfolder="tokenizer_2")
    logger.info("Loading text encoder two...")
    text_encoder_two = CLIPTextModelWithProjection.from_pretrained(base_model, subfolder="text_encoder_2", torch_dtype=torch.float16)

    logger.info("Loading UNet...")
    unet: UNet3DConditionModel = UNet3DConditionModel.from_pretrained_2d(
        pretrained_model_path=base_model,
        motion_module_path=motion_module_path,
        subfolder="unet",
        unet_additional_kwargs=infer_config.unet_additional_kwargs,
    )

    sched_kwargs = infer_config.noise_scheduler_kwargs
    scheduler = get_scheduler(model_config.scheduler, sched_kwargs)
    logger.info(f'Using scheduler "{model_config.scheduler}" ({scheduler.__class__.__name__})')

    if model_config.gradual_latent_hires_fix_map:
        if "enable" in model_config.gradual_latent_hires_fix_map:
            if model_config.gradual_latent_hires_fix_map["enable"]:
                if model_config.scheduler not in (DiffusionScheduler.euler_a, DiffusionScheduler.lcm):
                    logger.warning("gradual_latent_hires_fix is enabled")
                    logger.warning(f"{model_config.scheduler=}")
                    logger.warning("If generation aborts with an error, switch the scheduler to euler_a or lcm")

    if model_config.path is not None:
        model_path = data_dir.joinpath(model_config.path)
        logger.info(f"Loading weights from {model_path}")
        if model_path.is_file():
            logger.debug("Loading from single checkpoint file")
            unet_state_dict, tenc_state_dict, tenc2_state_dict, vae_state_dict = get_checkpoint_weights_sdxl(model_path)
        elif model_path.is_dir():
            logger.debug("Loading from Diffusers model directory")
            temp_pipeline = StableDiffusionXLPipeline.from_pretrained(model_path)
            unet_state_dict, tenc_state_dict, tenc2_state_dict, vae_state_dict = (
                temp_pipeline.unet.state_dict(),
                temp_pipeline.text_encoder.state_dict(),
                temp_pipeline.text_encoder_2.state_dict(),
                temp_pipeline.vae.state_dict(),
            )
            del temp_pipeline
        else:
            raise FileNotFoundError(f"model_path {model_path} is not a file or directory")

        logger.info("Merging weights into UNet...")
        _, unet_unex = unet.load_state_dict(unet_state_dict, strict=False)
        if len(unet_unex) > 0:
            raise ValueError(f"UNet has unexpected keys: {unet_unex}")
        tenc_missing, _ = text_encoder.load_state_dict(tenc_state_dict, strict=False)
        if len(tenc_missing) > 0:
            raise ValueError(f"TextEncoder has missing keys: {tenc_missing}")
        tenc2_missing, _ = text_encoder_two.load_state_dict(tenc2_state_dict, strict=False)
        if len(tenc2_missing) > 0:
            raise ValueError(f"TextEncoder2 has missing keys: {tenc2_missing}")
        vae_missing, _ = vae.load_state_dict(vae_state_dict, strict=False)
        if len(vae_missing) > 0:
            raise ValueError(f"VAE has missing keys: {vae_missing}")
    else:
        logger.info("Using base model weights (no checkpoint/LoRA)")

    if model_config.vae_path:
        vae_path = data_dir.joinpath(model_config.vae_path)
        logger.info(f"Loading vae from {vae_path}")

        if vae_path.is_dir():
            vae = AutoencoderKL.from_pretrained(vae_path)
        else:
            tensors = load_tensors(vae_path)
            tensors = convert_ldm_vae_checkpoint(tensors, vae.config)
            vae.load_state_dict(tensors)

    unet.to(torch.float16)
    text_encoder.to(torch.float16)
    text_encoder_two.to(torch.float16)

    # The state dicts only exist when a checkpoint was merged above.
    if model_config.path is not None:
        del unet_state_dict
        del tenc_state_dict
        del tenc2_state_dict
        del vae_state_dict

    if use_xformers:
        logger.info("Enabling xformers memory-efficient attention")
        unet.enable_xformers_memory_efficient_attention()

    for l in model_config.motion_lora_map:
        lora_path = data_dir.joinpath(l)
        if lora_path.is_file():
            logger.info(f"Loading motion lora {lora_path}")
            logger.info(f"alpha = {model_config.motion_lora_map[l]}")
            load_motion_lora(unet, lora_path, alpha=model_config.motion_lora_map[l])
        else:
            raise ValueError(f"{lora_path=} not found")

    logger.info("Creating AnimationPipeline...")
    pipeline = AnimationPipeline(
        vae=vae,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_two,
        tokenizer=tokenizer,
        tokenizer_2=tokenizer_two,
        unet=unet,
        scheduler=scheduler,
        controlnet_map=None,
    )

    del vae
    del text_encoder
    del text_encoder_two
    del tokenizer
    del tokenizer_two
    del unet

    torch.cuda.empty_cache()

    pipeline.lcm = None
    if model_config.lcm_map:
        if model_config.lcm_map["enable"]:
            prepare_lcm_lora()
            load_lcm_lora(pipeline, model_config.lcm_map, is_sdxl=True)

    load_lora_map(pipeline, model_config.lora_map, video_length, is_sdxl=True)

    pipeline.unet = pipeline.unet.half()
    pipeline.text_encoder = pipeline.text_encoder.half()
    pipeline.text_encoder_2 = pipeline.text_encoder_2.half()

    pipeline.text_encoder = pipeline.text_encoder.to("cuda")
    pipeline.text_encoder_2 = pipeline.text_encoder_2.to("cuda")

    load_text_embeddings(pipeline, is_sdxl=True)

    pipeline.text_encoder = pipeline.text_encoder.to("cpu")
    pipeline.text_encoder_2 = pipeline.text_encoder_2.to("cpu")

    return pipeline


def create_pipeline(
    base_model: Union[str, PathLike] = default_base_path,
    model_config: ModelConfig = ...,
    infer_config: InferenceConfig = ...,
    use_xformers: bool = True,
    video_length: int = 16,
    is_sdxl: bool = False,
) -> DiffusionPipeline:
    """Create an AnimationPipeline from a pretrained model.

    Uses the base_model argument to load or download the pretrained reference pipeline model.
    """

    logger.info("Checking motion module...")
    motion_module = data_dir.joinpath(model_config.motion_module)
    if not (motion_module.exists() and motion_module.is_file()):
        prepare_motion_module()
        if not (motion_module.exists() and motion_module.is_file()):
            # Try the safetensors variant of the same module name.
            motion_module = motion_module.with_suffix(".safetensors")
            if not (motion_module.exists() and motion_module.is_file()):
                # Last resort: fetch the known motion modules.
                ensure_motion_modules()
                if not (motion_module.exists() and motion_module.is_file()):
                    raise FileNotFoundError(f"Motion module {motion_module} does not exist or is not a file!")

    if is_sdxl:
        return create_pipeline_sdxl(
            base_model=base_model,
            model_config=model_config,
            infer_config=infer_config,
            use_xformers=use_xformers,
            video_length=video_length,
            motion_module_path=motion_module,
        )

    logger.info("Loading tokenizer...")
    tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(base_model, subfolder="tokenizer")
    logger.info("Loading text encoder...")
    text_encoder: CLIPSkipTextModel = CLIPSkipTextModel.from_pretrained(base_model, subfolder="text_encoder")
    logger.info("Loading VAE...")
    vae: AutoencoderKL = AutoencoderKL.from_pretrained(base_model, subfolder="vae")
    logger.info("Loading UNet...")
    unet: UNet3DConditionModel = UNet3DConditionModel.from_pretrained_2d(
        pretrained_model_path=base_model,
        motion_module_path=motion_module,
        subfolder="unet",
        unet_additional_kwargs=infer_config.unet_additional_kwargs,
    )
    feature_extractor = CLIPImageProcessor.from_pretrained(base_model, subfolder="feature_extractor")

    if model_config.gradual_latent_hires_fix_map:
        if "enable" in model_config.gradual_latent_hires_fix_map:
            if model_config.gradual_latent_hires_fix_map["enable"]:
                if model_config.scheduler not in (DiffusionScheduler.euler_a, DiffusionScheduler.lcm):
                    logger.warning("gradual_latent_hires_fix is enabled")
                    logger.warning(f"{model_config.scheduler=}")
                    logger.warning("If generation aborts with an error, switch the scheduler to euler_a or lcm")

    sched_kwargs = infer_config.noise_scheduler_kwargs
    scheduler = get_scheduler(model_config.scheduler, sched_kwargs)
    logger.info(f'Using scheduler "{model_config.scheduler}" ({scheduler.__class__.__name__})')

    if model_config.path is not None:
        model_path = data_dir.joinpath(model_config.path)
        logger.info(f"Loading weights from {model_path}")
        if model_path.is_file():
            logger.debug("Loading from single checkpoint file")
            unet_state_dict, tenc_state_dict, vae_state_dict = get_checkpoint_weights(model_path)
        elif model_path.is_dir():
            logger.debug("Loading from Diffusers model directory")
            temp_pipeline = StableDiffusionPipeline.from_pretrained(model_path)
            unet_state_dict, tenc_state_dict, vae_state_dict = (
                temp_pipeline.unet.state_dict(),
                temp_pipeline.text_encoder.state_dict(),
                temp_pipeline.vae.state_dict(),
            )
            del temp_pipeline
        else:
            raise FileNotFoundError(f"model_path {model_path} is not a file or directory")

        logger.info("Merging weights into UNet...")
        _, unet_unex = unet.load_state_dict(unet_state_dict, strict=False)
        if len(unet_unex) > 0:
            raise ValueError(f"UNet has unexpected keys: {unet_unex}")
        tenc_missing, _ = text_encoder.load_state_dict(tenc_state_dict, strict=False)
        if len(tenc_missing) > 0:
            raise ValueError(f"TextEncoder has missing keys: {tenc_missing}")
        vae_missing, _ = vae.load_state_dict(vae_state_dict, strict=False)
        if len(vae_missing) > 0:
            raise ValueError(f"VAE has missing keys: {vae_missing}")
    else:
        logger.info("Using base model weights (no checkpoint/LoRA)")

    if model_config.vae_path:
        vae_path = data_dir.joinpath(model_config.vae_path)
        logger.info(f"Loading vae from {vae_path}")

        if vae_path.is_dir():
            vae = AutoencoderKL.from_pretrained(vae_path)
        else:
            tensors = load_tensors(vae_path)
            tensors = convert_ldm_vae_checkpoint(tensors, vae.config)
            vae.load_state_dict(tensors)

    if use_xformers:
        logger.info("Enabling xformers memory-efficient attention")
        unet.enable_xformers_memory_efficient_attention()

    # Legacy per-file LoRA merge path; superseded by load_lora_map below and
    # intentionally left disabled.
    if False:
        for l in model_config.lora_map:
            lora_path = data_dir.joinpath(l)
            if lora_path.is_file():
                logger.info(f"Loading lora {lora_path}")
                logger.info(f"alpha = {model_config.lora_map[l]}")
                load_safetensors_lora(text_encoder, unet, lora_path, alpha=model_config.lora_map[l])

    for l in model_config.motion_lora_map:
        lora_path = data_dir.joinpath(l)
        if lora_path.is_file():
            logger.info(f"Loading motion lora {lora_path}")
            logger.info(f"alpha = {model_config.motion_lora_map[l]}")
            load_motion_lora(unet, lora_path, alpha=model_config.motion_lora_map[l])
        else:
            raise ValueError(f"{lora_path=} not found")

    logger.info("Creating AnimationPipeline...")
    pipeline = AnimationPipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        feature_extractor=feature_extractor,
        controlnet_map=None,
    )

    pipeline.lcm = None
    if model_config.lcm_map:
        if model_config.lcm_map["enable"]:
            prepare_lcm_lora()
            load_lcm_lora(pipeline, model_config.lcm_map, is_sdxl=False)

    load_lora_map(pipeline, model_config.lora_map, video_length)

    pipeline.unet = pipeline.unet.half()
    pipeline.text_encoder = pipeline.text_encoder.half()

    pipeline.text_encoder = pipeline.text_encoder.to("cuda")

    load_text_embeddings(pipeline)

    pipeline.text_encoder = pipeline.text_encoder.to("cpu")

    return pipeline


def load_controlnet_models(pipe: DiffusionPipeline, model_config: ModelConfig = ..., is_sdxl: bool = False):
    if is_sdxl:
        prepare_lllite()

    controlnet_map = {}
    if model_config.controlnet_map:
        c_image_dir = data_dir.joinpath(model_config.controlnet_map["input_image_dir"])

        for c in model_config.controlnet_map:
            item = model_config.controlnet_map[c]
            if isinstance(item, dict):
                if item["enable"]:
                    if is_valid_controlnet_type(c, is_sdxl):
                        img_dir = c_image_dir.joinpath(c)
                        cond_imgs = sorted(glob.glob(os.path.join(img_dir, "[0-9]*.png"), recursive=False))
                        # Only load a model if there are condition images for it.
                        if len(cond_imgs) > 0:
                            logger.info(f"loading {c=} model")
                            controlnet_map[c] = create_controlnet_model(pipe, c, is_sdxl)
                    else:
                        logger.info(f"invalid controlnet type for {'sdxl' if is_sdxl else 'sd15'}: {c}")

    if not controlnet_map:
        controlnet_map = None

    pipe.controlnet_map = controlnet_map


def unload_controlnet_models(pipe: AnimationPipeline):
    from animatediff.utils.util import show_gpu

    if pipe.controlnet_map:
        for c in pipe.controlnet_map:
            controlnet = pipe.controlnet_map[c]
            if isinstance(controlnet, ControlNetLLLite):
                controlnet.unapply_to()
            del controlnet

    pipe.controlnet_map = None
    torch.cuda.empty_cache()


def create_us_pipeline(
    model_config: ModelConfig = ...,
    infer_config: InferenceConfig = ...,
    use_xformers: bool = True,
    use_controlnet_ref: bool = False,
    use_controlnet_tile: bool = False,
    use_controlnet_line_anime: bool = False,
    use_controlnet_ip2p: bool = False,
) -> DiffusionPipeline:
    sched_kwargs = infer_config.noise_scheduler_kwargs
    scheduler = get_scheduler(model_config.scheduler, sched_kwargs)
    logger.info(f'Using scheduler "{model_config.scheduler}" ({scheduler.__class__.__name__})')

    controlnet = []
    if use_controlnet_tile:
        controlnet.append(ControlNetModel.from_pretrained('lllyasviel/control_v11f1e_sd15_tile'))
    if use_controlnet_line_anime:
        controlnet.append(ControlNetModel.from_pretrained('lllyasviel/control_v11p_sd15s2_lineart_anime'))
    if use_controlnet_ip2p:
        controlnet.append(ControlNetModel.from_pretrained('lllyasviel/control_v11e_sd15_ip2p'))

    if len(controlnet) == 1:
        controlnet = controlnet[0]
    elif len(controlnet) == 0:
        controlnet = None

    pipeline: DiffusionPipeline

    if model_config.path is not None:
        model_path = data_dir.joinpath(model_config.path)
        logger.info(f"Loading weights from {model_path}")
        if model_path.is_file():

            def is_empty_dir(path):
                import os
                return len(os.listdir(path)) == 0

            # Single checkpoints are converted once into a Diffusers directory
            # cached under models/huggingface, keyed by file name and size.
            save_path = data_dir.joinpath("models/huggingface/" + model_path.stem + "_" + str(model_path.stat().st_size))
            save_path.mkdir(exist_ok=True)
            if save_path.is_dir() and is_empty_dir(save_path):
                logger.debug("Loading from single checkpoint file")
                tmp_pipeline = StableDiffusionPipeline.from_single_file(
                    pretrained_model_link_or_path=str(model_path.absolute())
                )
                tmp_pipeline.save_pretrained(save_path, safe_serialization=True)
                del tmp_pipeline

            if use_controlnet_ref:
                pipeline = StableDiffusionControlNetImg2ImgReferencePipeline.from_pretrained(
                    save_path,
                    controlnet=controlnet,
                    local_files_only=False,
                    load_safety_checker=False,
                    safety_checker=None,
                )
            else:
                pipeline = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
                    save_path,
                    controlnet=controlnet,
                    local_files_only=False,
                    load_safety_checker=False,
                    safety_checker=None,
                )

        elif model_path.is_dir():
            logger.debug("Loading from Diffusers model directory")
            if use_controlnet_ref:
                pipeline = StableDiffusionControlNetImg2ImgReferencePipeline.from_pretrained(
                    model_path,
                    controlnet=controlnet,
                    local_files_only=True,
                    load_safety_checker=False,
                    safety_checker=None,
                )
            else:
                pipeline = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
                    model_path,
                    controlnet=controlnet,
                    local_files_only=True,
                    load_safety_checker=False,
                    safety_checker=None,
                )
        else:
            raise FileNotFoundError(f"model_path {model_path} is not a file or directory")
    else:
        raise ValueError("model_config.path is invalid")

    pipeline.scheduler = scheduler

    if use_xformers:
        logger.info("Enabling xformers memory-efficient attention")
        pipeline.enable_xformers_memory_efficient_attention()

    for l in model_config.lora_map:
        lora_path = data_dir.joinpath(l)
        if lora_path.is_file():
            alpha = model_config.lora_map[l]
            if isinstance(alpha, dict):
                # The upscale pipeline only supports a scalar alpha; dict-style
                # entries fall back to the default strength.
                alpha = 0.75

            logger.info(f"Loading lora {lora_path}")
            logger.info(f"alpha = {alpha}")
            load_safetensors_lora2(pipeline.text_encoder, pipeline.unet, lora_path, alpha=alpha, is_animatediff=False)

    pipeline.unet = pipeline.unet.half()
    pipeline.text_encoder = pipeline.text_encoder.half()

    pipeline.text_encoder = pipeline.text_encoder.to("cuda")

    load_text_embeddings(pipeline)

    pipeline.text_encoder = pipeline.text_encoder.to("cpu")

    return pipeline


def seed_everything(seed):
    import random

    import numpy as np

    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed % (2**32))
    random.seed(seed)
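
# numpy requires a seed in [0, 2**32), hence the modulo; torch and random
# accept larger Python ints directly.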


def controlnet_preprocess(
    controlnet_map: Dict[str, Any] = None,
    width: int = 512,
    height: int = 512,
    duration: int = 16,
    out_dir: PathLike = ...,
    device_str: str = None,
    is_sdxl: bool = False,
):
    if not controlnet_map:
        return None, None, None, None

    out_dir = Path(out_dir)

    # controlnet_image_map: frame_no -> {controlnet_type: preprocessed image}
    controlnet_image_map = {}

    # controlnet_type_map: controlnet_type -> generation parameters
    controlnet_type_map = {}

    c_image_dir = data_dir.joinpath(controlnet_map["input_image_dir"])
    save_detectmap = controlnet_map["save_detectmap"] if "save_detectmap" in controlnet_map else True

    preprocess_on_gpu = controlnet_map["preprocess_on_gpu"] if "preprocess_on_gpu" in controlnet_map else True
    device_str = device_str if preprocess_on_gpu else None

    for c in controlnet_map:
        if c == "controlnet_ref":
            continue

        item = controlnet_map[c]

        processed = False

        if isinstance(item, dict):
            if item["enable"]:
                if is_valid_controlnet_type(c, is_sdxl):
                    preprocessor_map = item["preprocessor"] if "preprocessor" in item else {}

                    img_dir = c_image_dir.joinpath(c)
                    cond_imgs = sorted(glob.glob(os.path.join(img_dir, "[0-9]*.png"), recursive=False))
                    if len(cond_imgs) > 0:
                        controlnet_type_map[c] = {
                            "controlnet_conditioning_scale": item["controlnet_conditioning_scale"],
                            "control_guidance_start": item["control_guidance_start"],
                            "control_guidance_end": item["control_guidance_end"],
                            "control_scale_list": item["control_scale_list"],
                            "guess_mode": item["guess_mode"] if "guess_mode" in item else False,
                            "control_region_list": item["control_region_list"] if "control_region_list" in item else [],
                        }

                        use_preprocessor = item["use_preprocessor"] if "use_preprocessor" in item else True

                        for img_path in tqdm(cond_imgs, desc=f"Preprocessing images ({c})"):
                            frame_no = int(Path(img_path).stem)
                            if frame_no < duration:
                                if frame_no not in controlnet_image_map:
                                    controlnet_image_map[frame_no] = {}
                                controlnet_image_map[frame_no][c] = get_preprocessed_img(c, get_resized_image2(img_path, 512), use_preprocessor, device_str, preprocessor_map)
                                processed = True
                else:
                    logger.info(f"invalid controlnet type for {'sdxl' if is_sdxl else 'sd15'}: {c}")

        if save_detectmap and processed:
            det_dir = out_dir.joinpath(f"{0:02d}_detectmap/{c}")
            det_dir.mkdir(parents=True, exist_ok=True)
            for frame_no in tqdm(controlnet_image_map, desc=f"Saving Preprocessed images ({c})"):
                save_path = det_dir.joinpath(f"{frame_no:08d}.png")
                if c in controlnet_image_map[frame_no]:
                    controlnet_image_map[frame_no][c].save(save_path)

        clear_controlnet_preprocessor(c)

    clear_controlnet_preprocessor()

    controlnet_ref_map = None

    if "controlnet_ref" in controlnet_map:
        r = controlnet_map["controlnet_ref"]
        if r["enable"]:
            org_name = data_dir.joinpath(r["ref_image"]).stem

            ref_image = get_resized_image2(data_dir.joinpath(r["ref_image"]), 512)

            if ref_image is not None:
                controlnet_ref_map = {
                    "ref_image": ref_image,
                    "style_fidelity": r["style_fidelity"],
                    "attention_auto_machine_weight": r["attention_auto_machine_weight"],
                    "gn_auto_machine_weight": r["gn_auto_machine_weight"],
                    "reference_attn": r["reference_attn"],
                    "reference_adain": r["reference_adain"],
                    "scale_pattern": r["scale_pattern"],
                }

                if save_detectmap:
                    det_dir = out_dir.joinpath(f"{0:02d}_detectmap/controlnet_ref")
                    det_dir.mkdir(parents=True, exist_ok=True)
                    save_path = det_dir.joinpath(f"{org_name}.png")
                    ref_image.save(save_path)

    controlnet_no_shrink = ["controlnet_tile", "animatediff_controlnet", "controlnet_canny", "controlnet_normalbae", "controlnet_depth", "controlnet_lineart", "controlnet_lineart_anime", "controlnet_scribble", "controlnet_seg", "controlnet_softedge", "controlnet_mlsd"]
    if "no_shrink_list" in controlnet_map:
        controlnet_no_shrink = controlnet_map["no_shrink_list"]

    return controlnet_image_map, controlnet_type_map, controlnet_ref_map, controlnet_no_shrink
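
# Returns per-frame condition images, per-type generation parameters, the
# optional reference-only settings, and the list of controlnet types exempt
# from condition-image shrinking.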


def ip_adapter_preprocess(
    ip_adapter_config_map: Dict[str, Any] = None,
    width: int = 512,
    height: int = 512,
    duration: int = 16,
    out_dir: PathLike = ...,
    is_sdxl: bool = False,
):
    ip_adapter_map = {}

    processed = False

    if ip_adapter_config_map:
        if ip_adapter_config_map["enable"]:
            resized_to_square = ip_adapter_config_map["resized_to_square"] if "resized_to_square" in ip_adapter_config_map else False
            image_dir = data_dir.joinpath(ip_adapter_config_map["input_image_dir"])
            imgs = sorted(chain.from_iterable([glob.glob(os.path.join(image_dir, f"[0-9]*{ext}")) for ext in IMG_EXTENSIONS]))
            if len(imgs) > 0:
                prepare_ip_adapter_sdxl() if is_sdxl else prepare_ip_adapter()
                ip_adapter_map["images"] = {}
                for img_path in tqdm(imgs, desc="Preprocessing images (ip_adapter)"):
                    frame_no = int(Path(img_path).stem)
                    if frame_no < duration:
                        if resized_to_square:
                            ip_adapter_map["images"][frame_no] = get_resized_image(img_path, 256, 256)
                        else:
                            ip_adapter_map["images"][frame_no] = get_resized_image2(img_path, 256)
                        processed = True

        if processed:
            ip_adapter_config_map["prompt_fixed_ratio"] = max(min(1.0, ip_adapter_config_map["prompt_fixed_ratio"]), 0)

            prompt_fixed_ratio = ip_adapter_config_map["prompt_fixed_ratio"]
            prompt_map = ip_adapter_map["images"]
            prompt_map = dict(sorted(prompt_map.items()))
            key_list = list(prompt_map.keys())
            for k0, k1 in zip(key_list, key_list[1:] + [duration]):
                k05 = k0 + round((k1 - k0) * prompt_fixed_ratio)
                if k05 == k1:
                    k05 -= 1
                if k05 != k0:
                    prompt_map[k05] = prompt_map[k0]
            ip_adapter_map["images"] = prompt_map

        if ip_adapter_config_map["save_input_image"] and processed:
            det_dir = out_dir.joinpath(f"{0:02d}_ip_adapter/")
            det_dir.mkdir(parents=True, exist_ok=True)
            for frame_no in tqdm(ip_adapter_map["images"], desc="Saving Preprocessed images (ip_adapter)"):
                save_path = det_dir.joinpath(f"{frame_no:08d}.png")
                ip_adapter_map["images"][frame_no].save(save_path)

    return ip_adapter_map if processed else None


def prompt_preprocess(
    prompt_config_map: Dict[str, Any],
    head_prompt: str,
    tail_prompt: str,
    prompt_fixed_ratio: float,
    video_length: int,
):
    prompt_map = {}
    for k in prompt_config_map.keys():
        if int(k) < video_length:
            pr = prompt_config_map[k]
            if head_prompt:
                pr = head_prompt + "," + pr
            if tail_prompt:
                pr = pr + "," + tail_prompt

            prompt_map[int(k)] = pr

    prompt_map = dict(sorted(prompt_map.items()))
    key_list = list(prompt_map.keys())
    for k0, k1 in zip(key_list, key_list[1:] + [video_length]):
        k05 = k0 + round((k1 - k0) * prompt_fixed_ratio)
        if k05 == k1:
            k05 -= 1
        if k05 != k0:
            prompt_map[k05] = prompt_map[k0]

    return prompt_map
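
# Worked example: prompt_map = {0: "a", 8: "b"} with video_length = 16 and
# prompt_fixed_ratio = 0.5 inserts duplicate keyframes at the midpoint of each
# interval, giving {0: "a", 4: "a", 8: "b", 12: "b"}; each prompt is held fixed
# for half its interval before interpolation toward the next one begins.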


def region_preprocess(
    model_config: ModelConfig = ...,
    width: int = 512,
    height: int = 512,
    duration: int = 16,
    out_dir: PathLike = ...,
    is_init_img_exist: bool = False,
    is_sdxl: bool = False,
):
    is_bg_init_img = False
    if is_init_img_exist:
        if model_config.region_map:
            if "background" in model_config.region_map:
                is_bg_init_img = model_config.region_map["background"]["is_init_img"]

    region_condi_list = []
    region2index = {}

    condi_index = 0

    prev_ip_map = None

    if not is_bg_init_img:
        ip_map = ip_adapter_preprocess(
            model_config.ip_adapter_map,
            width,
            height,
            duration,
            out_dir,
            is_sdxl
        )

        if ip_map:
            prev_ip_map = ip_map

        condition_map = {
            "prompt_map": prompt_preprocess(
                model_config.prompt_map,
                model_config.head_prompt,
                model_config.tail_prompt,
                model_config.prompt_fixed_ratio,
                duration
            ),
            "ip_adapter_map": ip_map
        }

        region_condi_list.append(condition_map)

        bg_src = condi_index
        condi_index += 1
    else:
        bg_src = -1

    region_list = [
        {
            "mask_images": None,
            "src": bg_src,
            "crop_generation_rate": 0
        }
    ]
    region2index["background"] = bg_src

    if model_config.region_map:
        for r in model_config.region_map:
            if r == "background":
                continue
            if not model_config.region_map[r]["enable"]:
                continue
            region_dir = out_dir.joinpath(f"region_{int(r):05d}/")
            region_dir.mkdir(parents=True, exist_ok=True)

            mask_map = mask_preprocess(
                model_config.region_map[r],
                width,
                height,
                duration,
                region_dir
            )

            if not mask_map:
                continue

            if not model_config.region_map[r]["is_init_img"]:
                ip_map = ip_adapter_preprocess(
                    model_config.region_map[r]["condition"]["ip_adapter_map"],
                    width,
                    height,
                    duration,
                    region_dir,
                    is_sdxl
                )

                if ip_map:
                    prev_ip_map = ip_map

                condition_map = {
                    "prompt_map": prompt_preprocess(
                        model_config.region_map[r]["condition"]["prompt_map"],
                        model_config.region_map[r]["condition"]["head_prompt"],
                        model_config.region_map[r]["condition"]["tail_prompt"],
                        model_config.region_map[r]["condition"]["prompt_fixed_ratio"],
                        duration
                    ),
                    "ip_adapter_map": ip_map
                }

                region_condi_list.append(condition_map)

                src = condi_index
                condi_index += 1
            else:
                if not is_init_img_exist:
                    logger.warning("'is_init_img' is true, but no init image exists; ignoring region")
                    continue
                src = -1

            region_list.append(
                {
                    "mask_images": mask_map,
                    "src": src,
                    "crop_generation_rate": model_config.region_map[r]["crop_generation_rate"] if "crop_generation_rate" in model_config.region_map[r] else 0
                }
            )
            region2index[r] = src

    ip_adapter_config_map = None

    if prev_ip_map is not None:
        ip_adapter_config_map = {}
        ip_adapter_config_map["scale"] = model_config.ip_adapter_map["scale"]
        ip_adapter_config_map["is_plus"] = model_config.ip_adapter_map["is_plus"]
        ip_adapter_config_map["is_plus_face"] = model_config.ip_adapter_map["is_plus_face"] if "is_plus_face" in model_config.ip_adapter_map else False
        ip_adapter_config_map["is_light"] = model_config.ip_adapter_map["is_light"] if "is_light" in model_config.ip_adapter_map else False
        ip_adapter_config_map["is_full_face"] = model_config.ip_adapter_map["is_full_face"] if "is_full_face" in model_config.ip_adapter_map else False
        # Regions without their own IP-Adapter images reuse the most recently
        # processed ones.
        for c in region_condi_list:
            if c["ip_adapter_map"] is None:
                logger.info("fill empty ip_adapter_map")
                c["ip_adapter_map"] = prev_ip_map

    if not region_condi_list:
        raise ValueError("error! There is not a single valid region")

    return region_condi_list, region_list, ip_adapter_config_map, region2index


def img2img_preprocess(
    img2img_config_map: Dict[str, Any] = None,
    width: int = 512,
    height: int = 512,
    duration: int = 16,
    out_dir: PathLike = ...,
):
    img2img_map = {}

    processed = False

    if img2img_config_map:
        if img2img_config_map["enable"]:
            image_dir = data_dir.joinpath(img2img_config_map["init_img_dir"])
            imgs = sorted(glob.glob(os.path.join(image_dir, "[0-9]*.png"), recursive=False))
            if len(imgs) > 0:
                img2img_map["images"] = {}
                img2img_map["denoising_strength"] = img2img_config_map["denoising_strength"]
                for img_path in tqdm(imgs, desc="Preprocessing images (img2img)"):
                    frame_no = int(Path(img_path).stem)
                    if frame_no < duration:
                        img2img_map["images"][frame_no] = get_resized_image(img_path, width, height)
                        processed = True

        if img2img_config_map["save_init_image"] and processed:
            det_dir = out_dir.joinpath(f"{0:02d}_img2img_init_img/")
            det_dir.mkdir(parents=True, exist_ok=True)
            for frame_no in tqdm(img2img_map["images"], desc="Saving Preprocessed images (img2img)"):
                save_path = det_dir.joinpath(f"{frame_no:08d}.png")
                img2img_map["images"][frame_no].save(save_path)

    return img2img_map if processed else None


def mask_preprocess(
    region_config_map: Dict[str, Any] = None,
    width: int = 512,
    height: int = 512,
    duration: int = 16,
    out_dir: PathLike = ...,
):
    mask_map = {}

    processed = False
    size = None
    mode = None

    if region_config_map:
        image_dir = data_dir.joinpath(region_config_map["mask_dir"])
        imgs = sorted(glob.glob(os.path.join(image_dir, "[0-9]*.png"), recursive=False))
        if len(imgs) > 0:
            for img_path in tqdm(imgs, desc="Preprocessing images (mask)"):
                frame_no = int(Path(img_path).stem)
                if frame_no < duration:
                    mask_map[frame_no] = get_resized_image(img_path, width, height)
                    if size is None:
                        size = mask_map[frame_no].size
                        mode = mask_map[frame_no].mode

            processed = True

        if processed:
            if 0 in mask_map:
                prev_img = mask_map[0]
            else:
                prev_img = Image.new(mode, size, color=0)

            for i in range(duration):
                if i in mask_map:
                    prev_img = mask_map[i]
                else:
                    mask_map[i] = prev_img

        if region_config_map["save_mask"] and processed:
            det_dir = out_dir.joinpath("mask/")
            det_dir.mkdir(parents=True, exist_ok=True)
            for frame_no in tqdm(mask_map, desc="Saving Preprocessed images (mask)"):
                save_path = det_dir.joinpath(f"{frame_no:08d}.png")
                mask_map[frame_no].save(save_path)

    return mask_map if processed else None
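
# Missing mask frames are filled by holding the most recent earlier mask;
# frames before the first provided mask get an all-black placeholder. For
# example, masks at frames 0 and 8 with duration 16 give frames 0-7 the first
# mask and frames 8-15 the second.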


def wild_card_conversion(model_config: ModelConfig = ...):
    from animatediff.utils.wild_card import replace_wild_card

    wild_card_dir = get_dir("wildcards")
    for k in model_config.prompt_map.keys():
        model_config.prompt_map[k] = replace_wild_card(model_config.prompt_map[k], wild_card_dir)

    if model_config.head_prompt:
        model_config.head_prompt = replace_wild_card(model_config.head_prompt, wild_card_dir)
    if model_config.tail_prompt:
        model_config.tail_prompt = replace_wild_card(model_config.tail_prompt, wild_card_dir)

    model_config.prompt_fixed_ratio = max(min(1.0, model_config.prompt_fixed_ratio), 0)

    if model_config.region_map:
        for r in model_config.region_map:
            if r == "background":
                continue

            if "condition" in model_config.region_map[r]:
                c = model_config.region_map[r]["condition"]
                for k in c["prompt_map"].keys():
                    c["prompt_map"][k] = replace_wild_card(c["prompt_map"][k], wild_card_dir)

                if "head_prompt" in c:
                    c["head_prompt"] = replace_wild_card(c["head_prompt"], wild_card_dir)
                if "tail_prompt" in c:
                    c["tail_prompt"] = replace_wild_card(c["tail_prompt"], wild_card_dir)
                if "prompt_fixed_ratio" in c:
                    c["prompt_fixed_ratio"] = max(min(1.0, c["prompt_fixed_ratio"]), 0)


def save_output(
    pipeline_output,
    frame_dir: PathLike,
    out_file: Path,
    output_map: Dict[str, Any] = None,
    no_frames: bool = False,
    save_frames=save_frames,
    save_video=None,
):
    output_format = "gif"
    output_fps = 8
    if output_map:
        output_format = output_map["format"] if "format" in output_map else output_format
        output_fps = output_map["fps"] if "fps" in output_map else output_fps
        if output_format == "mp4":
            output_format = "h264"

    if output_format == "gif":
        out_file = out_file.with_suffix(".gif")
        if not no_frames:
            if save_frames:
                save_frames(pipeline_output, frame_dir)

        if save_video:
            save_video(pipeline_output, out_file, output_fps)
        else:
            pipeline_output[0].save(
                fp=out_file, format="GIF", append_images=pipeline_output[1:], save_all=True, duration=(1 / output_fps * 1000), loop=0
            )
    else:
        if save_frames:
            save_frames(pipeline_output, frame_dir)

        from animatediff.rife.ffmpeg import (FfmpegEncoder, VideoCodec,
                                             codec_extn)

        out_file = out_file.with_suffix(f".{codec_extn(output_format)}")

        logger.info("Creating ffmpeg encoder...")
        encoder = FfmpegEncoder(
            frames_dir=frame_dir,
            out_file=out_file,
            codec=output_format,
            in_fps=output_fps,
            out_fps=output_fps,
            lossless=False,
            param=output_map["encode_param"] if "encode_param" in output_map else {}
        )
        logger.info("Encoding interpolated frames with ffmpeg...")
        result = encoder.encode()
        logger.debug(f"ffmpeg result: {result}")
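
# Example output_map (keys read above; values illustrative):
#   {"format": "h264", "fps": 8, "encode_param": {...}}
# "encode_param" is passed straight through to FfmpegEncoder.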


def run_inference(
    pipeline: DiffusionPipeline,
    n_prompt: str = ...,
    seed: int = -1,
    steps: int = 25,
    guidance_scale: float = 7.5,
    unet_batch_size: int = 1,
    width: int = 512,
    height: int = 512,
    duration: int = 16,
    idx: int = 0,
    out_dir: PathLike = ...,
    context_frames: int = -1,
    context_stride: int = 3,
    context_overlap: int = 4,
    context_schedule: str = "uniform",
    clip_skip: int = 1,
    controlnet_map: Dict[str, Any] = None,
    controlnet_image_map: Dict[str, Any] = None,
    controlnet_type_map: Dict[str, Any] = None,
    controlnet_ref_map: Dict[str, Any] = None,
    controlnet_no_shrink: List[str] = None,
    no_frames: bool = False,
    img2img_map: Dict[str, Any] = None,
    ip_adapter_config_map: Dict[str, Any] = None,
    region_list: List[Any] = None,
    region_condi_list: List[Any] = None,
    output_map: Dict[str, Any] = None,
    is_single_prompt_mode: bool = False,
    is_sdxl: bool = False,
    apply_lcm_lora: bool = False,
    gradual_latent_map: Dict[str, Any] = None,
):
    out_dir = Path(out_dir)

    # Build a filesystem-safe output name from the first prompt's leading tags.
    prompt_map = region_condi_list[0]["prompt_map"]
    prompt_tags = [re_clean_prompt.sub("", tag).strip().replace(" ", "-") for tag in prompt_map[list(prompt_map.keys())[0]].split(",")]
    prompt_str = "_".join(prompt_tags[:6])[:50]
    frame_dir = out_dir.joinpath(f"{idx:02d}-{seed}")
    out_file = out_dir.joinpath(f"{idx:02d}_{seed}_{prompt_str}")

    def preview_callback(i: int, video: torch.Tensor, save_fn: Callable[[torch.Tensor], None], out_file: str) -> None:
        save_fn(video, out_file=Path(f"{out_file}_preview@{i}"))

    save_fn = partial(
        save_output,
        frame_dir=frame_dir,
        output_map=output_map,
        no_frames=no_frames,
        save_frames=partial(save_frames, show_progress=False),
        save_video=save_video
    )
    callback = partial(preview_callback, save_fn=save_fn, out_file=out_file)

    seed_everything(seed)

    logger.info(f"{len(region_condi_list)=}")
    logger.info(f"{len(region_list)=}")

    pipeline_output = pipeline(
        negative_prompt=n_prompt,
        num_inference_steps=steps,
        guidance_scale=guidance_scale,
        unet_batch_size=unet_batch_size,
        width=width,
        height=height,
        video_length=duration,
        return_dict=False,
        context_frames=context_frames,
        context_stride=context_stride + 1,
        context_overlap=context_overlap,
        context_schedule=context_schedule,
        clip_skip=clip_skip,
        controlnet_type_map=controlnet_type_map,
        controlnet_image_map=controlnet_image_map,
        controlnet_ref_map=controlnet_ref_map,
        controlnet_no_shrink=controlnet_no_shrink,
        controlnet_max_samples_on_vram=controlnet_map["max_samples_on_vram"] if "max_samples_on_vram" in controlnet_map else 999,
        controlnet_max_models_on_vram=controlnet_map["max_models_on_vram"] if "max_models_on_vram" in controlnet_map else 99,
        controlnet_is_loop=controlnet_map["is_loop"] if "is_loop" in controlnet_map else True,
        img2img_map=img2img_map,
        ip_adapter_config_map=ip_adapter_config_map,
        region_list=region_list,
        region_condi_list=region_condi_list,
        interpolation_factor=1,
        is_single_prompt_mode=is_single_prompt_mode,
        apply_lcm_lora=apply_lcm_lora,
        gradual_latent_map=gradual_latent_map,
        callback=callback,
        callback_steps=output_map.get("preview_steps"),
    )
    logger.info("Generation complete, saving...")

    save_fn(pipeline_output, out_file=out_file)

    logger.info(f"Saved sample to {out_file}")
    return pipeline_output


def run_upscale(
    org_imgs: List[str],
    pipeline: DiffusionPipeline,
    prompt_map: Dict[int, str] = None,
    n_prompt: str = ...,
    seed: int = -1,
    steps: int = 25,
    strength: float = 0.5,
    guidance_scale: float = 7.5,
    clip_skip: int = 1,
    us_width: int = 512,
    us_height: int = 512,
    idx: int = 0,
    out_dir: PathLike = ...,
    upscale_config: Dict[str, Any] = None,
    use_controlnet_ref: bool = False,
    use_controlnet_tile: bool = False,
    use_controlnet_line_anime: bool = False,
    use_controlnet_ip2p: bool = False,
    no_frames: bool = False,
    output_map: Dict[str, Any] = None,
):
    from animatediff.utils.lpw_stable_diffusion import lpw_encode_prompt

    pipeline.set_progress_bar_config(disable=True)

    images = get_resized_images(org_imgs, us_width, us_height)

    # upscale_config entries override the keyword defaults.
    steps = steps if "steps" not in upscale_config else upscale_config["steps"]
    scheduler = upscale_config.get("scheduler")  # NOTE: read but currently unused below
    guidance_scale = guidance_scale if "guidance_scale" not in upscale_config else upscale_config["guidance_scale"]
    clip_skip = clip_skip if "clip_skip" not in upscale_config else upscale_config["clip_skip"]
    strength = strength if "strength" not in upscale_config else upscale_config["strength"]

    controlnet_conditioning_scale = []
    guess_mode = []
    control_guidance_start = []
    control_guidance_end = []

    if use_controlnet_tile:
        controlnet_conditioning_scale.append(upscale_config["controlnet_tile"]["controlnet_conditioning_scale"])
        guess_mode.append(upscale_config["controlnet_tile"]["guess_mode"])
        control_guidance_start.append(upscale_config["controlnet_tile"]["control_guidance_start"])
        control_guidance_end.append(upscale_config["controlnet_tile"]["control_guidance_end"])

    if use_controlnet_line_anime:
        controlnet_conditioning_scale.append(upscale_config["controlnet_line_anime"]["controlnet_conditioning_scale"])
        guess_mode.append(upscale_config["controlnet_line_anime"]["guess_mode"])
        control_guidance_start.append(upscale_config["controlnet_line_anime"]["control_guidance_start"])
        control_guidance_end.append(upscale_config["controlnet_line_anime"]["control_guidance_end"])

    if use_controlnet_ip2p:
        controlnet_conditioning_scale.append(upscale_config["controlnet_ip2p"]["controlnet_conditioning_scale"])
        guess_mode.append(upscale_config["controlnet_ip2p"]["guess_mode"])
        control_guidance_start.append(upscale_config["controlnet_ip2p"]["control_guidance_start"])
        control_guidance_end.append(upscale_config["controlnet_ip2p"]["control_guidance_end"])

    ref_image = None
    if use_controlnet_ref:
        if not upscale_config["controlnet_ref"]["use_frame_as_ref_image"] and not upscale_config["controlnet_ref"]["use_1st_frame_as_ref_image"]:
            ref_image = get_resized_images([data_dir.joinpath(upscale_config["controlnet_ref"]["ref_image"])], us_width, us_height)[0]

    generator = torch.manual_seed(seed)

    seed_everything(seed)

    prompt_embeds_map = {}
    prompt_map = dict(sorted(prompt_map.items()))
    negative = None

    do_classifier_free_guidance = guidance_scale > 1.0

    prompt_list = [prompt_map[key_frame] for key_frame in prompt_map.keys()]

    prompt_embeds, neg_embeds = lpw_encode_prompt(
        pipe=pipeline,
        prompt=prompt_list,
        do_classifier_free_guidance=do_classifier_free_guidance,
        negative_prompt=n_prompt,
    )

    if do_classifier_free_guidance:
        negative = neg_embeds.chunk(neg_embeds.shape[0], 0)
        positive = prompt_embeds.chunk(prompt_embeds.shape[0], 0)
    else:
        negative = [None]
        positive = prompt_embeds.chunk(prompt_embeds.shape[0], 0)

    for i, key_frame in enumerate(prompt_map):
        prompt_embeds_map[key_frame] = positive[i]

    key_first = list(prompt_map.keys())[0]
    key_last = list(prompt_map.keys())[-1]

    def get_current_prompt_embeds(
        center_frame: int = 0,
        video_length: int = 0
    ):
        key_prev = key_last
        key_next = key_first

        # Find the keyframes surrounding center_frame, wrapping around the clip.
        for p in prompt_map.keys():
            if p > center_frame:
                key_next = p
                break
            key_prev = p

        dist_prev = center_frame - key_prev
        if dist_prev < 0:
            dist_prev += video_length
        dist_next = key_next - center_frame
        if dist_next < 0:
            dist_next += video_length

        if key_prev == key_next or dist_prev + dist_next == 0:
            return prompt_embeds_map[key_prev]

        rate = dist_prev / (dist_prev + dist_next)

        return get_tensor_interpolation_method()(prompt_embeds_map[key_prev], prompt_embeds_map[key_next], rate)
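
    # Worked example: keyframes {0, 8} with video_length=16 and center_frame=12
    # give dist_prev=4 (from 8) and dist_next=4 (wrapping to 0), so the result
    # is the 50/50 interpolation between the two prompt embeddings.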

    line_anime_processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")

    out_images = []

    logger.info(f"{use_controlnet_tile=}")
    logger.info(f"{use_controlnet_line_anime=}")
    logger.info(f"{use_controlnet_ip2p=}")

    logger.info(f"{controlnet_conditioning_scale=}")
    logger.info(f"{guess_mode=}")
    logger.info(f"{control_guidance_start=}")
    logger.info(f"{control_guidance_end=}")

    for i, org_image in enumerate(tqdm(images, desc="Upscaling...")):
        cur_positive = get_current_prompt_embeds(i, len(images))

        condition_image = []

        if use_controlnet_tile:
            condition_image.append(org_image)
        if use_controlnet_line_anime:
            condition_image.append(line_anime_processor(org_image))
        if use_controlnet_ip2p:
            condition_image.append(org_image)

        if not use_controlnet_ref:
            out_image = pipeline(
                prompt_embeds=cur_positive,
                negative_prompt_embeds=negative[0],
                image=org_image,
                control_image=condition_image,
                width=org_image.size[0],
                height=org_image.size[1],
                strength=strength,
                num_inference_steps=steps,
                guidance_scale=guidance_scale,
                generator=generator,
                controlnet_conditioning_scale=controlnet_conditioning_scale if len(controlnet_conditioning_scale) > 1 else controlnet_conditioning_scale[0],
                # The pipeline takes a single bool here, so only the first entry applies.
                guess_mode=guess_mode[0],
                control_guidance_start=control_guidance_start if len(control_guidance_start) > 1 else control_guidance_start[0],
                control_guidance_end=control_guidance_end if len(control_guidance_end) > 1 else control_guidance_end[0],
            ).images[0]
        else:
            if upscale_config["controlnet_ref"]["use_1st_frame_as_ref_image"]:
                if i == 0:
                    ref_image = org_image
            elif upscale_config["controlnet_ref"]["use_frame_as_ref_image"]:
                ref_image = org_image

            out_image = pipeline(
                prompt_embeds=cur_positive,
                negative_prompt_embeds=negative[0],
                image=org_image,
                control_image=condition_image,
                width=org_image.size[0],
                height=org_image.size[1],
                strength=strength,
                num_inference_steps=steps,
                guidance_scale=guidance_scale,
                generator=generator,
                controlnet_conditioning_scale=controlnet_conditioning_scale if len(controlnet_conditioning_scale) > 1 else controlnet_conditioning_scale[0],
                guess_mode=guess_mode[0],
                ref_image=ref_image,
                attention_auto_machine_weight=upscale_config["controlnet_ref"]["attention_auto_machine_weight"],
                gn_auto_machine_weight=upscale_config["controlnet_ref"]["gn_auto_machine_weight"],
                style_fidelity=upscale_config["controlnet_ref"]["style_fidelity"],
                reference_attn=upscale_config["controlnet_ref"]["reference_attn"],
                reference_adain=upscale_config["controlnet_ref"]["reference_adain"],
            ).images[0]

        out_images.append(out_image)

    prompt_tags = [re_clean_prompt.sub("", tag).strip().replace(" ", "-") for tag in prompt_map[list(prompt_map.keys())[0]].split(",")]
    prompt_str = "_".join(prompt_tags[:6])[:50]

    out_file = out_dir.joinpath(f"{idx:02d}_{seed}_{prompt_str}")
    frame_dir = out_dir.joinpath(f"{idx:02d}-{seed}-upscaled")

    save_output(out_images, frame_dir, out_file, output_map, no_frames, save_imgs, None)

    logger.info(f"Saved sample to {out_file}")

    return out_images