# RePaint inpainting demo using Hugging Face diffusers.
# Uploaded by merve (HF staff) in commit 9cc3ad8 ("Upload 219 files"); original size 1.36 kB.
from io import BytesIO
import os

import PIL
import requests
import torch
from diffusers import RePaintPipeline, RePaintScheduler
def download_image(url, timeout=30):
    """Download an image over HTTP(S) and return it as an RGB PIL image.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the server, forwarded to
            ``requests.get`` so a stalled connection cannot hang the script.

    Returns:
        The decoded image, converted to RGB mode.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
            Raising here gives a clear error instead of an opaque
            decode failure from PIL on an HTML error page.
    """
    # `import PIL` alone does not load the `PIL.Image` submodule; import it
    # explicitly rather than relying on another library having done so.
    from PIL import Image

    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")
# Inputs live in the hf-internal-testing/diffusers-images dataset: the
# picture to inpaint (celeba_hq_256.png) and its mask (mask_256.png).
img_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/repaint/celeba_hq_256.png"
mask_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/repaint/mask_256.png"

# Fetch the image first, then the mask, as RGB PIL images resized to 256x256.
original_image, mask_image = [
    download_image(source).resize((256, 256)) for source in (img_url, mask_url)
]
# Load the RePaint scheduler and pipeline from the pretrained DDPM checkpoint
# "google/ddpm-ema-celebahq-256".

# Device selection. $REPAINT_DEVICE wins if set. Otherwise keep the original
# "cuda:1" when a second GPU exists, and fall back to the first GPU or the
# CPU so the demo does not crash on single-GPU or CPU-only machines.
if os.environ.get("REPAINT_DEVICE"):
    DEVICE = os.environ["REPAINT_DEVICE"]
elif torch.cuda.device_count() > 1:
    DEVICE = "cuda:1"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Weights cache. $REPAINT_CACHE_DIR wins if set. Otherwise use the original
# cluster path when it exists on this machine. Failing both, use None, which
# lets `from_pretrained` pick its default cache location.
_DEFAULT_CACHE_DIR = "/comp_robot/rentianhe/weights/diffusers/"
CACHE_DIR = os.environ.get("REPAINT_CACHE_DIR") or (
    _DEFAULT_CACHE_DIR if os.path.isdir(_DEFAULT_CACHE_DIR) else None
)

scheduler = RePaintScheduler.from_pretrained("google/ddpm-ema-celebahq-256", cache_dir=CACHE_DIR)
pipe = RePaintPipeline.from_pretrained("google/ddpm-ema-celebahq-256", scheduler=scheduler, cache_dir=CACHE_DIR)
pipe = pipe.to(DEVICE)

# Seeded generator on the same device as the pipeline so sampling is repeatable.
generator = torch.Generator(device=DEVICE).manual_seed(0)
# Sampling settings for the RePaint call. For the meaning of `eta`,
# `jump_length` and `jump_n_sample`, see the diffusers RePaintPipeline
# documentation. `generator` supplies the seeded randomness.
repaint_kwargs = dict(
    image=original_image,
    mask_image=mask_image,
    num_inference_steps=250,
    eta=0.0,
    jump_length=10,
    jump_n_sample=10,
    generator=generator,
)
output = pipe(**repaint_kwargs)

# The pipeline output carries a list of images; write the first one to disk.
inpainted_image = output.images[0]
inpainted_image.save("./repaint_demo.jpg")