# extensions/sd-webui-animatediff/scripts/animatediff_i2ibatch.py
# Uploaded by ehristoforu using huggingface_hub (commit 0163a2c, verified).
from pathlib import Path
from types import MethodType
import os
import cv2
import numpy as np
import torch
import hashlib
from PIL import Image, ImageOps, UnidentifiedImageError
from modules import processing, shared, scripts, img2img, devices, masking, sd_samplers, images
from modules.processing import (StableDiffusionProcessingImg2Img,
process_images,
create_binary_mask,
create_random_tensors,
images_tensor_to_samples,
setup_color_correction,
opt_f)
from modules.shared import opts
from modules.sd_samplers_common import images_tensor_to_samples, approximation_indexes
from scripts.animatediff_logger import logger_animatediff as logger
class AnimateDiffI2IBatch:
    """Patch the img2img batch tab so a folder of frames is rendered as one AnimateDiff job.

    ``hack`` replaces ``modules.img2img.process_batch`` with a wrapper that, when the
    AnimateDiff always-on script is enabled for the job, collects every readable image in
    the input directory into ``p.init_images`` (optionally with one mask per frame) and
    calls ``process_images(p)`` once. ``cap_init_image`` trims that frame list to the
    requested video length.
    """

    # The unpatched ``img2img.process_batch``. Also serves as the "already patched" flag.
    original_img2img_process_batch = None

    def hack(self):
        """Install the i2i-batch patch once; subsequent calls only log and return."""
        # TODO: PR this hack to A1111
        if AnimateDiffI2IBatch.original_img2img_process_batch is not None:
            logger.info("Hacking i2i-batch is already done.")
            return

        logger.info("Hacking i2i-batch.")
        AnimateDiffI2IBatch.original_img2img_process_batch = img2img.process_batch
        original_img2img_process_batch = AnimateDiffI2IBatch.original_img2img_process_batch

        def hacked_i2i_init(self, all_prompts, all_seeds, all_subseeds):  # only hack this when i2i-batch with batch mask
            """Replacement ``init`` for a processing object whose ``image_mask`` is a list.

            It is bound onto ``p`` via ``MethodType`` only when the batch has more than one
            mask file, so ``self.image_mask[idx]`` belongs to ``self.init_images[idx]``.

            Processing steps:
            - Each mask is binarized, optionally inverted, blurred and resized (or cropped
              in ``inpaint_full_res`` mode).
            - The frames are VAE-encoded into ``self.init_latent``.
            - ``self.mask`` / ``self.nmask`` are built with one latent mask per frame.

            Raises:
                RuntimeError: more init images than ``self.batch_size``.
            """
            self.image_cfg_scale: float = self.image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None

            self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
            crop_regions = []
            paste_to = []
            masks_for_overlay = []

            # Same list object as self.image_mask; entries are replaced in place below.
            image_masks = self.image_mask

            for idx, image_mask in enumerate(image_masks):
                # image_mask is passed in as RGBA by Gradio to support alpha masks,
                # but we still want to support binary masks.
                image_mask = create_binary_mask(image_mask)

                if self.inpainting_mask_invert:
                    image_mask = ImageOps.invert(image_mask)

                # Separable Gaussian blur: a horizontal pass, then a vertical pass.
                if self.mask_blur_x > 0:
                    np_mask = np.array(image_mask)
                    kernel_size = 2 * int(2.5 * self.mask_blur_x + 0.5) + 1
                    np_mask = cv2.GaussianBlur(np_mask, (kernel_size, 1), self.mask_blur_x)
                    image_mask = Image.fromarray(np_mask)

                if self.mask_blur_y > 0:
                    np_mask = np.array(image_mask)
                    kernel_size = 2 * int(2.5 * self.mask_blur_y + 0.5) + 1
                    np_mask = cv2.GaussianBlur(np_mask, (1, kernel_size), self.mask_blur_y)
                    image_mask = Image.fromarray(np_mask)

                if self.inpaint_full_res:
                    # Crop to the padded mask region and scale that region up to the
                    # output size; remember where to paste the result back.
                    masks_for_overlay.append(image_mask)
                    mask = image_mask.convert('L')
                    crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
                    crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
                    crop_regions.append(crop_region)
                    x1, y1, x2, y2 = crop_region

                    mask = mask.crop(crop_region)
                    image_mask = images.resize_image(2, mask, self.width, self.height)
                    paste_to.append((x1, y1, x2 - x1, y2 - y1))
                else:
                    image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)
                    np_mask = np.array(image_mask)
                    np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
                    masks_for_overlay.append(Image.fromarray(np_mask))

                image_masks[idx] = image_mask

            self.mask_for_overlay = masks_for_overlay[0]  # only for saving purpose
            if paste_to:
                self.paste_to = paste_to[0]
                # Full per-frame list, kept on a private attribute for AnimateDiff's own use.
                self._animatediff_paste_to_full = paste_to

            self.overlay_images = []
            add_color_corrections = opts.img2img_color_correction and self.color_corrections is None
            if add_color_corrections:
                self.color_corrections = []

            imgs = []
            for idx, img in enumerate(self.init_images):
                # An explicit latent_mask (single or per-frame list) wins over the processed mask.
                if self.latent_mask is None:
                    latent_mask = image_masks[idx]
                elif isinstance(self.latent_mask, list):
                    latent_mask = self.latent_mask[idx]
                else:
                    latent_mask = self.latent_mask

                # Save init image
                if opts.save_init_img:
                    self.init_img_hash = hashlib.md5(img.tobytes()).hexdigest()
                    images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False)

                image = images.flatten(img, opts.img2img_background_color)

                if not crop_regions and self.resize_mode != 3:
                    image = images.resize_image(self.resize_mode, image, self.width, self.height)

                if image_masks:
                    image_masked = Image.new('RGBa', (image.width, image.height))
                    image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(masks_for_overlay[idx].convert('L')))
                    self.overlay_images.append(image_masked.convert('RGBA'))

                # crop_region is not None if we are doing inpaint full res
                if crop_regions:
                    image = image.crop(crop_regions[idx])
                    image = images.resize_image(2, image, self.width, self.height)

                if image_masks and self.inpainting_fill != 1:
                    image = masking.fill(image, latent_mask)

                if add_color_corrections:
                    self.color_corrections.append(setup_color_correction(image))

                # HWC uint8 -> CHW float in [0, 1]
                image = np.array(image).astype(np.float32) / 255.0
                image = np.moveaxis(image, 2, 0)
                imgs.append(image)

            if len(imgs) == 1:
                # A single init image is repeated to fill the whole batch.
                batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
                self.overlay_images = self.overlay_images * self.batch_size
                if self.color_corrections is not None and len(self.color_corrections) == 1:
                    self.color_corrections = self.color_corrections * self.batch_size
            elif len(imgs) <= self.batch_size:
                self.batch_size = len(imgs)
                batch_images = np.array(imgs)
            else:
                raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")

            image = torch.from_numpy(batch_images)
            image = image.to(shared.device, dtype=devices.dtype_vae)

            if opts.sd_vae_encode_method != 'Full':
                self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method

            self.init_latent = images_tensor_to_samples(image, approximation_indexes.get(opts.sd_vae_encode_method), self.sd_model)
            devices.torch_gc()

            if self.resize_mode == 3:
                self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")

            if image_masks is not None:
                def process_latmask(init_mask):
                    """Resize a mask to the latent's spatial size, round it to 0/1 and tile it to 4 channels."""
                    latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
                    latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
                    latmask = latmask[0]
                    latmask = np.around(latmask)
                    return np.tile(latmask[None], (4, 1, 1))

                if self.latent_mask is not None and not isinstance(self.latent_mask, list):
                    # One shared latent mask for every frame.
                    latmask = process_latmask(self.latent_mask)
                else:
                    # One latent mask per frame, stacked along a new leading axis.
                    source_masks = self.latent_mask if isinstance(self.latent_mask, list) else image_masks
                    latmask = np.stack([process_latmask(x) for x in source_masks], axis=0)

                self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
                self.nmask = torch.asarray(latmask).to(shared.device).type(self.sd_model.dtype)

                # this needs to be fixed to be done in sample() using actual seeds for batches
                if self.inpainting_fill == 2:
                    self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
                elif self.inpainting_fill == 3:
                    self.init_latent = self.init_latent * self.mask

            self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_masks)  # let's ignore this image_masks which is related to inpaint model with different arch

        def hacked_img2img_process_batch_hijack(
                p: StableDiffusionProcessingImg2Img, input_dir: str, output_dir: str, inpaint_mask_dir: str,
                args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None):
            """Drop-in replacement for ``img2img.process_batch``.

            It delegates to the original implementation unless the AnimateDiff always-on
            script is enabled for ``p``. Otherwise it proceeds as follows:
            - Every readable image under ``input_dir`` becomes one frame of ``p.init_images``.
            - Masks are attached: one mask shared by all frames, or one per frame matched
              by index.
            - ``process_images(p)`` is called once.

            Returns:
                The result of ``process_images(p)``, or ``None`` when no frame could be
                loaded or when an external script handled the run.
            """
            if p.scripts:
                for script in p.scripts.alwayson_scripts:
                    if script.title().lower() == "animatediff":
                        ad_arg = p.script_args[script.args_from]
                        ad_enabled = ad_arg.get('enable', False) if isinstance(ad_arg, dict) else getattr(ad_arg, 'enable', False)
                        if ad_enabled:
                            p._animatediff_i2i_batch = 1  # i2i-batch mode, ordinary

            if not hasattr(p, '_animatediff_i2i_batch'):
                return original_img2img_process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale, scale_by, use_png_info, png_info_props, png_info_dir)

            output_dir = output_dir.strip()
            processing.fix_seed(p)

            # Deliberately not called ``images`` so the module-level ``images`` utilities stay usable.
            image_files = list(shared.walk_files(input_dir, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp", ".tif", ".tiff")))

            inpaint_masks = []
            is_inpaint_batch = False
            if inpaint_mask_dir:
                inpaint_masks = shared.listfiles(inpaint_mask_dir)
                is_inpaint_batch = bool(inpaint_masks)
                if is_inpaint_batch:
                    assert len(inpaint_masks) == 1 or len(inpaint_masks) == len(image_files), 'The number of masks must be 1 or equal to the number of images.'
                    logger.info(f"\n[i2i batch] Inpaint batch is enabled. {len(inpaint_masks)} masks found.")
                    if len(inpaint_masks) > 1:  # batch mask
                        p.init = MethodType(hacked_i2i_init, p)

            logger.info(f"[i2i batch] Will process {len(image_files)} images, creating {p.n_iter} new videos.")

            # extract "default" params to use in case getting png info fails
            prompt = p.prompt
            negative_prompt = p.negative_prompt
            seed = p.seed
            cfg_scale = p.cfg_scale
            sampler_name = p.sampler_name
            steps = p.steps

            frame_images = []
            frame_paths = []
            frame_masks = []
            for i, image_file in enumerate(image_files):
                try:
                    img = Image.open(image_file)
                except UnidentifiedImageError as e:
                    logger.warning(f"[i2i batch] Skipping unreadable image {image_file}: {e}")
                    continue
                # Use the EXIF orientation of photos taken by smartphones.
                img = ImageOps.exif_transpose(img)

                if to_scale:
                    p.width = int(img.width * scale_by)
                    p.height = int(img.height * scale_by)

                frame_images.append(img)
                frame_paths.append(Path(image_file))

                if is_inpaint_batch and len(inpaint_masks) > 1:
                    # Per-frame mask, matched to the image by its position in the file listing.
                    # Uses ``i`` (not the frame count) so skipped images keep the pairing intact.
                    frame_masks.append(Image.open(inpaint_masks[i]))

            if not frame_images:
                logger.warning(f"[i2i batch] No readable images found in {input_dir}; nothing to process.")
                return None

            if is_inpaint_batch:
                if len(inpaint_masks) == 1:
                    # One mask shared by every frame: open it once.
                    p.image_mask = Image.open(inpaint_masks[0])
                else:
                    p.image_mask = frame_masks

            if use_png_info:
                first_frame_path = frame_paths[0]
                try:
                    # Read PNG info from the first frame (or its namesake in png_info_dir).
                    info_img = frame_images[0]
                    if png_info_dir:
                        info_img_path = os.path.join(png_info_dir, first_frame_path.name)
                        info_img = Image.open(info_img_path)
                    from modules.generation_parameters_copypaste import parse_generation_parameters
                    geninfo, _ = images.read_info_from_image(info_img)
                    parsed_parameters = parse_generation_parameters(geninfo)
                    parsed_parameters = {k: v for k, v in parsed_parameters.items() if k in (png_info_props or {})}
                except Exception:
                    # Best effort: fall back to the UI values captured above.
                    parsed_parameters = {}

                p.prompt = prompt + (" " + parsed_parameters["Prompt"] if "Prompt" in parsed_parameters else "")
                p.negative_prompt = negative_prompt + (" " + parsed_parameters["Negative prompt"] if "Negative prompt" in parsed_parameters else "")
                p.seed = int(parsed_parameters.get("Seed", seed))
                p.cfg_scale = float(parsed_parameters.get("CFG scale", cfg_scale))
                p.sampler_name = parsed_parameters.get("Sampler", sampler_name)
                p.steps = int(parsed_parameters.get("Steps", steps))

            p.init_images = frame_images

            proc = scripts.scripts_img2img.run(p, *args)  # we should not support this, but just leave it here
            if proc is not None:
                logger.warning("Warning: you are using an unsupported external script. AnimateDiff may not work properly.")
                return None

            if output_dir:
                # Output files are named after the last loaded frame.
                image_path = frame_paths[-1]
                p.outpath_samples = output_dir
                p.override_settings['save_to_dirs'] = False
                if p.n_iter > 1 or p.batch_size > 1:
                    p.override_settings['samples_filename_pattern'] = f'{image_path.stem}-[generation_number]'
                else:
                    p.override_settings['samples_filename_pattern'] = f'{image_path.stem}'
            return process_images(p)

        img2img.process_batch = hacked_img2img_process_batch_hijack

    def cap_init_image(self, p: StableDiffusionProcessingImg2Img, params):
        """Fit an i2i-batch job's frames to ``params.video_length``.

        Acts only when AnimateDiff is enabled and ``p`` was flagged by the i2i-batch hijack.
        Applies these rules, in order:
        - Frames and per-frame masks longer than ``params.video_length`` are truncated.
        - If fewer frames exist than requested, ``params.video_length`` is lowered to the
          frame count.
        - ``params.batch_size`` is clamped to the frame count.
        """
        if not (params.enable and isinstance(p, StableDiffusionProcessingImg2Img) and hasattr(p, '_animatediff_i2i_batch')):
            return
        if len(p.init_images) > params.video_length:
            p.init_images = p.init_images[:params.video_length]
        # Check the type first: truthiness of a non-list mask object may be ill-defined.
        if isinstance(p.image_mask, list) and len(p.image_mask) > params.video_length:
            p.image_mask = p.image_mask[:params.video_length]
        if len(p.init_images) < params.video_length:
            params.video_length = len(p.init_images)
        if len(p.init_images) < params.batch_size:
            params.batch_size = len(p.init_images)


animatediff_i2ibatch = AnimateDiffI2IBatch()