import os
from io import BytesIO

import PIL
import PIL.Image
import requests
import torch
from diffusers import RePaintPipeline, RePaintScheduler

def download_image(url, timeout=30):
    """Download an image over HTTP and return it as an RGB PIL image.

    Args:
        url: Address of the image file.
        timeout: Seconds to wait for the server, passed to ``requests.get``.
            Without a timeout a stalled connection can block indefinitely.

    Returns:
        The decoded image, converted to RGB mode.

    Raises:
        requests.HTTPError: If the server answers with an error status. This
            is raised instead of letting Pillow fail on an error page's body.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    # ``PIL.Image`` is a submodule; a bare ``import PIL`` does not load it,
    # so it must be imported explicitly for this attribute access to work.
    return PIL.Image.open(BytesIO(response.content)).convert("RGB")


# Example inputs hosted in the hf-internal-testing "diffusers-images" dataset.
img_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/repaint/celeba_hq_256.png"
mask_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/repaint/mask_256.png"

# Size both inputs are resized to; the checkpoint name ("...-celebahq-256")
# suggests the model expects this resolution.
IMAGE_SIZE = (256, 256)

# Load the original image and the mask as PIL images.
original_image = download_image(img_url).resize(IMAGE_SIZE)
mask_image = download_image(mask_url).resize(IMAGE_SIZE)

# Runtime configuration. The defaults are the original author's
# machine-specific settings; set REPAINT_DEVICE / REPAINT_CACHE_DIR to run
# the demo on another machine without editing the file.
MODEL_ID = "google/ddpm-ema-celebahq-256"
DEVICE = os.environ.get("REPAINT_DEVICE", "cuda:1")
CACHE_DIR = os.environ.get("REPAINT_CACHE_DIR", "/comp_robot/rentianhe/weights/diffusers/")

# Load the RePaint scheduler and pipeline based on a pretrained DDPM model.
scheduler = RePaintScheduler.from_pretrained(MODEL_ID, cache_dir=CACHE_DIR)
pipe = RePaintPipeline.from_pretrained(MODEL_ID, scheduler=scheduler, cache_dir=CACHE_DIR)
pipe = pipe.to(DEVICE)

# Fixed seed on the pipeline's device so sampling is repeatable
# (modulo any nondeterminism in GPU kernels).
generator = torch.Generator(device=DEVICE).manual_seed(0)

# Run inpainting. jump_length / jump_n_sample configure RePaint's resampling
# schedule; see the diffusers RePaintPipeline docs for their exact meaning.
output = pipe(
    image=original_image,
    mask_image=mask_image,
    num_inference_steps=250,
    eta=0.0,
    jump_length=10,
    jump_n_sample=10,
    generator=generator,
)

# The pipeline returns a batch of images; this demo produces and saves one.
inpainted_image = output.images[0]
inpainted_image.save("./repaint_demo.jpg")