Inpaint-Anything / stable_diffusion_inpaint.py
myrad01's picture
Duplicate from InpaintAI/Inpaint-Anything
b09c1e6
import os
import sys
import glob
import argparse
import torch
import numpy as np
import PIL.Image as Image
from pathlib import Path
from diffusers import StableDiffusionInpaintPipeline
from utils.mask_processing import crop_for_filling_pre, crop_for_filling_post
from utils.crop_for_replacing import recover_size, resize_and_pad
from utils import load_img_to_array, save_array_to_img
def fill_img_with_sd(
        img: np.ndarray,
        mask: np.ndarray,
        text_prompt: str,
        device="cuda",
        *,
        pipe=None,
):
    """Fill the masked part of ``img`` with Stable Diffusion 2 inpainting.

    The image and mask are first reduced to a working crop by
    ``crop_for_filling_pre``; the crop is inpainted by the pipeline using
    ``text_prompt``; ``crop_for_filling_post`` then combines the generated
    crop with the original ``img`` and ``mask``.

    Args:
        img: Source image array.  Presumably H x W x 3 uint8 -- confirm
            against ``crop_for_filling_pre``.
        mask: Mask array for the region to fill.  The cropped mask is handed
            to the pipeline unchanged as ``mask_image``; presumably 255
            marks the pixels to regenerate -- TODO confirm.
        text_prompt: Prompt that guides the generated content.
        device: Device the pipeline is moved to when it is loaded here.
            Not used when ``pipe`` is supplied.
        pipe: Optional, already-loaded inpainting pipeline to reuse.  When
            ``None`` (the default) the
            ``stabilityai/stable-diffusion-2-inpainting`` weights are loaded
            on every call, so callers that process several masks should
            load the pipeline once and pass it in.  A supplied pipeline is
            used as-is; placing it on the right device is the caller's job.

    Returns:
        The value produced by ``crop_for_filling_post`` -- presumably the
        full-size image with the masked region replaced.
    """
    if pipe is None:
        # Default path: identical to the historical behaviour (fresh load).
        pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-inpainting",
            torch_dtype=torch.float32,
        ).to(device)

    img_crop, mask_crop = crop_for_filling_pre(img, mask)
    generated = pipe(
        prompt=text_prompt,
        image=Image.fromarray(img_crop),
        mask_image=Image.fromarray(mask_crop),
    ).images[0]
    # The pipeline returns a PIL image; convert back before merging.
    return crop_for_filling_post(img, mask, np.array(generated))
def replace_img_with_sd(
        img: np.ndarray,
        mask: np.ndarray,
        text_prompt: str,
        step: int = 50,
        device="cuda",
        *,
        pipe=None,
):
    """Regenerate everything outside ``mask`` with Stable Diffusion 2.

    The image and mask are resized/padded by ``resize_and_pad``, the
    *inverted* mask (``255 - mask``) is given to the inpainting pipeline,
    and the result is mapped back to the original size by ``recover_size``.
    Finally the original pixels are blended back in wherever the mask is
    set, so only the unmasked area ends up replaced.

    Args:
        img: Source image array with three dimensions (unpacked as
            ``height, width, _``).
        mask: Mask of the region to keep.  Presumably a 2-D uint8 array
            with values 0/255 (it is inverted with ``255 - mask`` and
            scaled by ``1 / 255``) -- TODO confirm against the callers.
        text_prompt: Prompt that guides the generated content.
        step: Number of inference steps passed to the pipeline.
        device: Device the pipeline is moved to when it is loaded here.
            Not used when ``pipe`` is supplied.
        pipe: Optional, already-loaded inpainting pipeline to reuse.  When
            ``None`` (the default) the
            ``stabilityai/stable-diffusion-2-inpainting`` weights are loaded
            on every call, so callers that invoke this repeatedly should
            load the pipeline once and pass it in.  A supplied pipeline is
            used as-is; placing it on the right device is the caller's job.

    Returns:
        The blended image as a floating-point array (the mask is divided by
        255 before blending, which promotes the result to float).
    """
    if pipe is None:
        # Default path: identical to the historical behaviour (fresh load).
        pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-inpainting",
            torch_dtype=torch.float32,
        ).to(device)

    img_padded, mask_padded, padding_factors = resize_and_pad(img, mask)
    generated = pipe(
        prompt=text_prompt,
        image=Image.fromarray(img_padded),
        # Inverted so the pipeline repaints the area outside the mask.
        mask_image=Image.fromarray(255 - mask_padded),
        num_inference_steps=step,
    ).images[0]

    height, width, _ = img.shape
    img_resized, mask_resized = recover_size(
        np.array(generated), mask_padded, (height, width), padding_factors)

    # Scale the mask to [0, 1] and add a trailing axis so it broadcasts
    # over the colour channels.
    keep = np.expand_dims(mask_resized, -1) / 255
    # Generated pixels where keep == 0, original pixels where keep == 1.
    return img_resized * (1 - keep) + img * keep
def setup_args(parser):
    """Register this script's command-line options on ``parser``.

    Adds four required string options (``--input_img``, ``--text_prompt``,
    ``--input_mask_glob``, ``--output_dir``), an optional integer
    ``--seed`` and a boolean ``--deterministic`` switch.  The parser is
    modified in place; nothing is returned.
    """
    # (flag, help text) for every mandatory string-valued option, in the
    # order they are registered.
    required_text_options = (
        ("--input_img", "Path to a single input img"),
        ("--text_prompt", "Text prompt"),
        ("--input_mask_glob", "Glob to input masks"),
        ("--output_dir", "Output path to the directory with results."),
    )
    for flag, description in required_text_options:
        parser.add_argument(flag, type=str, required=True, help=description)

    # Reproducibility controls; both are optional.
    parser.add_argument(
        "--seed",
        type=int,
        help="Specify seed for reproducibility.",
    )
    parser.add_argument(
        "--deterministic",
        action="store_true",
        help="Use deterministic algorithms for reproducibility.",
    )
if __name__ == "__main__":
    # Command-line entry point: fill one input image once per mask file and
    # write each result to <output_dir>/<input image stem>/.
    #
    # Example usage:
    #   python stable_diffusion_inpaint.py \
    #       --input_img FA_demo/FA1_dog.png \
    #       --input_mask_glob "results/FA1_dog/mask*.png" \
    #       --text_prompt "a teddy bear on a bench" \
    #       --output_dir results
    parser = argparse.ArgumentParser()
    setup_args(parser)
    args = parser.parse_args(sys.argv[1:])

    device = "cuda" if torch.cuda.is_available() else "cpu"

    if args.deterministic:
        # cuBLAS needs this workspace setting before deterministic
        # algorithms can be enforced on CUDA.
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
        torch.use_deterministic_algorithms(True)

    img_stem = Path(args.input_img).stem
    mask_ps = sorted(glob.glob(args.input_mask_glob))
    if not mask_ps:
        # Previously this case finished silently with no output at all;
        # surface it so a mistyped glob is noticed.
        print(
            f"Warning: no mask files matched {args.input_mask_glob!r}; "
            "nothing to do.",
            file=sys.stderr,
        )

    out_dir = Path(args.output_dir) / img_stem
    out_dir.mkdir(parents=True, exist_ok=True)

    img = load_img_to_array(args.input_img)
    for mask_p in mask_ps:
        # Re-seed before every mask so each result is generated from the
        # same seed regardless of how many masks precede it.
        if args.seed is not None:
            torch.manual_seed(args.seed)
        mask = load_img_to_array(mask_p)
        img_filled_p = out_dir / f"filled_with_{Path(mask_p).name}"
        img_filled = fill_img_with_sd(
            img, mask, args.text_prompt, device=device)
        save_array_to_img(img_filled, img_filled_p)