# Spaces: Running on Zero
import argparse | |
import os | |
import torch | |
from PIL import Image, ImageFilter | |
from transformers import CLIPTextModel | |
from diffusers import DPMSolverMultistepScheduler, StableDiffusionInpaintPipeline, UNet2DConditionModel | |
# Fixed inference settings shared by every run of this script.
BASE_MODEL_ID = "stabilityai/stable-diffusion-2-inpainting"
PROMPT = "a photo of sks"
NUM_SAMPLES = 16
NUM_INFERENCE_STEPS = 25
GUIDANCE_SCALE = 5
DEVICE = "cuda"


def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the inpainting inference script."""
    parser = argparse.ArgumentParser(description="Inference")
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--validation_image",
        type=str,
        required=True,
        # The value is opened as a single image file, not scanned as a directory.
        help="Path to the validation image file",
    )
    parser.add_argument(
        "--validation_mask",
        type=str,
        required=True,
        help="Path to the validation mask file",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./test-infer/",
        help="The output directory where predictions are saved",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible inference.")
    return parser


def parse_args(argv=None) -> argparse.Namespace:
    """Parse command-line options.

    Args:
        argv: Argument strings to parse. ``None`` makes argparse read
            ``sys.argv[1:]``.

    Returns:
        Namespace with ``model_path``, ``validation_image``,
        ``validation_mask``, ``output_dir`` and ``seed``.
    """
    return build_parser().parse_args(argv)


def _open_as(path, mode) -> "Image.Image":
    """Load the image at *path*, convert it to *mode*, and close the file."""
    with Image.open(path) as img:
        # convert() returns an independent in-memory image, so it stays usable
        # after the file handle is closed.
        return img.convert(mode)


def _load_pipeline(model_path):
    """Build the inpainting pipeline with the UNet and text encoder from *model_path*."""
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        BASE_MODEL_ID, torch_dtype=torch.float32, revision=None
    )
    pipe.unet = UNet2DConditionModel.from_pretrained(
        model_path,
        subfolder="unet",
        revision=None,
    )
    pipe.text_encoder = CLIPTextModel.from_pretrained(
        model_path,
        subfolder="text_encoder",
        revision=None,
    )
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    return pipe.to(DEVICE)


def _feather_mask(mask_image) -> "Image.Image":
    """Return *mask_image* with its bright region widened and its edge softened.

    ``MaxFilter(3)`` takes the maximum over each 3x3 window, so bright areas
    grow by one pixel (a dilation); ``BoxBlur(1)`` then smooths the border so
    the composite does not show a hard seam.
    """
    widened = mask_image.filter(ImageFilter.MaxFilter(3))
    return widened.filter(ImageFilter.BoxBlur(1))


def main(argv=None) -> None:
    """Generate inpainted samples and blend them back into the input image.

    Args:
        argv: Optional argument list; ``None`` reads the process command line.

    Raises:
        ValueError: If the validation image and mask differ in size.
    """
    args = parse_args(argv)

    # Load and validate the inputs before the expensive model load so that a
    # bad path or mismatched mask fails immediately.
    # Image.composite needs both images in the same mode and a mask in mode
    # "1", "L" or "RGBA"; the pipeline's PIL outputs are expected to be RGB.
    # NOTE(review): an RGBA mask is now reduced to its luminance rather than
    # its alpha band -- presumably matching how the pipeline reads the mask;
    # confirm.
    image = _open_as(args.validation_image, "RGB")
    mask_image = _open_as(args.validation_mask, "L")
    if image.size != mask_image.size:
        raise ValueError(
            f"Validation image size {image.size} does not match mask size {mask_image.size}."
        )

    os.makedirs(args.output_dir, exist_ok=True)

    # create & load model
    pipe = _load_pipeline(args.model_path)

    generator = None
    if args.seed is not None:
        generator = torch.Generator(device=DEVICE).manual_seed(args.seed)

    # The pipeline receives the unfiltered mask; only the blend uses the
    # feathered one.
    results = pipe(
        [PROMPT] * NUM_SAMPLES,
        image=image,
        mask_image=mask_image,
        num_inference_steps=NUM_INFERENCE_STEPS,
        guidance_scale=GUIDANCE_SCALE,
        generator=generator,
    ).images

    blend_mask = _feather_mask(mask_image)
    for idx, result in enumerate(results):
        # Take generated pixels where the mask is bright and original pixels
        # elsewhere.
        blended = Image.composite(result, image, blend_mask)
        blended.save(os.path.join(args.output_dir, f"{idx}.png"))

    del pipe
    torch.cuda.empty_cache()


if __name__ == "__main__":
    main()