File size: 5,027 Bytes
ffab2fd
 
 
 
931a3ea
ffab2fd
 
b3556d1
ffab2fd
b3556d1
9130e8d
b3556d1
ffab2fd
 
b3556d1
ffab2fd
 
 
 
 
 
 
1e2c561
ffab2fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73e49ee
ffab2fd
 
1e2c561
73e49ee
 
 
ffab2fd
497aa6f
 
73e49ee
 
ffab2fd
 
 
1e2c561
ffab2fd
 
73e49ee
 
 
3183ff3
1e2c561
ffab2fd
73e49ee
b88cbc2
497aa6f
 
73e49ee
ffab2fd
73e49ee
1e2c561
73e49ee
a323a6f
 
497aa6f
 
ffab2fd
 
 
dcfd465
a46d851
e5632e8
acbffdf
 
 
6f6abe5
 
acbffdf
 
f6c4ee3
 
1b478df
96a0d26
 
1b478df
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import os
import tempfile

import cv2
import gradio as gr
import mediapipe as mp
import numpy as np
import PIL
import torch
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image

# Hugging Face access token, read from the HF_TOKEN_SD environment variable
# (None when the variable is not set).
YOUR_TOKEN = os.environ.get('HF_TOKEN_SD')

model_path = "runwayml/stable-diffusion-inpainting"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The fp16 revision and float16 dtype are requested only when CUDA is
# available; otherwise the pipeline is loaded with its default settings.
_pipe_kwargs = {'use_auth_token': YOUR_TOKEN}
if torch.cuda.is_available():
    _pipe_kwargs.update(revision="fp16", torch_dtype=torch.float16)

pipe = StableDiffusionInpaintPipeline.from_pretrained(model_path, **_pipe_kwargs).to(device)


def image_grid(imgs, cols, rows=1):
    """Tile equally sized images into one ``rows`` x ``cols`` grid image.

    Images are pasted left to right, top to bottom. The size of the first
    image is used as the cell size for every tile.

    Args:
        imgs: Sequence of PIL images; its length must equal ``rows * cols``.
        cols: Number of columns in the grid.
        rows: Number of rows in the grid.

    Returns:
        A new RGB image of size ``(cols * w, rows * h)``, where ``(w, h)`` is
        the size of ``imgs[0]``.

    Raises:
        ValueError: If ``imgs`` is empty or its length is not ``rows * cols``.
    """
    # Explicit checks instead of ``assert`` so validation survives ``python -O``.
    if not imgs:
        raise ValueError("image_grid() requires at least one image")
    if len(imgs) != rows * cols:
        raise ValueError(
            "expected rows * cols = %d images, got %d" % (rows * cols, len(imgs))
        )

    w, h = imgs[0].size
    grid = PIL.Image.new('RGB', size=(cols * w, rows * h))

    for i, img in enumerate(imgs):
        # Column index selects the x offset, row index the y offset.
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid


def mediapipe_segmentation(image_file, mask_file):
    """Write a black-and-white segmentation mask for ``image_file``.

    The image is run through MediaPipe's selfie segmentation model. Pixels
    whose segmentation score is above 0.1 are written as white, all other
    pixels as black.

    Args:
        image_file: Path of the image to segment (read with ``cv2.imread``).
        mask_file: Path the mask image is written to (with ``cv2.imwrite``).

    Raises:
        FileNotFoundError: If ``image_file`` cannot be read as an image.
        OSError: If the mask cannot be written to ``mask_file``.
    """
    BG_COLOR = (0, 0, 0)  # black
    MASK_COLOR = (255, 255, 255)  # white

    # cv2.imread signals failure by returning None rather than raising.
    image = cv2.imread(image_file)
    if image is None:
        raise FileNotFoundError(f"Could not read image file: {image_file!r}")

    mp_selfie_segmentation = mp.solutions.selfie_segmentation
    with mp_selfie_segmentation.SelfieSegmentation(model_selection=0) as selfie_segmentation:
        # Convert the BGR image to RGB before processing.
        results = selfie_segmentation.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

        # Repeat the single-channel score map across three channels so it can
        # select between the two solid-colour images below.
        # To improve segmentation around boundaries, consider applying a joint
        # bilateral filter to "results.segmentation_mask" with "image".
        condition = np.stack((results.segmentation_mask,) * 3, axis=-1) > 0.1

        # Solid-colour images with the same shape as the input.
        fg_image = np.full(image.shape, MASK_COLOR, dtype=np.uint8)
        bg_image = np.full(image.shape, BG_COLOR, dtype=np.uint8)
        output_image = np.where(condition, fg_image, bg_image)

    # cv2.imwrite reports failure through its boolean return value.
    if not cv2.imwrite(mask_file, output_image):
        raise OSError(f"Could not write mask file: {mask_file!r}")


def image_inpainting(prompt, image_path, mask_image_path, num_samples=4, is_origin=False):
    """Inpaint an image with the module-level pipeline and return a grid of results.

    Both the image and the mask are loaded as RGB and resized to 512x512
    before being passed to ``pipe``.

    Args:
        prompt: Text prompt passed to the pipeline.
        image_path: Path of the image to inpaint.
        mask_image_path: Path of the mask image.
        num_samples: Number of images to generate; truncated to an integer
            and clamped to the range 1..4.
        is_origin: When False, the pipeline is called with
            ``guidance_scale=7.5`` and a generator seeded with 0. When True,
            neither is passed and the pipeline's own defaults apply.

    Returns:
        A single image containing the generated samples side by side in one
        row, as built by ``image_grid``.
    """
    # Context managers close the underlying files once the pixel data has
    # been converted into new in-memory images.
    with PIL.Image.open(image_path) as source:
        image = source.convert("RGB").resize((512, 512))
    with PIL.Image.open(mask_image_path) as source:
        mask_image = source.convert("RGB").resize((512, 512))

    # Clamp on both sides: at least one image must be produced for the grid,
    # and at most four are generated per call.
    num_samples = max(1, min(int(num_samples), 4))

    pipe_kwargs = {
        'prompt': prompt,
        'image': image,
        'mask_image': mask_image,
        'num_images_per_prompt': num_samples,
    }
    if not is_origin:
        pipe_kwargs['guidance_scale'] = 7.5
        # change the seed to get different results
        pipe_kwargs['generator'] = torch.Generator(device=device).manual_seed(0)

    images = pipe(**pipe_kwargs).images
    return image_grid(images, num_samples, 1)


# Title and HTML description strings for the demo.
# NOTE(review): neither name is referenced elsewhere in the visible code.
title = "Person Matting & Stable Diffusion In-Painting"
# The bold span is closed with </b> (it was previously the invalid <b/>).
description = "Inpainting Stable Diffusion <br/><b>mediapipe + Stable Diffusion</b><br/>"


def predict1(dict, prompt, num_samples):
    """Inpaint using a mask generated automatically by MediaPipe.

    The uploaded image is saved to disk, ``mediapipe_segmentation`` writes a
    mask for it, and both files are passed to ``image_inpainting`` with
    ``is_origin=False``.

    Args:
        dict: Mapping with an ``'image'`` entry that provides ``.save(path)``.
            Any ``'mask'`` entry is ignored here.
        prompt: Text prompt for the inpainting pipeline.
        num_samples: Number of images to generate.

    Returns:
        The image grid returned by ``image_inpainting``.
    """
    # A per-call temporary directory replaces the fixed 'image.png' and
    # 'm_mask.png' names in the working directory, so calls cannot overwrite
    # each other's files and nothing is left behind afterwards.
    with tempfile.TemporaryDirectory() as workdir:
        image_path = os.path.join(workdir, 'image.png')
        mask_path = os.path.join(workdir, 'm_mask.png')

        dict['image'].save(image_path)
        mediapipe_segmentation(image_path, mask_path)
        return image_inpainting(prompt, num_samples=num_samples, image_path=image_path,
                                mask_image_path=mask_path, is_origin=False)


def predict2(dict, prompt, num_samples):
    """Inpaint using the mask supplied together with the image.

    The image and mask are saved to disk and passed to ``image_inpainting``
    with ``is_origin=True``.

    Args:
        dict: Mapping with ``'image'`` and ``'mask'`` entries, each providing
            ``.save(path)``.
        prompt: Text prompt for the inpainting pipeline.
        num_samples: Number of images to generate.

    Returns:
        The image grid returned by ``image_inpainting``.

    Raises:
        ValueError: If no mask is supplied.
    """
    # Without a mask there is nothing to inpaint. Fail clearly instead of
    # falling back to whatever 'mask.png' an earlier call left on disk.
    if dict.get('mask') is None:
        raise ValueError("No mask provided: paint the area to replace on the image first.")

    # A per-call temporary directory keeps calls from overwriting each
    # other's files.
    with tempfile.TemporaryDirectory() as workdir:
        image_path = os.path.join(workdir, 'image.png')
        mask_path = os.path.join(workdir, 'mask.png')

        dict['image'].save(image_path)
        dict['mask'].save(mask_path)
        return image_inpainting(prompt, num_samples=num_samples, image_path=image_path,
                                mask_image_path=mask_path, is_origin=True)


# Input components shared by both interfaces below.
# The image input is an upload widget with the sketch tool enabled; predict1
# and predict2 read its value through the 'image' (and 'mask') keys.
image_input = gr.Image(source='upload', tool='sketch', type='pil')
prompt = gr.Textbox(label='prompt')
# Slider value is forwarded as num_samples; image_inpainting caps it at 4.
number = gr.Slider(1, 4, value=2, label='num_samples')

# Example rows, in the same order as the inputs: image path (located next to
# this file), prompt, num_samples.
examples = [
    [os.path.join(os.path.dirname(__file__), 'example1.png'), 'a bench in a field', 2],
    # [os.path.join(os.path.dirname(__file__), 'example2.png'), 'a big ship parked on the shore', 2],
    # [os.path.join(os.path.dirname(__file__), 'example3.png'), 'a palace with many steps', 2]
]

# One interface per prediction function: 'auto' uses the MediaPipe-generated
# mask (predict1), 'paint' uses the mask supplied with the image (predict2).
greeter_1 = gr.Interface(predict1, inputs=[image_input, prompt, number], outputs=gr.Image(label='auto'))
greeter_2 = gr.Interface(predict2, inputs=[image_input, prompt, number], outputs=gr.Image(label='paint'))
# NOTE(review): gr.Parallel presumably feeds the same inputs to both
# interfaces and shows their outputs together - confirm against the Gradio
# version in use.
demo = gr.Parallel(greeter_1, greeter_2, examples=examples, cache_examples=False)

# Launch the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch(enable_queue=True)