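# Demo: segment an object in an input image with SAM from a point prompt, then
# inpaint that mask region with Paint-by-Example using a reference image.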
# !pip install diffusers transformers
import os
from io import BytesIO

import numpy as np
import requests
import torch
from diffusers import DiffusionPipeline
from PIL import Image
from segment_anything import sam_model_registry, SamPredictor
"""
Step 1: Download and preprocess example demo images
"""


def download_image(url, timeout=30):
    """Download ``url`` and return it as an RGB ``PIL.Image.Image``.

    Args:
        url: HTTP(S) address of an image file.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` can block indefinitely on a stalled connection.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status. Failing
            here is clearer than handing an HTML error page to PIL, which would
            only report that it cannot identify the image file.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")
# Scene to edit and the reference ("example") object to paint into it.
# Other reference images that can be swapped in as example_url:
#   https://github.com/IDEA-Research/detrex-storage/blob/main/assets/grounded_sam/paint_by_example/pomeranian_example.jpg?raw=True
#   https://github.com/IDEA-Research/detrex-storage/blob/main/assets/grounded_sam/paint_by_example/example_image.jpg?raw=true
img_url = "https://github.com/IDEA-Research/detrex-storage/blob/main/assets/grounded_sam/paint_by_example/input_image.png?raw=true"
example_url = "https://github.com/IDEA-Research/detrex-storage/blob/main/assets/grounded_sam/paint_by_example/labrador_example.jpg?raw=true"

# Fetch both images (scene first, then reference) and bring them to a fixed
# 512x512 resolution; init_image is the image SAM segments and the pipeline edits.
init_image, example_image = (
    download_image(url).resize((512, 512)) for url in (img_url, example_url)
)
"""
Step 2: Initialize SAM and PaintByExample models
"""


def _resolve_device(requested):
    """Return ``requested`` if that torch device is usable here, else a fallback.

    A CUDA index beyond the number of visible GPUs falls back to ``"cuda:0"``;
    a CUDA request on a machine without CUDA falls back to ``"cpu"``. Non-CUDA
    requests are returned unchanged.
    """
    device = torch.device(requested)
    if device.type != "cuda":
        return requested
    if not torch.cuda.is_available():
        fallback = "cpu"
    elif device.index is not None and device.index >= torch.cuda.device_count():
        fallback = "cuda:0"
    else:
        return requested
    print(f"Device {requested!r} is unavailable; falling back to {fallback!r}.")
    return fallback


# Every setting below can be overridden with an environment variable of the same
# name; the defaults are the original author's GPU and filesystem paths.
DEVICE = _resolve_device(os.environ.get("DEVICE", "cuda:1"))
# Half precision is only used on GPU; PyTorch's CPU kernels have poor float16
# coverage, so keep float32 when running on CPU.
TORCH_DTYPE = torch.float16 if DEVICE.startswith("cuda") else torch.float32

# SAM
SAM_ENCODER_VERSION = "vit_h"
SAM_CHECKPOINT_PATH = os.environ.get(
    "SAM_CHECKPOINT_PATH",
    "/comp_robot/rentianhe/code/Grounded-Segment-Anything/sam_vit_h_4b8939.pth",
)
if not os.path.isfile(SAM_CHECKPOINT_PATH):
    raise FileNotFoundError(
        f"SAM checkpoint not found at {SAM_CHECKPOINT_PATH!r}; "
        "set the SAM_CHECKPOINT_PATH environment variable to a vit_h checkpoint."
    )
sam = sam_model_registry[SAM_ENCODER_VERSION](checkpoint=SAM_CHECKPOINT_PATH).to(device=DEVICE)
sam_predictor = SamPredictor(sam)
# Register the demo scene with the predictor before it is prompted in Step 3.
sam_predictor.set_image(np.array(init_image))

# PaintByExample Pipeline
_DEFAULT_CACHE_DIR = "/comp_robot/rentianhe/weights/diffusers/"
# Use the author's cache only where it exists; otherwise pass None so diffusers
# uses its default cache location instead of trying to create the author's path.
CACHE_DIR = os.environ.get("CACHE_DIR") or (
    _DEFAULT_CACHE_DIR if os.path.isdir(_DEFAULT_CACHE_DIR) else None
)
pipe = DiffusionPipeline.from_pretrained(
    "Fantasy-Studio/Paint-by-Example",
    torch_dtype=TORCH_DTYPE,
    cache_dir=CACHE_DIR,
)
pipe = pipe.to(DEVICE)
"""
Step 3: Get masks with SAM by prompt (box or point) and inpaint the mask region by example image.
"""
# One positive click at pixel (x=350, y=256) of the 512x512 init_image.
input_point = np.array([[350, 256]])
input_label = np.array([1])  # positive label
masks, _, _ = sam_predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=False,  # request a single mask
)
mask = masks[0]  # [1, 512, 512] to [512, 512] np.ndarray

# Build an explicit 8-bit grayscale mask (255 = region to repaint, 0 = keep)
# rather than relying on PIL's bool -> 1-bit mode "1" conversion.
# NOTE(review): SAM's predictor presumably returns a boolean mask; `> 0` also
# covers 0/1 integer masks.
mask_pil = Image.fromarray((mask > 0).astype(np.uint8) * 255)
mask_pil.save("./mask.jpg")  # for inspection only; the pipeline gets mask_pil directly

image = pipe(
    image=init_image,
    mask_image=mask_pil,
    example_image=example_image,
    # NOTE(review): 500 denoising steps is slow; fewer may give similar quality — confirm.
    num_inference_steps=500,
    guidance_scale=9.0,
).images[0]
image.save("./paint_by_example_demo.jpg")