BleachNick's picture
upload required packages
87d40d2
import argparse
import os
import torch
from PIL import Image, ImageFilter
from transformers import CLIPTextModel
from diffusers import DPMSolverMultistepScheduler, StableDiffusionInpaintPipeline, UNet2DConditionModel
parser = argparse.ArgumentParser(description="Inference")
parser.add_argument(
"--model_path",
type=str,
default=None,
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--validation_image",
type=str,
default=None,
required=True,
help="The directory of the validation image",
)
parser.add_argument(
"--validation_mask",
type=str,
default=None,
required=True,
help="The directory of the validation mask",
)
parser.add_argument(
"--output_dir",
type=str,
default="./test-infer/",
help="The output directory where predictions are saved",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible inference.")
args = parser.parse_args()
if __name__ == "__main__":
os.makedirs(args.output_dir, exist_ok=True)
generator = None
# create & load model
pipe = StableDiffusionInpaintPipeline.from_pretrained(
"stabilityai/stable-diffusion-2-inpainting", torch_dtype=torch.float32, revision=None
)
pipe.unet = UNet2DConditionModel.from_pretrained(
args.model_path,
subfolder="unet",
revision=None,
)
pipe.text_encoder = CLIPTextModel.from_pretrained(
args.model_path,
subfolder="text_encoder",
revision=None,
)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")
if args.seed is not None:
generator = torch.Generator(device="cuda").manual_seed(args.seed)
image = Image.open(args.validation_image)
mask_image = Image.open(args.validation_mask)
results = pipe(
["a photo of sks"] * 16,
image=image,
mask_image=mask_image,
num_inference_steps=25,
guidance_scale=5,
generator=generator,
).images
erode_kernel = ImageFilter.MaxFilter(3)
mask_image = mask_image.filter(erode_kernel)
blur_kernel = ImageFilter.BoxBlur(1)
mask_image = mask_image.filter(blur_kernel)
for idx, result in enumerate(results):
result = Image.composite(result, image, mask_image)
result.save(f"{args.output_dir}/{idx}.png")
del pipe
torch.cuda.empty_cache()